# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=g-import-not-at-top
16"""Callbacks: utilities called at certain points during model training.
17"""
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22import collections
23import copy
24import csv
25import io
26import json
27import os
28import re
29import tempfile
30import time
31
32import numpy as np
33import six
34
35from tensorflow.python.data.ops import iterator_ops
36from tensorflow.python.distribute import distributed_file_utils
37from tensorflow.python.distribute import multi_worker_util
38from tensorflow.python.eager import context
39from tensorflow.python.framework import ops
40from tensorflow.python.keras import backend as K
41from tensorflow.python.keras.distribute import multi_worker_training_state as training_state
42from tensorflow.python.keras.utils.data_utils import Sequence
43from tensorflow.python.keras.utils.generic_utils import Progbar
44from tensorflow.python.keras.utils.mode_keys import ModeKeys
45from tensorflow.python.lib.io import file_io
46from tensorflow.python.ops import array_ops
47from tensorflow.python.ops import math_ops
48from tensorflow.python.ops import summary_ops_v2
49from tensorflow.python.ops import variables
50from tensorflow.python.platform import tf_logging as logging
51from tensorflow.python.training import checkpoint_management
52from tensorflow.python.util.compat import collections_abc
53from tensorflow.python.util.tf_export import keras_export
54from tensorflow.tools.docs import doc_controls
55
56try:
57  import requests
58except ImportError:
59  requests = None


def configure_callbacks(callbacks,
                        model,
                        do_validation=False,
                        batch_size=None,
                        epochs=None,
                        steps_per_epoch=None,
                        samples=None,
                        verbose=1,
                        count_mode='steps',
                        mode=ModeKeys.TRAIN):
  """Configures callbacks for use in various training loops.

  Arguments:
      callbacks: List of Callbacks.
      model: Model being trained.
      do_validation: Whether or not validation loop will be run.
      batch_size: Number of samples per batch.
      epochs: Number of epochs to train.
      steps_per_epoch: Number of batches to run per training epoch.
      samples: Number of training samples.
      verbose: int, 0 or 1. Keras logging verbosity to pass to ProgbarLogger.
      count_mode: One of 'steps' or 'samples'. Per-batch or per-sample count.
      mode: String. One of ModeKeys.TRAIN, ModeKeys.TEST, or ModeKeys.PREDICT.
        Which loop mode to configure callbacks for.

  Returns:
      Instance of CallbackList used to control all Callbacks.
  """
  # Check if callbacks have already been configured.
  if isinstance(callbacks, CallbackList):
    return callbacks

  if not callbacks:
    callbacks = []

  # Add additional callbacks during training.
  if mode == ModeKeys.TRAIN:
    model.history = History()
    callbacks = [BaseLogger()] + (callbacks or []) + [model.history]
    if verbose:
      callbacks.append(ProgbarLogger(count_mode))
  callback_list = CallbackList(callbacks)

  # Set callback model
  callback_model = model._get_callback_model()  # pylint: disable=protected-access
  callback_list.set_model(callback_model)

  set_callback_parameters(
      callback_list,
      model,
      do_validation=do_validation,
      batch_size=batch_size,
      epochs=epochs,
      steps_per_epoch=steps_per_epoch,
      samples=samples,
      verbose=verbose,
      mode=mode)

  callback_list.model.stop_training = False
  return callback_list
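
# A minimal sketch of how a custom training loop might drive the helper above;
# `model` and `user_callbacks` are hypothetical stand-ins for a compiled Keras
# model and a user-supplied callback list:
#
#   cb_list = configure_callbacks(user_callbacks, model, epochs=5,
#                                 steps_per_epoch=100, verbose=1,
#                                 mode=ModeKeys.TRAIN)
#   cb_list.on_train_begin()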


def set_callback_parameters(callback_list,
                            model,
                            do_validation=False,
                            batch_size=None,
                            epochs=None,
                            steps_per_epoch=None,
                            samples=None,
                            verbose=1,
                            mode=ModeKeys.TRAIN):
  """Sets callback parameters.

  Arguments:
      callback_list: CallbackList instance.
      model: Model being trained.
      do_validation: Whether or not validation loop will be run.
      batch_size: Number of samples per batch.
      epochs: Number of epochs to train.
      steps_per_epoch: Number of batches to run per training epoch.
      samples: Number of training samples.
      verbose: int, 0 or 1. Keras logging verbosity to pass to ProgbarLogger.
      mode: String. One of ModeKeys.TRAIN, ModeKeys.TEST, or ModeKeys.PREDICT.
        Which loop mode to configure callbacks for.
  """
  metric_names = model.metrics_names
  for cbk in callback_list:
    if isinstance(cbk, (BaseLogger, ProgbarLogger)):
      cbk.stateful_metrics = metric_names[1:]  # Exclude `loss`

  # Set callback parameters
  callback_metrics = []
  # In a deferred-build scenario with iterator input, we compile when we
  # standardize the first batch of data.
  if mode != ModeKeys.PREDICT:
    callback_metrics = copy.copy(metric_names)
    if do_validation:
      callback_metrics += ['val_' + n for n in metric_names]
  callback_params = {
      'batch_size': batch_size,
      'epochs': epochs,
      'steps': steps_per_epoch,
      'samples': samples,
      'verbose': verbose,
      'do_validation': do_validation,
      'metrics': callback_metrics,
  }
  callback_list.set_params(callback_params)


def _is_generator_like(data):
  """Checks if data is a generator, Sequence, or Iterator."""
  return (hasattr(data, 'next') or hasattr(data, '__next__') or isinstance(
      data, (Sequence, iterator_ops.Iterator, iterator_ops.OwnedIterator)))


def make_logs(model, logs, outputs, mode, prefix=''):
  """Computes logs for sending to `on_batch_end` methods."""
  metric_names = model.metrics_names
  if mode in {ModeKeys.TRAIN, ModeKeys.TEST} and metric_names:
    for label, output in zip(metric_names, outputs):
      logs[prefix + label] = output
  else:
    logs['outputs'] = outputs
  return logs


class CallbackList(object):
  """Container abstracting a list of callbacks.

  Arguments:
      callbacks: List of `Callback` instances.
      queue_length: Queue length for keeping
          running statistics over callback execution time.
  """

  def __init__(self, callbacks=None, queue_length=10):
    callbacks = callbacks or []
    self.callbacks = [c for c in callbacks]
    self.queue_length = queue_length
    self.params = {}
    self.model = None
    self._reset_batch_timing()

  def _reset_batch_timing(self):
    self._delta_t_batch = 0.
    self._delta_ts = collections.defaultdict(
        lambda: collections.deque([], maxlen=self.queue_length))

  def append(self, callback):
    self.callbacks.append(callback)

  def set_params(self, params):
    self.params = params
    for callback in self.callbacks:
      callback.set_params(params)

  def set_model(self, model):
    self.model = model
    for callback in self.callbacks:
      callback.set_model(model)

  def _call_batch_hook(self, mode, hook, batch, logs=None):
    """Helper function for all batch_{begin | end} methods."""
    if not self.callbacks:
      return
    hook_name = 'on_{mode}_batch_{hook}'.format(mode=mode, hook=hook)
    if hook == 'begin':
      self._t_enter_batch = time.time()
    if hook == 'end':
      # Batch is ending, calculate batch time.
      self._delta_t_batch = time.time() - self._t_enter_batch

    logs = logs or {}
    t_before_callbacks = time.time()
    for callback in self.callbacks:
      batch_hook = getattr(callback, hook_name)
      batch_hook(batch, logs)
    self._delta_ts[hook_name].append(time.time() - t_before_callbacks)

    delta_t_median = np.median(self._delta_ts[hook_name])
    if (self._delta_t_batch > 0. and
        delta_t_median > 0.95 * self._delta_t_batch and delta_t_median > 0.1):
      logging.warning(
          'Method (%s) is slow compared '
          'to the batch update (%f). Check your callbacks.', hook_name,
          delta_t_median)

  def _call_begin_hook(self, mode):
    """Helper function for on_{train|test|predict}_begin methods."""
    if mode == ModeKeys.TRAIN:
      self.on_train_begin()
    elif mode == ModeKeys.TEST:
      self.on_test_begin()
    else:
      self.on_predict_begin()

  def _call_end_hook(self, mode):
    """Helper function for on_{train|test|predict}_end methods."""
    if mode == ModeKeys.TRAIN:
      self.on_train_end()
    elif mode == ModeKeys.TEST:
      self.on_test_end()
    else:
      self.on_predict_end()

  def on_batch_begin(self, batch, logs=None):
    self._call_batch_hook(ModeKeys.TRAIN, 'begin', batch, logs=logs)

  def on_batch_end(self, batch, logs=None):
    self._call_batch_hook(ModeKeys.TRAIN, 'end', batch, logs=logs)

  def on_epoch_begin(self, epoch, logs=None):
    """Calls the `on_epoch_begin` methods of its callbacks.

    This function should only be called during TRAIN mode.

    Arguments:
        epoch: integer, index of epoch.
        logs: dict. Currently no data is passed to this argument for this method
          but that may change in the future.
    """
    logs = logs or {}
    for callback in self.callbacks:
      callback.on_epoch_begin(epoch, logs)
    self._reset_batch_timing()

  def on_epoch_end(self, epoch, logs=None):
    """Calls the `on_epoch_end` methods of its callbacks.

    This function should only be called during TRAIN mode.

    Arguments:
        epoch: integer, index of epoch.
        logs: dict, metric results for this training epoch, and for the
          validation epoch if validation is performed. Validation result keys
          are prefixed with `val_`.
    """
    logs = logs or {}
    for callback in self.callbacks:
      callback.on_epoch_end(epoch, logs)

  def on_train_batch_begin(self, batch, logs=None):
    """Calls the `on_train_batch_begin` methods of its callbacks.

    Arguments:
        batch: integer, index of batch within the current epoch.
        logs: dict. Has keys `batch` and `size` representing the current batch
          number and the size of the batch.
    """
    self._call_batch_hook(ModeKeys.TRAIN, 'begin', batch, logs=logs)

  def on_train_batch_end(self, batch, logs=None):
    """Calls the `on_train_batch_end` methods of its callbacks.

    Arguments:
        batch: integer, index of batch within the current epoch.
        logs: dict. Metric results for this batch.
    """
    self._call_batch_hook(ModeKeys.TRAIN, 'end', batch, logs=logs)

  def on_test_batch_begin(self, batch, logs=None):
    """Calls the `on_test_batch_begin` methods of its callbacks.

    Arguments:
        batch: integer, index of batch within the current epoch.
        logs: dict. Has keys `batch` and `size` representing the current batch
          number and the size of the batch.
    """
    self._call_batch_hook(ModeKeys.TEST, 'begin', batch, logs=logs)

  def on_test_batch_end(self, batch, logs=None):
    """Calls the `on_test_batch_end` methods of its callbacks.

    Arguments:
        batch: integer, index of batch within the current epoch.
        logs: dict. Metric results for this batch.
    """
    self._call_batch_hook(ModeKeys.TEST, 'end', batch, logs=logs)

  def on_predict_batch_begin(self, batch, logs=None):
    """Calls the `on_predict_batch_begin` methods of its callbacks.

    Arguments:
        batch: integer, index of batch within the current epoch.
        logs: dict. Has keys `batch` and `size` representing the current batch
          number and the size of the batch.
    """
    self._call_batch_hook(ModeKeys.PREDICT, 'begin', batch, logs=logs)

  def on_predict_batch_end(self, batch, logs=None):
    """Calls the `on_predict_batch_end` methods of its callbacks.

    Arguments:
        batch: integer, index of batch within the current epoch.
        logs: dict. Metric results for this batch.
    """
    self._call_batch_hook(ModeKeys.PREDICT, 'end', batch, logs=logs)

  def on_train_begin(self, logs=None):
    """Calls the `on_train_begin` methods of its callbacks.

    Arguments:
        logs: dict. Currently no data is passed to this argument for this method
          but that may change in the future.
    """
    for callback in self.callbacks:
      callback.on_train_begin(logs)

  def on_train_end(self, logs=None):
    """Calls the `on_train_end` methods of its callbacks.

    Arguments:
        logs: dict. Currently no data is passed to this argument for this method
          but that may change in the future.
    """
    for callback in self.callbacks:
      callback.on_train_end(logs)

  def on_test_begin(self, logs=None):
    """Calls the `on_test_begin` methods of its callbacks.

    Arguments:
        logs: dict. Currently no data is passed to this argument for this method
          but that may change in the future.
    """
    for callback in self.callbacks:
      callback.on_test_begin(logs)

  def on_test_end(self, logs=None):
    """Calls the `on_test_end` methods of its callbacks.

    Arguments:
        logs: dict. Currently no data is passed to this argument for this method
          but that may change in the future.
    """
    for callback in self.callbacks:
      callback.on_test_end(logs)

  def on_predict_begin(self, logs=None):
    """Calls the `on_predict_begin` methods of its callbacks.

    Arguments:
        logs: dict. Currently no data is passed to this argument for this method
          but that may change in the future.
    """
    for callback in self.callbacks:
      callback.on_predict_begin(logs)

  def on_predict_end(self, logs=None):
    """Calls the `on_predict_end` methods of its callbacks.

    Arguments:
        logs: dict. Currently no data is passed to this argument for this method
          but that may change in the future.
    """
    for callback in self.callbacks:
      callback.on_predict_end(logs)

  def __iter__(self):
    return iter(self.callbacks)
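
# A minimal sketch of driving a CallbackList by hand, as the built-in training
# loops do; `model` is a hypothetical compiled Keras model:
#
#   cb_list = CallbackList([History(), TerminateOnNaN()])
#   cb_list.set_model(model)
#   cb_list.set_params({'epochs': 1, 'steps': 10, 'verbose': 0,
#                       'do_validation': False, 'metrics': ['loss']})
#   cb_list.on_epoch_begin(0)
#   cb_list.on_batch_begin(0)
#   # ... run one training step ...
#   cb_list.on_batch_end(0, logs={'loss': 0.5, 'size': 32})
#   cb_list.on_epoch_end(0, logs={'loss': 0.5})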


@keras_export('keras.callbacks.Callback')
class Callback(object):
  """Abstract base class used to build new callbacks.

  Attributes:
      params: dict. Training parameters
          (e.g. verbosity, batch size, number of epochs...).
      model: instance of `keras.models.Model`.
          Reference of the model being trained.
      validation_data: Deprecated. Do not use.

  The `logs` dictionary that callback methods
  take as argument will contain keys for quantities relevant to
  the current batch or epoch.

  Currently, the `.fit()` method of the `Model` class
  will include the following quantities in the `logs` that
  it passes to its callbacks:

      on_epoch_end: logs include `acc` and `loss`, and
          optionally include `val_loss`
          (if validation is enabled in `fit`), and `val_acc`
          (if validation and accuracy monitoring are enabled).
      on_batch_begin: logs include `size`,
          the number of samples in the current batch.
      on_batch_end: logs include `loss`, and optionally `acc`
          (if accuracy monitoring is enabled).
  """

  def __init__(self):
    self.validation_data = None
    self.model = None
    # Whether this Callback should only run on the chief worker in a
    # Multi-Worker setting.
    # TODO(omalleyt): Make this attr public once solution is stable.
    self._chief_worker_only = None

  def set_params(self, params):
    self.params = params

  def set_model(self, model):
    self.model = model

  @doc_controls.for_subclass_implementers
  def on_batch_begin(self, batch, logs=None):
    """A backwards compatibility alias for `on_train_batch_begin`."""

  @doc_controls.for_subclass_implementers
  def on_batch_end(self, batch, logs=None):
    """A backwards compatibility alias for `on_train_batch_end`."""

  @doc_controls.for_subclass_implementers
  def on_epoch_begin(self, epoch, logs=None):
    """Called at the start of an epoch.

    Subclasses should override for any actions to run. This function should only
    be called during TRAIN mode.

    Arguments:
        epoch: integer, index of epoch.
        logs: dict. Currently no data is passed to this argument for this method
          but that may change in the future.
    """

  @doc_controls.for_subclass_implementers
  def on_epoch_end(self, epoch, logs=None):
    """Called at the end of an epoch.

    Subclasses should override for any actions to run. This function should only
    be called during TRAIN mode.

    Arguments:
        epoch: integer, index of epoch.
        logs: dict, metric results for this training epoch, and for the
          validation epoch if validation is performed. Validation result keys
          are prefixed with `val_`.
    """

  @doc_controls.for_subclass_implementers
  def on_train_batch_begin(self, batch, logs=None):
    """Called at the beginning of a training batch in `fit` methods.

    Subclasses should override for any actions to run.

    Arguments:
        batch: integer, index of batch within the current epoch.
        logs: dict. Has keys `batch` and `size` representing the current batch
          number and the size of the batch.
    """
    # For backwards compatibility.
    self.on_batch_begin(batch, logs=logs)

  @doc_controls.for_subclass_implementers
  def on_train_batch_end(self, batch, logs=None):
    """Called at the end of a training batch in `fit` methods.

    Subclasses should override for any actions to run.

    Arguments:
        batch: integer, index of batch within the current epoch.
        logs: dict. Metric results for this batch.
    """
    # For backwards compatibility.
    self.on_batch_end(batch, logs=logs)

  @doc_controls.for_subclass_implementers
  def on_test_batch_begin(self, batch, logs=None):
    """Called at the beginning of a batch in `evaluate` methods.

    Also called at the beginning of a validation batch in the `fit`
    methods, if validation data is provided.

    Subclasses should override for any actions to run.

    Arguments:
        batch: integer, index of batch within the current epoch.
        logs: dict. Has keys `batch` and `size` representing the current batch
          number and the size of the batch.
    """

  @doc_controls.for_subclass_implementers
  def on_test_batch_end(self, batch, logs=None):
    """Called at the end of a batch in `evaluate` methods.

    Also called at the end of a validation batch in the `fit`
    methods, if validation data is provided.

    Subclasses should override for any actions to run.

    Arguments:
        batch: integer, index of batch within the current epoch.
        logs: dict. Metric results for this batch.
    """

  @doc_controls.for_subclass_implementers
  def on_predict_batch_begin(self, batch, logs=None):
    """Called at the beginning of a batch in `predict` methods.

    Subclasses should override for any actions to run.

    Arguments:
        batch: integer, index of batch within the current epoch.
        logs: dict. Has keys `batch` and `size` representing the current batch
          number and the size of the batch.
    """

  @doc_controls.for_subclass_implementers
  def on_predict_batch_end(self, batch, logs=None):
    """Called at the end of a batch in `predict` methods.

    Subclasses should override for any actions to run.

    Arguments:
        batch: integer, index of batch within the current epoch.
        logs: dict. Metric results for this batch.
    """

  @doc_controls.for_subclass_implementers
  def on_train_begin(self, logs=None):
    """Called at the beginning of training.

    Subclasses should override for any actions to run.

    Arguments:
        logs: dict. Currently no data is passed to this argument for this method
          but that may change in the future.
    """

  @doc_controls.for_subclass_implementers
  def on_train_end(self, logs=None):
    """Called at the end of training.

    Subclasses should override for any actions to run.

    Arguments:
        logs: dict. Currently no data is passed to this argument for this method
          but that may change in the future.
    """

  @doc_controls.for_subclass_implementers
  def on_test_begin(self, logs=None):
    """Called at the beginning of evaluation or validation.

    Subclasses should override for any actions to run.

    Arguments:
        logs: dict. Currently no data is passed to this argument for this method
          but that may change in the future.
    """

  @doc_controls.for_subclass_implementers
  def on_test_end(self, logs=None):
    """Called at the end of evaluation or validation.

    Subclasses should override for any actions to run.

    Arguments:
        logs: dict. Currently no data is passed to this argument for this method
          but that may change in the future.
    """

  @doc_controls.for_subclass_implementers
  def on_predict_begin(self, logs=None):
    """Called at the beginning of prediction.

    Subclasses should override for any actions to run.

    Arguments:
        logs: dict. Currently no data is passed to this argument for this method
          but that may change in the future.
    """

  @doc_controls.for_subclass_implementers
  def on_predict_end(self, logs=None):
    """Called at the end of prediction.

    Subclasses should override for any actions to run.

    Arguments:
        logs: dict. Currently no data is passed to this argument for this method
          but that may change in the future.
    """
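
# A minimal sketch of a custom callback built on the base class above; it
# overrides only the hooks it cares about (`model`, `x`, `y` are assumed):
#
#   class LossLogger(Callback):
#
#     def on_epoch_end(self, epoch, logs=None):
#       logs = logs or {}
#       print('Epoch %d finished with loss %s' % (epoch, logs.get('loss')))
#
#   model.fit(x, y, callbacks=[LossLogger()])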


@keras_export('keras.callbacks.BaseLogger')
class BaseLogger(Callback):
  """Callback that accumulates epoch averages of metrics.

  This callback is automatically applied to every Keras model.

  Arguments:
      stateful_metrics: Iterable of string names of metrics that
          should *not* be averaged over an epoch.
          Metrics in this list will be logged as-is in `on_epoch_end`.
          All others will be averaged in `on_epoch_end`.
  """

  def __init__(self, stateful_metrics=None):
    super(BaseLogger, self).__init__()
    self.stateful_metrics = set(stateful_metrics or [])

  def on_epoch_begin(self, epoch, logs=None):
    self.seen = 0
    self.totals = {}

  def on_batch_end(self, batch, logs=None):
    logs = logs or {}
    batch_size = logs.get('size', 0)
    # In case of distribution strategy we can potentially run multiple steps
    # at the same time; account for that in the `seen` calculation.
    num_steps = logs.get('num_steps', 1)
    self.seen += batch_size * num_steps

    for k, v in logs.items():
      if k in self.stateful_metrics:
        self.totals[k] = v
      else:
        if k in self.totals:
          self.totals[k] += v * batch_size
        else:
          self.totals[k] = v * batch_size

  def on_epoch_end(self, epoch, logs=None):
    if logs is not None:
      for k in self.params['metrics']:
        if k in self.totals:
          # Make value available to next callbacks.
          if k in self.stateful_metrics:
            logs[k] = self.totals[k]
          else:
            logs[k] = self.totals[k] / self.seen
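
# BaseLogger is inserted automatically by `configure_callbacks` during
# training, so users rarely instantiate it. A hand-driven sketch of the
# sample-weighted averaging it performs, assuming two batches of 10 samples:
#
#   logger = BaseLogger()
#   logger.set_params({'metrics': ['loss']})
#   logger.on_epoch_begin(0)
#   logger.on_batch_end(0, logs={'size': 10, 'loss': 1.0})
#   logger.on_batch_end(1, logs={'size': 10, 'loss': 0.0})
#   logs = {}
#   logger.on_epoch_end(0, logs=logs)  # logs['loss'] is now 0.5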


@keras_export('keras.callbacks.TerminateOnNaN')
class TerminateOnNaN(Callback):
  """Callback that terminates training when a NaN loss is encountered."""

  def on_batch_end(self, batch, logs=None):
    logs = logs or {}
    loss = logs.get('loss')
    if loss is not None:
      if np.isnan(loss) or np.isinf(loss):
        print('Batch %d: Invalid loss, terminating training' % batch)
        self.model.stop_training = True
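
# Usage sketch, assuming a compiled `model` and training arrays `x`, `y`:
#
#   model.fit(x, y, epochs=10, callbacks=[TerminateOnNaN()])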


@keras_export('keras.callbacks.ProgbarLogger')
class ProgbarLogger(Callback):
  """Callback that prints metrics to stdout.

  Arguments:
      count_mode: One of "steps" or "samples".
          Whether the progress bar should
          count samples seen or steps (batches) seen.
      stateful_metrics: Iterable of string names of metrics that
          should *not* be averaged over an epoch.
          Metrics in this list will be logged as-is.
          All others will be averaged over time (e.g. loss).

  Raises:
      ValueError: In case of invalid `count_mode`.
  """

  def __init__(self, count_mode='samples', stateful_metrics=None):
    super(ProgbarLogger, self).__init__()
    if count_mode == 'samples':
      self.use_steps = False
    elif count_mode == 'steps':
      self.use_steps = True
    else:
      raise ValueError('Unknown `count_mode`: ' + str(count_mode))
    self.stateful_metrics = set(stateful_metrics or [])
    self.log_values = None

  def on_train_begin(self, logs=None):
    self.verbose = self.params['verbose']
    self.epochs = self.params['epochs']

  def on_epoch_begin(self, epoch, logs=None):
    self.seen = 0
    if self.use_steps:
      self.target = self.params['steps']
    else:
      self.target = self.params['samples']

    if self.verbose:
      if self.epochs > 1:
        print('Epoch %d/%d' % (epoch + 1, self.epochs))
    self.progbar = Progbar(
        target=self.target,
        verbose=self.verbose,
        stateful_metrics=self.stateful_metrics,
        unit_name='step' if self.use_steps else 'sample')

  def on_batch_begin(self, batch, logs=None):
    self.log_values = []

  def on_batch_end(self, batch, logs=None):
    logs = logs or {}
    batch_size = logs.get('size', 0)
    # In case of distribution strategy we can potentially run multiple steps
    # at the same time; account for that in the `seen` calculation.
    num_steps = logs.get('num_steps', 1)
    if self.use_steps:
      self.seen += num_steps
    else:
      self.seen += batch_size * num_steps

    for k in self.params['metrics']:
      if k in logs:
        self.log_values.append((k, logs[k]))

    # Skip the progbar update for the last batch;
    # it will be handled by on_epoch_end.
    if self.verbose and (self.target is None or self.seen < self.target):
      self.progbar.update(self.seen, self.log_values)

  def on_epoch_end(self, epoch, logs=None):
    logs = logs or {}
    for k in self.params['metrics']:
      if k in logs:
        self.log_values.append((k, logs[k]))
    if self.verbose:
      self.progbar.update(self.seen, self.log_values)
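
# ProgbarLogger is added automatically when `fit` runs with `verbose=1`; a
# hand-driven sketch of step-based counting (no model needed):
#
#   bar = ProgbarLogger(count_mode='steps')
#   bar.set_params({'verbose': 1, 'epochs': 1, 'steps': 3, 'samples': None,
#                   'metrics': ['loss']})
#   bar.on_train_begin()
#   bar.on_epoch_begin(0)
#   for step in range(3):
#     bar.on_batch_begin(step)
#     bar.on_batch_end(step, logs={'loss': 0.1 * step, 'num_steps': 1})
#   bar.on_epoch_end(0, logs={'loss': 0.3})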


@keras_export('keras.callbacks.History')
class History(Callback):
  """Callback that records events into a `History` object.

  This callback is automatically applied to
  every Keras model. The `History` object
  gets returned by the `fit` method of models.
  """

  def on_train_begin(self, logs=None):
    self.epoch = []
    self.history = {}

  def on_epoch_end(self, epoch, logs=None):
    logs = logs or {}
    self.epoch.append(epoch)
    for k, v in logs.items():
      self.history.setdefault(k, []).append(v)
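
# The `History` instance is created for the user: `Model.fit` attaches one and
# returns it, so metrics can be inspected after training (assumes a compiled
# `model` and arrays `x`, `y`):
#
#   history = model.fit(x, y, epochs=3, validation_split=0.1)
#   print(history.epoch)            # [0, 1, 2]
#   print(history.history['loss'])  # one loss value per epoch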


@keras_export('keras.callbacks.ModelCheckpoint')
class ModelCheckpoint(Callback):
  """Callback to save the Keras model or model weights at some frequency.

  The `ModelCheckpoint` callback is used in conjunction with training via
  `model.fit()` to save a model or weights (in a checkpoint file) at some
  interval, so the model or weights can be loaded later to continue training
  from the saved state.

  A few options this callback provides include:

  - Whether to only keep the model that has achieved the "best performance" so
    far, or whether to save the model at the end of every epoch regardless of
    performance.
  - Definition of 'best'; which quantity to monitor and whether it should be
    maximized or minimized.
  - The frequency it should save at. Currently, the callback supports saving at
    the end of every epoch, or after a fixed number of training samples.
  - Whether only weights are saved, or the whole model is saved.

  Example:

  ```python
  EPOCHS = 10
  checkpoint_filepath = '/tmp/checkpoint'
  model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
      filepath=checkpoint_filepath,
      save_weights_only=True,
      monitor='val_acc',
      mode='max',
      save_best_only=True)

  # Model weights are saved at the end of every epoch, if it's the best seen
  # so far.
  model.fit(epochs=EPOCHS, callbacks=[model_checkpoint_callback])

  # The model weights (that are considered the best) are loaded into the model.
  model.load_weights(checkpoint_filepath)
  ```

  Arguments:
      filepath: string, path to save the model file. `filepath` can contain
        named formatting options, which will be filled with the value of
        `epoch` and keys in `logs` (passed in `on_epoch_end`). For example: if
        `filepath` is `weights.{epoch:02d}-{val_loss:.2f}.hdf5`, then the model
        checkpoints will be saved with the epoch number and the validation
        loss in the filename.
      monitor: quantity to monitor.
      verbose: verbosity mode, 0 or 1.
      save_best_only: if `save_best_only=True`, the latest best model according
        to the quantity monitored will not be overwritten.
        If `filepath` doesn't contain formatting options like `{epoch}` then
        `filepath` will be overwritten by each new better model.
      mode: one of {auto, min, max}. If `save_best_only=True`, the decision to
        overwrite the current save file is made based on either the maximization
        or the minimization of the monitored quantity. For `val_acc`, this
        should be `max`, for `val_loss` this should be `min`, etc. In `auto`
        mode, the direction is automatically inferred from the name of the
        monitored quantity.
      save_weights_only: if True, then only the model's weights will be saved
        (`model.save_weights(filepath)`), else the full model is saved
        (`model.save(filepath)`).
      save_freq: `'epoch'` or integer. When using `'epoch'`, the callback saves
        the model after each epoch. When using an integer, the callback saves
        the model at the end of a batch at which this many samples have been
        seen since the last saving. Note that if the saving isn't aligned to
        epochs, the monitored metric may potentially be less reliable (it
        could reflect as little as 1 batch, since the metrics get reset every
        epoch). Defaults to `'epoch'`.
      **kwargs: Additional arguments for backwards compatibility. Possible key
        is `period`.
  """

  def __init__(self,
               filepath,
               monitor='val_loss',
               verbose=0,
               save_best_only=False,
               save_weights_only=False,
               mode='auto',
               save_freq='epoch',
               **kwargs):
    super(ModelCheckpoint, self).__init__()
    self.monitor = monitor
    self.verbose = verbose
    self.filepath = filepath
    self.save_best_only = save_best_only
    self.save_weights_only = save_weights_only
    self.save_freq = save_freq
    self.epochs_since_last_save = 0
    self._samples_seen_since_last_saving = 0

    # Deprecated field `load_weights_on_restart` is for loading the checkpoint
    # file from `filepath` at the start of `model.fit()`.
    # TODO(rchao): Remove the arg during next breaking release.
    if 'load_weights_on_restart' in kwargs:
      self.load_weights_on_restart = kwargs['load_weights_on_restart']
      logging.warning('`load_weights_on_restart` argument is deprecated. '
                      'Please use `model.load_weights()` for loading weights '
                      'before the start of `model.fit()`.')
    else:
      self.load_weights_on_restart = False

    # Deprecated field `period` is for the number of epochs between which
    # the model is saved.
    if 'period' in kwargs:
      self.period = kwargs['period']
      logging.warning('`period` argument is deprecated. Please use `save_freq` '
                      'to specify the frequency in number of samples seen.')
    else:
      self.period = 1

    if mode not in ['auto', 'min', 'max']:
      logging.warning('ModelCheckpoint mode %s is unknown, '
                      'fallback to auto mode.', mode)
      mode = 'auto'

    if mode == 'min':
      self.monitor_op = np.less
      self.best = np.Inf
    elif mode == 'max':
      self.monitor_op = np.greater
      self.best = -np.Inf
    else:
      if 'acc' in self.monitor or self.monitor.startswith('fmeasure'):
        self.monitor_op = np.greater
        self.best = -np.Inf
      else:
        self.monitor_op = np.less
        self.best = np.Inf

    if self.save_freq != 'epoch' and not isinstance(self.save_freq, int):
      raise ValueError('Unrecognized save_freq: {}'.format(self.save_freq))

    # Only the chief worker writes model checkpoints, but all workers
    # restore the checkpoint at `on_train_begin()`.
    self._chief_worker_only = False

  def set_model(self, model):
    self.model = model
    # Use name matching rather than `isinstance` to avoid circular dependencies.
    if (not self.save_weights_only and
        not model._is_graph_network and  # pylint: disable=protected-access
        model.__class__.__name__ != 'Sequential'):
      self.save_weights_only = True

  def on_train_begin(self, logs=None):
    # pylint: disable=protected-access
    if self.model._in_multi_worker_mode():
      # MultiWorkerTrainingState is used to manage the training state needed
      # for preemption-recovery of a worker in multi-worker training.
      self.model._training_state = (
          training_state.MultiWorkerTrainingState(self.model, self.filepath))
      self._training_state = self.model._training_state
      if self._training_state.restore():
        # If the training state needs to be and is successfully restored,
        # it is recovering from a previous failure (or preemption). In such
        # a case, do not load the weights from the user-specified file path.
        return

    # If this is not multi-worker training, if restoring is not needed, or if
    # restoring failed, check whether it should load weights on restart.
    if self.load_weights_on_restart:
      if (not self.model._in_multi_worker_mode() or
          multi_worker_util.should_load_checkpoint()):
        filepath_to_load = (
            self._get_most_recently_modified_file_matching_pattern(
                self.filepath))
        if (filepath_to_load is not None and
            training_state.checkpoint_exists(filepath_to_load)):
          try:
            # `filepath` may contain placeholders such as `{epoch:02d}`, so
            # attempt to load the most recently modified file with a name
            # matching the pattern.
            self.model.load_weights(filepath_to_load)
          except (IOError, ValueError) as e:
            raise ValueError('Error loading file from {}. Reason: {}'.format(
                filepath_to_load, e))

  def on_train_end(self, logs=None):
    # pylint: disable=protected-access
    if self.model._in_multi_worker_mode():
      if self.model.stop_training or getattr(
          self.model, '_successful_loop_finish', False):
        # In multi-worker training, on successful exit of training, delete the
        # training state backup file that was saved for the purpose of worker
        # recovery.
        self._training_state.delete_backup()
        # Clean up the training state so the model is ready for the next
        # (possible) multi-worker training.
        del self._training_state
        del self.model._training_state

  def on_batch_end(self, batch, logs=None):
    logs = logs or {}
    if isinstance(self.save_freq, int):
      self._samples_seen_since_last_saving += logs.get('size', 1)
      if self._samples_seen_since_last_saving >= self.save_freq:
        self._save_model(epoch=self._current_epoch, logs=logs)
        self._samples_seen_since_last_saving = 0

  def on_epoch_begin(self, epoch, logs=None):
    self._current_epoch = epoch

  def on_epoch_end(self, epoch, logs=None):
    self.epochs_since_last_save += 1
    # pylint: disable=protected-access
    if self.save_freq == 'epoch':
      if self.model._in_multi_worker_mode():
        # Exclude training state variables in user-requested checkpoint file.
        with self._training_state.untrack_vars():
          self._save_model(epoch=epoch, logs=logs)
      else:
        self._save_model(epoch=epoch, logs=logs)
    if self.model._in_multi_worker_mode():
      # For multi-worker training, back up the weights and current training
      # state for possible future recovery.
      # TODO(rchao): Call `back_up` at finer period such as N steps.
      self._training_state.back_up(epoch)

  def _save_model(self, epoch, logs):
    """Saves the model.

    Arguments:
        epoch: the epoch this iteration is in.
        logs: the `logs` dict passed in to `on_batch_end` or `on_epoch_end`.
    """
    logs = logs or {}

    if isinstance(self.save_freq,
                  int) or self.epochs_since_last_save >= self.period:
      self.epochs_since_last_save = 0
      filepath = self._get_file_path(epoch, logs)

      try:
        if self.save_best_only:
          current = logs.get(self.monitor)
          if current is None:
            logging.warning('Can save best model only with %s available, '
                            'skipping.', self.monitor)
          else:
            if self.monitor_op(current, self.best):
              if self.verbose > 0:
                print('\nEpoch %05d: %s improved from %0.5f to %0.5f,'
                      ' saving model to %s' % (epoch + 1, self.monitor,
                                               self.best, current, filepath))
              self.best = current
              if self.save_weights_only:
                self.model.save_weights(filepath, overwrite=True)
              else:
                self.model.save(filepath, overwrite=True)
            else:
              if self.verbose > 0:
                print('\nEpoch %05d: %s did not improve from %0.5f' %
                      (epoch + 1, self.monitor, self.best))
        else:
          if self.verbose > 0:
            print('\nEpoch %05d: saving model to %s' % (epoch + 1, filepath))
          if self.save_weights_only:
            self.model.save_weights(filepath, overwrite=True)
          else:
            self.model.save(filepath, overwrite=True)

        self._maybe_remove_file()
      except IOError as e:
        # `e.errno` appears to be `None` so checking the content of `e.args[0]`.
        if 'is a directory' in six.ensure_str(e.args[0]):
          raise IOError('Please specify a non-directory filepath for '
                        'ModelCheckpoint. Filepath used is an existing '
                        'directory: {}'.format(filepath))

  def _get_file_path(self, epoch, logs):
    """Returns the file path for the checkpoint."""
    # pylint: disable=protected-access
    if not self.model._in_multi_worker_mode(
    ) or multi_worker_util.should_save_checkpoint():
      try:
        # `filepath` may contain placeholders such as `{epoch:02d}` and
        # `{mape:.2f}`. A mismatch between logged metrics and the path's
        # placeholders can cause formatting to fail.
        return self.filepath.format(epoch=epoch + 1, **logs)
      except KeyError as e:
        raise KeyError('Failed to format this callback filepath: "{}". '
                       'Reason: {}'.format(self.filepath, e))
    else:
      # If this is multi-worker training and this worker should not save the
      # checkpoint, we use a temp filepath to store a dummy checkpoint, so it
      # writes to a file that will be removed at the end of the
      # `_save_model()` call. This is because the SyncOnReadVariable needs to
      # be synced across all the workers in order to be read, and all workers
      # need to initiate that.
      self._temp_file_dir = tempfile.mkdtemp()
      extension = os.path.splitext(self.filepath)[1]
      return os.path.join(self._temp_file_dir, 'temp' + extension)

  def _maybe_remove_file(self):
    # Remove the checkpoint directory in multi-worker training where this
    # worker should not checkpoint. It is a dummy directory previously saved
    # for sync distributed training.

    if (self.model._in_multi_worker_mode() and  # pylint: disable=protected-access
        not multi_worker_util.should_save_checkpoint()):
      file_io.delete_recursively(self._temp_file_dir)
      del self._temp_file_dir

  def _get_most_recently_modified_file_matching_pattern(self, pattern):
    """Returns the most recently modified filepath matching pattern.

    The pattern may contain Python formatting placeholders. If
    `tf.train.latest_checkpoint()` does not return None, use that; otherwise,
    check for the most recently modified file that matches the pattern.

    In the rare case where more than one pattern-matching file shares the most
    recent modified time, return the filepath that is largest (by the `>`
    operator, lexicographically using the numeric equivalents). This provides
    a tie-breaker when multiple files are most recent. Note that a larger
    `filepath` can sometimes indicate a later time of modification (for
    instance, when epoch/batch is used as a formatting option), but not
    necessarily (when accuracy or loss is used). The tie-breaker is a best
    effort to return the most recent file and to avoid a nondeterministic
    result.

    The modified time of a file is obtained with `os.path.getmtime()`.

    This utility function is best demonstrated via an example:

    ```python
    file_pattern = 'f.batch{batch:02d}epoch{epoch:02d}.h5'
    test_dir = self.get_temp_dir()
    path_pattern = os.path.join(test_dir, file_pattern)
    file_paths = [
        os.path.join(test_dir, file_name) for file_name in
        ['f.batch03epoch02.h5', 'f.batch02epoch02.h5', 'f.batch01epoch01.h5']
    ]
    for file_path in file_paths:
      # Write something to each of the files
    self.assertEqual(
        _get_most_recently_modified_file_matching_pattern(path_pattern),
        file_paths[-1])
    ```

    Arguments:
        pattern: The file pattern that may optionally contain Python
            placeholders such as `{epoch:02d}`.

    Returns:
        The most recently modified file's full filepath matching `pattern`. If
        `pattern` does not contain any placeholder, this returns the filepath
        that exactly matches `pattern`. Returns `None` if no match is found.
    """
    dir_name = os.path.dirname(pattern)
    base_name = os.path.basename(pattern)
    base_name_regex = '^' + re.sub(r'{.*}', r'.*', base_name) + '$'

    # If tf.train.latest_checkpoint tells us there exists a latest checkpoint,
    # use that as it is more robust than `os.path.getmtime()`.
    latest_tf_checkpoint = checkpoint_management.latest_checkpoint(dir_name)
    if latest_tf_checkpoint is not None and re.match(
        base_name_regex, os.path.basename(latest_tf_checkpoint)):
      return latest_tf_checkpoint

    latest_mod_time = 0
    file_path_with_latest_mod_time = None
    n_file_with_latest_mod_time = 0
    file_path_with_largest_file_name = None

    if file_io.file_exists(dir_name):
      for file_name in os.listdir(dir_name):
        # Only consider if `file_name` matches the pattern.
        if re.match(base_name_regex, file_name):
          file_path = os.path.join(dir_name, file_name)
          mod_time = os.path.getmtime(file_path)
          if (file_path_with_largest_file_name is None or
              file_path > file_path_with_largest_file_name):
            file_path_with_largest_file_name = file_path
          if mod_time > latest_mod_time:
            latest_mod_time = mod_time
            file_path_with_latest_mod_time = file_path
            # When a file with a later modified time is found, reset the
            # counter for the number of files with the latest modified time.
            n_file_with_latest_mod_time = 1
          elif mod_time == latest_mod_time:
            # When a file's modified time ties the most recent, increment the
            # counter for the number of files with the latest modified time.
            n_file_with_latest_mod_time += 1

    if n_file_with_latest_mod_time == 1:
      # Return the sole file with the most recent modified time.
      return file_path_with_latest_mod_time
    else:
      # If more than one file has the latest modified time, return the file
      # path with the largest file name.
      return file_path_with_largest_file_name
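
# A sketch of epoch/metric placeholders in `filepath`, as described in the
# class docstring (assumes a compiled `model` with validation data `x_val`,
# `y_val` and training arrays `x`, `y`):
#
#   checkpoint = ModelCheckpoint(
#       filepath='weights.{epoch:02d}-{val_loss:.2f}.hdf5',
#       monitor='val_loss', mode='min', save_best_only=True)
#   model.fit(x, y, validation_data=(x_val, y_val), epochs=10,
#             callbacks=[checkpoint])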


@keras_export('keras.callbacks.EarlyStopping')
class EarlyStopping(Callback):
  """Stop training when a monitored metric has stopped improving.

  Assume the goal of training is to minimize the loss. With this, the
  metric to be monitored would be `'loss'`, and the mode would be `'min'`. A
  `model.fit()` training loop will check at the end of every epoch whether
  the loss is still decreasing, considering the `min_delta` and
  `patience` if applicable. Once the loss is found to be no longer decreasing,
  `model.stop_training` is set to True and training terminates.

  The quantity to be monitored needs to be available in the `logs` dict.
  To make it so, pass the loss or metrics at `model.compile()`.

  Example:

  >>> callback = tf.keras.callbacks.EarlyStopping(monitor='loss', patience=3)
  >>> # This callback will stop the training when there is no improvement in
  >>> # the loss for three consecutive epochs.
  >>> model = tf.keras.models.Sequential([tf.keras.layers.Dense(10)])
  >>> model.compile(tf.keras.optimizers.SGD(), loss='mse')
  >>> history = model.fit(np.arange(100).reshape(5, 20), np.zeros(5),
  ...                     epochs=10, callbacks=[callback])
  Train on 5 samples
  Epoch 1/10
  5/5 [==============================] - ... loss: 6533.1904
  Epoch 2/10
  5/5 [==============================] - ... loss: 110183360.0000
  Epoch 3/10
  5/5 [==============================] - ... loss: 1862575718400.0000
  Epoch 4/10
  5/5 [==============================] - ... loss: 31485597793124352.0000
  """

  def __init__(self,
               monitor='val_loss',
               min_delta=0,
               patience=0,
               verbose=0,
               mode='auto',
               baseline=None,
               restore_best_weights=False):
    """Initialize an EarlyStopping callback.

    Arguments:
        monitor: Quantity to be monitored.
        min_delta: Minimum change in the monitored quantity
            to qualify as an improvement, i.e. an absolute
            change of less than `min_delta` will count as no
            improvement.
        patience: Number of epochs with no improvement
            after which training will be stopped.
        verbose: verbosity mode.
        mode: One of `{"auto", "min", "max"}`. In `min` mode,
            training will stop when the quantity
            monitored has stopped decreasing; in `max`
            mode it will stop when the quantity
            monitored has stopped increasing; in `auto`
            mode, the direction is automatically inferred
            from the name of the monitored quantity.
        baseline: Baseline value for the monitored quantity.
            Training will stop if the model doesn't show improvement over the
            baseline.
        restore_best_weights: Whether to restore model weights from
            the epoch with the best value of the monitored quantity.
            If False, the model weights obtained at the last step of
            training are used.
    """
    super(EarlyStopping, self).__init__()

    self.monitor = monitor
    self.patience = patience
    self.verbose = verbose
    self.baseline = baseline
    self.min_delta = abs(min_delta)
    self.wait = 0
    self.stopped_epoch = 0
    self.restore_best_weights = restore_best_weights
    self.best_weights = None

    if mode not in ['auto', 'min', 'max']:
      logging.warning('EarlyStopping mode %s is unknown, '
                      'fallback to auto mode.', mode)
      mode = 'auto'

    if mode == 'min':
      self.monitor_op = np.less
    elif mode == 'max':
      self.monitor_op = np.greater
    else:
      if 'acc' in self.monitor:
        self.monitor_op = np.greater
      else:
        self.monitor_op = np.less

    if self.monitor_op == np.greater:
      self.min_delta *= 1
    else:
      self.min_delta *= -1

  def on_train_begin(self, logs=None):
    # Allow instances to be re-used
    self.wait = 0
    self.stopped_epoch = 0
    if self.baseline is not None:
      self.best = self.baseline
    else:
      self.best = np.Inf if self.monitor_op == np.less else -np.Inf

  def on_epoch_end(self, epoch, logs=None):
    current = self.get_monitor_value(logs)
    if current is None:
      return
    if self.monitor_op(current - self.min_delta, self.best):
      self.best = current
      self.wait = 0
      if self.restore_best_weights:
        self.best_weights = self.model.get_weights()
    else:
      self.wait += 1
      if self.wait >= self.patience:
        self.stopped_epoch = epoch
        self.model.stop_training = True
        if self.restore_best_weights:
          if self.verbose > 0:
            print('Restoring model weights from the end of the best epoch.')
          self.model.set_weights(self.best_weights)

  def on_train_end(self, logs=None):
    if self.stopped_epoch > 0 and self.verbose > 0:
      print('Epoch %05d: early stopping' % (self.stopped_epoch + 1))

  def get_monitor_value(self, logs):
    logs = logs or {}
    monitor_value = logs.get(self.monitor)
    if monitor_value is None:
      logging.warning('Early stopping conditioned on metric `%s` '
                      'which is not available. Available metrics are: %s',
                      self.monitor, ','.join(list(logs.keys())))
    return monitor_value
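
# A sketch combining `patience` with `restore_best_weights`, monitoring
# validation loss (assumes a compiled `model` and data `x`, `y`, `x_val`,
# `y_val`):
#
#   early_stop = EarlyStopping(monitor='val_loss', patience=5,
#                              restore_best_weights=True)
#   model.fit(x, y, validation_data=(x_val, y_val), epochs=100,
#             callbacks=[early_stop])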


@keras_export('keras.callbacks.RemoteMonitor')
class RemoteMonitor(Callback):
  """Callback used to stream events to a server.

  Requires the `requests` library.
  Events are sent to `root + '/publish/epoch/end/'` by default. Calls are
  HTTP POST, with a `data` argument which is a
  JSON-encoded dictionary of event data.
  If `send_as_json` is set to True, the content type of the request will be
  `"application/json"`. Otherwise the serialized JSON will be sent within a
  form.

  Arguments:
      root: String; root url of the target server.
      path: String; path relative to `root` to which the events will be sent.
      field: String; JSON field under which the data will be stored.
          The field is used only if the payload is sent within a form
          (i.e. `send_as_json` is set to False).
      headers: Dictionary; optional custom HTTP headers.
      send_as_json: Boolean; whether the request should be
          sent as `"application/json"`.
  """

  def __init__(self,
               root='http://localhost:9000',
               path='/publish/epoch/end/',
               field='data',
               headers=None,
               send_as_json=False):
    super(RemoteMonitor, self).__init__()

    self.root = root
    self.path = path
    self.field = field
    self.headers = headers
    self.send_as_json = send_as_json

  def on_epoch_end(self, epoch, logs=None):
    if requests is None:
      raise ImportError('RemoteMonitor requires the `requests` library.')
    logs = logs or {}
    send = {}
    send['epoch'] = epoch
    for k, v in logs.items():
      # np.ndarray and np.generic are not scalar types; therefore we must
      # unwrap their scalar values and pass them to the JSON-serializable
      # dict `send`.
      if isinstance(v, (np.ndarray, np.generic)):
        send[k] = v.item()
      else:
        send[k] = v
    try:
      if self.send_as_json:
        requests.post(self.root + self.path, json=send, headers=self.headers)
      else:
        requests.post(
            self.root + self.path, {self.field: json.dumps(send)},
            headers=self.headers)
    except requests.exceptions.RequestException:
      logging.warning('Warning: could not reach RemoteMonitor '
                      'root server at ' + str(self.root))
1411
1412
1413@keras_export('keras.callbacks.LearningRateScheduler')
1414class LearningRateScheduler(Callback):
  """Learning rate scheduler.

  Arguments:
      schedule: a function that takes an epoch index as input
          (integer, indexed from 0) and returns a new
          learning rate as output (float).
      verbose: int. 0: quiet, 1: update messages.

  ```python
  # This function keeps the learning rate at 0.001 for the first ten epochs
  # and decreases it exponentially after that.
  def scheduler(epoch):
    if epoch < 10:
      return 0.001
    else:
      return 0.001 * tf.math.exp(0.1 * (10 - epoch))

  callback = tf.keras.callbacks.LearningRateScheduler(scheduler)
  model.fit(data, labels, epochs=100, callbacks=[callback],
            validation_data=(val_data, val_labels))
  ```
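
  The schedule may also take the current learning rate as a second
  argument, as in the sketch below; a one-argument schedule like the one
  above is still supported for backward compatibility:

  ```python
  def scheduler(epoch, lr):
    # Keep the incoming rate for ten epochs, then decay it exponentially.
    if epoch < 10:
      return lr
    return lr * tf.math.exp(-0.1)
  ```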
  """

  def __init__(self, schedule, verbose=0):
    super(LearningRateScheduler, self).__init__()
    self.schedule = schedule
    self.verbose = verbose

  def on_epoch_begin(self, epoch, logs=None):
    if not hasattr(self.model.optimizer, 'lr'):
      raise ValueError('Optimizer must have an "lr" attribute.')
    try:  # new API
      lr = float(K.get_value(self.model.optimizer.lr))
      lr = self.schedule(epoch, lr)
    except TypeError:  # Support for old API for backward compatibility
      lr = self.schedule(epoch)
    if not isinstance(lr, (ops.Tensor, float, np.float32, np.float64)):
      raise ValueError('The output of the "schedule" function '
                       'should be a float.')
    if isinstance(lr, ops.Tensor) and not lr.dtype.is_floating:
      raise ValueError('The dtype of the "schedule" function output Tensor '
                       'should be float.')
    K.set_value(self.model.optimizer.lr, K.get_value(lr))
    if self.verbose > 0:
      print('\nEpoch %05d: LearningRateScheduler setting learning '
            'rate to %s.' % (epoch + 1, lr))

  def on_epoch_end(self, epoch, logs=None):
    logs = logs or {}
    logs['lr'] = K.get_value(self.model.optimizer.lr)


@keras_export('keras.callbacks.TensorBoard', v1=[])
class TensorBoard(Callback):
  # pylint: disable=line-too-long
  """Enable visualizations for TensorBoard.

  TensorBoard is a visualization tool provided with TensorFlow.

  This callback logs events for TensorBoard, including:

  * Metrics summary plots
  * Training graph visualization
  * Activation histograms
  * Sampled profiling

  If you have installed TensorFlow with pip, you should be able
  to launch TensorBoard from the command line:

  ```sh
  tensorboard --logdir=path_to_your_logs
  ```

  You can find more information about TensorBoard
  [here](https://www.tensorflow.org/get_started/summaries_and_tensorboard).

  Example:
  ```python
  tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir="./logs")
  model.fit(x_train, y_train, epochs=2, callbacks=[tensorboard_callback])
  # Run the `tensorboard` command to view the visualizations.
  ```

  Arguments:
      log_dir: the path of the directory where to save the log files to be
        parsed by TensorBoard.
      histogram_freq: frequency (in epochs) at which to compute activation and
        weight histograms for the layers of the model. If set to 0, histograms
        won't be computed. Validation data (or split) must be specified for
        histogram visualizations.
      write_graph: whether to visualize the graph in TensorBoard. The log file
        can become quite large when `write_graph` is set to True.
      write_images: whether to write model weights to visualize as images in
        TensorBoard.
      update_freq: `'batch'` or `'epoch'` or integer. When using `'batch'`,
        writes the losses and metrics to TensorBoard after each batch. The same
        applies for `'epoch'`. If using an integer, let's say `1000`, the
        callback will write the metrics and losses to TensorBoard every 1000
        batches. Note that writing too frequently to TensorBoard can slow down
        your training.
      profile_batch: Profile the batch to sample compute characteristics. By
        default, it will profile the second batch. Set profile_batch=0 to
        disable profiling. Must run in TensorFlow eager mode.
      embeddings_freq: frequency (in epochs) at which embedding layers will
        be visualized. If set to 0, embeddings won't be visualized.
      embeddings_metadata: a dictionary which maps layer names to file names in
        which metadata for each embedding layer is saved. See the
        [details](
          https://www.tensorflow.org/how_tos/embedding_viz/#metadata_optional)
        about the metadata file format. If the same metadata file is to be
        used for all embedding layers, a single string can be passed.
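
  For instance (an illustrative configuration; the values and directory
  name are arbitrary):

  ```python
  tensorboard_callback = tf.keras.callbacks.TensorBoard(
      log_dir='./logs',
      histogram_freq=1,    # weight histograms every epoch
      update_freq=1000,    # write metrics and losses every 1000 batches
      profile_batch=0)     # disable profiling
  model.fit(x_train, y_train, epochs=2, callbacks=[tensorboard_callback])
  ```
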
  Raises:
      ValueError: If histogram_freq is set and no validation data is provided.
  """

  # pylint: enable=line-too-long

  def __init__(self,
               log_dir='logs',
               histogram_freq=0,
               write_graph=True,
               write_images=False,
               update_freq='epoch',
               profile_batch=2,
               embeddings_freq=0,
               embeddings_metadata=None,
               **kwargs):
    super(TensorBoard, self).__init__()
    self._validate_kwargs(kwargs)

    self.log_dir = log_dir
    self.histogram_freq = histogram_freq
    self.write_graph = write_graph
    self.write_images = write_images
    if update_freq == 'batch':
      self.update_freq = 1
    else:
      self.update_freq = update_freq
    self.embeddings_freq = embeddings_freq
    self.embeddings_metadata = embeddings_metadata

    self._samples_seen = 0
    self._samples_seen_at_last_write = 0
    self._current_batch = 0

    # A collection of file writers currently in use, to be closed when
    # training ends for this callback. Writers are keyed by the
    # directory name under the root logdir: e.g., "train" or
    # "validation".
    self._train_run_name = 'train'
    self._validation_run_name = 'validation'
    self._writers = {}

    self._profile_batch = profile_batch
    # True when a trace is running.
    self._is_tracing = False

  def _validate_kwargs(self, kwargs):
    """Handles arguments that were only supported in V1."""
    if kwargs.get('write_grads', False):
      logging.warning('`write_grads` will be ignored in TensorFlow 2.0 '
                      'for the `TensorBoard` Callback.')
    if kwargs.get('batch_size', False):
      logging.warning('`batch_size` is no longer needed in the '
                      '`TensorBoard` Callback and will be ignored '
                      'in TensorFlow 2.0.')
    if kwargs.get('embeddings_layer_names', False):
      logging.warning('`embeddings_layer_names` is not supported in '
                      'TensorFlow 2.0. Instead, all `Embedding` layers '
                      'will be visualized.')
    if kwargs.get('embeddings_data', False):
      logging.warning('`embeddings_data` is not supported in TensorFlow '
                      '2.0. Instead, all `Embedding` variables will be '
                      'visualized.')

    unrecognized_kwargs = set(kwargs.keys()) - {
        'write_grads', 'embeddings_layer_names', 'embeddings_data', 'batch_size'
    }

    # Only allow kwargs that were supported in V1.
    if unrecognized_kwargs:
      raise ValueError('Unrecognized arguments in `TensorBoard` '
                       'Callback: ' + str(unrecognized_kwargs))

  def set_model(self, model):
    """Sets the Keras model and writes the graph if specified."""
    self.model = model

    # The TensorBoard callback involves writing a summary file in a
    # possibly distributed setting.
    self._log_write_dir = distributed_file_utils.write_dirpath(
        self.log_dir, self.model._get_distribution_strategy())  # pylint: disable=protected-access

    with context.eager_mode():
      self._close_writers()
      if self.write_graph:
        with self._get_writer(self._train_run_name).as_default():
          with summary_ops_v2.always_record_summaries():
            if not model.run_eagerly:
              summary_ops_v2.graph(K.get_graph(), step=0)

            summary_writable = (
                self.model._is_graph_network or  # pylint: disable=protected-access
                self.model.__class__.__name__ == 'Sequential')  # pylint: disable=protected-access
            if summary_writable:
              summary_ops_v2.keras_model('keras', self.model, step=0)

    if self.embeddings_freq:
      self._configure_embeddings()

    summary_state = summary_ops_v2._summary_state  # pylint: disable=protected-access
    self._prev_summary_recording = summary_state.is_recording
    self._prev_summary_writer = summary_state.writer
    self._prev_summary_step = summary_state.step

  def _configure_embeddings(self):
    """Configures the Projector for embeddings."""
    # TODO(omalleyt): Add integration tests.
    from tensorflow.python.keras.layers import embeddings
    try:
      from tensorboard.plugins import projector
    except ImportError:
      raise ImportError('Failed to import TensorBoard. Please make sure that '
                        'TensorBoard integration is complete.')
    config = projector.ProjectorConfig()
    for layer in self.model.layers:
      if isinstance(layer, embeddings.Embedding):
        embedding = config.embeddings.add()
        embedding.tensor_name = layer.embeddings.name

        if self.embeddings_metadata is not None:
          if isinstance(self.embeddings_metadata, str):
            embedding.metadata_path = self.embeddings_metadata
          else:
            # Look the layer up in the user-supplied metadata dict, not in
            # `embedding.metadata_path` (which is still empty at this point).
            if layer.name in self.embeddings_metadata:
              embedding.metadata_path = self.embeddings_metadata.pop(layer.name)

    if self.embeddings_metadata:
      raise ValueError('Unrecognized `Embedding` layer names passed to '
                       '`keras.callbacks.TensorBoard` `embeddings_metadata` '
                       'argument: ' + str(self.embeddings_metadata.keys()))

    class DummyWriter(object):
      """Dummy writer to conform to the `Projector` API."""

      def __init__(self, logdir):
        self.logdir = logdir

      def get_logdir(self):
        return self.logdir

    writer = DummyWriter(self._log_write_dir)
    projector.visualize_embeddings(writer, config)

  def _close_writers(self):
    """Close all remaining open file writers owned by this callback.

    If there are no such file writers, this is a no-op.
    """
    with context.eager_mode():
      for writer in six.itervalues(self._writers):
        writer.close()
      self._writers.clear()

  def _get_writer(self, writer_name):
    """Get a summary writer for the given subdirectory under the logdir.

    A writer will be created if it does not yet exist.

    Arguments:
      writer_name: The name of the directory for which to create or
        retrieve a writer. Should be either `self._train_run_name` or
        `self._validation_run_name`.

    Returns:
      A `SummaryWriter` object.
    """
    if writer_name not in self._writers:
      path = os.path.join(self._log_write_dir, writer_name)
      writer = summary_ops_v2.create_file_writer_v2(path)
      self._writers[writer_name] = writer
    return self._writers[writer_name]

  def _set_default_writer(self, writer_name):
    """Sets the default writer for custom batch-level summaries."""
    if self.update_freq == 'epoch':
      # Writer is only used for custom summaries, which are written
      # batch-by-batch.
      return

    step = self._total_batches_seen[writer_name]

    def _should_record():
      return math_ops.equal(step % self.update_freq, 0)

    summary_state = summary_ops_v2._summary_state  # pylint: disable=protected-access
    summary_state.is_recording = _should_record
    summary_state.writer = self._get_writer(writer_name)
    summary_ops_v2.set_step(step)

  def _init_batch_steps(self):
    """Create the total batch counters."""
    if ops.executing_eagerly_outside_functions():
      # Variables are needed for the `step` value of custom tf.summaries
      # to be updated inside a tf.function.
      self._total_batches_seen = {
          self._train_run_name: variables.Variable(0, dtype='int64'),
          self._validation_run_name: variables.Variable(0, dtype='int64')
      }
    else:
      # Custom tf.summaries are not supported in legacy graph mode.
      self._total_batches_seen = {
          self._train_run_name: 0,
          self._validation_run_name: 0
      }

  def _increment_step(self, writer_name):
    step = self._total_batches_seen[writer_name]
    if isinstance(step, variables.Variable):
      step.assign_add(1)
    else:
      self._total_batches_seen[writer_name] += 1

  def on_train_begin(self, logs=None):
    self._init_batch_steps()
    if self._profile_batch == 1:
      summary_ops_v2.trace_on(graph=True, profiler=True)
      self._is_tracing = True

  def on_test_begin(self, logs=None):
    self._set_default_writer(self._validation_run_name)

  def on_train_batch_end(self, batch, logs=None):
    """Writes scalar summaries for metrics on every training batch.

    Performs profiling if the current batch is the one selected via
    `profile_batch`.

    Arguments:
      batch: Integer, index of batch within the current epoch.
      logs: Dict. Metric results for this batch.
    """
    if self.update_freq == 'epoch' and self._profile_batch is None:
      return

    # Don't output batch_size and batch number as TensorBoard summaries
    logs = logs or {}
    train_batches = self._total_batches_seen[self._train_run_name]
    if self.update_freq != 'epoch' and batch % self.update_freq == 0:
      self._log_metrics(logs, prefix='batch_', step=train_batches)

    self._increment_step(self._train_run_name)

    if context.executing_eagerly():
      if self._is_tracing:
        self._log_trace()
      elif (not self._is_tracing and
            math_ops.equal(train_batches, self._profile_batch - 1)):
        # Enable tracing one batch early so that `profile_batch` itself is
        # captured; the trace is then exported in the branch above.
        self._enable_trace()

  def on_test_batch_end(self, batch, logs=None):
    if self.update_freq == 'epoch':
      return
    self._increment_step(self._validation_run_name)

  def on_epoch_begin(self, epoch, logs=None):
    self._set_default_writer(self._train_run_name)

  def on_epoch_end(self, epoch, logs=None):
    """Runs metrics and histogram summaries at epoch end."""
    self._log_metrics(logs, prefix='epoch_', step=epoch)

    if self.histogram_freq and epoch % self.histogram_freq == 0:
      self._log_weights(epoch)

    if self.embeddings_freq and epoch % self.embeddings_freq == 0:
      self._log_embeddings(epoch)

  def on_train_end(self, logs=None):
    if self._is_tracing:
      self._log_trace()
    self._close_writers()

    summary_state = summary_ops_v2._summary_state  # pylint: disable=protected-access
    summary_state.is_recording = self._prev_summary_recording
    summary_state.writer = self._prev_summary_writer
    summary_state.step = self._prev_summary_step

    # Safely remove the unneeded temp files.
    distributed_file_utils.remove_temp_dirpath(
        self.log_dir, self.model._get_distribution_strategy())  # pylint: disable=protected-access

  def _enable_trace(self):
    if context.executing_eagerly():
      summary_ops_v2.trace_on(graph=True, profiler=True)
      self._is_tracing = True

  def _log_trace(self):
    """Logs the trace graph to TensorBoard."""
    if context.executing_eagerly():
      with self._get_writer(self._train_run_name).as_default(), \
          summary_ops_v2.always_record_summaries():
        # TODO(b/126388999): Remove step info in the summary name.
        step = K.get_value(self._total_batches_seen[self._train_run_name])
        summary_ops_v2.trace_export(
            name='batch_%d' % step,
            step=step,
            profiler_outdir=os.path.join(self._log_write_dir, 'train'))
      self._is_tracing = False

  def _log_metrics(self, logs, prefix, step):
    """Writes metrics out as custom scalar summaries.

    Arguments:
        logs: Dict. Keys are scalar summary names, values are NumPy scalars.
        prefix: String. The prefix to apply to the scalar summary names.
        step: Int. The global step to use for TensorBoard.
    """
    if logs is None:
      logs = {}

    # Group metrics by the name of their associated file writer. Values
    # are lists of metrics, as (name, scalar_value) pairs.
    logs_by_writer = {
        self._train_run_name: [],
        self._validation_run_name: [],
    }
    validation_prefix = 'val_'
    for (name, value) in logs.items():
      if name in ('batch', 'size', 'num_steps'):
        # Scrub non-metric items.
        continue
      if name.startswith(validation_prefix):
        name = name[len(validation_prefix):]
        writer_name = self._validation_run_name
      else:
        writer_name = self._train_run_name
      name = prefix + name  # assign batch or epoch prefix
      logs_by_writer[writer_name].append((name, value))

    with context.eager_mode():
      with summary_ops_v2.always_record_summaries():
        for writer_name in logs_by_writer:
          these_logs = logs_by_writer[writer_name]
          if not these_logs:
            # Don't create a "validation" events file if we don't
            # actually have any validation data.
            continue
          writer = self._get_writer(writer_name)
          with writer.as_default():
            for (name, value) in these_logs:
              summary_ops_v2.scalar(name, value, step=step)

  def _log_weights(self, epoch):
    """Logs the weights of the Model to TensorBoard."""
    writer = self._get_writer(self._train_run_name)
    with context.eager_mode(), \
          writer.as_default(), \
          summary_ops_v2.always_record_summaries():
      for layer in self.model.layers:
        for weight in layer.weights:
          weight_name = weight.name.replace(':', '_')
          with ops.init_scope():
            weight = K.get_value(weight)
          summary_ops_v2.histogram(weight_name, weight, step=epoch)
          if self.write_images:
            self._log_weight_as_image(weight, weight_name, epoch)
      writer.flush()

  def _log_weight_as_image(self, weight, weight_name, epoch):
    """Logs a weight as a TensorBoard image."""
    w_img = array_ops.squeeze(weight)
    shape = K.int_shape(w_img)
    if len(shape) == 1:  # Bias case
      w_img = array_ops.reshape(w_img, [1, shape[0], 1, 1])
    elif len(shape) == 2:  # Dense layer kernel case
      if shape[0] > shape[1]:
        w_img = array_ops.transpose(w_img)
        shape = K.int_shape(w_img)
      w_img = array_ops.reshape(w_img, [1, shape[0], shape[1], 1])
    elif len(shape) == 3:  # ConvNet case
      if K.image_data_format() == 'channels_last':
        # Switch to channels_first to display every kernel as a separate
        # image.
        w_img = array_ops.transpose(w_img, perm=[2, 0, 1])
        shape = K.int_shape(w_img)
      w_img = array_ops.reshape(w_img, [shape[0], shape[1], shape[2], 1])

    shape = K.int_shape(w_img)
    # Not possible to handle 3D convnets etc.
    if len(shape) == 4 and shape[-1] in [1, 3, 4]:
      summary_ops_v2.image(weight_name, w_img, step=epoch)

  def _log_embeddings(self, epoch):
    embeddings_ckpt = os.path.join(self._log_write_dir, 'train',
                                   'keras_embedding.ckpt-{}'.format(epoch))
    self.model.save_weights(embeddings_ckpt)


@keras_export('keras.callbacks.ReduceLROnPlateau')
class ReduceLROnPlateau(Callback):
  """Reduce learning rate when a metric has stopped improving.

  Models often benefit from reducing the learning rate by a factor
  of 2-10 once learning stagnates. This callback monitors a
  quantity and if no improvement is seen for a 'patience' number
  of epochs, the learning rate is reduced.

  Example:

  ```python
  reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.2,
                                patience=5, min_lr=0.001)
  model.fit(X_train, Y_train, callbacks=[reduce_lr])
  ```

  Arguments:
      monitor: quantity to be monitored.
      factor: factor by which the learning rate will be reduced.
        `new_lr = lr * factor`.
      patience: number of epochs with no improvement after which learning rate
        will be reduced.
      verbose: int. 0: quiet, 1: update messages.
      mode: one of `{'auto', 'min', 'max'}`. In `min` mode, lr will be reduced
        when the quantity monitored has stopped decreasing; in `max` mode it
        will be reduced when the quantity monitored has stopped increasing; in
        `auto` mode, the direction is automatically inferred from the name of
        the monitored quantity.
      min_delta: threshold for measuring the new optimum, to only focus on
        significant changes.
      cooldown: number of epochs to wait before resuming normal operation after
        lr has been reduced.
      min_lr: lower bound on the learning rate.
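
  For example, to monitor accuracy in `max` mode (the values here are
  illustrative only):

  ```python
  reduce_lr = ReduceLROnPlateau(monitor='val_accuracy', mode='max',
                                factor=0.5, patience=3, cooldown=2,
                                min_lr=1e-6)
  model.fit(X_train, Y_train, callbacks=[reduce_lr])
  ```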
  """

  def __init__(self,
               monitor='val_loss',
               factor=0.1,
               patience=10,
               verbose=0,
               mode='auto',
               min_delta=1e-4,
               cooldown=0,
               min_lr=0,
               **kwargs):
    super(ReduceLROnPlateau, self).__init__()

    self.monitor = monitor
    if factor >= 1.0:
      raise ValueError('ReduceLROnPlateau does not support a factor >= 1.0.')
    if 'epsilon' in kwargs:
      min_delta = kwargs.pop('epsilon')
      logging.warning('`epsilon` argument is deprecated and '
                      'will be removed, use `min_delta` instead.')
    self.factor = factor
    self.min_lr = min_lr
    self.min_delta = min_delta
    self.patience = patience
    self.verbose = verbose
    self.cooldown = cooldown
    self.cooldown_counter = 0  # Cooldown counter.
    self.wait = 0
    self.best = 0
    self.mode = mode
    self.monitor_op = None
    self._reset()

  def _reset(self):
    """Resets wait counter and cooldown counter."""
    if self.mode not in ['auto', 'min', 'max']:
      logging.warning('Learning rate reduction mode %s is unknown, '
                      'falling back to auto mode.', self.mode)
      self.mode = 'auto'
    if (self.mode == 'min' or
        (self.mode == 'auto' and 'acc' not in self.monitor)):
      self.monitor_op = lambda a, b: np.less(a, b - self.min_delta)
      self.best = np.Inf
    else:
      self.monitor_op = lambda a, b: np.greater(a, b + self.min_delta)
      self.best = -np.Inf
    self.cooldown_counter = 0
    self.wait = 0

  def on_train_begin(self, logs=None):
    self._reset()

  def on_epoch_end(self, epoch, logs=None):
    logs = logs or {}
    logs['lr'] = K.get_value(self.model.optimizer.lr)
    current = logs.get(self.monitor)
    if current is None:
      logging.warning('Reduce LR on plateau conditioned on metric `%s` '
                      'which is not available. Available metrics are: %s',
                      self.monitor, ','.join(list(logs.keys())))
    else:
      if self.in_cooldown():
        self.cooldown_counter -= 1
        self.wait = 0

      if self.monitor_op(current, self.best):
        self.best = current
        self.wait = 0
      elif not self.in_cooldown():
        self.wait += 1
        if self.wait >= self.patience:
          old_lr = float(K.get_value(self.model.optimizer.lr))
          if old_lr > self.min_lr:
            new_lr = old_lr * self.factor
            new_lr = max(new_lr, self.min_lr)
            K.set_value(self.model.optimizer.lr, new_lr)
            if self.verbose > 0:
              print('\nEpoch %05d: ReduceLROnPlateau reducing learning '
                    'rate to %s.' % (epoch + 1, new_lr))
            self.cooldown_counter = self.cooldown
            self.wait = 0

  def in_cooldown(self):
    return self.cooldown_counter > 0


@keras_export('keras.callbacks.CSVLogger')
class CSVLogger(Callback):
  """Callback that streams epoch results to a csv file.

  Supports all values that can be represented as a string,
  including 1D iterables such as np.ndarray.

  Example:

  ```python
  csv_logger = CSVLogger('training.log')
  model.fit(X_train, Y_train, callbacks=[csv_logger])
  ```

  Arguments:
      filename: filename of the csv file, e.g. 'run/log.csv'.
      separator: string used to separate elements in the csv file.
      append: True: append if file exists (useful for continuing
          training). False: overwrite existing file.
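
  To resume logging after an interrupted run, pass `append=True` so new
  epochs are appended to the existing file (the filename is arbitrary):

  ```python
  csv_logger = CSVLogger('training.log', append=True)
  model.fit(X_train, Y_train, callbacks=[csv_logger])
  ```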
  """

  def __init__(self, filename, separator=',', append=False):
    self.sep = separator
    self.filename = filename
    self.append = append
    self.writer = None
    self.keys = None
    self.append_header = True
    if six.PY2:
      self.file_flags = 'b'
      self._open_args = {}
    else:
      self.file_flags = ''
      self._open_args = {'newline': '\n'}
    super(CSVLogger, self).__init__()

  def on_train_begin(self, logs=None):
    if self.append:
      if file_io.file_exists(self.filename):
        with open(self.filename, 'r' + self.file_flags) as f:
          self.append_header = not bool(len(f.readline()))
      mode = 'a'
    else:
      mode = 'w'
    self.csv_file = io.open(self.filename,
                            mode + self.file_flags,
                            **self._open_args)

  def on_epoch_end(self, epoch, logs=None):
    logs = logs or {}

    def handle_value(k):
      is_zero_dim_ndarray = isinstance(k, np.ndarray) and k.ndim == 0
      if isinstance(k, six.string_types):
        return k
      elif isinstance(k, collections_abc.Iterable) and not is_zero_dim_ndarray:
        return '"[%s]"' % (', '.join(map(str, k)))
      else:
        return k

    if self.keys is None:
      self.keys = sorted(logs.keys())

    if self.model.stop_training:
      # We set NA so that csv parsers do not fail for this last epoch.
      logs = dict([(k, logs[k]) if k in logs else (k, 'NA') for k in self.keys])

    if not self.writer:

      class CustomDialect(csv.excel):
        delimiter = self.sep

      fieldnames = ['epoch'] + self.keys
      if six.PY2:
        fieldnames = [unicode(x) for x in fieldnames]

      self.writer = csv.DictWriter(
          self.csv_file,
          fieldnames=fieldnames,
          dialect=CustomDialect)
      if self.append_header:
        self.writer.writeheader()

    row_dict = collections.OrderedDict({'epoch': epoch})
    row_dict.update((key, handle_value(logs[key])) for key in self.keys)
    self.writer.writerow(row_dict)
    self.csv_file.flush()

  def on_train_end(self, logs=None):
    self.csv_file.close()
    self.writer = None


@keras_export('keras.callbacks.LambdaCallback')
class LambdaCallback(Callback):
  r"""Callback for creating simple, custom callbacks on-the-fly.

  This callback is constructed with anonymous functions that will be called
  at the appropriate time. Note that the callbacks expect positional
  arguments, as:

   - `on_epoch_begin` and `on_epoch_end` expect two positional arguments:
      `epoch`, `logs`
   - `on_batch_begin` and `on_batch_end` expect two positional arguments:
      `batch`, `logs`
   - `on_train_begin` and `on_train_end` expect one positional argument:
      `logs`

  Arguments:
      on_epoch_begin: called at the beginning of every epoch.
      on_epoch_end: called at the end of every epoch.
      on_batch_begin: called at the beginning of every batch.
      on_batch_end: called at the end of every batch.
      on_train_begin: called at the beginning of model training.
      on_train_end: called at the end of model training.

  Example:

  ```python
  # Print the batch number at the beginning of every batch.
  batch_print_callback = LambdaCallback(
      on_batch_begin=lambda batch, logs: print(batch))

  # Stream the epoch loss to a file in JSON format. The file content
  # is not well-formed JSON but rather has a JSON object per line.
  import json
  json_log = open('loss_log.json', mode='wt', buffering=1)
  json_logging_callback = LambdaCallback(
      on_epoch_end=lambda epoch, logs: json_log.write(
          json.dumps({'epoch': epoch, 'loss': logs['loss']}) + '\n'),
      on_train_end=lambda logs: json_log.close()
  )

  # Terminate some processes after having finished model training.
  processes = ...
  cleanup_callback = LambdaCallback(
      on_train_end=lambda logs: [
          p.terminate() for p in processes if p.is_alive()])

  model.fit(...,
            callbacks=[batch_print_callback,
                       json_logging_callback,
                       cleanup_callback])
  ```
  """

  def __init__(self,
               on_epoch_begin=None,
               on_epoch_end=None,
               on_batch_begin=None,
               on_batch_end=None,
               on_train_begin=None,
               on_train_end=None,
               **kwargs):
    super(LambdaCallback, self).__init__()
    self.__dict__.update(kwargs)
    if on_epoch_begin is not None:
      self.on_epoch_begin = on_epoch_begin
    else:
      self.on_epoch_begin = lambda epoch, logs: None
    if on_epoch_end is not None:
      self.on_epoch_end = on_epoch_end
    else:
      self.on_epoch_end = lambda epoch, logs: None
    if on_batch_begin is not None:
      self.on_batch_begin = on_batch_begin
    else:
      self.on_batch_begin = lambda batch, logs: None
    if on_batch_end is not None:
      self.on_batch_end = on_batch_end
    else:
      self.on_batch_end = lambda batch, logs: None
    if on_train_begin is not None:
      self.on_train_begin = on_train_begin
    else:
      self.on_train_begin = lambda logs: None
    if on_train_end is not None:
      self.on_train_end = on_train_end
    else:
      self.on_train_end = lambda logs: None