• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Training-related part of the Keras engine."""
16
17import copy
18import itertools
19import json
20import os
21import warnings
22import weakref
23
24from tensorflow.python.autograph.lang import directives
25from tensorflow.python.data.ops import dataset_ops
26from tensorflow.python.data.ops import options as options_lib
27from tensorflow.python.distribute import collective_all_reduce_strategy
28from tensorflow.python.distribute import distribution_strategy_context as ds_context
29from tensorflow.python.distribute import values as ds_values
30from tensorflow.python.distribute.coordinator import cluster_coordinator
31from tensorflow.python.eager import backprop
32from tensorflow.python.eager import context
33from tensorflow.python.eager import def_function
34from tensorflow.python.framework import composite_tensor
35from tensorflow.python.framework import errors
36from tensorflow.python.framework import errors_impl
37from tensorflow.python.framework import func_graph
38from tensorflow.python.framework import ops
39from tensorflow.python.framework import sparse_tensor
40from tensorflow.python.framework import tensor_shape
41from tensorflow.python.keras import backend
42from tensorflow.python.keras import callbacks as callbacks_module
43from tensorflow.python.keras import optimizer_v1
44from tensorflow.python.keras import optimizers
45from tensorflow.python.keras.engine import base_layer
46from tensorflow.python.keras.engine import base_layer_utils
47from tensorflow.python.keras.engine import compile_utils
48from tensorflow.python.keras.engine import data_adapter
49from tensorflow.python.keras.engine import training_utils
50from tensorflow.python.keras.mixed_precision import loss_scale_optimizer as lso
51from tensorflow.python.keras.mixed_precision import policy
52from tensorflow.python.keras.saving import hdf5_format
53from tensorflow.python.keras.saving import save
54from tensorflow.python.keras.saving import saving_utils
55from tensorflow.python.keras.saving.saved_model import json_utils
56from tensorflow.python.keras.saving.saved_model import model_serialization
57from tensorflow.python.keras.utils import generic_utils
58from tensorflow.python.keras.utils import layer_utils
59from tensorflow.python.keras.utils import object_identity
60from tensorflow.python.keras.utils import tf_utils
61from tensorflow.python.keras.utils import version_utils
62from tensorflow.python.keras.utils.io_utils import ask_to_proceed_with_overwrite
63from tensorflow.python.keras.utils.io_utils import path_to_string
64from tensorflow.python.keras.utils.mode_keys import ModeKeys
65from tensorflow.python.ops import array_ops
66from tensorflow.python.ops import math_ops
67from tensorflow.python.ops import sparse_ops
68from tensorflow.python.ops import summary_ops_v2
69from tensorflow.python.ops import variables
70from tensorflow.python.platform import tf_logging as logging
71from tensorflow.python.profiler import trace
72from tensorflow.python.saved_model import constants as sm_constants
73from tensorflow.python.saved_model import loader_impl as sm_loader
74from tensorflow.python.training import checkpoint_management
75from tensorflow.python.training import py_checkpoint_reader
76from tensorflow.python.training.tracking import base as trackable
77from tensorflow.python.training.tracking import graph_view as graph_view_lib
78from tensorflow.python.training.tracking import util as trackable_utils
79from tensorflow.python.util import nest
80from tensorflow.python.util import tf_decorator
81from tensorflow.python.util.tf_export import keras_export
82from tensorflow.tools.docs import doc_controls
83
84
85# pylint: disable=g-import-not-at-top
86try:
87  import h5py
88except ImportError:
89  h5py = None
90# pylint: enable=g-import-not-at-top
91
92
def disable_multi_worker(method):
  """Wraps `method` so that it refuses to run in multi-worker mode.

  Args:
    method: The method to guard. The object it is bound to must provide
      `_in_multi_worker_mode()`.

  Returns:
    The result of `tf_decorator.make_decorator` for `method` and the guarding
    wrapper defined here.
  """

  def _method_wrapper(self, *args, **kwargs):
    # Outside multi-worker mode the call is forwarded untouched.
    if not self._in_multi_worker_mode():  # pylint: disable=protected-access
      return method(self, *args, **kwargs)
    raise ValueError('{} is not supported in multi-worker mode.'.format(
        method.__name__))

  return tf_decorator.make_decorator(
      target=method, decorator_func=_method_wrapper)
104
105
def inject_functional_model_class(cls):
  """Swaps `Functional` in for `Model` throughout the bases of `cls`.

  Args:
    cls: The class whose hierarchy should be rewritten.

  Returns:
    `functional.Functional` when `cls` is `Model` (or the v1 `Model`),
    `object` when `cls` is `object`, and otherwise `cls` itself after its
    `__bases__` have been rewritten recursively.
  """
  from tensorflow.python.keras.engine import functional  # pylint: disable=g-import-not-at-top
  from tensorflow.python.keras.engine import training_v1  # pylint: disable=g-import-not-at-top

  if cls == Model or cls == training_v1.Model:
    return functional.Functional
  # With multiple inheritance some branches never reach a Keras model; the
  # recursion bottoms out at `object` for those and leaves them untouched.
  if cls == object:
    return object

  cls.__bases__ = tuple(map(inject_functional_model_class, cls.__bases__))
  # Now that `Functional` is part of the hierarchy, run `__new__` once so any
  # class swapping it performs (skipped earlier) gets a chance to happen.
  cls.__new__(cls)

  return cls
124
125
def is_functional_model_init_params(args, kwargs):
  """Tells whether constructor arguments look like `(inputs, outputs)`.

  Args:
    args: Positional arguments passed to the constructor.
    kwargs: Keyword arguments passed to the constructor.

  Returns:
    True when exactly two positional arguments are given, or when `outputs`
    is a keyword and either one positional argument or an `inputs` keyword
    accompanies it. False otherwise.
  """
  if len(args) == 2:
    return True
  # Every remaining accepted form requires an explicit `outputs` keyword.
  if 'outputs' not in kwargs:
    return False
  return len(args) == 1 or 'inputs' in kwargs
130
131
132@keras_export('keras.Model', 'keras.models.Model')
133class Model(base_layer.Layer, version_utils.ModelVersionSelector):
134  """`Model` groups layers into an object with training and inference features.
135
136  Args:
137      inputs: The input(s) of the model: a `keras.Input` object or list of
138          `keras.Input` objects.
139      outputs: The output(s) of the model. See Functional API example below.
140      name: String, the name of the model.
141
142  There are two ways to instantiate a `Model`:
143
144  1 - With the "Functional API", where you start from `Input`,
145  you chain layer calls to specify the model's forward pass,
146  and finally you create your model from inputs and outputs:
147
148  ```python
149  import tensorflow as tf
150
151  inputs = tf.keras.Input(shape=(3,))
152  x = tf.keras.layers.Dense(4, activation=tf.nn.relu)(inputs)
153  outputs = tf.keras.layers.Dense(5, activation=tf.nn.softmax)(x)
154  model = tf.keras.Model(inputs=inputs, outputs=outputs)
155  ```
156
157  Note: Only dicts, lists, and tuples of input tensors are supported. Nested
158  inputs are not supported (e.g. lists of list or dicts of dict).
159
160  2 - By subclassing the `Model` class: in that case, you should define your
161  layers in `__init__` and you should implement the model's forward pass
162  in `call`.
163
164  ```python
165  import tensorflow as tf
166
167  class MyModel(tf.keras.Model):
168
169    def __init__(self):
170      super(MyModel, self).__init__()
171      self.dense1 = tf.keras.layers.Dense(4, activation=tf.nn.relu)
172      self.dense2 = tf.keras.layers.Dense(5, activation=tf.nn.softmax)
173
174    def call(self, inputs):
175      x = self.dense1(inputs)
176      return self.dense2(x)
177
178  model = MyModel()
179  ```
180
181  If you subclass `Model`, you can optionally have
182  a `training` argument (boolean) in `call`, which you can use to specify
183  a different behavior in training and inference:
184
185  ```python
186  import tensorflow as tf
187
188  class MyModel(tf.keras.Model):
189
190    def __init__(self):
191      super(MyModel, self).__init__()
192      self.dense1 = tf.keras.layers.Dense(4, activation=tf.nn.relu)
193      self.dense2 = tf.keras.layers.Dense(5, activation=tf.nn.softmax)
194      self.dropout = tf.keras.layers.Dropout(0.5)
195
196    def call(self, inputs, training=False):
197      x = self.dense1(inputs)
198      if training:
199        x = self.dropout(x, training=training)
200      return self.dense2(x)
201
202  model = MyModel()
203  ```
204
205  Once the model is created, you can config the model with losses and metrics
206  with `model.compile()`, train the model with `model.fit()`, or use the model
207  to do prediction with `model.predict()`.
208  """
209  _TF_MODULE_IGNORED_PROPERTIES = frozenset(
210      itertools.chain(('_train_counter', '_test_counter', '_predict_counter',
211                       '_steps_per_execution'),
212                      base_layer.Layer._TF_MODULE_IGNORED_PROPERTIES))  # pylint: disable=protected-access
213
214  def __new__(cls, *args, **kwargs):
215    # Signature detection
216    if is_functional_model_init_params(args, kwargs) and cls == Model:
217      # Functional model
218      from tensorflow.python.keras.engine import functional  # pylint: disable=g-import-not-at-top
219      return functional.Functional(skip_init=True, *args, **kwargs)
220    else:
221      return super(Model, cls).__new__(cls, *args, **kwargs)
222
  @trackable.no_automatic_dependency_tracking
  def __init__(self, *args, **kwargs):
    """Initializes the model as a Functional or as a subclassed model.

    Two paths exist. If the arguments pass `is_functional_model_init_params`
    and `self` is not yet a `Functional`, the class hierarchy is rewritten via
    `inject_functional_model_class`, `Functional.__init__` is run, and the
    method returns early. Otherwise the attributes of a subclassed model are
    initialized.

    Args:
      *args: Positional arguments; forwarded to `Functional.__init__` (and to
        any base listed after `Functional`) in the functional path.
      **kwargs: Keyword arguments. In the subclassed path they are checked
        with `generic_utils.validate_kwargs` against `trainable`, `dtype`,
        `dynamic`, `name`, `autocast`, `inputs` and `outputs` and then passed
        to the parent `__init__`.

    Raises:
      TypeError: In the functional path, when keyword arguments outside
        `inputs`, `outputs`, `name`, `trainable`, `skip_init` are given and no
        base class after `Functional` exists to receive them.
    """
    self._is_model_for_instrumentation = True

    # Special case for Subclassed Functional Model, which we couldn't detect
    # when __new__ is called. We only realize it is a functional model when it
    # calls super.__init__ with input and output tensor.
    from tensorflow.python.keras.engine import functional  # pylint: disable=g-import-not-at-top
    if (is_functional_model_init_params(args, kwargs) and
        not isinstance(self, functional.Functional)):
      # Filter the kwargs for multiple inheritance.
      supported_kwargs = ['inputs', 'outputs', 'name', 'trainable', 'skip_init']
      # `model_kwargs` go to `Functional.__init__`; `other_kwargs` are left
      # for the remaining base classes (or reported as an error below).
      model_kwargs = {k: kwargs[k] for k in kwargs if k in supported_kwargs}
      other_kwargs = {k: kwargs[k] for k in kwargs if k not in supported_kwargs}
      inject_functional_model_class(self.__class__)
      functional.Functional.__init__(self, *args, **model_kwargs)

      # In case there is any multiple inheritance here, we need to call the
      # __init__ for any class that appears after the Functional class.
      clz_to_init = []
      found_functional_class = False
      for clz in self.__class__.__bases__:
        if issubclass(clz, functional.Functional):
          found_functional_class = True
          continue
        if found_functional_class:
          clz_to_init.append(clz)

      if clz_to_init:
        for clz in clz_to_init:
          clz.__init__(self, *args, **other_kwargs)
      elif other_kwargs:
        # In case there are unused kwargs, we should raise an error to user, in
        # case they have a typo in the param name.
        raise TypeError(
            'The following keyword arguments aren\'t supported: {}'.format(
                other_kwargs))
      # The functional path is complete; skip the subclassed-model setup below.
      return

    # The following are implemented as property functions:
    # self.trainable_weights
    # self.non_trainable_weights
    # `inputs` / `outputs` will only appear in kwargs if either are misspelled.
    generic_utils.validate_kwargs(kwargs, {
        'trainable', 'dtype', 'dynamic', 'name', 'autocast', 'inputs', 'outputs'
    })
    super(Model, self).__init__(**kwargs)
    # By default, Model is a subclass model, which is not in graph network.
    self._is_graph_network = False

    self.inputs = None
    self.outputs = None
    self.input_names = None
    self.output_names = None
    # stop_training is used by callback to stop training when error happens
    self.stop_training = False
    self.history = None
    # These objects are used in the default `Model.compile`. They are not
    # guaranteed to be set after `Model.compile` is called, as users can
    # override compile with custom logic.
    self.compiled_loss = None
    self.compiled_metrics = None

    # This is True for Sequential networks and Functional networks.
    self._compute_output_and_mask_jointly = False

    # Don't reset compilation if already done. This may occur if calling
    # `__init__` (or `_init_graph_network`) on an already-compiled model
    # such as a Sequential model. Sequential models may need to rebuild
    # themselves after compilation.
    self._maybe_create_attribute('_is_compiled', False)
    self._maybe_create_attribute('optimizer', None)

    # Model must be created under scope of DistStrat it will be trained with.
    if ds_context.has_strategy():
      self._distribution_strategy = ds_context.get_strategy()
    else:
      self._distribution_strategy = None

    self._cluster_coordinator = None

    # Defaults to value of `tf.config.experimental_functions_run_eagerly`.
    self._run_eagerly = None
    # Initialize cache attrs.
    self._reset_compile_cache()

    # Fault-tolerance handler. Set in `ModelCheckpoint`.
    self._training_state = None
    self._saved_model_inputs_spec = None
    self._trackable_saver = saver_with_op_caching(self)

    # Replaced by a variable in `_configure_steps_per_execution`.
    self._steps_per_execution = None

    self._init_batch_counters()
    # Marker that `__setattr__` looks for to detect a missing
    # `super().__init__()` call in subclasses.
    self._base_model_initialized = True
318
319  @trackable.no_automatic_dependency_tracking
320  def _init_batch_counters(self):
321    # Untracked Variables, used to keep track of mini-batches seen in `fit`,
322    # `evaluate`, and `predict`.
323    agg = variables.VariableAggregationV2.ONLY_FIRST_REPLICA
324    self._train_counter = variables.Variable(0, dtype='int64', aggregation=agg)
325    self._test_counter = variables.Variable(0, dtype='int64', aggregation=agg)
326    self._predict_counter = variables.Variable(
327        0, dtype='int64', aggregation=agg)
328
329  def __setattr__(self, name, value):
330    if not getattr(self, '_self_setattr_tracking', True):
331      super(Model, self).__setattr__(name, value)
332      return
333
334    if all(
335        isinstance(v, (base_layer.Layer, variables.Variable)) or
336        base_layer_utils.has_weights(v) for v in nest.flatten(value)):
337      try:
338        self._base_model_initialized
339      except AttributeError:
340        raise RuntimeError(
341            'It looks like you are subclassing `Model` and you '
342            'forgot to call `super().__init__()`.'
343            ' Always start with this line.')
344
345    super(Model, self).__setattr__(name, value)
346
347  @generic_utils.default
348  def build(self, input_shape):
349    """Builds the model based on input shapes received.
350
351    This is to be used for subclassed models, which do not know at instantiation
352    time what their inputs look like.
353
354    This method only exists for users who want to call `model.build()` in a
355    standalone way (as a substitute for calling the model on real data to
356    build it). It will never be called by the framework (and thus it will
357    never throw unexpected errors in an unrelated workflow).
358
359    Args:
360     input_shape: Single tuple, TensorShape, or list/dict of shapes, where
361         shapes are tuples, integers, or TensorShapes.
362
363    Raises:
364      ValueError:
365        1. In case of invalid user-provided data (not of type tuple,
366           list, TensorShape, or dict).
367        2. If the model requires call arguments that are agnostic
368           to the input shapes (positional or kwarg in call signature).
369        3. If not all layers were properly built.
370        4. If float type inputs are not supported within the layers.
371
372      In each of these cases, the user should build their model by calling it
373      on real tensor data.
374    """
375    if self._is_graph_network:
376      super(Model, self).build(input_shape)
377      return
378
379    if input_shape is None:
380      raise ValueError('Input shape must be defined when calling build on a '
381                       'model subclass network.')
382    valid_types = (tuple, list, tensor_shape.TensorShape, dict)
383    if not isinstance(input_shape, valid_types):
384      raise ValueError('Specified input shape is not one of the valid types. '
385                       'Please specify a batch input shape of type tuple or '
386                       'list of input shapes. User provided '
387                       'input type: {}'.format(type(input_shape)))
388
389    if input_shape and not self.inputs:
390      # We create placeholders for the `None`s in the shape and build the model
391      # in a Graph. Since tf.Variable is compatible with both eager execution
392      # and graph building, the variables created after building the model in
393      # a Graph are still valid when executing eagerly.
394      if context.executing_eagerly():
395        graph = func_graph.FuncGraph('build_graph')
396      else:
397        graph = backend.get_graph()
398      with graph.as_default():
399        if (isinstance(input_shape, list) and
400            all(d is None or isinstance(d, int) for d in input_shape)):
401          input_shape = tuple(input_shape)
402        if isinstance(input_shape, list):
403          x = [base_layer_utils.generate_placeholders_from_shape(shape)
404               for shape in input_shape]
405        elif isinstance(input_shape, dict):
406          x = {
407              k: base_layer_utils.generate_placeholders_from_shape(shape)
408              for k, shape in input_shape.items()
409          }
410        else:
411          x = base_layer_utils.generate_placeholders_from_shape(input_shape)
412
413        kwargs = {}
414        call_signature = self._call_full_argspec
415        call_args = call_signature.args
416        # Exclude `self`, `inputs`, and any argument with a default value.
417        if len(call_args) > 2:
418          if call_signature.defaults:
419            call_args = call_args[2:-len(call_signature.defaults)]
420          else:
421            call_args = call_args[2:]
422          for arg in call_args:
423            if arg == 'training':
424              # Case where `training` is a positional arg with no default.
425              kwargs['training'] = False
426            else:
427              # Has invalid call signature with unknown positional arguments.
428              raise ValueError(
429                  'Currently, you cannot build your model if it has '
430                  'positional or keyword arguments that are not '
431                  'inputs to the model, but are required for its '
432                  '`call` method. Instead, in order to instantiate '
433                  'and build your model, `call` your model on real '
434                  'tensor data with all expected call arguments.')
435        elif len(call_args) < 2:
436          # Signature without `inputs`.
437          raise ValueError('You can only call `build` on a model if its `call` '
438                           'method accepts an `inputs` argument.')
439        try:
440          self.call(x, **kwargs)
441        except (errors.InvalidArgumentError, TypeError):
442          raise ValueError('You cannot build your model by calling `build` '
443                           'if your layers do not support float type inputs. '
444                           'Instead, in order to instantiate and build your '
445                           'model, `call` your model on real tensor data (of '
446                           'the correct dtype).')
447    super(Model, self).build(input_shape)
448
449  @doc_controls.doc_in_current_and_subclasses
450  def call(self, inputs, training=None, mask=None):
451    """Calls the model on new inputs.
452
453    In this case `call` just reapplies
454    all ops in the graph to the new inputs
455    (e.g. build a new computational graph from the provided inputs).
456
457    Note: This method should not be called directly. It is only meant to be
458    overridden when subclassing `tf.keras.Model`.
459    To call a model on an input, always use the `__call__` method,
460    i.e. `model(inputs)`, which relies on the underlying `call` method.
461
462    Args:
463        inputs: Input tensor, or dict/list/tuple of input tensors.
464        training: Boolean or boolean scalar tensor, indicating whether to run
465          the `Network` in training mode or inference mode.
466        mask: A mask or list of masks. A mask can be
467            either a tensor or None (no mask).
468
469    Returns:
470        A tensor if there is a single output, or
471        a list of tensors if there are more than one outputs.
472    """
473    raise NotImplementedError('When subclassing the `Model` class, you should '
474                              'implement a `call` method.')
475
476  def compile(self,
477              optimizer='rmsprop',
478              loss=None,
479              metrics=None,
480              loss_weights=None,
481              weighted_metrics=None,
482              run_eagerly=None,
483              steps_per_execution=None,
484              **kwargs):
485    """Configures the model for training.
486
487    Args:
488        optimizer: String (name of optimizer) or optimizer instance. See
489          `tf.keras.optimizers`.
490        loss: String (name of objective function), objective function or
491          `tf.keras.losses.Loss` instance. See `tf.keras.losses`. An objective
492          function is any callable with the signature `loss = fn(y_true,
493          y_pred)`, where y_true = ground truth values with shape =
494          `[batch_size, d0, .. dN]`, except sparse loss functions such as sparse
495          categorical crossentropy where shape = `[batch_size, d0, .. dN-1]`.
496          y_pred = predicted values with shape = `[batch_size, d0, .. dN]`. It
497          returns a weighted loss float tensor. If a custom `Loss` instance is
498          used and reduction is set to `None`, return value has the shape
499          `[batch_size, d0, .. dN-1]` i.e. per-sample or per-timestep loss
500          values; otherwise, it is a scalar. If the model has multiple outputs,
501          you can use a different loss on each output by passing a dictionary
502          or a list of losses. The loss value that will be minimized by the
503          model will then be the sum of all individual losses, unless
504          `loss_weights` is specified.
505        metrics: List of metrics to be evaluated by the model during training
506          and testing. Each of this can be a string (name of a built-in
507          function), function or a `tf.keras.metrics.Metric` instance. See
508          `tf.keras.metrics`. Typically you will use `metrics=['accuracy']`. A
509          function is any callable with the signature `result = fn(y_true,
510          y_pred)`. To specify different metrics for different outputs of a
511          multi-output model, you could also pass a dictionary, such as
512          `metrics={'output_a': 'accuracy', 'output_b': ['accuracy', 'mse']}`.
513          You can also pass a list to specify a metric or a list of metrics
514          for each output, such as `metrics=[['accuracy'], ['accuracy', 'mse']]`
515          or `metrics=['accuracy', ['accuracy', 'mse']]`. When you pass the
516          strings 'accuracy' or 'acc', we convert this to one of
517          `tf.keras.metrics.BinaryAccuracy`,
518          `tf.keras.metrics.CategoricalAccuracy`,
519          `tf.keras.metrics.SparseCategoricalAccuracy` based on the loss
520          function used and the model output shape. We do a similar
521          conversion for the strings 'crossentropy' and 'ce' as well.
522        loss_weights: Optional list or dictionary specifying scalar coefficients
523          (Python floats) to weight the loss contributions of different model
524          outputs. The loss value that will be minimized by the model will then
525          be the *weighted sum* of all individual losses, weighted by the
526          `loss_weights` coefficients.
527            If a list, it is expected to have a 1:1 mapping to the model's
528              outputs. If a dict, it is expected to map output names (strings)
529              to scalar coefficients.
530        weighted_metrics: List of metrics to be evaluated and weighted by
531          `sample_weight` or `class_weight` during training and testing.
532        run_eagerly: Bool. Defaults to `False`. If `True`, this `Model`'s
533          logic will not be wrapped in a `tf.function`. Recommended to leave
534          this as `None` unless your `Model` cannot be run inside a
535          `tf.function`. `run_eagerly=True` is not supported when using
536          `tf.distribute.experimental.ParameterServerStrategy`.
537        steps_per_execution: Int. Defaults to 1. The number of batches to
538          run during each `tf.function` call. Running multiple batches
539          inside a single `tf.function` call can greatly improve performance
540          on TPUs or small models with a large Python overhead.
541          At most, one full epoch will be run each
542          execution. If a number larger than the size of the epoch is passed,
543          the execution will be truncated to the size of the epoch.
544          Note that if `steps_per_execution` is set to `N`,
545          `Callback.on_batch_begin` and `Callback.on_batch_end` methods
546          will only be called every `N` batches
547          (i.e. before/after each `tf.function` execution).
548        **kwargs: Arguments supported for backwards compatibility only.
549
550    Raises:
551        ValueError: In case of invalid arguments for
552            `optimizer`, `loss` or `metrics`.
553    """
554    with self.distribute_strategy.scope():
555      if 'experimental_steps_per_execution' in kwargs:
556        logging.warning('The argument `steps_per_execution` is no longer '
557                        'experimental. Pass `steps_per_execution` instead of '
558                        '`experimental_steps_per_execution`.')
559        if not steps_per_execution:
560          steps_per_execution = kwargs.pop('experimental_steps_per_execution')
561
562      # When compiling from an already-serialized model, we do not want to
563      # reapply some processing steps (e.g. metric renaming for multi-output
564      # models, which have prefixes added for each corresponding output name).
565      from_serialized = kwargs.pop('from_serialized', False)
566
567      self._validate_compile(optimizer, metrics, **kwargs)
568      self._run_eagerly = run_eagerly
569
570      self.optimizer = self._get_optimizer(optimizer)
571      self.compiled_loss = compile_utils.LossesContainer(
572          loss, loss_weights, output_names=self.output_names)
573      self.compiled_metrics = compile_utils.MetricsContainer(
574          metrics, weighted_metrics, output_names=self.output_names,
575          from_serialized=from_serialized)
576
577      self._configure_steps_per_execution(steps_per_execution or 1)
578
579      # Initializes attrs that are reset each time `compile` is called.
580      self._reset_compile_cache()
581      self._is_compiled = True
582
583      self.loss = loss or {}  # Backwards compat.
584
585  def _get_optimizer(self, optimizer):
586    """Wraps `optimizer` in `LossScaleOptimizer` if necessary."""
587    # The deprecated PolicyV1 has a loss_scale, which we use for backwards
588    # compatibility to match TF 2.3 behavior. The new Policy does not have a
589    # loss_scale, so we use dynamic loss scaling if the mixed_float16 policy is
590    # used.
591    if isinstance(self._dtype_policy, policy.PolicyV1):
592      loss_scale = self._dtype_policy.loss_scale
593    elif self._dtype_policy.name == 'mixed_float16':
594      loss_scale = 'dynamic'
595    else:
596      loss_scale = None
597
598    def _get_single_optimizer(opt):
599      opt = optimizers.get(opt)
600      if (loss_scale is not None and
601          not isinstance(opt, lso.LossScaleOptimizer)):
602        if loss_scale == 'dynamic':
603          opt = lso.LossScaleOptimizer(opt)
604        else:
605          opt = lso.LossScaleOptimizerV1(opt, loss_scale)
606      return opt
607
608    return nest.map_structure(_get_single_optimizer, optimizer)
609
610  @trackable.no_automatic_dependency_tracking
611  def _reset_compile_cache(self):
612    self.train_function = None
613    self.test_function = None
614    self.predict_function = None
615    # Used to cache the `tf.function`'ed `train_function` to be logged in
616    # TensorBoard, since the original `train_function` is not necessarily
617    # a `tf.function` (e.g., with ParameterServerStrategy, the `train_function`
618    # is a scheduling of the actual training function to a remote worker).
619    self.train_tf_function = None
620
621    # Used to cache `trainable` attr of `Layer`s for `fit`.
622    self._compiled_trainable_state = self._get_trainable_state()
623
624  @trackable.no_automatic_dependency_tracking
625  def _configure_steps_per_execution(self, steps_per_execution):
626    self._steps_per_execution = variables.Variable(
627        steps_per_execution,
628        dtype='int64',
629        aggregation=variables.VariableAggregationV2.ONLY_FIRST_REPLICA)
630
  @property
  def _should_compute_mask(self):
    """Always `False` for this class."""
    return False
634
635  @property
636  def metrics(self):
637    """Returns the model's metrics added using `compile`, `add_metric` APIs.
638
639    Note: Metrics passed to `compile()` are available only after a `keras.Model`
640    has been trained/evaluated on actual data.
641
642    Examples:
643
644    >>> inputs = tf.keras.layers.Input(shape=(3,))
645    >>> outputs = tf.keras.layers.Dense(2)(inputs)
646    >>> model = tf.keras.models.Model(inputs=inputs, outputs=outputs)
647    >>> model.compile(optimizer="Adam", loss="mse", metrics=["mae"])
648    >>> [m.name for m in model.metrics]
649    []
650
651    >>> x = np.random.random((2, 3))
652    >>> y = np.random.randint(0, 2, (2, 2))
653    >>> model.fit(x, y)
654    >>> [m.name for m in model.metrics]
655    ['loss', 'mae']
656
657    >>> inputs = tf.keras.layers.Input(shape=(3,))
658    >>> d = tf.keras.layers.Dense(2, name='out')
659    >>> output_1 = d(inputs)
660    >>> output_2 = d(inputs)
661    >>> model = tf.keras.models.Model(
662    ...    inputs=inputs, outputs=[output_1, output_2])
663    >>> model.add_metric(
664    ...    tf.reduce_sum(output_2), name='mean', aggregation='mean')
665    >>> model.compile(optimizer="Adam", loss="mse", metrics=["mae", "acc"])
666    >>> model.fit(x, (y, y))
667    >>> [m.name for m in model.metrics]
668    ['loss', 'out_loss', 'out_1_loss', 'out_mae', 'out_acc', 'out_1_mae',
669    'out_1_acc', 'mean']
670
671    """
672    metrics = []
673    if self._is_compiled:
674      # TODO(omalleyt): Track `LossesContainer` and `MetricsContainer` objects
675      # so that attr names are not load-bearing.
676      if self.compiled_loss is not None:
677        metrics += self.compiled_loss.metrics
678      if self.compiled_metrics is not None:
679        metrics += self.compiled_metrics.metrics
680
681    for l in self._flatten_layers():
682      metrics.extend(l._metrics)  # pylint: disable=protected-access
683    return metrics
684
685  @property
686  def metrics_names(self):
687    """Returns the model's display labels for all outputs.
688
689    Note: `metrics_names` are available only after a `keras.Model` has been
690    trained/evaluated on actual data.
691
692    Examples:
693
694    >>> inputs = tf.keras.layers.Input(shape=(3,))
695    >>> outputs = tf.keras.layers.Dense(2)(inputs)
696    >>> model = tf.keras.models.Model(inputs=inputs, outputs=outputs)
697    >>> model.compile(optimizer="Adam", loss="mse", metrics=["mae"])
698    >>> model.metrics_names
699    []
700
701    >>> x = np.random.random((2, 3))
702    >>> y = np.random.randint(0, 2, (2, 2))
703    >>> model.fit(x, y)
704    >>> model.metrics_names
705    ['loss', 'mae']
706
707    >>> inputs = tf.keras.layers.Input(shape=(3,))
708    >>> d = tf.keras.layers.Dense(2, name='out')
709    >>> output_1 = d(inputs)
710    >>> output_2 = d(inputs)
711    >>> model = tf.keras.models.Model(
712    ...    inputs=inputs, outputs=[output_1, output_2])
713    >>> model.compile(optimizer="Adam", loss="mse", metrics=["mae", "acc"])
714    >>> model.fit(x, (y, y))
715    >>> model.metrics_names
716    ['loss', 'out_loss', 'out_1_loss', 'out_mae', 'out_acc', 'out_1_mae',
717    'out_1_acc']
718
719    """
720
721    # This property includes all output names including `loss` and per-output
722    # losses for backward compatibility.
723    return [m.name for m in self.metrics]
724
725  @property
726  def distribute_strategy(self):
727    """The `tf.distribute.Strategy` this model was created under."""
728    return self._distribution_strategy or ds_context.get_strategy()
729
730  @property
731  def run_eagerly(self):
732    """Settable attribute indicating whether the model should run eagerly.
733
734    Running eagerly means that your model will be run step by step,
735    like Python code. Your model might run slower, but it should become easier
736    for you to debug it by stepping into individual layer calls.
737
738    By default, we will attempt to compile your model to a static graph to
739    deliver the best execution performance.
740
741    Returns:
742      Boolean, whether the model should run eagerly.
743    """
744    if self.dynamic and self._run_eagerly is False:  # pylint:disable=g-bool-id-comparison
745      # TODO(fchollet): consider using py_func to enable this.
746      raise ValueError('Your model contains layers that can only be '
747                       'successfully run in eager execution (layers '
748                       'constructed with `dynamic=True`). '
749                       'You cannot set `run_eagerly=False`.')
750
751    if self._cluster_coordinator and self._run_eagerly:
752      raise ValueError('When using `Model` with `ParameterServerStrategy`, '
753                       '`run_eagerly` is not supported.')
754
755    # Run eagerly logic, by priority:
756    # (1) Dynamic models must be run eagerly.
757    # (2) Explicitly setting run_eagerly causes a Model to be run eagerly.
758    # (3) Not explicitly setting run_eagerly defaults to TF's global setting.
759    return (self.dynamic or self._run_eagerly or
760            (def_function.functions_run_eagerly() and
761             self._run_eagerly is None))
762
763  @run_eagerly.setter
764  def run_eagerly(self, value):
765    self._run_eagerly = value
766
767  def train_step(self, data):
768    """The logic for one training step.
769
770    This method can be overridden to support custom training logic.
771    For concrete examples of how to override this method see
772    [Customizing what happends in fit](https://www.tensorflow.org/guide/keras/customizing_what_happens_in_fit).
773    This method is called by `Model.make_train_function`.
774
775    This method should contain the mathematical logic for one step of training.
776    This typically includes the forward pass, loss calculation, backpropagation,
777    and metric updates.
778
779    Configuration details for *how* this logic is run (e.g. `tf.function` and
780    `tf.distribute.Strategy` settings), should be left to
781    `Model.make_train_function`, which can also be overridden.
782
783    Args:
784      data: A nested structure of `Tensor`s.
785
786    Returns:
787      A `dict` containing values that will be passed to
788      `tf.keras.callbacks.CallbackList.on_train_batch_end`. Typically, the
789      values of the `Model`'s metrics are returned. Example:
790      `{'loss': 0.2, 'accuracy': 0.7}`.
791
792    """
793    # These are the only transformations `Model.fit` applies to user-input
794    # data when a `tf.data.Dataset` is provided.
795    data = data_adapter.expand_1d(data)
796    x, y, sample_weight = data_adapter.unpack_x_y_sample_weight(data)
797    # Run forward pass.
798    with backprop.GradientTape() as tape:
799      y_pred = self(x, training=True)
800      loss = self.compiled_loss(
801          y, y_pred, sample_weight, regularization_losses=self.losses)
802    # Run backwards pass.
803    self.optimizer.minimize(loss, self.trainable_variables, tape=tape)
804    self.compiled_metrics.update_state(y, y_pred, sample_weight)
805    # Collect metrics to return
806    return_metrics = {}
807    for metric in self.metrics:
808      result = metric.result()
809      if isinstance(result, dict):
810        return_metrics.update(result)
811      else:
812        return_metrics[metric.name] = result
813    return return_metrics
814
  def make_train_function(self):
    """Creates a function that executes one step of training.

    This method can be overridden to support custom training logic.
    This method is called by `Model.fit` and `Model.train_on_batch`.

    Typically, this method directly controls `tf.function` and
    `tf.distribute.Strategy` settings, and delegates the actual training
    logic to `Model.train_step`.

    This function is cached the first time `Model.fit` or
    `Model.train_on_batch` is called. The cache is cleared whenever
    `Model.compile` is called.

    The function built here pulls one element from the iterator per step,
    runs `Model.train_step` on it through `distribute_strategy.run`,
    increments `self._train_counter` and hands the outputs to
    `write_scalar_summaries`. When `self._steps_per_execution` is not 1, the
    step is repeated that many times per call and the outputs of the last step
    are returned. Unless `self.run_eagerly` is truthy, the function is wrapped
    with `def_function.function` (and also stored as `self.train_tf_function`).
    If `self._cluster_coordinator` is set, the stored `self.train_function`
    instead schedules that function through the coordinator.

    Returns:
      Function. The function created by this method should accept a
      `tf.data.Iterator`, and return a `dict` containing values that will
      be passed to `tf.keras.Callbacks.on_train_batch_end`, such as
      `{'loss': 0.2, 'accuracy': 0.7}`.
    """
    # Reuse the cached function if one has already been built.
    if self.train_function is not None:
      return self.train_function

    def step_function(model, iterator):
      """Runs a single training step."""

      def run_step(data):
        outputs = model.train_step(data)
        # Ensure counter is updated only if `train_step` succeeds.
        with ops.control_dependencies(_minimum_control_deps(outputs)):
          model._train_counter.assign_add(1)  # pylint: disable=protected-access
        return outputs

      data = next(iterator)
      outputs = model.distribute_strategy.run(run_step, args=(data,))
      # NOTE(review): `reduction='first'` presumably keeps a single replica's
      # values -- confirm against `reduce_per_replica`.
      outputs = reduce_per_replica(
          outputs, self.distribute_strategy, reduction='first')
      write_scalar_summaries(outputs, step=model._train_counter)  # pylint: disable=protected-access
      return outputs

    # Pick the single-step or the multi-step variant once, up front, based on
    # the current value of `_steps_per_execution`.
    if self._steps_per_execution.numpy().item() == 1:

      def train_function(iterator):
        """Runs a training execution with one step."""
        return step_function(self, iterator)

    else:

      def train_function(iterator):
        """Runs a training execution with multiple steps."""
        # Only the outputs of the final step survive the loop.
        for _ in math_ops.range(self._steps_per_execution):
          outputs = step_function(self, iterator)
        return outputs

    if not self.run_eagerly:
      train_function = def_function.function(
          train_function, experimental_relax_shapes=True)
      # Keep a handle on the `def_function.function`-wrapped version as well.
      self.train_tf_function = train_function

    self.train_function = train_function

    # With a cluster coordinator, callers get a wrapper that schedules the
    # function on the coordinator instead of invoking it directly.
    if self._cluster_coordinator:
      self.train_function = lambda iterator: self._cluster_coordinator.schedule(  # pylint: disable=g-long-lambda
          train_function, args=(iterator,))

    return self.train_function
881
882  def fit(self,
883          x=None,
884          y=None,
885          batch_size=None,
886          epochs=1,
887          verbose='auto',
888          callbacks=None,
889          validation_split=0.,
890          validation_data=None,
891          shuffle=True,
892          class_weight=None,
893          sample_weight=None,
894          initial_epoch=0,
895          steps_per_epoch=None,
896          validation_steps=None,
897          validation_batch_size=None,
898          validation_freq=1,
899          max_queue_size=10,
900          workers=1,
901          use_multiprocessing=False):
902    """Trains the model for a fixed number of epochs (iterations on a dataset).
903
904    Args:
905        x: Input data. It could be:
906          - A Numpy array (or array-like), or a list of arrays
907            (in case the model has multiple inputs).
908          - A TensorFlow tensor, or a list of tensors
909            (in case the model has multiple inputs).
910          - A dict mapping input names to the corresponding array/tensors,
911            if the model has named inputs.
912          - A `tf.data` dataset. Should return a tuple
913            of either `(inputs, targets)` or
914            `(inputs, targets, sample_weights)`.
915          - A generator or `keras.utils.Sequence` returning `(inputs, targets)`
916            or `(inputs, targets, sample_weights)`.
917          - A `tf.keras.utils.experimental.DatasetCreator`, which wraps a
918            callable that takes a single argument of type
919            `tf.distribute.InputContext`, and returns a `tf.data.Dataset`.
920            `DatasetCreator` should be used when users prefer to specify the
921            per-replica batching and sharding logic for the `Dataset`.
922            See `tf.keras.utils.experimental.DatasetCreator` doc for more
923            information.
924          A more detailed description of unpacking behavior for iterator types
925          (Dataset, generator, Sequence) is given below. If using
926          `tf.distribute.experimental.ParameterServerStrategy`, only
927          `DatasetCreator` type is supported for `x`.
928        y: Target data. Like the input data `x`,
929          it could be either Numpy array(s) or TensorFlow tensor(s).
930          It should be consistent with `x` (you cannot have Numpy inputs and
931          tensor targets, or inversely). If `x` is a dataset, generator,
932          or `keras.utils.Sequence` instance, `y` should
933          not be specified (since targets will be obtained from `x`).
934        batch_size: Integer or `None`.
935            Number of samples per gradient update.
936            If unspecified, `batch_size` will default to 32.
937            Do not specify the `batch_size` if your data is in the
938            form of datasets, generators, or `keras.utils.Sequence` instances
939            (since they generate batches).
940        epochs: Integer. Number of epochs to train the model.
941            An epoch is an iteration over the entire `x` and `y`
942            data provided.
943            Note that in conjunction with `initial_epoch`,
944            `epochs` is to be understood as "final epoch".
945            The model is not trained for a number of iterations
946            given by `epochs`, but merely until the epoch
947            of index `epochs` is reached.
948        verbose: 'auto', 0, 1, or 2. Verbosity mode.
949            0 = silent, 1 = progress bar, 2 = one line per epoch.
950            'auto' defaults to 1 for most cases, but 2 when used with
951            `ParameterServerStrategy`. Note that the progress bar is not
952            particularly useful when logged to a file, so verbose=2 is
953            recommended when not running interactively (eg, in a production
954            environment).
955        callbacks: List of `keras.callbacks.Callback` instances.
956            List of callbacks to apply during training.
957            See `tf.keras.callbacks`. Note `tf.keras.callbacks.ProgbarLogger`
958            and `tf.keras.callbacks.History` callbacks are created automatically
959            and need not be passed into `model.fit`.
960            `tf.keras.callbacks.ProgbarLogger` is created or not based on
961            `verbose` argument to `model.fit`.
962            Callbacks with batch-level calls are currently unsupported with
963            `tf.distribute.experimental.ParameterServerStrategy`, and users are
964            advised to implement epoch-level calls instead with an appropriate
965            `steps_per_epoch` value.
966        validation_split: Float between 0 and 1.
967            Fraction of the training data to be used as validation data.
968            The model will set apart this fraction of the training data,
969            will not train on it, and will evaluate
970            the loss and any model metrics
971            on this data at the end of each epoch.
972            The validation data is selected from the last samples
973            in the `x` and `y` data provided, before shuffling. This argument is
974            not supported when `x` is a dataset, generator or
975           `keras.utils.Sequence` instance.
976            `validation_split` is not yet supported with
977            `tf.distribute.experimental.ParameterServerStrategy`.
978        validation_data: Data on which to evaluate
979            the loss and any model metrics at the end of each epoch.
980            The model will not be trained on this data. Thus, note the fact
981            that the validation loss of data provided using `validation_split`
982            or `validation_data` is not affected by regularization layers like
983            noise and dropout.
984            `validation_data` will override `validation_split`.
985            `validation_data` could be:
986              - A tuple `(x_val, y_val)` of Numpy arrays or tensors.
987              - A tuple `(x_val, y_val, val_sample_weights)` of NumPy arrays.
988              - A `tf.data.Dataset`.
989              - A Python generator or `keras.utils.Sequence` returning
990              `(inputs, targets)` or `(inputs, targets, sample_weights)`.
991            `validation_data` is not yet supported with
992            `tf.distribute.experimental.ParameterServerStrategy`.
993        shuffle: Boolean (whether to shuffle the training data
994            before each epoch) or str (for 'batch'). This argument is ignored
995            when `x` is a generator or an object of tf.data.Dataset.
996            'batch' is a special option for dealing
997            with the limitations of HDF5 data; it shuffles in batch-sized
998            chunks. Has no effect when `steps_per_epoch` is not `None`.
999        class_weight: Optional dictionary mapping class indices (integers)
1000            to a weight (float) value, used for weighting the loss function
1001            (during training only).
1002            This can be useful to tell the model to
1003            "pay more attention" to samples from
1004            an under-represented class.
1005        sample_weight: Optional Numpy array of weights for
1006            the training samples, used for weighting the loss function
1007            (during training only). You can either pass a flat (1D)
1008            Numpy array with the same length as the input samples
1009            (1:1 mapping between weights and samples),
1010            or in the case of temporal data,
1011            you can pass a 2D array with shape
1012            `(samples, sequence_length)`,
1013            to apply a different weight to every timestep of every sample. This
1014            argument is not supported when `x` is a dataset, generator, or
1015           `keras.utils.Sequence` instance, instead provide the sample_weights
1016            as the third element of `x`.
1017        initial_epoch: Integer.
1018            Epoch at which to start training
1019            (useful for resuming a previous training run).
1020        steps_per_epoch: Integer or `None`.
1021            Total number of steps (batches of samples)
1022            before declaring one epoch finished and starting the
1023            next epoch. When training with input tensors such as
1024            TensorFlow data tensors, the default `None` is equal to
1025            the number of samples in your dataset divided by
1026            the batch size, or 1 if that cannot be determined. If x is a
1027            `tf.data` dataset, and 'steps_per_epoch'
1028            is None, the epoch will run until the input dataset is exhausted.
1029            When passing an infinitely repeating dataset, you must specify the
1030            `steps_per_epoch` argument. If `steps_per_epoch=-1` the training
1031            will run indefinitely with an infinitely repeating dataset.
1032            This argument is not supported with array inputs.
1033            When using `tf.distribute.experimental.ParameterServerStrategy`:
1034              * `steps_per_epoch=None` is not supported.
1035        validation_steps: Only relevant if `validation_data` is provided and
1036            is a `tf.data` dataset. Total number of steps (batches of
1037            samples) to draw before stopping when performing validation
1038            at the end of every epoch. If 'validation_steps' is None, validation
1039            will run until the `validation_data` dataset is exhausted. In the
1040            case of an infinitely repeated dataset, it will run into an
1041            infinite loop. If 'validation_steps' is specified and only part of
1042            the dataset will be consumed, the evaluation will start from the
1043            beginning of the dataset at each epoch. This ensures that the same
1044            validation samples are used every time.
1045        validation_batch_size: Integer or `None`.
1046            Number of samples per validation batch.
1047            If unspecified, will default to `batch_size`.
1048            Do not specify the `validation_batch_size` if your data is in the
1049            form of datasets, generators, or `keras.utils.Sequence` instances
1050            (since they generate batches).
1051        validation_freq: Only relevant if validation data is provided. Integer
1052            or `collections.abc.Container` instance (e.g. list, tuple, etc.).
1053            If an integer, specifies how many training epochs to run before a
1054            new validation run is performed, e.g. `validation_freq=2` runs
1055            validation every 2 epochs. If a Container, specifies the epochs on
1056            which to run validation, e.g. `validation_freq=[1, 2, 10]` runs
1057            validation at the end of the 1st, 2nd, and 10th epochs.
1058        max_queue_size: Integer. Used for generator or `keras.utils.Sequence`
1059            input only. Maximum size for the generator queue.
1060            If unspecified, `max_queue_size` will default to 10.
1061        workers: Integer. Used for generator or `keras.utils.Sequence` input
1062            only. Maximum number of processes to spin up
1063            when using process-based threading. If unspecified, `workers`
1064            will default to 1.
1065        use_multiprocessing: Boolean. Used for generator or
1066            `keras.utils.Sequence` input only. If `True`, use process-based
1067            threading. If unspecified, `use_multiprocessing` will default to
1068            `False`. Note that because this implementation relies on
1069            multiprocessing, you should not pass non-picklable arguments to
1070            the generator as they can't be passed easily to children processes.
1071
1072    Unpacking behavior for iterator-like inputs:
1073        A common pattern is to pass a tf.data.Dataset, generator, or
1074      tf.keras.utils.Sequence to the `x` argument of fit, which will in fact
1075      yield not only features (x) but optionally targets (y) and sample weights.
1076      Keras requires that the output of such iterator-likes be unambiguous. The
1077      iterator should return a tuple of length 1, 2, or 3, where the optional
1078      second and third elements will be used for y and sample_weight
1079      respectively. Any other type provided will be wrapped in a length one
1080      tuple, effectively treating everything as 'x'. When yielding dicts, they
1081      should still adhere to the top-level tuple structure.
1082      e.g. `({"x0": x0, "x1": x1}, y)`. Keras will not attempt to separate
1083      features, targets, and weights from the keys of a single dict.
1084        A notable unsupported data type is the namedtuple. The reason is that
1085      it behaves like both an ordered datatype (tuple) and a mapping
1086      datatype (dict). So given a namedtuple of the form:
1087          `namedtuple("example_tuple", ["y", "x"])`
1088      it is ambiguous whether to reverse the order of the elements when
1089      interpreting the value. Even worse is a tuple of the form:
1090          `namedtuple("other_tuple", ["x", "y", "z"])`
1091      where it is unclear if the tuple was intended to be unpacked into x, y,
1092      and sample_weight or passed through as a single element to `x`. As a
1093      result the data processing code will simply raise a ValueError if it
1094      encounters a namedtuple. (Along with instructions to remedy the issue.)
1095
1096    Returns:
1097        A `History` object. Its `History.history` attribute is
1098        a record of training loss values and metrics values
1099        at successive epochs, as well as validation loss values
1100        and validation metrics values (if applicable).
1101
1102    Raises:
1103        RuntimeError: 1. If the model was never compiled or,
1104        2. If `model.fit` is  wrapped in `tf.function`.
1105
1106        ValueError: In case of mismatch between the provided input data
1107            and what the model expects or when the input data is empty.
1108    """
1109    # Legacy graph support is contained in `training_v1.Model`.
1110    version_utils.disallow_legacy_graph('Model', 'fit')
1111    self._assert_compile_was_called()
1112    self._check_call_args('fit')
1113    _disallow_inside_tf_function('fit')
1114
1115    if verbose == 'auto':
1116      if self.distribute_strategy._should_use_with_coordinator:  # pylint: disable=protected-access
1117        verbose = 2  # Default to epoch-level logging for PSStrategy.
1118      else:
1119        verbose = 1  # Default to batch-level logging otherwise.
1120
1121    if validation_split:
1122      # Create the validation data using the training data. Only supported for
1123      # `Tensor` and `NumPy` input.
1124      (x, y, sample_weight), validation_data = (
1125          data_adapter.train_validation_split(
1126              (x, y, sample_weight), validation_split=validation_split))
1127
1128    if validation_data:
1129      val_x, val_y, val_sample_weight = (
1130          data_adapter.unpack_x_y_sample_weight(validation_data))
1131
1132    if self.distribute_strategy._should_use_with_coordinator:  # pylint: disable=protected-access
1133      self._cluster_coordinator = cluster_coordinator.ClusterCoordinator(
1134          self.distribute_strategy)
1135
1136    with self.distribute_strategy.scope(), \
1137         training_utils.RespectCompiledTrainableState(self):
1138      # Creates a `tf.data.Dataset` and handles batch and epoch iteration.
1139      data_handler = data_adapter.get_data_handler(
1140          x=x,
1141          y=y,
1142          sample_weight=sample_weight,
1143          batch_size=batch_size,
1144          steps_per_epoch=steps_per_epoch,
1145          initial_epoch=initial_epoch,
1146          epochs=epochs,
1147          shuffle=shuffle,
1148          class_weight=class_weight,
1149          max_queue_size=max_queue_size,
1150          workers=workers,
1151          use_multiprocessing=use_multiprocessing,
1152          model=self,
1153          steps_per_execution=self._steps_per_execution)
1154
1155      # Container that configures and calls `tf.keras.Callback`s.
1156      if not isinstance(callbacks, callbacks_module.CallbackList):
1157        callbacks = callbacks_module.CallbackList(
1158            callbacks,
1159            add_history=True,
1160            add_progbar=verbose != 0,
1161            model=self,
1162            verbose=verbose,
1163            epochs=epochs,
1164            steps=data_handler.inferred_steps)
1165
1166      self.stop_training = False
1167      self.train_function = self.make_train_function()
1168      self._train_counter.assign(0)
1169      callbacks.on_train_begin()
1170      training_logs = None
1171      # Handle fault-tolerance for multi-worker.
1172      # TODO(omalleyt): Fix the ordering issues that mean this has to
1173      # happen after `callbacks.on_train_begin`.
1174      data_handler._initial_epoch = (  # pylint: disable=protected-access
1175          self._maybe_load_initial_epoch_from_ckpt(initial_epoch))
1176      logs = None
1177      for epoch, iterator in data_handler.enumerate_epochs():
1178        self.reset_metrics()
1179        callbacks.on_epoch_begin(epoch)
1180        with data_handler.catch_stop_iteration():
1181          for step in data_handler.steps():
1182            with trace.Trace(
1183                'train',
1184                epoch_num=epoch,
1185                step_num=step,
1186                batch_size=batch_size,
1187                _r=1):
1188              callbacks.on_train_batch_begin(step)
1189              tmp_logs = self.train_function(iterator)
1190              if data_handler.should_sync:
1191                context.async_wait()
1192              logs = tmp_logs  # No error, now safe to assign to logs.
1193              end_step = step + data_handler.step_increment
1194              callbacks.on_train_batch_end(end_step, logs)
1195              if self.stop_training:
1196                break
1197
1198        logs = tf_utils.sync_to_numpy_or_python_type(logs)
1199        if logs is None:
1200          raise ValueError('Expect x to be a non-empty array or dataset.')
1201        epoch_logs = copy.copy(logs)
1202
1203        # Run validation.
1204        if validation_data and self._should_eval(epoch, validation_freq):
1205          # Create data_handler for evaluation and cache it.
1206          if getattr(self, '_eval_data_handler', None) is None:
1207            self._eval_data_handler = data_adapter.get_data_handler(
1208                x=val_x,
1209                y=val_y,
1210                sample_weight=val_sample_weight,
1211                batch_size=validation_batch_size or batch_size,
1212                steps_per_epoch=validation_steps,
1213                initial_epoch=0,
1214                epochs=1,
1215                max_queue_size=max_queue_size,
1216                workers=workers,
1217                use_multiprocessing=use_multiprocessing,
1218                model=self,
1219                steps_per_execution=self._steps_per_execution)
1220          val_logs = self.evaluate(
1221              x=val_x,
1222              y=val_y,
1223              sample_weight=val_sample_weight,
1224              batch_size=validation_batch_size or batch_size,
1225              steps=validation_steps,
1226              callbacks=callbacks,
1227              max_queue_size=max_queue_size,
1228              workers=workers,
1229              use_multiprocessing=use_multiprocessing,
1230              return_dict=True,
1231              _use_cached_eval_dataset=True)
1232          val_logs = {'val_' + name: val for name, val in val_logs.items()}
1233          epoch_logs.update(val_logs)
1234
1235        callbacks.on_epoch_end(epoch, epoch_logs)
1236        training_logs = epoch_logs
1237        if self.stop_training:
1238          break
1239
1240      # If eval data_hanlder exists, delete it after all epochs are done.
1241      if getattr(self, '_eval_data_handler', None) is not None:
1242        del self._eval_data_handler
1243      callbacks.on_train_end(logs=training_logs)
1244      return self.history
1245
1246  def test_step(self, data):
1247    """The logic for one evaluation step.
1248
1249    This method can be overridden to support custom evaluation logic.
1250    This method is called by `Model.make_test_function`.
1251
1252    This function should contain the mathematical logic for one step of
1253    evaluation.
1254    This typically includes the forward pass, loss calculation, and metrics
1255    updates.
1256
1257    Configuration details for *how* this logic is run (e.g. `tf.function` and
1258    `tf.distribute.Strategy` settings), should be left to
1259    `Model.make_test_function`, which can also be overridden.
1260
1261    Args:
1262      data: A nested structure of `Tensor`s.
1263
1264    Returns:
1265      A `dict` containing values that will be passed to
1266      `tf.keras.callbacks.CallbackList.on_train_batch_end`. Typically, the
1267      values of the `Model`'s metrics are returned.
1268    """
1269    data = data_adapter.expand_1d(data)
1270    x, y, sample_weight = data_adapter.unpack_x_y_sample_weight(data)
1271
1272    y_pred = self(x, training=False)
1273    # Updates stateful loss metrics.
1274    self.compiled_loss(
1275        y, y_pred, sample_weight, regularization_losses=self.losses)
1276    self.compiled_metrics.update_state(y, y_pred, sample_weight)
1277    # Collect metrics to return
1278    return_metrics = {}
1279    for metric in self.metrics:
1280      result = metric.result()
1281      if isinstance(result, dict):
1282        return_metrics.update(result)
1283      else:
1284        return_metrics[metric.name] = result
1285    return return_metrics
1286
  def make_test_function(self):
    """Creates a function that executes one step of evaluation.

    This method can be overridden to support custom evaluation logic.
    This method is called by `Model.evaluate` and `Model.test_on_batch`.

    Typically, this method directly controls `tf.function` and
    `tf.distribute.Strategy` settings, and delegates the actual evaluation
    logic to `Model.test_step`.

    This function is cached the first time `Model.evaluate` or
    `Model.test_on_batch` is called. The cache is cleared whenever
    `Model.compile` is called.

    The function built here pulls one element from the iterator per step,
    runs `Model.test_step` on it through `distribute_strategy.run` and
    increments `self._test_counter`. When `self._steps_per_execution` is not
    1, the step is repeated that many times per call and the outputs of the
    last step are returned. Unless `self.run_eagerly` is truthy, the function
    is wrapped with `def_function.function`. If `self._cluster_coordinator` is
    set, the stored `self.test_function` instead schedules that function
    through the coordinator.

    Returns:
      Function. The function created by this method should accept a
      `tf.data.Iterator`, and return a `dict` containing values that will
      be passed to `tf.keras.Callbacks.on_test_batch_end`.
    """
    # Reuse the cached function if one has already been built.
    if self.test_function is not None:
      return self.test_function

    def step_function(model, iterator):
      """Runs a single evaluation step."""

      def run_step(data):
        outputs = model.test_step(data)
        # Ensure counter is updated only if `test_step` succeeds.
        with ops.control_dependencies(_minimum_control_deps(outputs)):
          model._test_counter.assign_add(1)  # pylint: disable=protected-access
        return outputs

      data = next(iterator)
      outputs = model.distribute_strategy.run(run_step, args=(data,))
      # NOTE(review): `reduction='first'` presumably keeps a single replica's
      # values -- confirm against `reduce_per_replica`.
      outputs = reduce_per_replica(
          outputs, self.distribute_strategy, reduction='first')
      return outputs

    # Pick the single-step or the multi-step variant once, up front, based on
    # the current value of `_steps_per_execution`.
    if self._steps_per_execution.numpy().item() == 1:

      def test_function(iterator):
        """Runs an evaluation execution with one step."""
        return step_function(self, iterator)

    else:

      def test_function(iterator):
        """Runs an evaluation execution with multiple steps."""
        # Only the outputs of the final step survive the loop.
        for _ in math_ops.range(self._steps_per_execution):
          outputs = step_function(self, iterator)
        return outputs

    if not self.run_eagerly:
      test_function = def_function.function(
          test_function, experimental_relax_shapes=True)

    self.test_function = test_function

    # With a cluster coordinator, callers get a wrapper that schedules the
    # function on the coordinator instead of invoking it directly.
    if self._cluster_coordinator:
      self.test_function = lambda iterator: self._cluster_coordinator.schedule(  # pylint: disable=g-long-lambda
          test_function, args=(iterator,))

    return self.test_function
1350
  def evaluate(self,
               x=None,
               y=None,
               batch_size=None,
               verbose=1,
               sample_weight=None,
               steps=None,
               callbacks=None,
               max_queue_size=10,
               workers=1,
               use_multiprocessing=False,
               return_dict=False,
               **kwargs):
    """Returns the loss value & metrics values for the model in test mode.

    Computation is done in batches (see the `batch_size` arg.)

    Args:
        x: Input data. It could be:
          - A Numpy array (or array-like), or a list of arrays
            (in case the model has multiple inputs).
          - A TensorFlow tensor, or a list of tensors
            (in case the model has multiple inputs).
          - A dict mapping input names to the corresponding array/tensors,
            if the model has named inputs.
          - A `tf.data` dataset. Should return a tuple
            of either `(inputs, targets)` or
            `(inputs, targets, sample_weights)`.
          - A generator or `keras.utils.Sequence` returning `(inputs, targets)`
            or `(inputs, targets, sample_weights)`.
          A more detailed description of unpacking behavior for iterator types
          (Dataset, generator, Sequence) is given in the `Unpacking behavior
          for iterator-like inputs` section of `Model.fit`.
        y: Target data. Like the input data `x`, it could be either Numpy
          array(s) or TensorFlow tensor(s). It should be consistent with `x`
          (you cannot have Numpy inputs and tensor targets, or inversely). If
          `x` is a dataset, generator or `keras.utils.Sequence` instance, `y`
          should not be specified (since targets will be obtained from the
          iterator/dataset).
        batch_size: Integer or `None`. Number of samples per batch of
          computation. If unspecified, `batch_size` will default to 32. Do not
          specify the `batch_size` if your data is in the form of a dataset,
          generators, or `keras.utils.Sequence` instances (since they generate
          batches).
        verbose: 0 or 1. Verbosity mode. 0 = silent, 1 = progress bar.
        sample_weight: Optional Numpy array of weights for the test samples,
          used for weighting the loss function. You can either pass a flat (1D)
          Numpy array with the same length as the input samples
          (1:1 mapping between weights and samples), or in the case of
          temporal data, you can pass a 2D array with shape `(samples,
          sequence_length)`, to apply a different weight to every timestep
          of every sample. This argument is not supported when `x` is a
          dataset, instead pass sample weights as the third element of `x`.
        steps: Integer or `None`. Total number of steps (batches of samples)
          before declaring the evaluation round finished. Ignored with the
          default value of `None`. If x is a `tf.data` dataset and `steps` is
          None, 'evaluate' will run until the dataset is exhausted. This
          argument is not supported with array inputs.
        callbacks: List of `keras.callbacks.Callback` instances. List of
          callbacks to apply during evaluation. See
          [callbacks](/api_docs/python/tf/keras/callbacks).
        max_queue_size: Integer. Used for generator or `keras.utils.Sequence`
          input only. Maximum size for the generator queue. If unspecified,
          `max_queue_size` will default to 10.
        workers: Integer. Used for generator or `keras.utils.Sequence` input
          only. Maximum number of processes to spin up when using process-based
          threading. If unspecified, `workers` will default to 1.
        use_multiprocessing: Boolean. Used for generator or
          `keras.utils.Sequence` input only. If `True`, use process-based
          threading. If unspecified, `use_multiprocessing` will default to
          `False`. Note that because this implementation relies on
          multiprocessing, you should not pass non-picklable arguments to the
          generator as they can't be passed easily to children processes.
        return_dict: If `True`, loss and metric results are returned as a dict,
          with each key being the name of the metric. If `False`, they are
          returned as a list.
        **kwargs: Not part of the public API. The only key that is consumed is
          the private `_use_cached_eval_dataset` flag (meant for the call made
          from `Model.fit`); any other keyword raises `TypeError`.

    See the discussion of `Unpacking behavior for iterator-like inputs` for
    `Model.fit`.

    `Model.evaluate` is not yet supported with
    `tf.distribute.experimental.ParameterServerStrategy`.

    Returns:
        If `return_dict` is `False` (the default): scalar test loss (if the
        model has a single output and no metrics) or list of scalars (if the
        model has multiple outputs and/or metrics). The attribute
        `model.metrics_names` will give you the display labels for the scalar
        outputs.
        If `return_dict` is `True`: the logs of the last executed test step,
        converted to Numpy / Python values.

    Raises:
        RuntimeError: If `model.evaluate` is wrapped in `tf.function`.
        TypeError: If unexpected keyword arguments are passed.
        ValueError: in case of invalid arguments.
    """
    version_utils.disallow_legacy_graph('Model', 'evaluate')
    self._assert_compile_was_called()
    self._check_call_args('evaluate')
    _disallow_inside_tf_function('evaluate')
    # `_use_cached_eval_dataset` is the only keyword tolerated in `**kwargs`;
    # it is popped first so that anything left over can be rejected below.
    use_cached_eval_dataset = kwargs.pop('_use_cached_eval_dataset', False)
    if kwargs:
      raise TypeError('Invalid keyword arguments: %s' % (kwargs,))

    # NOTE(review): the docstring states ParameterServerStrategy is not yet
    # supported, yet a `ClusterCoordinator` is created here whenever the
    # strategy asks for one -- confirm which of the two is current.
    if self.distribute_strategy._should_use_with_coordinator:  # pylint: disable=protected-access
      self._cluster_coordinator = cluster_coordinator.ClusterCoordinator(
          self.distribute_strategy)

    with self.distribute_strategy.scope():
      # Use cached evaluation data only when it's called in `Model.fit`
      if (use_cached_eval_dataset
          and getattr(self, '_eval_data_handler', None) is not None):
        data_handler = self._eval_data_handler
      else:
        # Creates a `tf.data.Dataset` and handles batch and epoch iteration.
        data_handler = data_adapter.get_data_handler(
            x=x,
            y=y,
            sample_weight=sample_weight,
            batch_size=batch_size,
            steps_per_epoch=steps,
            initial_epoch=0,
            epochs=1,
            max_queue_size=max_queue_size,
            workers=workers,
            use_multiprocessing=use_multiprocessing,
            model=self,
            steps_per_execution=self._steps_per_execution)

      # Container that configures and calls `tf.keras.Callback`s. A
      # `CallbackList` supplied by the caller is used as-is.
      if not isinstance(callbacks, callbacks_module.CallbackList):
        callbacks = callbacks_module.CallbackList(
            callbacks,
            add_history=True,
            add_progbar=verbose != 0,
            model=self,
            verbose=verbose,
            epochs=1,
            steps=data_handler.inferred_steps)

      # `logs` stays an empty dict when no step is executed.
      logs = {}
      self.test_function = self.make_test_function()
      self._test_counter.assign(0)
      callbacks.on_test_begin()
      for _, iterator in data_handler.enumerate_epochs():  # Single epoch.
        self.reset_metrics()
        with data_handler.catch_stop_iteration():
          for step in data_handler.steps():
            with trace.Trace('test', step_num=step, _r=1):
              callbacks.on_test_batch_begin(step)
              # Results are held in `tmp_logs` until the optional sync below
              # has completed, so `logs` only ever holds a finished step.
              tmp_logs = self.test_function(iterator)
              if data_handler.should_sync:
                context.async_wait()
              logs = tmp_logs  # No error, now safe to assign to logs.
              end_step = step + data_handler.step_increment
              callbacks.on_test_batch_end(end_step, logs)
      # Convert the logs of the last step to Numpy / Python values before
      # handing them to callbacks and to the caller.
      logs = tf_utils.sync_to_numpy_or_python_type(logs)
      callbacks.on_test_end(logs=logs)

      if return_dict:
        return logs
      else:
        return flatten_metrics_in_order(logs, self.metrics_names)
1512
1513  def predict_step(self, data):
1514    """The logic for one inference step.
1515
1516    This method can be overridden to support custom inference logic.
1517    This method is called by `Model.make_predict_function`.
1518
1519    This method should contain the mathematical logic for one step of inference.
1520    This typically includes the forward pass.
1521
1522    Configuration details for *how* this logic is run (e.g. `tf.function` and
1523    `tf.distribute.Strategy` settings), should be left to
1524    `Model.make_predict_function`, which can also be overridden.
1525
1526    Args:
1527      data: A nested structure of `Tensor`s.
1528
1529    Returns:
1530      The result of one inference step, typically the output of calling the
1531      `Model` on data.
1532    """
1533    data = data_adapter.expand_1d(data)
1534    x, _, _ = data_adapter.unpack_x_y_sample_weight(data)
1535    return self(x, training=False)
1536
  def make_predict_function(self):
    """Creates a function that executes one step of inference.

    This method can be overridden to support custom inference logic.
    This method is called by `Model.predict` and `Model.predict_on_batch`.

    Typically, this method directly controls `tf.function` and
    `tf.distribute.Strategy` settings, and delegates the actual inference
    logic to `Model.predict_step`.

    This function is cached the first time `Model.predict` or
    `Model.predict_on_batch` is called. The cache is cleared whenever
    `Model.compile` is called.

    Returns:
      Function. The function created by this method should accept a
      `tf.data.Iterator`, and return the outputs of the `Model`.
    """
    # Reuse the previously built function, if any.
    if self.predict_function is not None:
      return self.predict_function

    def step_function(model, iterator):
      """Runs a single inference step."""

      def run_step(data):
        outputs = model.predict_step(data)
        # Ensure counter is updated only if `predict_step` succeeds.
        with ops.control_dependencies(_minimum_control_deps(outputs)):
          model._predict_counter.assign_add(1)  # pylint: disable=protected-access
        return outputs

      data = next(iterator)
      outputs = model.distribute_strategy.run(run_step, args=(data,))
      # Per-replica results are merged with the 'concat' reduction.
      outputs = reduce_per_replica(
          outputs, self.distribute_strategy, reduction='concat')
      return outputs

    if (self._steps_per_execution is None or
        self._steps_per_execution.numpy().item() == 1):

      def predict_function(iterator):
        """Runs an inference execution with one step."""
        return step_function(self, iterator)

    else:

      def predict_function(iterator):
        """Runs an inference execution with multiple steps."""
        # The first step runs outside the loop; the loop then runs the
        # remaining `steps_per_execution - 1` steps and concatenates each
        # step's outputs onto the accumulated `outputs`.
        # NOTE(review): the shape invariants presumably leave the batch
        # dimension dynamic so `outputs` may grow between iterations --
        # confirm against `tf_utils.get_tensor_spec`.
        outputs = step_function(self, iterator)
        for _ in math_ops.range(self._steps_per_execution - 1):
          directives.set_loop_options(
              shape_invariants=[(
                  t, tf_utils.get_tensor_spec(t, dynamic_batch=True).shape)
                                for t in nest.flatten(outputs)])
          step_outputs = step_function(self, iterator)
          outputs = nest.map_structure(lambda t1, t2: concat([t1, t2]), outputs,
                                       step_outputs)
        return outputs

    # Wrap in a `tf.function` unless the model is configured to run eagerly.
    if not self.run_eagerly:
      predict_function = def_function.function(
          predict_function, experimental_relax_shapes=True)

    self.predict_function = predict_function
    return self.predict_function
1602
1603  def predict(self,
1604              x,
1605              batch_size=None,
1606              verbose=0,
1607              steps=None,
1608              callbacks=None,
1609              max_queue_size=10,
1610              workers=1,
1611              use_multiprocessing=False):
1612    """Generates output predictions for the input samples.
1613
1614    Computation is done in batches. This method is designed for performance in
1615    large scale inputs. For small amount of inputs that fit in one batch,
1616    directly using `__call__` is recommended for faster execution, e.g.,
1617    `model(x)`, or `model(x, training=False)` if you have layers such as
1618    `tf.keras.layers.BatchNormalization` that behaves differently during
1619    inference. Also, note the fact that test loss is not affected by
1620    regularization layers like noise and dropout.
1621
1622    Args:
1623        x: Input samples. It could be:
1624          - A Numpy array (or array-like), or a list of arrays
1625            (in case the model has multiple inputs).
1626          - A TensorFlow tensor, or a list of tensors
1627            (in case the model has multiple inputs).
1628          - A `tf.data` dataset.
1629          - A generator or `keras.utils.Sequence` instance.
1630          A more detailed description of unpacking behavior for iterator types
1631          (Dataset, generator, Sequence) is given in the `Unpacking behavior
1632          for iterator-like inputs` section of `Model.fit`.
1633        batch_size: Integer or `None`.
1634            Number of samples per batch.
1635            If unspecified, `batch_size` will default to 32.
1636            Do not specify the `batch_size` if your data is in the
1637            form of dataset, generators, or `keras.utils.Sequence` instances
1638            (since they generate batches).
1639        verbose: Verbosity mode, 0 or 1.
1640        steps: Total number of steps (batches of samples)
1641            before declaring the prediction round finished.
1642            Ignored with the default value of `None`. If x is a `tf.data`
1643            dataset and `steps` is None, `predict` will
1644            run until the input dataset is exhausted.
1645        callbacks: List of `keras.callbacks.Callback` instances.
1646            List of callbacks to apply during prediction.
1647            See [callbacks](/api_docs/python/tf/keras/callbacks).
1648        max_queue_size: Integer. Used for generator or `keras.utils.Sequence`
1649            input only. Maximum size for the generator queue.
1650            If unspecified, `max_queue_size` will default to 10.
1651        workers: Integer. Used for generator or `keras.utils.Sequence` input
1652            only. Maximum number of processes to spin up when using
1653            process-based threading. If unspecified, `workers` will default
1654            to 1.
1655        use_multiprocessing: Boolean. Used for generator or
1656            `keras.utils.Sequence` input only. If `True`, use process-based
1657            threading. If unspecified, `use_multiprocessing` will default to
1658            `False`. Note that because this implementation relies on
1659            multiprocessing, you should not pass non-picklable arguments to
1660            the generator as they can't be passed easily to children processes.
1661
1662    See the discussion of `Unpacking behavior for iterator-like inputs` for
1663    `Model.fit`. Note that Model.predict uses the same interpretation rules as
1664    `Model.fit` and `Model.evaluate`, so inputs must be unambiguous for all
1665    three methods.
1666
1667    Returns:
1668        Numpy array(s) of predictions.
1669
1670    Raises:
1671        RuntimeError: If `model.predict` is wrapped in `tf.function`.
1672        ValueError: In case of mismatch between the provided
1673            input data and the model's expectations,
1674            or in case a stateful model receives a number of samples
1675            that is not a multiple of the batch size.
1676    """
1677    version_utils.disallow_legacy_graph('Model', 'predict')
1678    self._check_call_args('predict')
1679    _disallow_inside_tf_function('predict')
1680
1681    # TODO(yashkatariya): Cache model on the coordinator for faster prediction.
1682    # If running under PSS, then swap it with OneDeviceStrategy so that
1683    # execution will run on the coordinator.
1684    original_pss_strategy = None
1685    if self.distribute_strategy._should_use_with_coordinator:  # pylint: disable=protected-access
1686      original_pss_strategy = self.distribute_strategy
1687      self._distribution_strategy = None
1688
1689    # Cluster coordinator is set by `.fit()` and `.evaluate()` which is not
1690    # needed in `.predict()` because all the predictions happen on the
1691    # coordinator/locally.
1692    if self._cluster_coordinator:
1693      self._cluster_coordinator = None
1694
1695    outputs = None
1696    with self.distribute_strategy.scope():
1697      # Creates a `tf.data.Dataset` and handles batch and epoch iteration.
1698      dataset_types = (dataset_ops.DatasetV1, dataset_ops.DatasetV2)
1699      if (self._in_multi_worker_mode() or _is_tpu_multi_host(
1700          self.distribute_strategy)) and isinstance(x, dataset_types):
1701        try:
1702          options = options_lib.Options()
1703          data_option = options_lib.AutoShardPolicy.DATA
1704          options.experimental_distribute.auto_shard_policy = data_option
1705          x = x.with_options(options)
1706        except ValueError:
1707          warnings.warn('Using Model.predict with '
1708                        'MultiWorkerDistributionStrategy or TPUStrategy and '
1709                        'AutoShardPolicy.FILE might lead to out-of-order result'
1710                        '. Consider setting it to AutoShardPolicy.DATA.')
1711
1712      data_handler = data_adapter.get_data_handler(
1713          x=x,
1714          batch_size=batch_size,
1715          steps_per_epoch=steps,
1716          initial_epoch=0,
1717          epochs=1,
1718          max_queue_size=max_queue_size,
1719          workers=workers,
1720          use_multiprocessing=use_multiprocessing,
1721          model=self,
1722          steps_per_execution=self._steps_per_execution)
1723
1724      # Container that configures and calls `tf.keras.Callback`s.
1725      if not isinstance(callbacks, callbacks_module.CallbackList):
1726        callbacks = callbacks_module.CallbackList(
1727            callbacks,
1728            add_history=True,
1729            add_progbar=verbose != 0,
1730            model=self,
1731            verbose=verbose,
1732            epochs=1,
1733            steps=data_handler.inferred_steps)
1734
1735      self.predict_function = self.make_predict_function()
1736      self._predict_counter.assign(0)
1737      callbacks.on_predict_begin()
1738      batch_outputs = None
1739      for _, iterator in data_handler.enumerate_epochs():  # Single epoch.
1740        with data_handler.catch_stop_iteration():
1741          for step in data_handler.steps():
1742            callbacks.on_predict_batch_begin(step)
1743            tmp_batch_outputs = self.predict_function(iterator)
1744            if data_handler.should_sync:
1745              context.async_wait()
1746            batch_outputs = tmp_batch_outputs  # No error, now safe to assign.
1747            if outputs is None:
1748              outputs = nest.map_structure(lambda batch_output: [batch_output],
1749                                           batch_outputs)
1750            else:
1751              nest.map_structure_up_to(
1752                  batch_outputs,
1753                  lambda output, batch_output: output.append(batch_output),
1754                  outputs, batch_outputs)
1755            end_step = step + data_handler.step_increment
1756            callbacks.on_predict_batch_end(end_step, {'outputs': batch_outputs})
1757      if batch_outputs is None:
1758        raise ValueError('Expect x to be a non-empty array or dataset.')
1759      callbacks.on_predict_end()
1760    all_outputs = nest.map_structure_up_to(batch_outputs, concat, outputs)
1761
1762    # If originally PSS strategy was used, then replace it back since predict
1763    # is running under `OneDeviceStrategy` after the swap and once its done
1764    # we need to replace it back to PSS again.
1765    if original_pss_strategy is not None:
1766      self._distribution_strategy = original_pss_strategy
1767
1768    return tf_utils.sync_to_numpy_or_python_type(all_outputs)
1769
1770  def reset_metrics(self):
1771    """Resets the state of all the metrics in the model.
1772
1773    Examples:
1774
1775    >>> inputs = tf.keras.layers.Input(shape=(3,))
1776    >>> outputs = tf.keras.layers.Dense(2)(inputs)
1777    >>> model = tf.keras.models.Model(inputs=inputs, outputs=outputs)
1778    >>> model.compile(optimizer="Adam", loss="mse", metrics=["mae"])
1779
1780    >>> x = np.random.random((2, 3))
1781    >>> y = np.random.randint(0, 2, (2, 2))
1782    >>> _ = model.fit(x, y, verbose=0)
1783    >>> assert all(float(m.result()) for m in model.metrics)
1784
1785    >>> model.reset_metrics()
1786    >>> assert all(float(m.result()) == 0 for m in model.metrics)
1787
1788    """
1789    for m in self.metrics:
1790      m.reset_state()
1791
1792  def train_on_batch(self,
1793                     x,
1794                     y=None,
1795                     sample_weight=None,
1796                     class_weight=None,
1797                     reset_metrics=True,
1798                     return_dict=False):
1799    """Runs a single gradient update on a single batch of data.
1800
1801    Args:
1802        x: Input data. It could be:
1803          - A Numpy array (or array-like), or a list of arrays
1804              (in case the model has multiple inputs).
1805          - A TensorFlow tensor, or a list of tensors
1806              (in case the model has multiple inputs).
1807          - A dict mapping input names to the corresponding array/tensors,
1808              if the model has named inputs.
1809        y: Target data. Like the input data `x`, it could be either Numpy
1810          array(s) or TensorFlow tensor(s). It should be consistent with `x`
1811          (you cannot have Numpy inputs and tensor targets, or inversely).
1812        sample_weight: Optional array of the same length as x, containing
1813          weights to apply to the model's loss for each sample. In the case of
1814          temporal data, you can pass a 2D array with shape (samples,
1815          sequence_length), to apply a different weight to every timestep of
1816          every sample.
1817        class_weight: Optional dictionary mapping class indices (integers) to a
1818          weight (float) to apply to the model's loss for the samples from this
1819          class during training. This can be useful to tell the model to "pay
1820          more attention" to samples from an under-represented class.
1821        reset_metrics: If `True`, the metrics returned will be only for this
1822          batch. If `False`, the metrics will be statefully accumulated across
1823          batches.
1824        return_dict: If `True`, loss and metric results are returned as a dict,
1825          with each key being the name of the metric. If `False`, they are
1826          returned as a list.
1827
1828    Returns:
1829        Scalar training loss
1830        (if the model has a single output and no metrics)
1831        or list of scalars (if the model has multiple outputs
1832        and/or metrics). The attribute `model.metrics_names` will give you
1833        the display labels for the scalar outputs.
1834
1835    Raises:
1836      RuntimeError: If `model.train_on_batch` is wrapped in `tf.function`.
1837      ValueError: In case of invalid user-provided arguments.
1838    """
1839    self._assert_compile_was_called()
1840    self._check_call_args('train_on_batch')
1841    _disallow_inside_tf_function('train_on_batch')
1842    with self.distribute_strategy.scope(), \
1843         training_utils.RespectCompiledTrainableState(self):
1844      iterator = data_adapter.single_batch_iterator(self.distribute_strategy, x,
1845                                                    y, sample_weight,
1846                                                    class_weight)
1847      self.train_function = self.make_train_function()
1848      logs = self.train_function(iterator)
1849
1850    if reset_metrics:
1851      self.reset_metrics()
1852    logs = tf_utils.sync_to_numpy_or_python_type(logs)
1853    if return_dict:
1854      return logs
1855    else:
1856      return flatten_metrics_in_order(logs, self.metrics_names)
1857
1858  def test_on_batch(self,
1859                    x,
1860                    y=None,
1861                    sample_weight=None,
1862                    reset_metrics=True,
1863                    return_dict=False):
1864    """Test the model on a single batch of samples.
1865
1866    Args:
1867        x: Input data. It could be:
1868          - A Numpy array (or array-like), or a list of arrays (in case the
1869              model has multiple inputs).
1870          - A TensorFlow tensor, or a list of tensors (in case the model has
1871              multiple inputs).
1872          - A dict mapping input names to the corresponding array/tensors, if
1873              the model has named inputs.
1874        y: Target data. Like the input data `x`, it could be either Numpy
1875          array(s) or TensorFlow tensor(s). It should be consistent with `x`
1876          (you cannot have Numpy inputs and tensor targets, or inversely).
1877        sample_weight: Optional array of the same length as x, containing
1878          weights to apply to the model's loss for each sample. In the case of
1879          temporal data, you can pass a 2D array with shape (samples,
1880          sequence_length), to apply a different weight to every timestep of
1881          every sample.
1882        reset_metrics: If `True`, the metrics returned will be only for this
1883          batch. If `False`, the metrics will be statefully accumulated across
1884          batches.
1885        return_dict: If `True`, loss and metric results are returned as a dict,
1886          with each key being the name of the metric. If `False`, they are
1887          returned as a list.
1888
1889    Returns:
1890        Scalar test loss (if the model has a single output and no metrics)
1891        or list of scalars (if the model has multiple outputs
1892        and/or metrics). The attribute `model.metrics_names` will give you
1893        the display labels for the scalar outputs.
1894
1895    Raises:
1896        RuntimeError: If `model.test_on_batch` is wrapped in `tf.function`.
1897        ValueError: In case of invalid user-provided arguments.
1898    """
1899    self._assert_compile_was_called()
1900    self._check_call_args('test_on_batch')
1901    _disallow_inside_tf_function('test_on_batch')
1902    with self.distribute_strategy.scope():
1903      iterator = data_adapter.single_batch_iterator(self.distribute_strategy, x,
1904                                                    y, sample_weight)
1905      self.test_function = self.make_test_function()
1906      logs = self.test_function(iterator)
1907
1908    if reset_metrics:
1909      self.reset_metrics()
1910    logs = tf_utils.sync_to_numpy_or_python_type(logs)
1911    if return_dict:
1912      return logs
1913    else:
1914      return flatten_metrics_in_order(logs, self.metrics_names)
1915
1916  def predict_on_batch(self, x):
1917    """Returns predictions for a single batch of samples.
1918
1919    Args:
1920        x: Input data. It could be:
1921          - A Numpy array (or array-like), or a list of arrays (in case the
1922              model has multiple inputs).
1923          - A TensorFlow tensor, or a list of tensors (in case the model has
1924              multiple inputs).
1925
1926    Returns:
1927        Numpy array(s) of predictions.
1928
1929    Raises:
1930        RuntimeError: If `model.predict_on_batch` is wrapped in `tf.function`.
1931        ValueError: In case of mismatch between given number of inputs and
1932          expectations of the model.
1933    """
1934    self._check_call_args('predict_on_batch')
1935    _disallow_inside_tf_function('predict_on_batch')
1936    with self.distribute_strategy.scope():
1937      iterator = data_adapter.single_batch_iterator(self.distribute_strategy, x)
1938      self.predict_function = self.make_predict_function()
1939      outputs = self.predict_function(iterator)
1940    return tf_utils.sync_to_numpy_or_python_type(outputs)
1941
1942  def fit_generator(self,
1943                    generator,
1944                    steps_per_epoch=None,
1945                    epochs=1,
1946                    verbose=1,
1947                    callbacks=None,
1948                    validation_data=None,
1949                    validation_steps=None,
1950                    validation_freq=1,
1951                    class_weight=None,
1952                    max_queue_size=10,
1953                    workers=1,
1954                    use_multiprocessing=False,
1955                    shuffle=True,
1956                    initial_epoch=0):
1957    """Fits the model on data yielded batch-by-batch by a Python generator.
1958
1959    DEPRECATED:
1960      `Model.fit` now supports generators, so there is no longer any need to use
1961      this endpoint.
1962    """
1963    warnings.warn('`Model.fit_generator` is deprecated and '
1964                  'will be removed in a future version. '
1965                  'Please use `Model.fit`, which supports generators.')
1966    return self.fit(
1967        generator,
1968        steps_per_epoch=steps_per_epoch,
1969        epochs=epochs,
1970        verbose=verbose,
1971        callbacks=callbacks,
1972        validation_data=validation_data,
1973        validation_steps=validation_steps,
1974        validation_freq=validation_freq,
1975        class_weight=class_weight,
1976        max_queue_size=max_queue_size,
1977        workers=workers,
1978        use_multiprocessing=use_multiprocessing,
1979        shuffle=shuffle,
1980        initial_epoch=initial_epoch)
1981
1982  def evaluate_generator(self,
1983                         generator,
1984                         steps=None,
1985                         callbacks=None,
1986                         max_queue_size=10,
1987                         workers=1,
1988                         use_multiprocessing=False,
1989                         verbose=0):
1990    """Evaluates the model on a data generator.
1991
1992    DEPRECATED:
1993      `Model.evaluate` now supports generators, so there is no longer any need
1994      to use this endpoint.
1995    """
1996    warnings.warn('`Model.evaluate_generator` is deprecated and '
1997                  'will be removed in a future version. '
1998                  'Please use `Model.evaluate`, which supports generators.')
1999    self._check_call_args('evaluate_generator')
2000
2001    return self.evaluate(
2002        generator,
2003        steps=steps,
2004        max_queue_size=max_queue_size,
2005        workers=workers,
2006        use_multiprocessing=use_multiprocessing,
2007        verbose=verbose,
2008        callbacks=callbacks)
2009
2010  def predict_generator(self,
2011                        generator,
2012                        steps=None,
2013                        callbacks=None,
2014                        max_queue_size=10,
2015                        workers=1,
2016                        use_multiprocessing=False,
2017                        verbose=0):
2018    """Generates predictions for the input samples from a data generator.
2019
2020    DEPRECATED:
2021      `Model.predict` now supports generators, so there is no longer any need
2022      to use this endpoint.
2023    """
2024    warnings.warn('`Model.predict_generator` is deprecated and '
2025                  'will be removed in a future version. '
2026                  'Please use `Model.predict`, which supports generators.')
2027    return self.predict(
2028        generator,
2029        steps=steps,
2030        max_queue_size=max_queue_size,
2031        workers=workers,
2032        use_multiprocessing=use_multiprocessing,
2033        verbose=verbose,
2034        callbacks=callbacks)
2035
2036  ######################################################################
2037  # Functions below are not training related. They are for model weights
2038  # tracking, save/load, serialization, etc.
2039  ######################################################################
2040
2041  @property
2042  def trainable_weights(self):
2043    self._assert_weights_created()
2044    if not self._trainable:
2045      return []
2046    trainable_variables = []
2047    for trackable_obj in self._self_tracked_trackables:
2048      trainable_variables += trackable_obj.trainable_variables
2049    trainable_variables += self._trainable_weights
2050    return self._dedup_weights(trainable_variables)
2051
2052  @property
2053  def non_trainable_weights(self):
2054    self._assert_weights_created()
2055    non_trainable_variables = []
2056    for trackable_obj in self._self_tracked_trackables:
2057      non_trainable_variables += trackable_obj.non_trainable_variables
2058
2059    if not self._trainable:
2060      # Return order is all trainable vars, then all non-trainable vars.
2061      trainable_variables = []
2062      for trackable_obj in self._self_tracked_trackables:
2063        trainable_variables += trackable_obj.trainable_variables
2064
2065      non_trainable_variables = (
2066          trainable_variables + self._trainable_weights +
2067          non_trainable_variables + self._non_trainable_weights)
2068    else:
2069      non_trainable_variables = (
2070          non_trainable_variables + self._non_trainable_weights)
2071
2072    return self._dedup_weights(non_trainable_variables)
2073
2074  def get_weights(self):
2075    """Retrieves the weights of the model.
2076
2077    Returns:
2078        A flat list of Numpy arrays.
2079    """
2080    with self.distribute_strategy.scope():
2081      return super(Model, self).get_weights()
2082
2083  def save(self,
2084           filepath,
2085           overwrite=True,
2086           include_optimizer=True,
2087           save_format=None,
2088           signatures=None,
2089           options=None,
2090           save_traces=True):
2091    # pylint: disable=line-too-long
2092    """Saves the model to Tensorflow SavedModel or a single HDF5 file.
2093
2094    Please see `tf.keras.models.save_model` or the
2095    [Serialization and Saving guide](https://keras.io/guides/serialization_and_saving/)
2096    for details.
2097
2098    Args:
2099        filepath: String, PathLike, path to SavedModel or H5 file to save the
2100            model.
2101        overwrite: Whether to silently overwrite any existing file at the
2102            target location, or provide the user with a manual prompt.
2103        include_optimizer: If True, save optimizer's state together.
2104        save_format: Either `'tf'` or `'h5'`, indicating whether to save the
2105            model to Tensorflow SavedModel or HDF5. Defaults to 'tf' in TF 2.X,
2106            and 'h5' in TF 1.X.
2107        signatures: Signatures to save with the SavedModel. Applicable to the
2108            'tf' format only. Please see the `signatures` argument in
2109            `tf.saved_model.save` for details.
2110        options: (only applies to SavedModel format)
2111            `tf.saved_model.SaveOptions` object that specifies options for
2112            saving to SavedModel.
2113        save_traces: (only applies to SavedModel format) When enabled, the
2114            SavedModel will store the function traces for each layer. This
2115            can be disabled, so that only the configs of each layer are stored.
2116            Defaults to `True`. Disabling this will decrease serialization time
2117            and reduce file size, but it requires that all custom layers/models
2118            implement a `get_config()` method.
2119
2120    Example:
2121
2122    ```python
2123    from keras.models import load_model
2124
2125    model.save('my_model.h5')  # creates a HDF5 file 'my_model.h5'
2126    del model  # deletes the existing model
2127
2128    # returns a compiled model
2129    # identical to the previous one
2130    model = load_model('my_model.h5')
2131    ```
2132    """
2133    # pylint: enable=line-too-long
2134    save.save_model(self, filepath, overwrite, include_optimizer, save_format,
2135                    signatures, options, save_traces)
2136
2137  def save_weights(self,
2138                   filepath,
2139                   overwrite=True,
2140                   save_format=None,
2141                   options=None):
2142    """Saves all layer weights.
2143
2144    Either saves in HDF5 or in TensorFlow format based on the `save_format`
2145    argument.
2146
2147    When saving in HDF5 format, the weight file has:
2148      - `layer_names` (attribute), a list of strings
2149          (ordered names of model layers).
2150      - For every layer, a `group` named `layer.name`
2151          - For every such layer group, a group attribute `weight_names`,
2152              a list of strings
2153              (ordered names of weights tensor of the layer).
2154          - For every weight in the layer, a dataset
2155              storing the weight value, named after the weight tensor.
2156
2157    When saving in TensorFlow format, all objects referenced by the network are
2158    saved in the same format as `tf.train.Checkpoint`, including any `Layer`
2159    instances or `Optimizer` instances assigned to object attributes. For
2160    networks constructed from inputs and outputs using `tf.keras.Model(inputs,
2161    outputs)`, `Layer` instances used by the network are tracked/saved
2162    automatically. For user-defined classes which inherit from `tf.keras.Model`,
2163    `Layer` instances must be assigned to object attributes, typically in the
2164    constructor. See the documentation of `tf.train.Checkpoint` and
2165    `tf.keras.Model` for details.
2166
2167    While the formats are the same, do not mix `save_weights` and
2168    `tf.train.Checkpoint`. Checkpoints saved by `Model.save_weights` should be
2169    loaded using `Model.load_weights`. Checkpoints saved using
2170    `tf.train.Checkpoint.save` should be restored using the corresponding
2171    `tf.train.Checkpoint.restore`. Prefer `tf.train.Checkpoint` over
2172    `save_weights` for training checkpoints.
2173
2174    The TensorFlow format matches objects and variables by starting at a root
2175    object, `self` for `save_weights`, and greedily matching attribute
2176    names. For `Model.save` this is the `Model`, and for `Checkpoint.save` this
2177    is the `Checkpoint` even if the `Checkpoint` has a model attached. This
2178    means saving a `tf.keras.Model` using `save_weights` and loading into a
2179    `tf.train.Checkpoint` with a `Model` attached (or vice versa) will not match
2180    the `Model`'s variables. See the [guide to training
2181    checkpoints](https://www.tensorflow.org/guide/checkpoint) for details
2182    on the TensorFlow format.
2183
2184    Args:
2185        filepath: String or PathLike, path to the file to save the weights to.
2186            When saving in TensorFlow format, this is the prefix used for
2187            checkpoint files (multiple files are generated). Note that the '.h5'
2188            suffix causes weights to be saved in HDF5 format.
2189        overwrite: Whether to silently overwrite any existing file at the
2190            target location, or provide the user with a manual prompt.
2191        save_format: Either 'tf' or 'h5'. A `filepath` ending in '.h5' or
2192            '.keras' will default to HDF5 if `save_format` is `None`. Otherwise
2193            `None` defaults to 'tf'.
2194        options: Optional `tf.train.CheckpointOptions` object that specifies
2195            options for saving weights.
2196
2197    Raises:
2198        ImportError: If h5py is not available when attempting to save in HDF5
2199            format.
2200        ValueError: For invalid/unknown format arguments.
2201    """
2202    self._assert_weights_created()
2203    filepath = path_to_string(filepath)
2204    filepath_is_h5 = saving_utils.is_hdf5_filepath(filepath)
2205    if save_format is None:
2206      if filepath_is_h5:
2207        save_format = 'h5'
2208      else:
2209        save_format = 'tf'
2210    else:
2211      user_format = save_format.lower().strip()
2212      if user_format in ('tensorflow', 'tf'):
2213        save_format = 'tf'
2214      elif user_format in ('hdf5', 'h5', 'keras'):
2215        save_format = 'h5'
2216      else:
2217        raise ValueError(
2218            'Unknown format "%s". Was expecting one of {"tf", "h5"}.' % (
2219                save_format,))
2220    if save_format == 'tf' and filepath_is_h5:
2221      raise ValueError(
2222          ('save_weights got save_format="tf"/"tensorflow", but the '
2223           'filepath ("%s") looks like an HDF5 file. Omit the ".h5"/".keras" '
2224           'when saving in TensorFlow format.')
2225          % filepath)
2226
2227    if save_format == 'h5' and h5py is None:
2228      raise ImportError(
2229          '`save_weights` requires h5py when saving in hdf5.')
2230    if save_format == 'tf':
2231      check_filepath = filepath + '.index'
2232    else:
2233      check_filepath = filepath
2234    # If file exists and should not be overwritten:
2235    if not overwrite and os.path.isfile(check_filepath):
2236      proceed = ask_to_proceed_with_overwrite(check_filepath)
2237      if not proceed:
2238        return
2239    if save_format == 'h5':
2240      with h5py.File(filepath, 'w') as f:
2241        hdf5_format.save_weights_to_hdf5_group(f, self.layers)
2242    else:
2243      if context.executing_eagerly():
2244        session = None
2245      else:
2246        session = backend.get_session()
2247      self._trackable_saver.save(filepath, session=session, options=options)
2248      # Record this checkpoint so it's visible from tf.train.latest_checkpoint.
2249      checkpoint_management.update_checkpoint_state_internal(
2250          save_dir=os.path.dirname(filepath),
2251          model_checkpoint_path=filepath,
2252          save_relative_paths=True,
2253          all_model_checkpoint_paths=[filepath])
2254
2255  def load_weights(self,
2256                   filepath,
2257                   by_name=False,
2258                   skip_mismatch=False,
2259                   options=None):
2260    """Loads all layer weights, either from a TensorFlow or an HDF5 weight file.
2261
2262    If `by_name` is False weights are loaded based on the network's
2263    topology. This means the architecture should be the same as when the weights
2264    were saved.  Note that layers that don't have weights are not taken into
2265    account in the topological ordering, so adding or removing layers is fine as
2266    long as they don't have weights.
2267
2268    If `by_name` is True, weights are loaded into layers only if they share the
2269    same name. This is useful for fine-tuning or transfer-learning models where
2270    some of the layers have changed.
2271
2272    Only topological loading (`by_name=False`) is supported when loading weights
2273    from the TensorFlow format. Note that topological loading differs slightly
2274    between TensorFlow and HDF5 formats for user-defined classes inheriting from
2275    `tf.keras.Model`: HDF5 loads based on a flattened list of weights, while the
2276    TensorFlow format loads based on the object-local names of attributes to
2277    which layers are assigned in the `Model`'s constructor.
2278
2279    Args:
2280        filepath: String, path to the weights file to load. For weight files in
2281            TensorFlow format, this is the file prefix (the same as was passed
2282            to `save_weights`). This can also be a path to a SavedModel
2283            saved from `model.save`.
2284        by_name: Boolean, whether to load weights by name or by topological
2285            order. Only topological loading is supported for weight files in
2286            TensorFlow format.
2287        skip_mismatch: Boolean, whether to skip loading of layers where there is
2288            a mismatch in the number of weights, or a mismatch in the shape of
2289            the weight (only valid when `by_name=True`).
2290        options: Optional `tf.train.CheckpointOptions` object that specifies
2291            options for loading weights.
2292
2293    Returns:
2294        When loading a weight file in TensorFlow format, returns the same status
2295        object as `tf.train.Checkpoint.restore`. When graph building, restore
2296        ops are run automatically as soon as the network is built (on first call
2297        for user-defined classes inheriting from `Model`, immediately if it is
2298        already built).
2299
2300        When loading weights in HDF5 format, returns `None`.
2301
2302    Raises:
2303        ImportError: If h5py is not available and the weight file is in HDF5
2304            format.
2305        ValueError: If `skip_mismatch` is set to `True` when `by_name` is
2306          `False`.
2307    """
2308    if backend.is_tpu_strategy(self._distribution_strategy):
2309      if (self._distribution_strategy.extended.steps_per_run > 1 and
2310          (not saving_utils.is_hdf5_filepath(filepath))):
2311        raise ValueError('Load weights is not yet supported with TPUStrategy '
2312                         'with steps_per_run greater than 1.')
2313    if skip_mismatch and not by_name:
2314      raise ValueError(
2315          'When calling model.load_weights, skip_mismatch can only be set to '
2316          'True when by_name is True.')
2317
2318    filepath, save_format = _detect_save_format(filepath)
2319    if save_format == 'tf':
2320      status = self._trackable_saver.restore(filepath, options)
2321      if by_name:
2322        raise NotImplementedError(
2323            'Weights may only be loaded based on topology into Models when '
2324            'loading TensorFlow-formatted weights (got by_name=True to '
2325            'load_weights).')
2326      if not context.executing_eagerly():
2327        session = backend.get_session()
2328        # Restore existing variables (if any) immediately, and set up a
2329        # streaming restore for any variables created in the future.
2330        trackable_utils.streaming_restore(status=status, session=session)
2331      status.assert_nontrivial_match()
2332    else:
2333      status = None
2334      if h5py is None:
2335        raise ImportError(
2336            '`load_weights` requires h5py when loading weights from HDF5.')
2337      if not self._is_graph_network and not self.built:
2338        raise ValueError(
2339            'Unable to load weights saved in HDF5 format into a subclassed '
2340            'Model which has not created its variables yet. Call the Model '
2341            'first, then load the weights.')
2342      self._assert_weights_created()
2343      with h5py.File(filepath, 'r') as f:
2344        if 'layer_names' not in f.attrs and 'model_weights' in f:
2345          f = f['model_weights']
2346        if by_name:
2347          hdf5_format.load_weights_from_hdf5_group_by_name(
2348              f, self.layers, skip_mismatch=skip_mismatch)
2349        else:
2350          hdf5_format.load_weights_from_hdf5_group(f, self.layers)
2351
2352    # Perform any layer defined finalization of the layer state.
2353    for layer in self.layers:
2354      layer.finalize_state()
2355    return status
2356
2357  def _updated_config(self):
2358    """Util shared between different serialization methods.
2359
2360    Returns:
2361        Model config with Keras version information added.
2362    """
2363    from tensorflow.python.keras import __version__ as keras_version  # pylint: disable=g-import-not-at-top
2364
2365    config = self.get_config()
2366    model_config = {
2367        'class_name': self.__class__.__name__,
2368        'config': config,
2369        'keras_version': keras_version,
2370        'backend': backend.backend()
2371    }
2372    return model_config
2373
2374  def get_config(self):
2375    raise NotImplementedError
2376
2377  @classmethod
2378  def from_config(cls, config, custom_objects=None):
2379    # `from_config` assumes `cls` is either `Functional` or a child class of
2380    # `Functional`. In the case that `cls` is meant to behave like a child class
2381    # of `Functional` but only inherits from the `Model` class, we have to call
2382    # `cls(...)` instead of `Functional.from_config`.
2383    from tensorflow.python.keras.engine import functional  # pylint: disable=g-import-not-at-top
2384    with generic_utils.SharedObjectLoadingScope():
2385      input_tensors, output_tensors, created_layers = (
2386          functional.reconstruct_from_config(config, custom_objects))
2387      # Initialize a model belonging to `cls`, which can be user-defined or
2388      # `Functional`.
2389      model = cls(inputs=input_tensors, outputs=output_tensors,
2390                  name=config.get('name'))
2391      functional.connect_ancillary_layers(model, created_layers)
2392      return model
2393
2394  def to_json(self, **kwargs):
2395    """Returns a JSON string containing the network configuration.
2396
2397    To load a network from a JSON save file, use
2398    `keras.models.model_from_json(json_string, custom_objects={})`.
2399
2400    Args:
2401        **kwargs: Additional keyword arguments
2402            to be passed to `json.dumps()`.
2403
2404    Returns:
2405        A JSON string.
2406    """
2407    model_config = self._updated_config()
2408    return json.dumps(
2409        model_config, default=json_utils.get_json_type, **kwargs)
2410
2411  def to_yaml(self, **kwargs):
2412    """Returns a yaml string containing the network configuration.
2413
2414    Note: Since TF 2.6, this method is no longer supported and will raise a
2415    RuntimeError.
2416
2417    To load a network from a yaml save file, use
2418    `keras.models.model_from_yaml(yaml_string, custom_objects={})`.
2419
2420    `custom_objects` should be a dictionary mapping
2421    the names of custom losses / layers / etc to the corresponding
2422    functions / classes.
2423
2424    Args:
2425        **kwargs: Additional keyword arguments
2426            to be passed to `yaml.dump()`.
2427
2428    Returns:
2429        A YAML string.
2430
2431    Raises:
2432        RuntimeError: announces that the method poses a security risk
2433    """
2434    raise RuntimeError(
2435        'Method `model.to_yaml()` has been removed due to security risk of '
2436        'arbitrary code execution. Please use `model.to_json()` instead.'
2437    )
2438
2439  def reset_states(self):
2440    for layer in self.layers:
2441      if hasattr(layer, 'reset_states') and getattr(layer, 'stateful', False):
2442        layer.reset_states()
2443
2444  @property
2445  @doc_controls.do_not_generate_docs
2446  def state_updates(self):
2447    """Deprecated, do NOT use!
2448
2449    Returns the `updates` from all layers that are stateful.
2450
2451    This is useful for separating training updates and
2452    state updates, e.g. when we need to update a layer's internal state
2453    during prediction.
2454
2455    Returns:
2456        A list of update ops.
2457    """
2458    warnings.warn('`Model.state_updates` will be removed in a future version. '
2459                  'This property should not be used in TensorFlow 2.0, '
2460                  'as `updates` are applied automatically.')
2461    state_updates = []
2462    for layer in self.layers:
2463      if getattr(layer, 'stateful', False):
2464        if hasattr(layer, 'updates'):
2465          state_updates += layer.updates
2466    return state_updates
2467
2468  @property
2469  def weights(self):
2470    """Returns the list of all layer variables/weights.
2471
2472    Note: This will not track the weights of nested `tf.Modules` that are not
2473    themselves Keras layers.
2474
2475    Returns:
2476      A list of variables.
2477    """
2478    return self._dedup_weights(self._undeduplicated_weights)
2479
2480  @property
2481  def _undeduplicated_weights(self):
2482    """Returns the undeduplicated list of all layer variables/weights."""
2483    self._assert_weights_created()
2484    weights = []
2485    for layer in self._self_tracked_trackables:
2486      weights += layer.variables
2487    weights += (self._trainable_weights + self._non_trainable_weights)
2488    return weights
2489
2490  def summary(self, line_length=None, positions=None, print_fn=None):
2491    """Prints a string summary of the network.
2492
2493    Args:
2494        line_length: Total length of printed lines
2495            (e.g. set this to adapt the display to different
2496            terminal window sizes).
2497        positions: Relative or absolute positions of log elements
2498            in each line. If not provided,
2499            defaults to `[.33, .55, .67, 1.]`.
2500        print_fn: Print function to use. Defaults to `print`.
2501            It will be called on each line of the summary.
2502            You can set it to a custom function
2503            in order to capture the string summary.
2504
2505    Raises:
2506        ValueError: if `summary()` is called before the model is built.
2507    """
2508    if not self.built:
2509      raise ValueError('This model has not yet been built. '
2510                       'Build the model first by calling `build()` or calling '
2511                       '`fit()` with some data, or specify '
2512                       'an `input_shape` argument in the first layer(s) for '
2513                       'automatic build.')
2514    layer_utils.print_summary(self,
2515                              line_length=line_length,
2516                              positions=positions,
2517                              print_fn=print_fn)
2518
2519  @property
2520  def layers(self):
2521    return list(self._flatten_layers(include_self=False, recursive=False))
2522
2523  def get_layer(self, name=None, index=None):
2524    """Retrieves a layer based on either its name (unique) or index.
2525
2526    If `name` and `index` are both provided, `index` will take precedence.
2527    Indices are based on order of horizontal graph traversal (bottom-up).
2528
2529    Args:
2530        name: String, name of layer.
2531        index: Integer, index of layer.
2532
2533    Returns:
2534        A layer instance.
2535
2536    Raises:
2537        ValueError: In case of invalid layer name or index.
2538    """
2539    # TODO(fchollet): We could build a dictionary based on layer names
2540    # since they are constant, but we have not done that yet.
2541    if index is not None and name is not None:
2542      raise ValueError('Provide only a layer name or a layer index.')
2543
2544    if index is not None:
2545      if len(self.layers) <= index:
2546        raise ValueError('Was asked to retrieve layer at index ' + str(index) +
2547                         ' but model only has ' + str(len(self.layers)) +
2548                         ' layers.')
2549      else:
2550        return self.layers[index]
2551
2552    if name is not None:
2553      for layer in self.layers:
2554        if layer.name == name:
2555          return layer
2556      raise ValueError('No such layer: ' + name + '.')
2557    raise ValueError('Provide either a layer name or layer index.')
2558
2559  @trackable.no_automatic_dependency_tracking
2560  def _set_save_spec(self, inputs):
2561    if self._saved_model_inputs_spec is not None:
2562      return  # Already set.
2563
2564    input_names = self.input_names
2565    if not input_names:
2566      input_names = compile_utils.create_pseudo_input_names(inputs)
2567
2568    flat_inputs = nest.flatten(inputs)
2569    specs = []
2570    for name, tensor in zip(input_names, flat_inputs):
2571      specs.append(
2572          tf_utils.get_tensor_spec(tensor, dynamic_batch=False, name=name))
2573    specs = nest.pack_sequence_as(inputs, specs)
2574
2575    self._saved_model_inputs_spec = specs
2576
2577    # Store the input shapes
2578    if (self.__class__.__name__ == 'Sequential' and
2579        self._build_input_shape is None):
2580      self._build_input_shape = nest.map_structure(
2581          lambda x: None if x is None else x.shape, specs)
2582
2583  def _assert_weights_created(self):
2584    """Asserts that all the weights for the model have been created.
2585
2586    For a non-dynamic model, the weights must already be created after the
2587    layer has been called. For a dynamic model, the exact list of weights can
2588    never be known for certain since it may change at any time during execution.
2589
2590    We run this check right before accessing weights or getting the Numpy value
2591    for the current weights. Otherwise, if the layer has never been called,
2592    the user would just get an empty list, which is misleading.
2593
2594    Raises:
2595      ValueError: if the weights of the network has not yet been created.
2596    """
2597    if self.dynamic:
2598      return
2599
2600    if ('build' in self.__class__.__dict__ and
2601        self.__class__ != Model and
2602        not self.built):
2603      # For any model that has customized build() method but hasn't
2604      # been invoked yet, this will cover both sequential and subclass model.
2605      # Also make sure to exclude Model class itself which has build() defined.
2606      raise ValueError('Weights for model %s have not yet been created. '
2607                       'Weights are created when the Model is first called on '
2608                       'inputs or `build()` is called with an `input_shape`.' %
2609                       self.name)
2610
2611  def _check_call_args(self, method_name):
2612    """Check that `call` has only one positional arg."""
2613    # Always allow first arg, regardless of arg name.
2614    fullargspec = self._call_full_argspec
2615    if fullargspec.defaults:
2616      positional_args = fullargspec.args[:-len(fullargspec.defaults)]
2617    else:
2618      positional_args = fullargspec.args
2619    if 'training' in positional_args:
2620      positional_args.remove('training')
2621
2622    # self and first arg can be positional.
2623    if len(positional_args) > 2:
2624      extra_args = positional_args[2:]
2625      raise ValueError(
2626          'Models passed to `' + method_name + '` can only have `training` '
2627          'and the first argument in `call` as positional arguments, '
2628          'found: ' + str(extra_args) + '.')
2629
2630  def _validate_compile(self, optimizer, metrics, **kwargs):
2631    """Performs validation checks for the default `compile`."""
2632    if any(
2633        isinstance(opt, optimizer_v1.Optimizer)
2634        for opt in nest.flatten(optimizer)):
2635      raise ValueError(
2636          '`tf.compat.v1.keras` Optimizer (', optimizer, ') is '
2637          'not supported when eager execution is enabled. Use a '
2638          '`tf.keras` Optimizer instead, or disable eager '
2639          'execution.')
2640
2641    kwargs.pop('cloning', None)  # Legacy DistStrat argument, never used.
2642    kwargs.pop('experimental_run_tf_function', None)  # Always `True`.
2643    if kwargs.pop('distribute', None) is not None:
2644      raise ValueError(
2645          'Distribute argument in compile is not available in TF 2.0 please '
2646          'create the model under the distribution strategy scope.')
2647    if kwargs.pop('target_tensors', None) is not None:
2648      raise ValueError(
2649          'target_tensors argument is not supported when executing eagerly.')
2650    invalid_kwargs = set(kwargs) - {'sample_weight_mode'}
2651    if invalid_kwargs:
2652      raise TypeError('Invalid keyword argument(s) in `compile`: %s' %
2653                      (invalid_kwargs,))
2654
2655    # Model must be created and compiled with the same DistStrat.
2656    if self.built and ds_context.has_strategy():
2657      strategy = ds_context.get_strategy()
2658      for v in self.variables:
2659        if not strategy.extended.variable_created_in_scope(v):
2660          raise ValueError(
2661              'Variable (%s) was not created in the distribution strategy '
2662              'scope of (%s). It is most likely due to not all layers or '
2663              'the model or optimizer being created outside the distribution '
2664              'strategy scope. Try to make sure your code looks similar '
2665              'to the following.\n'
2666              'with strategy.scope():\n'
2667              '  model=_create_model()\n'
2668              '  model.compile(...)' % (v, strategy))
2669
2670    # Model metrics must be created in the same distribution strategy scope
2671    # as the model.
2672    strategy = self.distribute_strategy
2673    for metric in nest.flatten(metrics):
2674      for v in getattr(metric, 'variables', []):
2675        if not strategy.extended.variable_created_in_scope(v):
2676          raise ValueError(
2677              'Metric (%s) passed to model.compile was created inside of a '
2678              'different distribution strategy scope than the model. All '
2679              'metrics must be created in the same distribution strategy '
2680              'scope as the model (in this case %s). If you pass in a string '
2681              'identifier for a metric to compile the metric will '
2682              'automatically be created in the correct distribution '
2683              'strategy scope.' % (metric, strategy)
2684          )
2685
2686    # Model metrics must be created in the same distribution strategy scope
2687    # as the model.
2688    for opt in nest.flatten(optimizer):
2689      for v in getattr(opt, '_weights', []):
2690        if not strategy.extended.variable_created_in_scope(v):
2691          raise ValueError(
2692              'Optimizer (%s) passed to model.compile was created inside of a '
2693              'different distribution strategy scope than the model. All '
2694              'optimizers must be created in the same distribution strategy '
2695              'scope as the model (in this case %s). If you pass in a string '
2696              'identifier for an optimizer to compile the optimizer will '
2697              'automatically be created in the correct distribution '
2698              'strategy scope.' % (opt, strategy))
2699
2700  def _maybe_load_initial_epoch_from_ckpt(self, initial_epoch):
2701    """Maybe load initial epoch from ckpt considering possible worker recovery.
2702
2703    Refer to tensorflow/python/keras/distribute/worker_training_state.py
2704    for more information.
2705
2706    Args:
2707      initial_epoch: The original initial_epoch user passes in in `fit()`.
2708
2709    Returns:
2710      If the training is recovering from previous failure under multi-worker
2711      training setting, return the epoch the training is supposed to continue
2712      at. Otherwise, return the `initial_epoch` the user passes in.
2713    """
2714    if self._training_state is not None:
2715      return self._training_state.maybe_load_initial_epoch_from_ckpt(
2716          initial_epoch, mode=ModeKeys.TRAIN)
2717    return initial_epoch
2718
2719  def _assert_compile_was_called(self):
2720    # Checks whether `compile` has been called. If it has been called,
2721    # then the optimizer is set. This is different from whether the
2722    # model is compiled
2723    # (i.e. whether the model is built and its inputs/outputs are set).
2724    if not self._is_compiled:
2725      raise RuntimeError('You must compile your model before '
2726                         'training/testing. '
2727                         'Use `model.compile(optimizer, loss)`.')
2728
  def _set_inputs(self, inputs, outputs=None, training=None):
    """This method is for compat with Modelv1. Only inputs are needed here.

    Args:
      inputs: Forwarded unchanged to `self._set_save_spec`.
      outputs: Accepted for signature compatibility; not used by this method.
      training: Accepted for signature compatibility; not used by this method.
    """
    self._set_save_spec(inputs)
2732
  @property
  def _trackable_saved_model_saver(self):
    """Returns a `model_serialization.ModelSavedModelSaver` for this model.

    A fresh `ModelSavedModelSaver(self)` is built on every access; nothing is
    cached by this property.
    """
    return model_serialization.ModelSavedModelSaver(self)
2736
2737  def _list_functions_for_serialization(self, serialization_cache):
2738    # SavedModel needs to ignore the execution functions.
2739    train_function = self.train_function
2740    test_function = self.test_function
2741    predict_function = self.predict_function
2742    train_tf_function = self.train_tf_function
2743    self.train_function = None
2744    self.test_function = None
2745    self.predict_function = None
2746    self.train_tf_function = None
2747    functions = super(
2748        Model, self)._list_functions_for_serialization(serialization_cache)
2749    self.train_function = train_function
2750    self.test_function = test_function
2751    self.predict_function = predict_function
2752    self.train_tf_function = train_tf_function
2753    return functions
2754
2755  def _should_eval(self, epoch, validation_freq):
2756    epoch = epoch + 1  # one-index the user-facing epoch.
2757    if isinstance(validation_freq, int):
2758      return epoch % validation_freq == 0
2759    elif isinstance(validation_freq, list):
2760      return epoch in validation_freq
2761    else:
2762      raise ValueError('Expected `validation_freq` to be a list or int.')
2763
2764  ######################################################################
2765  # Functions below exist only as v1 / v2 compatibility shims.
2766  ######################################################################
2767
2768  def _get_compile_args(self, user_metrics=True):
2769    """Used for saving or cloning a Model.
2770
2771    Args:
2772      user_metrics: Whether to return user-supplied metrics or `Metric` objects.
2773        Defaults to returning the user-supplied metrics.
2774
2775    Returns:
2776      Dictionary of arguments that were used when compiling the model.
2777    """
2778    self._assert_compile_was_called()
2779    # pylint: disable=protected-access
2780
2781    saved_metrics = self.compiled_metrics._user_metrics
2782    saved_weighted_metrics = self.compiled_metrics._user_weighted_metrics
2783
2784    if not user_metrics:
2785      if saved_metrics is not None:
2786        saved_metrics = self.compiled_metrics._metrics
2787      if saved_weighted_metrics is not None:
2788        saved_weighted_metrics = self.compiled_metrics._weighted_metrics
2789
2790    compile_args = {
2791        'optimizer': self.optimizer,
2792        'loss': self.compiled_loss._user_losses,
2793        'metrics': saved_metrics,
2794        'weighted_metrics': saved_weighted_metrics,
2795        'loss_weights': self.compiled_loss._user_loss_weights,
2796    }
2797    # pylint: enable=protected-access
2798    return compile_args
2799
  def _get_callback_model(self):
    """Returns `self`; kept as a v1 / v2 compatibility shim."""
    return self
2802
2803  def _in_multi_worker_mode(self):
2804    return self.distribute_strategy.extended._in_multi_worker_mode()  # pylint: disable=protected-access
2805
  @property
  def _compile_was_called(self):
    """Returns `self._is_compiled`; kept as a v1 / v2 compatibility shim."""
    return self._is_compiled
2809
2810
def reduce_per_replica(values, strategy, reduction='first'):
  """Reduce PerReplica objects.

  Every leaf of `values` is reduced independently. With `reduction='concat'`
  under a multi-worker `CollectiveAllReduceStrategy`, every leaf (per-replica
  or not) is handed to `_multi_worker_concat`.

  Args:
    values: Structure of `PerReplica` objects or `Tensor`s. `Tensor`s are
      returned as-is.
    strategy: `tf.distribute.Strategy` object.
    reduction: One of 'first', 'concat'.

  Returns:
    Structure of `Tensor`s.

  Raises:
    ValueError: If a per-replica leaf is encountered and `reduction` is neither
      'first' nor 'concat'.
  """

  def _reduce(v):
    """Reduce a single `PerReplica` object."""
    wants_concat = reduction == 'concat'
    if wants_concat and _collective_all_reduce_multi_worker(strategy):
      return _multi_worker_concat(v, strategy)
    if not _is_per_replica_instance(v):
      # Plain values pass through untouched.
      return v
    if reduction == 'first':
      return strategy.unwrap(v)[0]
    if not wants_concat:
      raise ValueError('`reduction` must be "first" or "concat".')
    if _is_tpu_multi_host(strategy):
      return _tpu_multi_host_concat(v, strategy)
    return concat(strategy.unwrap(v))

  return nest.map_structure(_reduce, values)
2841
2842
def concat(tensors, axis=0):
  """Concats `tensor`s along `axis`.

  The sparse or dense code path is chosen by inspecting only the first element
  of `tensors`.

  Args:
    tensors: Indexable sequence of tensors to join.
    axis: Axis to concatenate along. Defaults to 0.

  Returns:
    The result of `sparse_ops.sparse_concat_v2` when the first element is a
    `SparseTensor`, otherwise the result of `array_ops.concat`.
  """
  first = tensors[0]
  if not isinstance(first, sparse_tensor.SparseTensor):
    return array_ops.concat(tensors, axis=axis)
  return sparse_ops.sparse_concat_v2(axis=axis, sp_inputs=tensors)
2848
2849
def _is_tpu_multi_host(strategy):
  """Returns a truthy value for a TPU strategy spanning more than one host.

  `strategy.extended.num_hosts` is only read when
  `backend.is_tpu_strategy(strategy)` is truthy.
  """
  is_tpu = backend.is_tpu_strategy(strategy)
  return is_tpu and strategy.extended.num_hosts > 1
2853
2854
def _tpu_multi_host_concat(v, strategy):
  """Correctly order TPU PerReplica objects.

  Args:
    v: Per-replica value to unwrap with `strategy.unwrap`.
    strategy: Strategy providing `unwrap` and
      `extended.num_replicas_per_host`.

  Returns:
    The result of `concat` over the reordered per-replica values.
  """
  unwrapped = strategy.unwrap(v)
  # When distributed datasets are created from Tensors / NumPy,
  # TPUStrategy.experimental_distribute_dataset shards data in
  # (Replica, Host) order, and TPUStrategy.unwrap returns it in
  # (Host, Replica) order.
  # TODO(b/150317897): Figure out long-term plan here.
  stride = strategy.extended.num_replicas_per_host
  # For each replica offset, take every `stride`-th unwrapped value.
  reordered = [
      replica
      for offset in range(stride)
      for replica in unwrapped[offset::stride]
  ]
  return concat(reordered)
2868
2869
def _collective_all_reduce_multi_worker(strategy):
  """Checks for a `CollectiveAllReduceStrategy` running in multi-worker mode.

  Returns `False` for any other strategy type; otherwise returns whatever
  `strategy.extended._in_multi_worker_mode()` reports.
  """
  if not isinstance(
      strategy, collective_all_reduce_strategy.CollectiveAllReduceStrategy):
    return False
  return strategy.extended._in_multi_worker_mode()  # pylint: disable=protected-access
2874
2875
2876# TODO(wxinyi): merge this with _tpu_multi_host_concat once we have all_gather
2877# for all strategies
def _multi_worker_concat(v, strategy):
  """Order PerReplica objects for CollectiveAllReduceStrategy and concat.

  Args:
    v: Either a per-replica value (as judged by `_is_per_replica_instance`) or
      a single tensor.
    strategy: Strategy providing `gather`, `num_replicas_in_sync` and
      `extended.worker_devices`.

  Returns:
    The result of `concat` over the gathered pieces after they have been
    regrouped by local replica index.
  """
  # Gather `v` along the first axis.
  # NOTE(review): assumes `strategy.gather` concatenates the values of all
  # replicas in sync along `axis` -- confirm against the tf.distribute docs.
  replicas = strategy.gather(v, axis=0)
  # v might not have the same shape on different replicas
  if _is_per_replica_instance(v):
    # Collect the first-dimension size of every component of `v` into a 1-D
    # tensor, then gather those sizes as well.
    shapes = array_ops.concat([
        array_ops.expand_dims_v2(array_ops.shape(single_value)[0], axis=0)
        for single_value in v.values
    ],
                              axis=0)
    all_shapes = strategy.gather(shapes, axis=0)
  else:
    # v is a tensor. This may happen when, say, we have 2x1 multi-worker.
    all_shapes = strategy.gather(
        array_ops.expand_dims_v2(array_ops.shape(v)[0], axis=0), axis=0)

  # Cut the gathered tensor back into `num_replicas_in_sync` pieces, using the
  # gathered first-dimension sizes as the split sizes.
  replicas = array_ops.split(
      replicas,
      num_or_size_splits=all_shapes,
      num=strategy.num_replicas_in_sync)
  ordered_replicas = []
  num_replicas_per_worker = len(strategy.extended.worker_devices)
  # For each local replica index, take every `num_replicas_per_worker`-th
  # piece, so pieces sharing a local index end up adjacent.
  for replica_id in range(num_replicas_per_worker):
    ordered_replicas += replicas[replica_id::num_replicas_per_worker]
  return concat(ordered_replicas)
2903
2904
def _is_scalar(x):
  """Returns True iff `x` is a `Tensor` or `Variable` whose rank is 0."""
  if not isinstance(x, (ops.Tensor, variables.Variable)):
    return False
  return x.shape.rank == 0
2907
2908
def write_scalar_summaries(logs, step):
  """Writes one scalar summary per scalar entry of `logs`.

  Entries for which `_is_scalar` is false are skipped. Each summary is tagged
  with the entry's key prefixed by 'batch_'.

  Args:
    logs: Dict mapping names to values.
    step: Step passed through to `summary_ops_v2.scalar`.
  """
  for key, val in logs.items():
    if not _is_scalar(val):
      continue
    summary_ops_v2.scalar('batch_' + key, val, step=step)
2913
2914
def _minimum_control_deps(outputs):
  """Returns the minimum control dependencies to ensure step succeeded.

  Args:
    outputs: Structure of step outputs; flattened with composites expanded.

  Returns:
    An empty list when executing eagerly or when every flattened output is a
    `Variable`; otherwise a one-element list holding the first flattened
    output that is not a `Variable`.
  """
  if context.executing_eagerly():
    return []  # Control dependencies not needed.
  flat_outputs = nest.flatten(outputs, expand_composites=True)
  # Variables can't be control dependencies.
  usable = [
      out for out in flat_outputs if not isinstance(out, variables.Variable)
  ]
  # First Tensor or Op if there is one, else no viable control dependency.
  return usable[:1]
2925
2926
def _disallow_inside_tf_function(method_name):
  """Raises if called while inside a function, per `ops.inside_function()`.

  Args:
    method_name: Name of the `Model` method being guarded; only used to build
      the error message.

  Raises:
    RuntimeError: If `ops.inside_function()` returns a truthy value.
  """
  if ops.inside_function():
    # The second sentence previously opened a backtick around
    # `Model.{method_name}` without closing it.
    error_msg = (
        'Detected a call to `Model.{method_name}` inside a `tf.function`. '
        '`Model.{method_name}` is a high-level endpoint that manages its '
        'own `tf.function`. Please move the call to `Model.{method_name}` '
        'outside of all enclosing `tf.function`s. Note that you can call a '
        '`Model` directly on `Tensor`s inside a `tf.function` like: '
        '`model(x)`.'
    ).format(method_name=method_name)
    raise RuntimeError(error_msg)
2937
2938
def _detect_save_format(filepath):
  """Returns path to weights file and save format.

  Args:
    filepath: Path-like location of the weights; converted with
      `path_to_string` first.

  Returns:
    A `(filepath, save_format)` tuple where `save_format` is 'h5' or 'tf'. For
    a SavedModel directory the returned path points at the variables
    checkpoint inside it.

  Raises:
    ValueError: If `filepath` contains a SavedModel but its variables
      checkpoint is not readable.
  """
  filepath = path_to_string(filepath)
  if saving_utils.is_hdf5_filepath(filepath):
    return filepath, 'h5'

  # Filepath could be a TensorFlow checkpoint file prefix or SavedModel
  # directory. It's possible for filepath to be both a prefix and directory.
  # Prioritize checkpoint over SavedModel.
  if _is_readable_tf_checkpoint(filepath):
    return filepath, 'tf'

  if not sm_loader.contains_saved_model(filepath):
    # Not a TensorFlow checkpoint. This filepath is likely an H5 file that
    # doesn't have the hdf5/keras extensions.
    return filepath, 'h5'

  checkpoint_in_saved_model = os.path.join(
      filepath, sm_constants.VARIABLES_DIRECTORY,
      sm_constants.VARIABLES_FILENAME)
  if not _is_readable_tf_checkpoint(checkpoint_in_saved_model):
    raise ValueError('Unable to load weights. filepath {} appears to be a '
                     'SavedModel directory, but checkpoint either doesn\'t '
                     'exist, or is incorrectly formatted.'.format(filepath))
  return checkpoint_in_saved_model, 'tf'
2966
2967
def _is_readable_tf_checkpoint(filepath):
  """Returns whether a checkpoint reader can be opened on `filepath`.

  Only `errors_impl.DataLossError` is treated as "not readable"; any other
  exception raised by `NewCheckpointReader` propagates to the caller.
  """
  try:
    py_checkpoint_reader.NewCheckpointReader(filepath)
  except errors_impl.DataLossError:
    # The checkpoint is not readable in TensorFlow format.
    return False
  else:
    return True
2975
2976
def flatten_metrics_in_order(logs, metrics_names):
  """Turns the `logs` dict into a list as per key order of `metrics_names`.

  Values whose keys appear in `metrics_names` come first, in that order.
  Values for any remaining keys of `logs` follow, sorted by key.

  Args:
    logs: Dict mapping metric names to values.
    metrics_names: Names defining the order of the leading results.

  Returns:
    The ordered list of values, or the single value itself when exactly one
    result was collected.
  """
  ordered = [logs[name] for name in metrics_names if name in logs]
  ordered.extend(
      logs[key] for key in sorted(logs.keys()) if key not in metrics_names)
  return ordered[0] if len(ordered) == 1 else ordered
2989
2990
def _is_per_replica_instance(obj):
  """Returns True iff `obj` is both `DistributedValues` and `CompositeTensor`."""
  if not isinstance(obj, ds_values.DistributedValues):
    return False
  return isinstance(obj, composite_tensor.CompositeTensor)
2994
2995
def saver_with_op_caching(obj):
  """Builds a `TrackableSaver` over a weak reference to `obj`.

  A saveables cache (`ObjectIdentityWeakKeyDictionary`) is supplied to the
  object graph view only when not executing eagerly; in eager mode the cache
  is `None`.

  Args:
    obj: Object to save; only a `weakref.ref` to it is handed to the view.

  Returns:
    A `trackable_utils.TrackableSaver` wrapping the object graph view.
  """
  saveables_cache = (
      None if context.executing_eagerly()
      else object_identity.ObjectIdentityWeakKeyDictionary())
  graph_view = graph_view_lib.ObjectGraphView(
      weakref.ref(obj), saveables_cache=saveables_cache)
  return trackable_utils.TrackableSaver(graph_view)
3004