1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Training-related part of the Keras engine.""" 16 17import copy 18import itertools 19import json 20import os 21import warnings 22import weakref 23 24from tensorflow.python.autograph.lang import directives 25from tensorflow.python.data.ops import dataset_ops 26from tensorflow.python.data.ops import options as options_lib 27from tensorflow.python.distribute import collective_all_reduce_strategy 28from tensorflow.python.distribute import distribution_strategy_context as ds_context 29from tensorflow.python.distribute import values as ds_values 30from tensorflow.python.distribute.coordinator import cluster_coordinator 31from tensorflow.python.eager import backprop 32from tensorflow.python.eager import context 33from tensorflow.python.eager import def_function 34from tensorflow.python.framework import composite_tensor 35from tensorflow.python.framework import errors 36from tensorflow.python.framework import errors_impl 37from tensorflow.python.framework import func_graph 38from tensorflow.python.framework import ops 39from tensorflow.python.framework import sparse_tensor 40from tensorflow.python.framework import tensor_shape 41from tensorflow.python.keras import backend 42from tensorflow.python.keras import callbacks as callbacks_module 43from tensorflow.python.keras import 
optimizer_v1 44from tensorflow.python.keras import optimizers 45from tensorflow.python.keras.engine import base_layer 46from tensorflow.python.keras.engine import base_layer_utils 47from tensorflow.python.keras.engine import compile_utils 48from tensorflow.python.keras.engine import data_adapter 49from tensorflow.python.keras.engine import training_utils 50from tensorflow.python.keras.mixed_precision import loss_scale_optimizer as lso 51from tensorflow.python.keras.mixed_precision import policy 52from tensorflow.python.keras.saving import hdf5_format 53from tensorflow.python.keras.saving import save 54from tensorflow.python.keras.saving import saving_utils 55from tensorflow.python.keras.saving.saved_model import json_utils 56from tensorflow.python.keras.saving.saved_model import model_serialization 57from tensorflow.python.keras.utils import generic_utils 58from tensorflow.python.keras.utils import layer_utils 59from tensorflow.python.keras.utils import object_identity 60from tensorflow.python.keras.utils import tf_utils 61from tensorflow.python.keras.utils import version_utils 62from tensorflow.python.keras.utils.io_utils import ask_to_proceed_with_overwrite 63from tensorflow.python.keras.utils.io_utils import path_to_string 64from tensorflow.python.keras.utils.mode_keys import ModeKeys 65from tensorflow.python.ops import array_ops 66from tensorflow.python.ops import math_ops 67from tensorflow.python.ops import sparse_ops 68from tensorflow.python.ops import summary_ops_v2 69from tensorflow.python.ops import variables 70from tensorflow.python.platform import tf_logging as logging 71from tensorflow.python.profiler import trace 72from tensorflow.python.saved_model import constants as sm_constants 73from tensorflow.python.saved_model import loader_impl as sm_loader 74from tensorflow.python.training import checkpoint_management 75from tensorflow.python.training import py_checkpoint_reader 76from tensorflow.python.training.tracking import base as trackable 77from 
def disable_multi_worker(method):
  """Wraps `method` so that it refuses to run in multi-worker mode.

  Args:
    method: The instance method to guard. Its first positional argument is
      the object on which `_in_multi_worker_mode()` is queried.

  Returns:
    The decorator produced by `tf_decorator.make_decorator` with `method` as
    the target and the guarding wrapper as the decorator function.
  """

  def _method_wrapper(self, *args, **kwargs):
    # Happy path first: outside multi-worker mode simply delegate.
    multi_worker = self._in_multi_worker_mode()  # pylint: disable=protected-access
    if not multi_worker:
      return method(self, *args, **kwargs)
    message = '{} is not supported in multi-worker mode.'.format(
        method.__name__)
    raise ValueError(message)

  decorated = tf_decorator.make_decorator(
      target=method, decorator_func=_method_wrapper)
  return decorated
def is_functional_model_init_params(args, kwargs):
  """Tells whether constructor arguments look like a Functional-API call.

  The arguments are treated as functional-style when any of these hold:
    * exactly two positional arguments were given;
    * exactly one positional argument was given together with an `outputs`
      keyword;
    * both `inputs` and `outputs` were given as keywords.

  Args:
    args: Sequence of positional arguments passed to the constructor.
    kwargs: Mapping of keyword arguments passed to the constructor.

  Returns:
    `True` if one of the conditions above is met, `False` otherwise.
  """
  if len(args) == 2:
    return True
  has_outputs_kwarg = 'outputs' in kwargs
  if len(args) == 1 and has_outputs_kwarg:
    return True
  return 'inputs' in kwargs and has_outputs_kwarg
163 164 ```python 165 import tensorflow as tf 166 167 class MyModel(tf.keras.Model): 168 169 def __init__(self): 170 super(MyModel, self).__init__() 171 self.dense1 = tf.keras.layers.Dense(4, activation=tf.nn.relu) 172 self.dense2 = tf.keras.layers.Dense(5, activation=tf.nn.softmax) 173 174 def call(self, inputs): 175 x = self.dense1(inputs) 176 return self.dense2(x) 177 178 model = MyModel() 179 ``` 180 181 If you subclass `Model`, you can optionally have 182 a `training` argument (boolean) in `call`, which you can use to specify 183 a different behavior in training and inference: 184 185 ```python 186 import tensorflow as tf 187 188 class MyModel(tf.keras.Model): 189 190 def __init__(self): 191 super(MyModel, self).__init__() 192 self.dense1 = tf.keras.layers.Dense(4, activation=tf.nn.relu) 193 self.dense2 = tf.keras.layers.Dense(5, activation=tf.nn.softmax) 194 self.dropout = tf.keras.layers.Dropout(0.5) 195 196 def call(self, inputs, training=False): 197 x = self.dense1(inputs) 198 if training: 199 x = self.dropout(x, training=training) 200 return self.dense2(x) 201 202 model = MyModel() 203 ``` 204 205 Once the model is created, you can config the model with losses and metrics 206 with `model.compile()`, train the model with `model.fit()`, or use the model 207 to do prediction with `model.predict()`. 
def __new__(cls, *args, **kwargs):
  """Allocates the instance, redirecting functional-style calls.

  When `Model` itself (not a subclass) is invoked with arguments that
  `is_functional_model_init_params` recognises, a `functional.Functional`
  is constructed with `skip_init=True` and returned instead. In every other
  case allocation is delegated to the parent class.

  Args:
    *args: Positional constructor arguments.
    **kwargs: Keyword constructor arguments.

  Returns:
    Either a `functional.Functional` object or the object returned by the
    parent `__new__`.
  """
  # Signature detection.
  build_functional = (
      is_functional_model_init_params(args, kwargs) and cls == Model)
  if not build_functional:
    return super(Model, cls).__new__(cls, *args, **kwargs)

  # Functional model.
  from tensorflow.python.keras.engine import functional  # pylint: disable=g-import-not-at-top
  return functional.Functional(skip_init=True, *args, **kwargs)
242 clz_to_init = [] 243 found_functional_class = False 244 for clz in self.__class__.__bases__: 245 if issubclass(clz, functional.Functional): 246 found_functional_class = True 247 continue 248 if found_functional_class: 249 clz_to_init.append(clz) 250 251 if clz_to_init: 252 for clz in clz_to_init: 253 clz.__init__(self, *args, **other_kwargs) 254 elif other_kwargs: 255 # In case there are unused kwargs, we should raise an error to user, in 256 # case they have a typo in the param name. 257 raise TypeError( 258 'The following keyword arguments aren\'t supported: {}'.format( 259 other_kwargs)) 260 return 261 262 # The following are implemented as property functions: 263 # self.trainable_weights 264 # self.non_trainable_weights 265 # `inputs` / `outputs` will only appear in kwargs if either are misspelled. 266 generic_utils.validate_kwargs(kwargs, { 267 'trainable', 'dtype', 'dynamic', 'name', 'autocast', 'inputs', 'outputs' 268 }) 269 super(Model, self).__init__(**kwargs) 270 # By default, Model is a subclass model, which is not in graph network. 271 self._is_graph_network = False 272 273 self.inputs = None 274 self.outputs = None 275 self.input_names = None 276 self.output_names = None 277 # stop_training is used by callback to stop training when error happens 278 self.stop_training = False 279 self.history = None 280 # These objects are used in the default `Model.compile`. They are not 281 # guaranteed to be set after `Model.compile` is called, as users can 282 # override compile with custom logic. 283 self.compiled_loss = None 284 self.compiled_metrics = None 285 286 # This is True for Sequential networks and Functional networks. 287 self._compute_output_and_mask_jointly = False 288 289 # Don't reset compilation if already done. This may occur if calling 290 # `__init__` (or `_init_graph_network`) on an already-compiled model 291 # such as a Sequential model. Sequential models may need to rebuild 292 # themselves after compilation. 
@trackable.no_automatic_dependency_tracking
def _init_batch_counters(self):
  """Creates the counters of mini-batches seen in `fit`/`evaluate`/`predict`.

  Three `int64` variables, each starting at 0 and using
  `ONLY_FIRST_REPLICA` aggregation, are stored on `self` as
  `_train_counter`, `_test_counter` and `_predict_counter`. The decorator
  disables automatic dependency tracking while they are assigned.
  """
  first_replica_only = variables.VariableAggregationV2.ONLY_FIRST_REPLICA
  # Creation order is kept: train, then test, then predict.
  for counter_attr in ('_train_counter', '_test_counter', '_predict_counter'):
    counter = variables.Variable(
        0, dtype='int64', aggregation=first_replica_only)
    setattr(self, counter_attr, counter)
@generic_utils.default
def build(self, input_shape):
  """Builds the model based on input shapes received.

  This is to be used for subclassed models, which do not know at instantiation
  time what their inputs look like.

  This method only exists for users who want to call `model.build()` in a
  standalone way (as a substitute for calling the model on real data to
  build it). It will never be called by the framework (and thus it will
  never throw unexpected errors in an unrelated workflow).

  Args:
   input_shape: Single tuple, TensorShape, or list/dict of shapes, where
     shapes are tuples, integers, or TensorShapes.

  Raises:
    ValueError:
      1. In case of invalid user-provided data (not of type tuple,
         list, TensorShape, or dict).
      2. If the model requires call arguments that are agnostic
         to the input shapes (positional or kwarg in call signature).
      3. If not all layers were properly built.
      4. If `call` fails on the generated placeholders with an
         `InvalidArgumentError` or `TypeError` (e.g. float type inputs are
         not supported within the layers). The original error is chained
         as `__cause__` and quoted at the end of the message.

    In each of these cases, the user should build their model by calling it
    on real tensor data.
  """
  if self._is_graph_network:
    # Graph networks already know their inputs; defer to the parent class.
    super(Model, self).build(input_shape)
    return

  if input_shape is None:
    raise ValueError('Input shape must be defined when calling build on a '
                     'model subclass network.')
  valid_types = (tuple, list, tensor_shape.TensorShape, dict)
  if not isinstance(input_shape, valid_types):
    raise ValueError('Specified input shape is not one of the valid types. '
                     'Please specify a batch input shape of type tuple or '
                     'list of input shapes. User provided '
                     'input type: {}'.format(type(input_shape)))

  if input_shape and not self.inputs:
    # We create placeholders for the `None`s in the shape and build the model
    # in a Graph. Since tf.Variable is compatible with both eager execution
    # and graph building, the variables created after building the model in
    # a Graph are still valid when executing eagerly.
    if context.executing_eagerly():
      graph = func_graph.FuncGraph('build_graph')
    else:
      graph = backend.get_graph()
    with graph.as_default():
      # A flat list of ints/None describes one shape, not a list of shapes.
      if (isinstance(input_shape, list) and
          all(d is None or isinstance(d, int) for d in input_shape)):
        input_shape = tuple(input_shape)
      if isinstance(input_shape, list):
        x = [base_layer_utils.generate_placeholders_from_shape(shape)
             for shape in input_shape]
      elif isinstance(input_shape, dict):
        x = {
            k: base_layer_utils.generate_placeholders_from_shape(shape)
            for k, shape in input_shape.items()
        }
      else:
        x = base_layer_utils.generate_placeholders_from_shape(input_shape)

      kwargs = {}
      call_signature = self._call_full_argspec
      call_args = call_signature.args
      # Exclude `self`, `inputs`, and any argument with a default value.
      if len(call_args) > 2:
        if call_signature.defaults:
          call_args = call_args[2:-len(call_signature.defaults)]
        else:
          call_args = call_args[2:]
        for arg in call_args:
          if arg == 'training':
            # Case where `training` is a positional arg with no default.
            kwargs['training'] = False
          else:
            # Has invalid call signature with unknown positional arguments.
            raise ValueError(
                'Currently, you cannot build your model if it has '
                'positional or keyword arguments that are not '
                'inputs to the model, but are required for its '
                '`call` method. Instead, in order to instantiate '
                'and build your model, `call` your model on real '
                'tensor data with all expected call arguments.')
      elif len(call_args) < 2:
        # Signature without `inputs`.
        raise ValueError('You can only call `build` on a model if its `call` '
                         'method accepts an `inputs` argument.')
      try:
        self.call(x, **kwargs)
      except (errors.InvalidArgumentError, TypeError) as e:
        # A `TypeError` from `call` is not necessarily a dtype problem, so
        # surface the real error instead of hiding it behind the generic
        # hint. The original message is kept verbatim as the prefix.
        raise ValueError('You cannot build your model by calling `build` '
                         'if your layers do not support float type inputs. '
                         'Instead, in order to instantiate and build your '
                         'model, `call` your model on real tensor data (of '
                         'the correct dtype).\n\nThe actual error from '
                         '`call` is: {}.'.format(e)) from e
  super(Model, self).build(input_shape)
490 loss: String (name of objective function), objective function or 491 `tf.keras.losses.Loss` instance. See `tf.keras.losses`. An objective 492 function is any callable with the signature `loss = fn(y_true, 493 y_pred)`, where y_true = ground truth values with shape = 494 `[batch_size, d0, .. dN]`, except sparse loss functions such as sparse 495 categorical crossentropy where shape = `[batch_size, d0, .. dN-1]`. 496 y_pred = predicted values with shape = `[batch_size, d0, .. dN]`. It 497 returns a weighted loss float tensor. If a custom `Loss` instance is 498 used and reduction is set to `None`, return value has the shape 499 `[batch_size, d0, .. dN-1]` i.e. per-sample or per-timestep loss 500 values; otherwise, it is a scalar. If the model has multiple outputs, 501 you can use a different loss on each output by passing a dictionary 502 or a list of losses. The loss value that will be minimized by the 503 model will then be the sum of all individual losses, unless 504 `loss_weights` is specified. 505 metrics: List of metrics to be evaluated by the model during training 506 and testing. Each of this can be a string (name of a built-in 507 function), function or a `tf.keras.metrics.Metric` instance. See 508 `tf.keras.metrics`. Typically you will use `metrics=['accuracy']`. A 509 function is any callable with the signature `result = fn(y_true, 510 y_pred)`. To specify different metrics for different outputs of a 511 multi-output model, you could also pass a dictionary, such as 512 `metrics={'output_a': 'accuracy', 'output_b': ['accuracy', 'mse']}`. 513 You can also pass a list to specify a metric or a list of metrics 514 for each output, such as `metrics=[['accuracy'], ['accuracy', 'mse']]` 515 or `metrics=['accuracy', ['accuracy', 'mse']]`. 
When you pass the 516 strings 'accuracy' or 'acc', we convert this to one of 517 `tf.keras.metrics.BinaryAccuracy`, 518 `tf.keras.metrics.CategoricalAccuracy`, 519 `tf.keras.metrics.SparseCategoricalAccuracy` based on the loss 520 function used and the model output shape. We do a similar 521 conversion for the strings 'crossentropy' and 'ce' as well. 522 loss_weights: Optional list or dictionary specifying scalar coefficients 523 (Python floats) to weight the loss contributions of different model 524 outputs. The loss value that will be minimized by the model will then 525 be the *weighted sum* of all individual losses, weighted by the 526 `loss_weights` coefficients. 527 If a list, it is expected to have a 1:1 mapping to the model's 528 outputs. If a dict, it is expected to map output names (strings) 529 to scalar coefficients. 530 weighted_metrics: List of metrics to be evaluated and weighted by 531 `sample_weight` or `class_weight` during training and testing. 532 run_eagerly: Bool. Defaults to `False`. If `True`, this `Model`'s 533 logic will not be wrapped in a `tf.function`. Recommended to leave 534 this as `None` unless your `Model` cannot be run inside a 535 `tf.function`. `run_eagerly=True` is not supported when using 536 `tf.distribute.experimental.ParameterServerStrategy`. 537 steps_per_execution: Int. Defaults to 1. The number of batches to 538 run during each `tf.function` call. Running multiple batches 539 inside a single `tf.function` call can greatly improve performance 540 on TPUs or small models with a large Python overhead. 541 At most, one full epoch will be run each 542 execution. If a number larger than the size of the epoch is passed, 543 the execution will be truncated to the size of the epoch. 544 Note that if `steps_per_execution` is set to `N`, 545 `Callback.on_batch_begin` and `Callback.on_batch_end` methods 546 will only be called every `N` batches 547 (i.e. before/after each `tf.function` execution). 
def _get_optimizer(self, optimizer):
  """Resolves `optimizer` and adds loss scaling where the policy needs it.

  Args:
    optimizer: An optimizer identifier/instance, or a nested structure of
      them; every leaf is processed independently via `nest.map_structure`.

  Returns:
    The same structure with each leaf passed through `optimizers.get` and,
    when a loss scale applies and the leaf is not already a
    `LossScaleOptimizer`, wrapped in a loss-scale optimizer.
  """
  # The deprecated PolicyV1 has a loss_scale, which we use for backwards
  # compatibility to match TF 2.3 behavior. The new Policy does not have a
  # loss_scale, so we use dynamic loss scaling if the mixed_float16 policy is
  # used.
  dtype_policy = self._dtype_policy
  if isinstance(dtype_policy, policy.PolicyV1):
    loss_scale = dtype_policy.loss_scale
  elif dtype_policy.name == 'mixed_float16':
    loss_scale = 'dynamic'
  else:
    loss_scale = None

  def _resolve_and_wrap(candidate):
    resolved = optimizers.get(candidate)
    # Nothing to do without a loss scale or if the wrapping already exists.
    if loss_scale is None or isinstance(resolved, lso.LossScaleOptimizer):
      return resolved
    if loss_scale == 'dynamic':
      return lso.LossScaleOptimizer(resolved)
    return lso.LossScaleOptimizerV1(resolved, loss_scale)

  return nest.map_structure(_resolve_and_wrap, optimizer)
@property
def metrics_names(self):
  """Names of all entries in `self.metrics`, in the same order.

  Note: `metrics_names` are available only after a `keras.Model` has been
  trained/evaluated on actual data.

  Examples:

  >>> inputs = tf.keras.layers.Input(shape=(3,))
  >>> outputs = tf.keras.layers.Dense(2)(inputs)
  >>> model = tf.keras.models.Model(inputs=inputs, outputs=outputs)
  >>> model.compile(optimizer="Adam", loss="mse", metrics=["mae"])
  >>> model.metrics_names
  []

  >>> x = np.random.random((2, 3))
  >>> y = np.random.randint(0, 2, (2, 2))
  >>> model.fit(x, y)
  >>> model.metrics_names
  ['loss', 'mae']

  >>> inputs = tf.keras.layers.Input(shape=(3,))
  >>> d = tf.keras.layers.Dense(2, name='out')
  >>> output_1 = d(inputs)
  >>> output_2 = d(inputs)
  >>> model = tf.keras.models.Model(
  ...    inputs=inputs, outputs=[output_1, output_2])
  >>> model.compile(optimizer="Adam", loss="mse", metrics=["mae", "acc"])
  >>> model.fit(x, (y, y))
  >>> model.metrics_names
  ['loss', 'out_loss', 'out_1_loss', 'out_mae', 'out_acc', 'out_1_mae',
  'out_1_acc']

  Returns:
    A new list holding the `name` attribute of every object in
    `self.metrics`.
  """
  # For backward compatibility the list also carries `loss` and the
  # per-output loss names, since those are part of `self.metrics`.
  return [metric.name for metric in self.metrics]
746 raise ValueError('Your model contains layers that can only be ' 747 'successfully run in eager execution (layers ' 748 'constructed with `dynamic=True`). ' 749 'You cannot set `run_eagerly=False`.') 750 751 if self._cluster_coordinator and self._run_eagerly: 752 raise ValueError('When using `Model` with `ParameterServerStrategy`, ' 753 '`run_eagerly` is not supported.') 754 755 # Run eagerly logic, by priority: 756 # (1) Dynamic models must be run eagerly. 757 # (2) Explicitly setting run_eagerly causes a Model to be run eagerly. 758 # (3) Not explicitly setting run_eagerly defaults to TF's global setting. 759 return (self.dynamic or self._run_eagerly or 760 (def_function.functions_run_eagerly() and 761 self._run_eagerly is None)) 762 763 @run_eagerly.setter 764 def run_eagerly(self, value): 765 self._run_eagerly = value 766 767 def train_step(self, data): 768 """The logic for one training step. 769 770 This method can be overridden to support custom training logic. 771 For concrete examples of how to override this method see 772 [Customizing what happends in fit](https://www.tensorflow.org/guide/keras/customizing_what_happens_in_fit). 773 This method is called by `Model.make_train_function`. 774 775 This method should contain the mathematical logic for one step of training. 776 This typically includes the forward pass, loss calculation, backpropagation, 777 and metric updates. 778 779 Configuration details for *how* this logic is run (e.g. `tf.function` and 780 `tf.distribute.Strategy` settings), should be left to 781 `Model.make_train_function`, which can also be overridden. 782 783 Args: 784 data: A nested structure of `Tensor`s. 785 786 Returns: 787 A `dict` containing values that will be passed to 788 `tf.keras.callbacks.CallbackList.on_train_batch_end`. Typically, the 789 values of the `Model`'s metrics are returned. Example: 790 `{'loss': 0.2, 'accuracy': 0.7}`. 
def train_step(self, data):
  """The logic for one training step.

  This method can be overridden to support custom training logic.
  For concrete examples of how to override this method see
  [Customizing what happens in fit](https://www.tensorflow.org/guide/keras/customizing_what_happens_in_fit).
  This method is called by `Model.make_train_function`.

  This method should contain the mathematical logic for one step of training.
  This typically includes the forward pass, loss calculation, backpropagation,
  and metric updates.

  Configuration details for *how* this logic is run (e.g. `tf.function` and
  `tf.distribute.Strategy` settings), should be left to
  `Model.make_train_function`, which can also be overridden.

  Args:
    data: A nested structure of `Tensor`s.

  Returns:
    A `dict` containing values that will be passed to
    `tf.keras.callbacks.CallbackList.on_train_batch_end`. Typically, the
    values of the `Model`'s metrics are returned. Example:
    `{'loss': 0.2, 'accuracy': 0.7}`.
  """
  # These are the only transformations `Model.fit` applies to user-input
  # data when a `tf.data.Dataset` is provided.
  x, y, sample_weight = data_adapter.unpack_x_y_sample_weight(
      data_adapter.expand_1d(data))

  # Forward pass, recorded on the tape so the loss can be differentiated.
  with backprop.GradientTape() as tape:
    y_pred = self(x, training=True)
    loss = self.compiled_loss(
        y, y_pred, sample_weight, regularization_losses=self.losses)

  # Backward pass and weight update.
  self.optimizer.minimize(loss, self.trainable_variables, tape=tape)
  self.compiled_metrics.update_state(y, y_pred, sample_weight)

  # Gather the current value of every metric. A metric whose result is
  # already a dict contributes its entries directly; any other result is
  # stored under the metric's name.
  step_logs = {}
  for metric in self.metrics:
    value = metric.result()
    step_logs.update(
        value if isinstance(value, dict) else {metric.name: value})
  return step_logs
def make_train_function(self):
  """Creates a function that executes one step of training.

  This method can be overridden to support custom training logic.
  This method is called by `Model.fit` and `Model.train_on_batch`.

  Typically, this method directly controls `tf.function` and
  `tf.distribute.Strategy` settings, and delegates the actual training
  logic to `Model.train_step`.

  This function is cached the first time `Model.fit` or
  `Model.train_on_batch` is called. The cache is cleared whenever
  `Model.compile` is called.

  Returns:
    Function. The function created by this method should accept a
    `tf.data.Iterator`, and return a `dict` containing values that will
    be passed to `tf.keras.Callbacks.on_train_batch_end`, such as
    `{'loss': 0.2, 'accuracy': 0.7}`.
  """
  if self.train_function is not None:
    return self.train_function

  def step_function(model, iterator):
    """Runs a single training step."""

    def run_step(data):
      outputs = model.train_step(data)
      # Ensure counter is updated only if `train_step` succeeds.
      with ops.control_dependencies(_minimum_control_deps(outputs)):
        model._train_counter.assign_add(1)  # pylint: disable=protected-access
      return outputs

    data = next(iterator)
    outputs = model.distribute_strategy.run(run_step, args=(data,))
    outputs = reduce_per_replica(
        outputs, self.distribute_strategy, reduction='first')
    write_scalar_summaries(outputs, step=model._train_counter)  # pylint: disable=protected-access
    return outputs

  # `_steps_per_execution` may be unset (`None`); treat that like a single
  # step per execution instead of failing on `None.numpy()`. This mirrors
  # the guard used by `make_predict_function`.
  if (self._steps_per_execution is None or
      self._steps_per_execution.numpy().item() == 1):

    def train_function(iterator):
      """Runs a training execution with one step."""
      return step_function(self, iterator)

  else:

    def train_function(iterator):
      """Runs a training execution with multiple steps."""
      for _ in math_ops.range(self._steps_per_execution):
        outputs = step_function(self, iterator)
      return outputs

  if not self.run_eagerly:
    train_function = def_function.function(
        train_function, experimental_relax_shapes=True)
    # Keep a handle on the traced function itself, since `train_function`
    # below may be replaced by a coordinator-scheduling wrapper.
    self.train_tf_function = train_function

  self.train_function = train_function

  if self._cluster_coordinator:
    # With a cluster coordinator, each call schedules the function on the
    # coordinator rather than running it inline.
    self.train_function = lambda iterator: self._cluster_coordinator.schedule(  # pylint: disable=g-long-lambda
        train_function, args=(iterator,))

  return self.train_function
def fit(self,
        x=None,
        y=None,
        batch_size=None,
        epochs=1,
        verbose='auto',
        callbacks=None,
        validation_split=0.,
        validation_data=None,
        shuffle=True,
        class_weight=None,
        sample_weight=None,
        initial_epoch=0,
        steps_per_epoch=None,
        validation_steps=None,
        validation_batch_size=None,
        validation_freq=1,
        max_queue_size=10,
        workers=1,
        use_multiprocessing=False):
  """Trains the model for a fixed number of epochs (iterations on a dataset).

  Args:
    x: Input data. It could be:
      - A Numpy array (or array-like), or a list of arrays
        (in case the model has multiple inputs).
      - A TensorFlow tensor, or a list of tensors
        (in case the model has multiple inputs).
      - A dict mapping input names to the corresponding array/tensors,
        if the model has named inputs.
      - A `tf.data` dataset. Should return a tuple
        of either `(inputs, targets)` or
        `(inputs, targets, sample_weights)`.
      - A generator or `keras.utils.Sequence` returning `(inputs, targets)`
        or `(inputs, targets, sample_weights)`.
      - A `tf.keras.utils.experimental.DatasetCreator`, which wraps a
        callable that takes a single argument of type
        `tf.distribute.InputContext`, and returns a `tf.data.Dataset`.
        `DatasetCreator` should be used when users prefer to specify the
        per-replica batching and sharding logic for the `Dataset`.
        See `tf.keras.utils.experimental.DatasetCreator` doc for more
        information.
      A more detailed description of unpacking behavior for iterator types
      (Dataset, generator, Sequence) is given below. If using
      `tf.distribute.experimental.ParameterServerStrategy`, only
      `DatasetCreator` type is supported for `x`.
    y: Target data. Like the input data `x`,
      it could be either Numpy array(s) or TensorFlow tensor(s).
      It should be consistent with `x` (you cannot have Numpy inputs and
      tensor targets, or inversely). If `x` is a dataset, generator,
      or `keras.utils.Sequence` instance, `y` should
      not be specified (since targets will be obtained from `x`).
    batch_size: Integer or `None`.
      Number of samples per gradient update.
      If unspecified, `batch_size` will default to 32.
      Do not specify the `batch_size` if your data is in the
      form of datasets, generators, or `keras.utils.Sequence` instances
      (since they generate batches).
    epochs: Integer. Number of epochs to train the model.
      An epoch is an iteration over the entire `x` and `y`
      data provided.
      Note that in conjunction with `initial_epoch`,
      `epochs` is to be understood as "final epoch".
      The model is not trained for a number of iterations
      given by `epochs`, but merely until the epoch
      of index `epochs` is reached.
    verbose: 'auto', 0, 1, or 2. Verbosity mode.
      0 = silent, 1 = progress bar, 2 = one line per epoch.
      'auto' defaults to 1 for most cases, but 2 when used with
      `ParameterServerStrategy`. Note that the progress bar is not
      particularly useful when logged to a file, so verbose=2 is
      recommended when not running interactively (e.g., in a production
      environment).
    callbacks: List of `keras.callbacks.Callback` instances.
      List of callbacks to apply during training.
      See `tf.keras.callbacks`. Note `tf.keras.callbacks.ProgbarLogger`
      and `tf.keras.callbacks.History` callbacks are created automatically
      and need not be passed into `model.fit`.
      `tf.keras.callbacks.ProgbarLogger` is created or not based on
      `verbose` argument to `model.fit`.
      Callbacks with batch-level calls are currently unsupported with
      `tf.distribute.experimental.ParameterServerStrategy`, and users are
      advised to implement epoch-level calls instead with an appropriate
      `steps_per_epoch` value.
    validation_split: Float between 0 and 1.
      Fraction of the training data to be used as validation data.
      The model will set apart this fraction of the training data,
      will not train on it, and will evaluate
      the loss and any model metrics
      on this data at the end of each epoch.
      The validation data is selected from the last samples
      in the `x` and `y` data provided, before shuffling. This argument is
      not supported when `x` is a dataset, generator or
      `keras.utils.Sequence` instance.
      `validation_split` is not yet supported with
      `tf.distribute.experimental.ParameterServerStrategy`.
    validation_data: Data on which to evaluate
      the loss and any model metrics at the end of each epoch.
      The model will not be trained on this data. Thus, note the fact
      that the validation loss of data provided using `validation_split`
      or `validation_data` is not affected by regularization layers like
      noise and dropout.
      `validation_data` will override `validation_split`.
      `validation_data` could be:
        - A tuple `(x_val, y_val)` of Numpy arrays or tensors.
        - A tuple `(x_val, y_val, val_sample_weights)` of NumPy arrays.
        - A `tf.data.Dataset`.
        - A Python generator or `keras.utils.Sequence` returning
          `(inputs, targets)` or `(inputs, targets, sample_weights)`.
      `validation_data` is not yet supported with
      `tf.distribute.experimental.ParameterServerStrategy`.
    shuffle: Boolean (whether to shuffle the training data
      before each epoch) or str (for 'batch'). This argument is ignored
      when `x` is a generator or an object of tf.data.Dataset.
      'batch' is a special option for dealing
      with the limitations of HDF5 data; it shuffles in batch-sized
      chunks. Has no effect when `steps_per_epoch` is not `None`.
    class_weight: Optional dictionary mapping class indices (integers)
      to a weight (float) value, used for weighting the loss function
      (during training only).
      This can be useful to tell the model to
      "pay more attention" to samples from
      an under-represented class.
    sample_weight: Optional Numpy array of weights for
      the training samples, used for weighting the loss function
      (during training only). You can either pass a flat (1D)
      Numpy array with the same length as the input samples
      (1:1 mapping between weights and samples),
      or in the case of temporal data,
      you can pass a 2D array with shape
      `(samples, sequence_length)`,
      to apply a different weight to every timestep of every sample. This
      argument is not supported when `x` is a dataset, generator, or
      `keras.utils.Sequence` instance, instead provide the sample_weights
      as the third element of `x`.
    initial_epoch: Integer.
      Epoch at which to start training
      (useful for resuming a previous training run).
    steps_per_epoch: Integer or `None`.
      Total number of steps (batches of samples)
      before declaring one epoch finished and starting the
      next epoch. When training with input tensors such as
      TensorFlow data tensors, the default `None` is equal to
      the number of samples in your dataset divided by
      the batch size, or 1 if that cannot be determined. If x is a
      `tf.data` dataset, and 'steps_per_epoch'
      is None, the epoch will run until the input dataset is exhausted.
      When passing an infinitely repeating dataset, you must specify the
      `steps_per_epoch` argument. If `steps_per_epoch=-1` the training
      will run indefinitely with an infinitely repeating dataset.
      This argument is not supported with array inputs.
      When using `tf.distribute.experimental.ParameterServerStrategy`:
        * `steps_per_epoch=None` is not supported.
    validation_steps: Only relevant if `validation_data` is provided and
      is a `tf.data` dataset. Total number of steps (batches of
      samples) to draw before stopping when performing validation
      at the end of every epoch. If 'validation_steps' is None, validation
      will run until the `validation_data` dataset is exhausted. In the
      case of an infinitely repeated dataset, it will run into an
      infinite loop. If 'validation_steps' is specified and only part of
      the dataset will be consumed, the evaluation will start from the
      beginning of the dataset at each epoch. This ensures that the same
      validation samples are used every time.
    validation_batch_size: Integer or `None`.
      Number of samples per validation batch.
      If unspecified, will default to `batch_size`.
      Do not specify the `validation_batch_size` if your data is in the
      form of datasets, generators, or `keras.utils.Sequence` instances
      (since they generate batches).
    validation_freq: Only relevant if validation data is provided. Integer
      or `collections.abc.Container` instance (e.g. list, tuple, etc.).
      If an integer, specifies how many training epochs to run before a
      new validation run is performed, e.g. `validation_freq=2` runs
      validation every 2 epochs. If a Container, specifies the epochs on
      which to run validation, e.g. `validation_freq=[1, 2, 10]` runs
      validation at the end of the 1st, 2nd, and 10th epochs.
    max_queue_size: Integer. Used for generator or `keras.utils.Sequence`
      input only. Maximum size for the generator queue.
      If unspecified, `max_queue_size` will default to 10.
    workers: Integer. Used for generator or `keras.utils.Sequence` input
      only. Maximum number of processes to spin up
      when using process-based threading. If unspecified, `workers`
      will default to 1.
    use_multiprocessing: Boolean. Used for generator or
      `keras.utils.Sequence` input only. If `True`, use process-based
      threading. If unspecified, `use_multiprocessing` will default to
      `False`. Note that because this implementation relies on
      multiprocessing, you should not pass non-picklable arguments to
      the generator as they can't be passed easily to children processes.

  Unpacking behavior for iterator-like inputs:
    A common pattern is to pass a tf.data.Dataset, generator, or
    tf.keras.utils.Sequence to the `x` argument of fit, which will in fact
    yield not only features (x) but optionally targets (y) and sample weights.
    Keras requires that the output of such iterator-likes be unambiguous. The
    iterator should return a tuple of length 1, 2, or 3, where the optional
    second and third elements will be used for y and sample_weight
    respectively. Any other type provided will be wrapped in a length one
    tuple, effectively treating everything as 'x'. When yielding dicts, they
    should still adhere to the top-level tuple structure.
    e.g. `({"x0": x0, "x1": x1}, y)`. Keras will not attempt to separate
    features, targets, and weights from the keys of a single dict.
      A notable unsupported data type is the namedtuple. The reason is that
    it behaves like both an ordered datatype (tuple) and a mapping
    datatype (dict). So given a namedtuple of the form:
        `namedtuple("example_tuple", ["y", "x"])`
    it is ambiguous whether to reverse the order of the elements when
    interpreting the value. Even worse is a tuple of the form:
        `namedtuple("other_tuple", ["x", "y", "z"])`
    where it is unclear if the tuple was intended to be unpacked into x, y,
    and sample_weight or passed through as a single element to `x`. As a
    result the data processing code will simply raise a ValueError if it
    encounters a namedtuple. (Along with instructions to remedy the issue.)

  Returns:
    A `History` object. Its `History.history` attribute is
    a record of training loss values and metrics values
    at successive epochs, as well as validation loss values
    and validation metrics values (if applicable).

  Raises:
    RuntimeError: 1. If the model was never compiled or,
    2. If `model.fit` is wrapped in `tf.function`.

    ValueError: In case of mismatch between the provided input data
      and what the model expects or when the input data is empty.
  """
  # Legacy graph support is contained in `training_v1.Model`.
  version_utils.disallow_legacy_graph('Model', 'fit')
  self._assert_compile_was_called()
  self._check_call_args('fit')
  _disallow_inside_tf_function('fit')

  if verbose == 'auto':
    if self.distribute_strategy._should_use_with_coordinator:  # pylint: disable=protected-access
      verbose = 2  # Default to epoch-level logging for PSStrategy.
    else:
      verbose = 1  # Default to batch-level logging otherwise.

  if validation_split:
    # Create the validation data using the training data. Only supported for
    # `Tensor` and `NumPy` input.
    (x, y, sample_weight), validation_data = (
        data_adapter.train_validation_split(
            (x, y, sample_weight), validation_split=validation_split))

  if validation_data:
    val_x, val_y, val_sample_weight = (
        data_adapter.unpack_x_y_sample_weight(validation_data))

  if self.distribute_strategy._should_use_with_coordinator:  # pylint: disable=protected-access
    self._cluster_coordinator = cluster_coordinator.ClusterCoordinator(
        self.distribute_strategy)

  with self.distribute_strategy.scope(), \
       training_utils.RespectCompiledTrainableState(self):
    # Creates a `tf.data.Dataset` and handles batch and epoch iteration.
    data_handler = data_adapter.get_data_handler(
        x=x,
        y=y,
        sample_weight=sample_weight,
        batch_size=batch_size,
        steps_per_epoch=steps_per_epoch,
        initial_epoch=initial_epoch,
        epochs=epochs,
        shuffle=shuffle,
        class_weight=class_weight,
        max_queue_size=max_queue_size,
        workers=workers,
        use_multiprocessing=use_multiprocessing,
        model=self,
        steps_per_execution=self._steps_per_execution)

    # Container that configures and calls `tf.keras.Callback`s.
    if not isinstance(callbacks, callbacks_module.CallbackList):
      callbacks = callbacks_module.CallbackList(
          callbacks,
          add_history=True,
          add_progbar=verbose != 0,
          model=self,
          verbose=verbose,
          epochs=epochs,
          steps=data_handler.inferred_steps)

    self.stop_training = False
    self.train_function = self.make_train_function()
    self._train_counter.assign(0)
    callbacks.on_train_begin()
    training_logs = None
    # Handle fault-tolerance for multi-worker.
    # TODO(omalleyt): Fix the ordering issues that mean this has to
    # happen after `callbacks.on_train_begin`.
    data_handler._initial_epoch = (  # pylint: disable=protected-access
        self._maybe_load_initial_epoch_from_ckpt(initial_epoch))
    logs = None
    try:
      for epoch, iterator in data_handler.enumerate_epochs():
        self.reset_metrics()
        callbacks.on_epoch_begin(epoch)
        with data_handler.catch_stop_iteration():
          for step in data_handler.steps():
            with trace.Trace(
                'train',
                epoch_num=epoch,
                step_num=step,
                batch_size=batch_size,
                _r=1):
              callbacks.on_train_batch_begin(step)
              tmp_logs = self.train_function(iterator)
              if data_handler.should_sync:
                context.async_wait()
              logs = tmp_logs  # No error, now safe to assign to logs.
              end_step = step + data_handler.step_increment
              callbacks.on_train_batch_end(end_step, logs)
              if self.stop_training:
                break

        logs = tf_utils.sync_to_numpy_or_python_type(logs)
        if logs is None:
          raise ValueError('Expect x to be a non-empty array or dataset.')
        epoch_logs = copy.copy(logs)

        # Run validation.
        if validation_data and self._should_eval(epoch, validation_freq):
          # Create data_handler for evaluation and cache it.
          if getattr(self, '_eval_data_handler', None) is None:
            self._eval_data_handler = data_adapter.get_data_handler(
                x=val_x,
                y=val_y,
                sample_weight=val_sample_weight,
                batch_size=validation_batch_size or batch_size,
                steps_per_epoch=validation_steps,
                initial_epoch=0,
                epochs=1,
                max_queue_size=max_queue_size,
                workers=workers,
                use_multiprocessing=use_multiprocessing,
                model=self,
                steps_per_execution=self._steps_per_execution)
          val_logs = self.evaluate(
              x=val_x,
              y=val_y,
              sample_weight=val_sample_weight,
              batch_size=validation_batch_size or batch_size,
              steps=validation_steps,
              callbacks=callbacks,
              max_queue_size=max_queue_size,
              workers=workers,
              use_multiprocessing=use_multiprocessing,
              return_dict=True,
              _use_cached_eval_dataset=True)
          val_logs = {'val_' + name: val for name, val in val_logs.items()}
          epoch_logs.update(val_logs)

        callbacks.on_epoch_end(epoch, epoch_logs)
        training_logs = epoch_logs
        if self.stop_training:
          break
    finally:
      # If eval data_handler exists, delete it once the epoch loop is over.
      # This runs in `finally` so that an exception or interrupt during
      # training cannot leave a stale cached handler on the model, which a
      # later `fit` call would otherwise reuse with its old validation data.
      if getattr(self, '_eval_data_handler', None) is not None:
        del self._eval_data_handler
    callbacks.on_train_end(logs=training_logs)
    return self.history
def test_step(self, data):
  """The logic for one evaluation step.

  This method can be overridden to support custom evaluation logic.
  This method is called by `Model.make_test_function`.

  This function should contain the mathematical logic for one step of
  evaluation.
  This typically includes the forward pass, loss calculation, and metrics
  updates.

  Configuration details for *how* this logic is run (e.g. `tf.function` and
  `tf.distribute.Strategy` settings), should be left to
  `Model.make_test_function`, which can also be overridden.

  Args:
    data: A nested structure of `Tensor`s.

  Returns:
    A `dict` containing values that will be passed to
    `tf.keras.callbacks.CallbackList.on_test_batch_end`. Typically, the
    values of the `Model`'s metrics are returned.
  """
  x, y, sample_weight = data_adapter.unpack_x_y_sample_weight(
      data_adapter.expand_1d(data))

  y_pred = self(x, training=False)
  # The loss value itself is not needed here; the call is made so that the
  # stateful loss metrics are updated.
  self.compiled_loss(
      y, y_pred, sample_weight, regularization_losses=self.losses)
  self.compiled_metrics.update_state(y, y_pred, sample_weight)

  # Gather the current value of every metric. A metric whose result is
  # already a dict contributes its entries directly; any other result is
  # stored under the metric's name.
  step_logs = {}
  for metric in self.metrics:
    value = metric.result()
    step_logs.update(
        value if isinstance(value, dict) else {metric.name: value})
  return step_logs
def make_test_function(self):
  """Creates a function that executes one step of evaluation.

  This method can be overridden to support custom evaluation logic.
  This method is called by `Model.evaluate` and `Model.test_on_batch`.

  Typically, this method directly controls `tf.function` and
  `tf.distribute.Strategy` settings, and delegates the actual evaluation
  logic to `Model.test_step`.

  This function is cached the first time `Model.evaluate` or
  `Model.test_on_batch` is called. The cache is cleared whenever
  `Model.compile` is called.

  Returns:
    Function. The function created by this method should accept a
    `tf.data.Iterator`, and return a `dict` containing values that will
    be passed to `tf.keras.Callbacks.on_test_batch_end`.
  """
  if self.test_function is not None:
    return self.test_function

  def step_function(model, iterator):
    """Runs a single evaluation step."""

    def run_step(data):
      outputs = model.test_step(data)
      # Ensure counter is updated only if `test_step` succeeds.
      with ops.control_dependencies(_minimum_control_deps(outputs)):
        model._test_counter.assign_add(1)  # pylint: disable=protected-access
      return outputs

    data = next(iterator)
    outputs = model.distribute_strategy.run(run_step, args=(data,))
    outputs = reduce_per_replica(
        outputs, self.distribute_strategy, reduction='first')
    return outputs

  # `_steps_per_execution` may be unset (`None`); treat that like a single
  # step per execution instead of failing on `None.numpy()`. This mirrors
  # the guard used by `make_predict_function`.
  if (self._steps_per_execution is None or
      self._steps_per_execution.numpy().item() == 1):

    def test_function(iterator):
      """Runs an evaluation execution with one step."""
      return step_function(self, iterator)

  else:

    def test_function(iterator):
      """Runs an evaluation execution with multiple steps."""
      for _ in math_ops.range(self._steps_per_execution):
        outputs = step_function(self, iterator)
      return outputs

  if not self.run_eagerly:
    test_function = def_function.function(
        test_function, experimental_relax_shapes=True)

  self.test_function = test_function

  if self._cluster_coordinator:
    # With a cluster coordinator, each call schedules the function on the
    # coordinator rather than running it inline.
    self.test_function = lambda iterator: self._cluster_coordinator.schedule(  # pylint: disable=g-long-lambda
        test_function, args=(iterator,))

  return self.test_function
def evaluate(self,
             x=None,
             y=None,
             batch_size=None,
             verbose=1,
             sample_weight=None,
             steps=None,
             callbacks=None,
             max_queue_size=10,
             workers=1,
             use_multiprocessing=False,
             return_dict=False,
             **kwargs):
  """Returns the loss value & metrics values for the model in test mode.

  Computation is done in batches (see the `batch_size` arg.)

  Args:
    x: Input data. It could be:
      - A Numpy array (or array-like), or a list of arrays
        (in case the model has multiple inputs).
      - A TensorFlow tensor, or a list of tensors
        (in case the model has multiple inputs).
      - A dict mapping input names to the corresponding array/tensors,
        if the model has named inputs.
      - A `tf.data` dataset. Should return a tuple
        of either `(inputs, targets)` or
        `(inputs, targets, sample_weights)`.
      - A generator or `keras.utils.Sequence` returning `(inputs, targets)`
        or `(inputs, targets, sample_weights)`.
      A more detailed description of unpacking behavior for iterator types
      (Dataset, generator, Sequence) is given in the `Unpacking behavior
      for iterator-like inputs` section of `Model.fit`.
    y: Target data. Like the input data `x`, it could be either Numpy
      array(s) or TensorFlow tensor(s). It should be consistent with `x`
      (you cannot have Numpy inputs and tensor targets, or inversely). If
      `x` is a dataset, generator or `keras.utils.Sequence` instance, `y`
      should not be specified (since targets will be obtained from the
      iterator/dataset).
    batch_size: Integer or `None`. Number of samples per batch of
      computation. If unspecified, `batch_size` will default to 32. Do not
      specify the `batch_size` if your data is in the form of a dataset,
      generators, or `keras.utils.Sequence` instances (since they generate
      batches).
    verbose: 0 or 1. Verbosity mode. 0 = silent, 1 = progress bar.
    sample_weight: Optional Numpy array of weights for the test samples,
      used for weighting the loss function. You can either pass a flat (1D)
      Numpy array with the same length as the input samples
      (1:1 mapping between weights and samples), or in the case of
      temporal data, you can pass a 2D array with shape `(samples,
      sequence_length)`, to apply a different weight to every timestep
      of every sample. This argument is not supported when `x` is a
      dataset, instead pass sample weights as the third element of `x`.
    steps: Integer or `None`. Total number of steps (batches of samples)
      before declaring the evaluation round finished. Ignored with the
      default value of `None`. If x is a `tf.data` dataset and `steps` is
      None, 'evaluate' will run until the dataset is exhausted. This
      argument is not supported with array inputs.
    callbacks: List of `keras.callbacks.Callback` instances. List of
      callbacks to apply during evaluation. See
      [callbacks](/api_docs/python/tf/keras/callbacks).
    max_queue_size: Integer. Used for generator or `keras.utils.Sequence`
      input only. Maximum size for the generator queue. If unspecified,
      `max_queue_size` will default to 10.
    workers: Integer. Used for generator or `keras.utils.Sequence` input
      only. Maximum number of processes to spin up when using process-based
      threading. If unspecified, `workers` will default to 1.
    use_multiprocessing: Boolean. Used for generator or
      `keras.utils.Sequence` input only. If `True`, use process-based
      threading. If unspecified, `use_multiprocessing` will default to
      `False`. Note that because this implementation relies on
      multiprocessing, you should not pass non-picklable arguments to the
      generator as they can't be passed easily to children processes.
    return_dict: If `True`, loss and metric results are returned as a dict,
      with each key being the name of the metric. If `False`, they are
      returned as a list.
    **kwargs: Unused at this time.

  See the discussion of `Unpacking behavior for iterator-like inputs` for
  `Model.fit`.

  `Model.evaluate` is not yet supported with
  `tf.distribute.experimental.ParameterServerStrategy`.

  Returns:
    Scalar test loss (if the model has a single output and no metrics)
    or list of scalars (if the model has multiple outputs
    and/or metrics). The attribute `model.metrics_names` will give you
    the display labels for the scalar outputs.

  Raises:
    RuntimeError: If `model.evaluate` is wrapped in `tf.function`.
    ValueError: in case of invalid arguments.
  """
  version_utils.disallow_legacy_graph('Model', 'evaluate')
  self._assert_compile_was_called()
  self._check_call_args('evaluate')
  _disallow_inside_tf_function('evaluate')

  # `_use_cached_eval_dataset` is the only keyword accepted through
  # `**kwargs`; anything left over after popping it is rejected.
  use_cached_eval_dataset = kwargs.pop('_use_cached_eval_dataset', False)
  if kwargs:
    raise TypeError('Invalid keyword arguments: %s' % (kwargs,))

  if self.distribute_strategy._should_use_with_coordinator:  # pylint: disable=protected-access
    self._cluster_coordinator = cluster_coordinator.ClusterCoordinator(
        self.distribute_strategy)

  with self.distribute_strategy.scope():
    # Use cached evaluation data only when it's called in `Model.fit`.
    data_handler = None
    if use_cached_eval_dataset:
      data_handler = getattr(self, '_eval_data_handler', None)
    if data_handler is None:
      # Creates a `tf.data.Dataset` and handles batch and epoch iteration.
      data_handler = data_adapter.get_data_handler(
          x=x,
          y=y,
          sample_weight=sample_weight,
          batch_size=batch_size,
          steps_per_epoch=steps,
          initial_epoch=0,
          epochs=1,
          max_queue_size=max_queue_size,
          workers=workers,
          use_multiprocessing=use_multiprocessing,
          model=self,
          steps_per_execution=self._steps_per_execution)

    # Container that configures and calls `tf.keras.Callback`s.
    if not isinstance(callbacks, callbacks_module.CallbackList):
      callbacks = callbacks_module.CallbackList(
          callbacks,
          add_history=True,
          add_progbar=verbose != 0,
          model=self,
          verbose=verbose,
          epochs=1,
          steps=data_handler.inferred_steps)

    logs = {}
    self.test_function = self.make_test_function()
    self._test_counter.assign(0)
    callbacks.on_test_begin()
    for _, iterator in data_handler.enumerate_epochs():  # Single epoch.
      self.reset_metrics()
      with data_handler.catch_stop_iteration():
        for step in data_handler.steps():
          with trace.Trace('test', step_num=step, _r=1):
            callbacks.on_test_batch_begin(step)
            step_logs = self.test_function(iterator)
            if data_handler.should_sync:
              context.async_wait()
            logs = step_logs  # No error, now safe to assign to logs.
            callbacks.on_test_batch_end(
                step + data_handler.step_increment, logs)
    logs = tf_utils.sync_to_numpy_or_python_type(logs)
    callbacks.on_test_end(logs=logs)

    if not return_dict:
      return flatten_metrics_in_order(logs, self.metrics_names)
    return logs
def predict_step(self, data):
  """The logic for one inference step.

  This method can be overridden to support custom inference logic.
  This method is called by `Model.make_predict_function`.

  This method should contain the mathematical logic for one step of inference.
  This typically includes the forward pass.

  Configuration details for *how* this logic is run (e.g. `tf.function` and
  `tf.distribute.Strategy` settings), should be left to
  `Model.make_predict_function`, which can also be overridden.

  Args:
    data: A nested structure of `Tensor`s.

  Returns:
    The result of one inference step, typically the output of calling the
    `Model` on data.
  """
  # Only the inputs are used for inference; the other two unpacked slots
  # (targets and sample weights) are discarded.
  x, _, _ = data_adapter.unpack_x_y_sample_weight(
      data_adapter.expand_1d(data))
  return self(x, training=False)
def make_predict_function(self):
  """Creates a function that executes one step of inference.

  This method can be overridden to support custom inference logic.
  This method is called by `Model.predict` and `Model.predict_on_batch`.

  Typically, this method directly controls `tf.function` and
  `tf.distribute.Strategy` settings, and delegates the actual inference
  logic to `Model.predict_step`.

  This function is cached the first time `Model.predict` or
  `Model.predict_on_batch` is called. The cache is cleared whenever
  `Model.compile` is called.

  Returns:
    Function. The function created by this method should accept a
    `tf.data.Iterator`, and return the outputs of the `Model`.
  """
  # Reuse the cached function when one has already been built.
  if self.predict_function is not None:
    return self.predict_function

  def step_function(model, iterator):
    """Runs a single inference step."""

    def run_step(data):
      outputs = model.predict_step(data)
      # Ensure counter is updated only if `predict_step` succeeds.
      with ops.control_dependencies(_minimum_control_deps(outputs)):
        model._predict_counter.assign_add(1)  # pylint: disable=protected-access
      return outputs

    data = next(iterator)
    outputs = model.distribute_strategy.run(run_step, args=(data,))
    # Merge the per-replica results into a single concatenated result.
    outputs = reduce_per_replica(
        outputs, model.distribute_strategy, reduction='concat')
    return outputs

  if (self._steps_per_execution is None or
      self._steps_per_execution.numpy().item() == 1):

    def predict_function(iterator):
      """Runs an inference execution with one step."""
      return step_function(self, iterator)

  else:

    def predict_function(iterator):
      """Runs an inference execution with multiple steps."""
      outputs = step_function(self, iterator)
      for _ in math_ops.range(self._steps_per_execution - 1):
        # The accumulated outputs grow along the batch dimension on every
        # iteration, so the loop needs shape invariants with a dynamic batch.
        directives.set_loop_options(
            shape_invariants=[(
                t, tf_utils.get_tensor_spec(t, dynamic_batch=True).shape)
                              for t in nest.flatten(outputs)])
        step_outputs = step_function(self, iterator)
        outputs = nest.map_structure(lambda t1, t2: concat([t1, t2]), outputs,
                                     step_outputs)
      return outputs

  if not self.run_eagerly:
    predict_function = def_function.function(
        predict_function, experimental_relax_shapes=True)

  self.predict_function = predict_function
  return self.predict_function
def predict(self,
            x,
            batch_size=None,
            verbose=0,
            steps=None,
            callbacks=None,
            max_queue_size=10,
            workers=1,
            use_multiprocessing=False):
  """Generates output predictions for the input samples.

  Computation is done in batches. This method is designed for performance in
  large scale inputs. For small amount of inputs that fit in one batch,
  directly using `__call__` is recommended for faster execution, e.g.,
  `model(x)`, or `model(x, training=False)` if you have layers such as
  `tf.keras.layers.BatchNormalization` that behaves differently during
  inference. Also, note the fact that test loss is not affected by
  regularization layers like noise and dropout.

  Args:
      x: Input samples. It could be:
        - A Numpy array (or array-like), or a list of arrays
          (in case the model has multiple inputs).
        - A TensorFlow tensor, or a list of tensors
          (in case the model has multiple inputs).
        - A `tf.data` dataset.
        - A generator or `keras.utils.Sequence` instance.
        A more detailed description of unpacking behavior for iterator types
        (Dataset, generator, Sequence) is given in the `Unpacking behavior
        for iterator-like inputs` section of `Model.fit`.
      batch_size: Integer or `None`.
          Number of samples per batch.
          If unspecified, `batch_size` will default to 32.
          Do not specify the `batch_size` if your data is in the
          form of dataset, generators, or `keras.utils.Sequence` instances
          (since they generate batches).
      verbose: Verbosity mode, 0 or 1.
      steps: Total number of steps (batches of samples)
          before declaring the prediction round finished.
          Ignored with the default value of `None`. If x is a `tf.data`
          dataset and `steps` is None, `predict` will
          run until the input dataset is exhausted.
      callbacks: List of `keras.callbacks.Callback` instances.
          List of callbacks to apply during prediction.
          See [callbacks](/api_docs/python/tf/keras/callbacks).
      max_queue_size: Integer. Used for generator or `keras.utils.Sequence`
          input only. Maximum size for the generator queue.
          If unspecified, `max_queue_size` will default to 10.
      workers: Integer. Used for generator or `keras.utils.Sequence` input
          only. Maximum number of processes to spin up when using
          process-based threading. If unspecified, `workers` will default
          to 1.
      use_multiprocessing: Boolean. Used for generator or
          `keras.utils.Sequence` input only. If `True`, use process-based
          threading. If unspecified, `use_multiprocessing` will default to
          `False`. Note that because this implementation relies on
          multiprocessing, you should not pass non-picklable arguments to
          the generator as they can't be passed easily to children processes.

  See the discussion of `Unpacking behavior for iterator-like inputs` for
  `Model.fit`. Note that Model.predict uses the same interpretation rules as
  `Model.fit` and `Model.evaluate`, so inputs must be unambiguous for all
  three methods.

  Returns:
      Numpy array(s) of predictions.

  Raises:
      RuntimeError: If `model.predict` is wrapped in `tf.function`.
      ValueError: In case of mismatch between the provided
          input data and the model's expectations,
          or in case a stateful model receives a number of samples
          that is not a multiple of the batch size,
          or if `x` yields no batches at all.
  """
  version_utils.disallow_legacy_graph('Model', 'predict')
  self._check_call_args('predict')
  _disallow_inside_tf_function('predict')

  # TODO(yashkatariya): Cache model on the coordinator for faster prediction.
  # If running under PSS, then swap it with OneDeviceStrategy so that
  # execution will run on the coordinator.
  original_pss_strategy = None
  if self.distribute_strategy._should_use_with_coordinator:  # pylint: disable=protected-access
    original_pss_strategy = self.distribute_strategy
    self._distribution_strategy = None

  try:
    # Cluster coordinator is set by `.fit()` and `.evaluate()` which is not
    # needed in `.predict()` because all the predictions happen on the
    # coordinator/locally.
    if self._cluster_coordinator:
      self._cluster_coordinator = None

    outputs = None
    with self.distribute_strategy.scope():
      # Creates a `tf.data.Dataset` and handles batch and epoch iteration.
      dataset_types = (dataset_ops.DatasetV1, dataset_ops.DatasetV2)
      if (self._in_multi_worker_mode() or _is_tpu_multi_host(
          self.distribute_strategy)) and isinstance(x, dataset_types):
        try:
          options = options_lib.Options()
          data_option = options_lib.AutoShardPolicy.DATA
          options.experimental_distribute.auto_shard_policy = data_option
          x = x.with_options(options)
        except ValueError:
          warnings.warn('Using Model.predict with '
                        'MultiWorkerDistributionStrategy or TPUStrategy and '
                        'AutoShardPolicy.FILE might lead to out-of-order result'
                        '. Consider setting it to AutoShardPolicy.DATA.')

      data_handler = data_adapter.get_data_handler(
          x=x,
          batch_size=batch_size,
          steps_per_epoch=steps,
          initial_epoch=0,
          epochs=1,
          max_queue_size=max_queue_size,
          workers=workers,
          use_multiprocessing=use_multiprocessing,
          model=self,
          steps_per_execution=self._steps_per_execution)

      # Container that configures and calls `tf.keras.Callback`s.
      if not isinstance(callbacks, callbacks_module.CallbackList):
        callbacks = callbacks_module.CallbackList(
            callbacks,
            add_history=True,
            add_progbar=verbose != 0,
            model=self,
            verbose=verbose,
            epochs=1,
            steps=data_handler.inferred_steps)

      self.predict_function = self.make_predict_function()
      self._predict_counter.assign(0)
      callbacks.on_predict_begin()
      batch_outputs = None
      for _, iterator in data_handler.enumerate_epochs():  # Single epoch.
        with data_handler.catch_stop_iteration():
          for step in data_handler.steps():
            callbacks.on_predict_batch_begin(step)
            tmp_batch_outputs = self.predict_function(iterator)
            if data_handler.should_sync:
              context.async_wait()
            batch_outputs = tmp_batch_outputs  # No error, now safe to assign.
            if outputs is None:
              # First batch: start one list per leaf of the output structure.
              outputs = nest.map_structure(lambda batch_output: [batch_output],
                                           batch_outputs)
            else:
              nest.map_structure_up_to(
                  batch_outputs,
                  lambda output, batch_output: output.append(batch_output),
                  outputs, batch_outputs)
            end_step = step + data_handler.step_increment
            callbacks.on_predict_batch_end(end_step, {'outputs': batch_outputs})
      if batch_outputs is None:
        raise ValueError('Expect x to be a non-empty array or dataset.')
      callbacks.on_predict_end()
    all_outputs = nest.map_structure_up_to(batch_outputs, concat, outputs)
  finally:
    # If originally PSS strategy was used, then replace it back since predict
    # is running under `OneDeviceStrategy` after the swap and once its done
    # we need to replace it back to PSS again. Doing this in `finally`
    # guarantees the model keeps its strategy even when prediction raises.
    if original_pss_strategy is not None:
      self._distribution_strategy = original_pss_strategy

  return tf_utils.sync_to_numpy_or_python_type(all_outputs)
def reset_metrics(self):
  """Clears the accumulated state of every metric attached to the model.

  Examples:

  >>> inputs = tf.keras.layers.Input(shape=(3,))
  >>> outputs = tf.keras.layers.Dense(2)(inputs)
  >>> model = tf.keras.models.Model(inputs=inputs, outputs=outputs)
  >>> model.compile(optimizer="Adam", loss="mse", metrics=["mae"])

  >>> x = np.random.random((2, 3))
  >>> y = np.random.randint(0, 2, (2, 2))
  >>> _ = model.fit(x, y, verbose=0)
  >>> assert all(float(m.result()) for m in model.metrics)

  >>> model.reset_metrics()
  >>> assert all(float(m.result()) == 0 for m in model.metrics)

  """
  for metric in self.metrics:
    metric.reset_state()
def train_on_batch(self,
                   x,
                   y=None,
                   sample_weight=None,
                   class_weight=None,
                   reset_metrics=True,
                   return_dict=False):
  """Performs one gradient update using a single batch of data.

  Args:
      x: Input data. It could be:
        - A Numpy array (or array-like), or a list of arrays
            (in case the model has multiple inputs).
        - A TensorFlow tensor, or a list of tensors
            (in case the model has multiple inputs).
        - A dict mapping input names to the corresponding array/tensors,
            if the model has named inputs.
      y: Target data. Like the input data `x`, it could be either Numpy
        array(s) or TensorFlow tensor(s). It should be consistent with `x`
        (you cannot have Numpy inputs and tensor targets, or inversely).
      sample_weight: Optional array of the same length as x, containing
        weights to apply to the model's loss for each sample. In the case of
        temporal data, you can pass a 2D array with shape (samples,
        sequence_length), to apply a different weight to every timestep of
        every sample.
      class_weight: Optional dictionary mapping class indices (integers) to a
        weight (float) to apply to the model's loss for the samples from this
        class during training. This can be useful to tell the model to "pay
        more attention" to samples from an under-represented class.
      reset_metrics: If `True`, the metrics returned will be only for this
        batch. If `False`, the metrics will be statefully accumulated across
        batches.
      return_dict: If `True`, loss and metric results are returned as a dict,
        with each key being the name of the metric. If `False`, they are
        returned as a list.

  Returns:
      Scalar training loss
      (if the model has a single output and no metrics)
      or list of scalars (if the model has multiple outputs
      and/or metrics). The attribute `model.metrics_names` will give you
      the display labels for the scalar outputs.

  Raises:
    RuntimeError: If `model.train_on_batch` is wrapped in `tf.function`.
    ValueError: In case of invalid user-provided arguments.
  """
  self._assert_compile_was_called()
  self._check_call_args('train_on_batch')
  _disallow_inside_tf_function('train_on_batch')

  with self.distribute_strategy.scope():
    with training_utils.RespectCompiledTrainableState(self):
      batch_iterator = data_adapter.single_batch_iterator(
          self.distribute_strategy, x, y, sample_weight, class_weight)
      self.train_function = self.make_train_function()
      step_logs = self.train_function(batch_iterator)

  # The logs were produced before this point, so resetting here only affects
  # what the next call accumulates on top of.
  if reset_metrics:
    self.reset_metrics()

  results = tf_utils.sync_to_numpy_or_python_type(step_logs)
  if return_dict:
    return results
  return flatten_metrics_in_order(results, self.metrics_names)
def test_on_batch(self,
                  x,
                  y=None,
                  sample_weight=None,
                  reset_metrics=True,
                  return_dict=False):
  """Evaluates the model on one batch of samples.

  Args:
      x: Input data. It could be:
        - A Numpy array (or array-like), or a list of arrays (in case the
            model has multiple inputs).
        - A TensorFlow tensor, or a list of tensors (in case the model has
            multiple inputs).
        - A dict mapping input names to the corresponding array/tensors, if
            the model has named inputs.
      y: Target data. Like the input data `x`, it could be either Numpy
        array(s) or TensorFlow tensor(s). It should be consistent with `x`
        (you cannot have Numpy inputs and tensor targets, or inversely).
      sample_weight: Optional array of the same length as x, containing
        weights to apply to the model's loss for each sample. In the case of
        temporal data, you can pass a 2D array with shape (samples,
        sequence_length), to apply a different weight to every timestep of
        every sample.
      reset_metrics: If `True`, the metrics returned will be only for this
        batch. If `False`, the metrics will be statefully accumulated across
        batches.
      return_dict: If `True`, loss and metric results are returned as a dict,
        with each key being the name of the metric. If `False`, they are
        returned as a list.

  Returns:
      Scalar test loss (if the model has a single output and no metrics)
      or list of scalars (if the model has multiple outputs
      and/or metrics). The attribute `model.metrics_names` will give you
      the display labels for the scalar outputs.

  Raises:
      RuntimeError: If `model.test_on_batch` is wrapped in `tf.function`.
      ValueError: In case of invalid user-provided arguments.
  """
  self._assert_compile_was_called()
  self._check_call_args('test_on_batch')
  _disallow_inside_tf_function('test_on_batch')

  with self.distribute_strategy.scope():
    batch_iterator = data_adapter.single_batch_iterator(
        self.distribute_strategy, x, y, sample_weight)
    self.test_function = self.make_test_function()
    step_logs = self.test_function(batch_iterator)

  if reset_metrics:
    self.reset_metrics()

  results = tf_utils.sync_to_numpy_or_python_type(step_logs)
  if return_dict:
    return results
  return flatten_metrics_in_order(results, self.metrics_names)
def predict_on_batch(self, x):
  """Computes predictions for one batch of samples.

  Args:
      x: Input data. It could be:
        - A Numpy array (or array-like), or a list of arrays (in case the
            model has multiple inputs).
        - A TensorFlow tensor, or a list of tensors (in case the model has
            multiple inputs).

  Returns:
      Numpy array(s) of predictions.

  Raises:
      RuntimeError: If `model.predict_on_batch` is wrapped in `tf.function`.
      ValueError: In case of mismatch between given number of inputs and
        expectations of the model.
  """
  self._check_call_args('predict_on_batch')
  _disallow_inside_tf_function('predict_on_batch')
  with self.distribute_strategy.scope():
    batch_iterator = data_adapter.single_batch_iterator(
        self.distribute_strategy, x)
    self.predict_function = self.make_predict_function()
    predictions = self.predict_function(batch_iterator)
  return tf_utils.sync_to_numpy_or_python_type(predictions)

def fit_generator(self,
                  generator,
                  steps_per_epoch=None,
                  epochs=1,
                  verbose=1,
                  callbacks=None,
                  validation_data=None,
                  validation_steps=None,
                  validation_freq=1,
                  class_weight=None,
                  max_queue_size=10,
                  workers=1,
                  use_multiprocessing=False,
                  shuffle=True,
                  initial_epoch=0):
  """Fits the model on data yielded batch-by-batch by a Python generator.

  DEPRECATED:
    `Model.fit` now supports generators, so there is no longer any need to use
    this endpoint. This method emits a deprecation warning and forwards every
    argument to `Model.fit`.
  """
  warnings.warn(
      '`Model.fit_generator` is deprecated and will be removed in a future '
      'version. Please use `Model.fit`, which supports generators.')
  fit_options = {
      'steps_per_epoch': steps_per_epoch,
      'epochs': epochs,
      'verbose': verbose,
      'callbacks': callbacks,
      'validation_data': validation_data,
      'validation_steps': validation_steps,
      'validation_freq': validation_freq,
      'class_weight': class_weight,
      'max_queue_size': max_queue_size,
      'workers': workers,
      'use_multiprocessing': use_multiprocessing,
      'shuffle': shuffle,
      'initial_epoch': initial_epoch,
  }
  return self.fit(generator, **fit_options)
def evaluate_generator(self,
                       generator,
                       steps=None,
                       callbacks=None,
                       max_queue_size=10,
                       workers=1,
                       use_multiprocessing=False,
                       verbose=0):
  """Evaluates the model on a data generator.

  DEPRECATED:
    `Model.evaluate` now supports generators, so there is no longer any need
    to use this endpoint. This method emits a deprecation warning, checks the
    call arguments and forwards everything to `Model.evaluate`.
  """
  warnings.warn(
      '`Model.evaluate_generator` is deprecated and will be removed in a '
      'future version. Please use `Model.evaluate`, which supports '
      'generators.')
  self._check_call_args('evaluate_generator')

  evaluate_options = {
      'steps': steps,
      'max_queue_size': max_queue_size,
      'workers': workers,
      'use_multiprocessing': use_multiprocessing,
      'verbose': verbose,
      'callbacks': callbacks,
  }
  return self.evaluate(generator, **evaluate_options)

def predict_generator(self,
                      generator,
                      steps=None,
                      callbacks=None,
                      max_queue_size=10,
                      workers=1,
                      use_multiprocessing=False,
                      verbose=0):
  """Generates predictions for the input samples from a data generator.

  DEPRECATED:
    `Model.predict` now supports generators, so there is no longer any need
    to use this endpoint. This method emits a deprecation warning and forwards
    every argument to `Model.predict`.
  """
  warnings.warn(
      '`Model.predict_generator` is deprecated and will be removed in a '
      'future version. Please use `Model.predict`, which supports '
      'generators.')
  predict_options = {
      'steps': steps,
      'max_queue_size': max_queue_size,
      'workers': workers,
      'use_multiprocessing': use_multiprocessing,
      'verbose': verbose,
      'callbacks': callbacks,
  }
  return self.predict(generator, **predict_options)
######################################################################
# Functions below are not training related. They are for model weights
# tracking, save/load, serialization, etc.
######################################################################

@property
def trainable_weights(self):
  """De-duplicated trainable variables of the model and its tracked objects.

  Returns an empty list when the model itself is not trainable. Otherwise the
  variables of the tracked objects come first, followed by the model's own
  trainable weights.
  """
  self._assert_weights_created()
  if not self._trainable:
    return []
  collected = []
  for tracked in self._self_tracked_trackables:
    collected.extend(tracked.trainable_variables)
  collected.extend(self._trainable_weights)
  return self._dedup_weights(collected)

@property
def non_trainable_weights(self):
  """De-duplicated non-trainable variables of the model.

  When the model itself is not trainable, the variables that would otherwise
  be trainable are reported here as well, ahead of the non-trainable ones.
  """
  self._assert_weights_created()
  frozen = []
  for tracked in self._self_tracked_trackables:
    frozen.extend(tracked.non_trainable_variables)

  if self._trainable:
    return self._dedup_weights(frozen + self._non_trainable_weights)

  # Return order is all trainable vars, then all non-trainable vars.
  would_train = []
  for tracked in self._self_tracked_trackables:
    would_train.extend(tracked.trainable_variables)
  everything = (
      would_train + self._trainable_weights +
      frozen + self._non_trainable_weights)
  return self._dedup_weights(everything)

def get_weights(self):
  """Retrieves the weights of the model.

  The lookup runs inside the scope of the model's distribution strategy.

  Returns:
      A flat list of Numpy arrays.
  """
  with self.distribute_strategy.scope():
    weights = super(Model, self).get_weights()
    return weights
def save(self,
         filepath,
         overwrite=True,
         include_optimizer=True,
         save_format=None,
         signatures=None,
         options=None,
         save_traces=True):
  # pylint: disable=line-too-long
  """Writes the model to a Tensorflow SavedModel or a single HDF5 file.

  This is a thin wrapper that forwards all of its arguments to
  `tf.keras.models.save_model`; see that function or the
  [Serialization and Saving guide](https://keras.io/guides/serialization_and_saving/)
  for details.

  Args:
      filepath: String, PathLike, path to SavedModel or H5 file to save the
          model.
      overwrite: Whether to silently overwrite any existing file at the
          target location, or provide the user with a manual prompt.
      include_optimizer: If True, save optimizer's state together.
      save_format: Either `'tf'` or `'h5'`, indicating whether to save the
          model to Tensorflow SavedModel or HDF5. Defaults to 'tf' in TF 2.X,
          and 'h5' in TF 1.X.
      signatures: Signatures to save with the SavedModel. Applicable to the
          'tf' format only. Please see the `signatures` argument in
          `tf.saved_model.save` for details.
      options: (only applies to SavedModel format)
          `tf.saved_model.SaveOptions` object that specifies options for
          saving to SavedModel.
      save_traces: (only applies to SavedModel format) When enabled, the
          SavedModel will store the function traces for each layer. This
          can be disabled, so that only the configs of each layer are stored.
          Defaults to `True`. Disabling this will decrease serialization time
          and reduce file size, but it requires that all custom layers/models
          implement a `get_config()` method.

  Example:

  ```python
  from keras.models import load_model

  model.save('my_model.h5')  # creates a HDF5 file 'my_model.h5'
  del model  # deletes the existing model

  # returns a compiled model
  # identical to the previous one
  model = load_model('my_model.h5')
  ```
  """
  # pylint: enable=line-too-long
  save.save_model(
      self, filepath, overwrite, include_optimizer, save_format, signatures,
      options, save_traces)
def save_weights(self,
                 filepath,
                 overwrite=True,
                 save_format=None,
                 options=None):
  """Saves all layer weights.

  The weights are written either in HDF5 or in TensorFlow format, depending
  on the `save_format` argument.

  When saving in HDF5 format, the weight file has:
    - `layer_names` (attribute), a list of strings
        (ordered names of model layers).
    - For every layer, a `group` named `layer.name`
        - For every such layer group, a group attribute `weight_names`,
            a list of strings
            (ordered names of weights tensor of the layer).
        - For every weight in the layer, a dataset
            storing the weight value, named after the weight tensor.

  When saving in TensorFlow format, all objects referenced by the network are
  saved in the same format as `tf.train.Checkpoint`, including any `Layer`
  instances or `Optimizer` instances assigned to object attributes. For
  networks constructed from inputs and outputs using `tf.keras.Model(inputs,
  outputs)`, `Layer` instances used by the network are tracked/saved
  automatically. For user-defined classes which inherit from `tf.keras.Model`,
  `Layer` instances must be assigned to object attributes, typically in the
  constructor. See the documentation of `tf.train.Checkpoint` and
  `tf.keras.Model` for details.

  While the formats are the same, do not mix `save_weights` and
  `tf.train.Checkpoint`. Checkpoints saved by `Model.save_weights` should be
  loaded using `Model.load_weights`. Checkpoints saved using
  `tf.train.Checkpoint.save` should be restored using the corresponding
  `tf.train.Checkpoint.restore`. Prefer `tf.train.Checkpoint` over
  `save_weights` for training checkpoints.

  The TensorFlow format matches objects and variables by starting at a root
  object, `self` for `save_weights`, and greedily matching attribute
  names. For `Model.save` this is the `Model`, and for `Checkpoint.save` this
  is the `Checkpoint` even if the `Checkpoint` has a model attached. This
  means saving a `tf.keras.Model` using `save_weights` and loading into a
  `tf.train.Checkpoint` with a `Model` attached (or vice versa) will not match
  the `Model`'s variables. See the [guide to training
  checkpoints](https://www.tensorflow.org/guide/checkpoint) for details
  on the TensorFlow format.

  Args:
      filepath: String or PathLike, path to the file to save the weights to.
          When saving in TensorFlow format, this is the prefix used for
          checkpoint files (multiple files are generated). Note that the '.h5'
          suffix causes weights to be saved in HDF5 format.
      overwrite: Whether to silently overwrite any existing file at the
          target location, or provide the user with a manual prompt.
      save_format: Either 'tf' or 'h5'. A `filepath` ending in '.h5' or
          '.keras' will default to HDF5 if `save_format` is `None`. Otherwise
          `None` defaults to 'tf'.
      options: Optional `tf.train.CheckpointOptions` object that specifies
          options for saving weights.

  Raises:
      ImportError: If h5py is not available when attempting to save in HDF5
          format.
      ValueError: For invalid/unknown format arguments.
  """
  self._assert_weights_created()
  filepath = path_to_string(filepath)
  looks_like_h5 = saving_utils.is_hdf5_filepath(filepath)

  # Resolve the requested format to one of the canonical names 'tf' / 'h5'.
  if save_format is None:
    save_format = 'h5' if looks_like_h5 else 'tf'
  else:
    aliases = {
        'tensorflow': 'tf',
        'tf': 'tf',
        'hdf5': 'h5',
        'h5': 'h5',
        'keras': 'h5',
    }
    requested = save_format.lower().strip()
    if requested not in aliases:
      raise ValueError(
          'Unknown format "%s". Was expecting one of {"tf", "h5"}.' % (
              save_format,))
    save_format = aliases[requested]

  to_h5 = save_format == 'h5'
  if not to_h5 and looks_like_h5:
    raise ValueError(
        ('save_weights got save_format="tf"/"tensorflow", but the '
         'filepath ("%s") looks like an HDF5 file. Omit the ".h5"/".keras" '
         'when saving in TensorFlow format.')
        % filepath)
  if to_h5 and h5py is None:
    raise ImportError(
        '`save_weights` requires h5py when saving in hdf5.')

  # For the TensorFlow format `filepath` is only a prefix, so the existence
  # check is made against the '.index' file.
  existing_path = filepath if to_h5 else filepath + '.index'
  # If file exists and should not be overwritten:
  if not overwrite and os.path.isfile(existing_path):
    if not ask_to_proceed_with_overwrite(existing_path):
      return

  if to_h5:
    with h5py.File(filepath, 'w') as h5_file:
      hdf5_format.save_weights_to_hdf5_group(h5_file, self.layers)
    return

  session = None if context.executing_eagerly() else backend.get_session()
  self._trackable_saver.save(filepath, session=session, options=options)
  # Record this checkpoint so it's visible from tf.train.latest_checkpoint.
  checkpoint_management.update_checkpoint_state_internal(
      save_dir=os.path.dirname(filepath),
      model_checkpoint_path=filepath,
      save_relative_paths=True,
      all_model_checkpoint_paths=[filepath])
2271 2272 Only topological loading (`by_name=False`) is supported when loading weights 2273 from the TensorFlow format. Note that topological loading differs slightly 2274 between TensorFlow and HDF5 formats for user-defined classes inheriting from 2275 `tf.keras.Model`: HDF5 loads based on a flattened list of weights, while the 2276 TensorFlow format loads based on the object-local names of attributes to 2277 which layers are assigned in the `Model`'s constructor. 2278 2279 Args: 2280 filepath: String, path to the weights file to load. For weight files in 2281 TensorFlow format, this is the file prefix (the same as was passed 2282 to `save_weights`). This can also be a path to a SavedModel 2283 saved from `model.save`. 2284 by_name: Boolean, whether to load weights by name or by topological 2285 order. Only topological loading is supported for weight files in 2286 TensorFlow format. 2287 skip_mismatch: Boolean, whether to skip loading of layers where there is 2288 a mismatch in the number of weights, or a mismatch in the shape of 2289 the weight (only valid when `by_name=True`). 2290 options: Optional `tf.train.CheckpointOptions` object that specifies 2291 options for loading weights. 2292 2293 Returns: 2294 When loading a weight file in TensorFlow format, returns the same status 2295 object as `tf.train.Checkpoint.restore`. When graph building, restore 2296 ops are run automatically as soon as the network is built (on first call 2297 for user-defined classes inheriting from `Model`, immediately if it is 2298 already built). 2299 2300 When loading weights in HDF5 format, returns `None`. 2301 2302 Raises: 2303 ImportError: If h5py is not available and the weight file is in HDF5 2304 format. 2305 ValueError: If `skip_mismatch` is set to `True` when `by_name` is 2306 `False`. 
2307 """ 2308 if backend.is_tpu_strategy(self._distribution_strategy): 2309 if (self._distribution_strategy.extended.steps_per_run > 1 and 2310 (not saving_utils.is_hdf5_filepath(filepath))): 2311 raise ValueError('Load weights is not yet supported with TPUStrategy ' 2312 'with steps_per_run greater than 1.') 2313 if skip_mismatch and not by_name: 2314 raise ValueError( 2315 'When calling model.load_weights, skip_mismatch can only be set to ' 2316 'True when by_name is True.') 2317 2318 filepath, save_format = _detect_save_format(filepath) 2319 if save_format == 'tf': 2320 status = self._trackable_saver.restore(filepath, options) 2321 if by_name: 2322 raise NotImplementedError( 2323 'Weights may only be loaded based on topology into Models when ' 2324 'loading TensorFlow-formatted weights (got by_name=True to ' 2325 'load_weights).') 2326 if not context.executing_eagerly(): 2327 session = backend.get_session() 2328 # Restore existing variables (if any) immediately, and set up a 2329 # streaming restore for any variables created in the future. 2330 trackable_utils.streaming_restore(status=status, session=session) 2331 status.assert_nontrivial_match() 2332 else: 2333 status = None 2334 if h5py is None: 2335 raise ImportError( 2336 '`load_weights` requires h5py when loading weights from HDF5.') 2337 if not self._is_graph_network and not self.built: 2338 raise ValueError( 2339 'Unable to load weights saved in HDF5 format into a subclassed ' 2340 'Model which has not created its variables yet. 
def _updated_config(self):
  """Builds the config dict shared by the serialization methods.

  Returns:
    A dict holding the class name, the result of `self.get_config()`, the
    Keras version and the backend name, under the keys `class_name`,
    `config`, `keras_version` and `backend`.
  """
  from tensorflow.python.keras import __version__ as keras_version  # pylint: disable=g-import-not-at-top

  # `get_config` is evaluated first so a failure there surfaces before any
  # other field is computed.
  own_config = self.get_config()
  return {
      'class_name': self.__class__.__name__,
      'config': own_config,
      'keras_version': keras_version,
      'backend': backend.backend(),
  }

def get_config(self):
  """Always raises `NotImplementedError`.

  `_updated_config` calls this method, so a class that wants to be serialized
  through it has to supply its own implementation.

  Raises:
    NotImplementedError: Unconditionally.
  """
  raise NotImplementedError

@classmethod
def from_config(cls, config, custom_objects=None):
  """Instantiates `cls` from a functional-model config.

  Args:
    config: Config mapping passed to `functional.reconstruct_from_config`.
      Its optional `'name'` entry is used as the model name.
    custom_objects: Forwarded unchanged to
      `functional.reconstruct_from_config`.

  Returns:
    The `cls` instance built from the reconstructed inputs and outputs, after
    `functional.connect_ancillary_layers` has been applied to it.
  """
  # `from_config` assumes `cls` is either `Functional` or a child class of
  # `Functional`. When `cls` only inherits from `Model` but is meant to behave
  # like a `Functional` child, `cls(...)` has to be called directly instead of
  # `Functional.from_config`.
  from tensorflow.python.keras.engine import functional  # pylint: disable=g-import-not-at-top
  with generic_utils.SharedObjectLoadingScope():
    reconstructed = functional.reconstruct_from_config(config, custom_objects)
    graph_inputs, graph_outputs, created = reconstructed
    # `cls` may be user-defined or `Functional` itself.
    instance = cls(
        inputs=graph_inputs, outputs=graph_outputs, name=config.get('name'))
    functional.connect_ancillary_layers(instance, created)
  return instance
def to_json(self, **kwargs):
  """Serializes the network configuration to a JSON string.

  The result can be loaded back with
  `keras.models.model_from_json(json_string, custom_objects={})`.

  Args:
    **kwargs: Extra keyword arguments forwarded to `json.dumps()`.

  Returns:
    The JSON string for `self._updated_config()`.
  """
  return json.dumps(
      self._updated_config(), default=json_utils.get_json_type, **kwargs)

def to_yaml(self, **kwargs):
  """Unsupported: always raises `RuntimeError`.

  YAML export was removed (per the original note, since TF 2.6) because of the
  risk of arbitrary code execution. Use `to_json()` instead.

  Args:
    **kwargs: Accepted for signature compatibility; never used.

  Raises:
    RuntimeError: Unconditionally, announcing that the method was removed.
  """
  removal_notice = (
      'Method `model.to_yaml()` has been removed due to security risk of '
      'arbitrary code execution. Please use `model.to_json()` instead.')
  raise RuntimeError(removal_notice)

def reset_states(self):
  """Calls `reset_states()` on every stateful layer of the model.

  A layer is reset only when it exposes a `reset_states` attribute and its
  `stateful` attribute is truthy (a missing `stateful` counts as `False`).
  """
  for candidate in self.layers:
    if not hasattr(candidate, 'reset_states'):
      continue
    if not getattr(candidate, 'stateful', False):
      continue
    candidate.reset_states()
@property
@doc_controls.do_not_generate_docs
def state_updates(self):
  """Deprecated, do NOT use!

  Collects the `updates` of every layer whose `stateful` attribute is truthy.
  Accessing the property emits a warning announcing its removal.

  Returns:
    A list of update ops.
  """
  warnings.warn('`Model.state_updates` will be removed in a future version. '
                'This property should not be used in TensorFlow 2.0, '
                'as `updates` are applied automatically.')
  collected = []
  for layer in self.layers:
    if not getattr(layer, 'stateful', False):
      continue
    # Stateful layers without an `updates` attribute contribute nothing.
    if hasattr(layer, 'updates'):
      collected.extend(layer.updates)
  return collected

@property
def weights(self):
  """Returns the list of all layer variables/weights, deduplicated.

  Note: This will not track the weights of nested `tf.Modules` that are not
  themselves Keras layers.

  Returns:
    A list of variables.
  """
  all_weights = self._undeduplicated_weights
  return self._dedup_weights(all_weights)

@property
def _undeduplicated_weights(self):
  """Returns every layer variable/weight, duplicates included.

  The variables of each tracked trackable come first, followed by the model's
  own trainable and then non-trainable weights.
  """
  self._assert_weights_created()
  gathered = []
  for tracked in self._self_tracked_trackables:
    gathered.extend(tracked.variables)
  gathered.extend(self._trainable_weights + self._non_trainable_weights)
  return gathered
def summary(self, line_length=None, positions=None, print_fn=None):
  """Prints a string summary of the network.

  Args:
    line_length: Total length of printed lines (e.g. set this to adapt the
      display to different terminal window sizes).
    positions: Relative or absolute positions of log elements in each line.
      If not provided, defaults to `[.33, .55, .67, 1.]`.
    print_fn: Print function to use. Defaults to `print`. It will be called
      on each line of the summary. You can set it to a custom function in
      order to capture the string summary.

  Raises:
    ValueError: if `summary()` is called before the model is built.
  """
  if not self.built:
    raise ValueError('This model has not yet been built. '
                     'Build the model first by calling `build()` or calling '
                     '`fit()` with some data, or specify '
                     'an `input_shape` argument in the first layer(s) for '
                     'automatic build.')
  layer_utils.print_summary(self,
                            line_length=line_length,
                            positions=positions,
                            print_fn=print_fn)

@property
def layers(self):
  """Returns a fresh list of the model's direct (non-recursive) layers.

  The list is rebuilt from `self._flatten_layers` on every access and does not
  include the model itself.
  """
  return list(self._flatten_layers(include_self=False, recursive=False))

def get_layer(self, name=None, index=None):
  """Retrieves a layer based on either its name (unique) or index.

  Exactly one of `name` and `index` must be provided; passing both, or
  neither, is an error. Indices follow the order of `self.layers`, and
  negative indices count from the end as with a regular list.

  Args:
    name: String, name of layer.
    index: Integer, index of layer.

  Returns:
    A layer instance.

  Raises:
    ValueError: If both or neither of `name` and `index` are given, if
      `index` is out of range, or if no layer is called `name`.
  """
  # TODO(fchollet): We could build a dictionary based on layer names
  # since they are constant, but we have not done that yet.
  if index is not None and name is not None:
    raise ValueError('Provide only a layer name or a layer index.')
  if index is None and name is None:
    raise ValueError('Provide either a layer name or layer index.')

  # `self.layers` builds a new list on every access; read it once.
  layers = self.layers

  if index is not None:
    num_layers = len(layers)
    # Reject both ends of the range so that an invalid negative index raises
    # the documented `ValueError` rather than leaking an `IndexError`.
    if not -num_layers <= index < num_layers:
      raise ValueError('Was asked to retrieve layer at index ' + str(index) +
                       ' but model only has ' + str(num_layers) +
                       ' layers.')
    return layers[index]

  for layer in layers:
    if layer.name == name:
      return layer
  # `str()` keeps this a `ValueError` even when `name` is not a string.
  raise ValueError('No such layer: ' + str(name) + '.')
@trackable.no_automatic_dependency_tracking
def _set_save_spec(self, inputs):
  """Records the tensor specs of `inputs`, unless a spec is already stored.

  Args:
    inputs: Structure of inputs. It is flattened, converted element-wise to
      tensor specs (with `dynamic_batch=False`) and packed back into the
      original structure.
  """
  if self._saved_model_inputs_spec is not None:
    return  # A spec was recorded earlier; keep it.

  # Fall back to generated names when the model has no input names.
  names = self.input_names or compile_utils.create_pseudo_input_names(inputs)

  flat_specs = [
      tf_utils.get_tensor_spec(tensor, dynamic_batch=False, name=name)
      for name, tensor in zip(names, nest.flatten(inputs))
  ]
  packed_specs = nest.pack_sequence_as(inputs, flat_specs)
  self._saved_model_inputs_spec = packed_specs

  # For `Sequential` models without a recorded build shape, remember the
  # input shapes taken from the specs.
  is_sequential = self.__class__.__name__ == 'Sequential'
  if is_sequential and self._build_input_shape is None:

    def _shape_or_none(spec):
      return None if spec is None else spec.shape

    self._build_input_shape = nest.map_structure(_shape_or_none, packed_specs)

def _assert_weights_created(self):
  """Asserts that all the weights for the model have been created.

  For a non-dynamic model, the weights must already be created after the
  layer has been called. For a dynamic model, the exact list of weights can
  never be known for certain since it may change at any time during execution,
  so the check is skipped.

  The check runs right before weights are accessed; otherwise a model that
  has never been called would just yield an empty list, which is misleading.

  Raises:
    ValueError: if the weights of the network has not yet been created.
  """
  if self.dynamic:
    return

  model_cls = self.__class__
  # Only classes that define their own `build()` are checked; `Model` itself
  # also defines `build()` and is therefore excluded explicitly. This covers
  # both sequential and subclassed models.
  defines_own_build = 'build' in model_cls.__dict__ and model_cls != Model
  if defines_own_build and not self.built:
    raise ValueError('Weights for model %s have not yet been created. '
                     'Weights are created when the Model is first called on '
                     'inputs or `build()` is called with an `input_shape`.' %
                     self.name)
% 2609 self.name) 2610 2611 def _check_call_args(self, method_name): 2612 """Check that `call` has only one positional arg.""" 2613 # Always allow first arg, regardless of arg name. 2614 fullargspec = self._call_full_argspec 2615 if fullargspec.defaults: 2616 positional_args = fullargspec.args[:-len(fullargspec.defaults)] 2617 else: 2618 positional_args = fullargspec.args 2619 if 'training' in positional_args: 2620 positional_args.remove('training') 2621 2622 # self and first arg can be positional. 2623 if len(positional_args) > 2: 2624 extra_args = positional_args[2:] 2625 raise ValueError( 2626 'Models passed to `' + method_name + '` can only have `training` ' 2627 'and the first argument in `call` as positional arguments, ' 2628 'found: ' + str(extra_args) + '.') 2629 2630 def _validate_compile(self, optimizer, metrics, **kwargs): 2631 """Performs validation checks for the default `compile`.""" 2632 if any( 2633 isinstance(opt, optimizer_v1.Optimizer) 2634 for opt in nest.flatten(optimizer)): 2635 raise ValueError( 2636 '`tf.compat.v1.keras` Optimizer (', optimizer, ') is ' 2637 'not supported when eager execution is enabled. Use a ' 2638 '`tf.keras` Optimizer instead, or disable eager ' 2639 'execution.') 2640 2641 kwargs.pop('cloning', None) # Legacy DistStrat argument, never used. 2642 kwargs.pop('experimental_run_tf_function', None) # Always `True`. 2643 if kwargs.pop('distribute', None) is not None: 2644 raise ValueError( 2645 'Distribute argument in compile is not available in TF 2.0 please ' 2646 'create the model under the distribution strategy scope.') 2647 if kwargs.pop('target_tensors', None) is not None: 2648 raise ValueError( 2649 'target_tensors argument is not supported when executing eagerly.') 2650 invalid_kwargs = set(kwargs) - {'sample_weight_mode'} 2651 if invalid_kwargs: 2652 raise TypeError('Invalid keyword argument(s) in `compile`: %s' % 2653 (invalid_kwargs,)) 2654 2655 # Model must be created and compiled with the same DistStrat. 
2656 if self.built and ds_context.has_strategy(): 2657 strategy = ds_context.get_strategy() 2658 for v in self.variables: 2659 if not strategy.extended.variable_created_in_scope(v): 2660 raise ValueError( 2661 'Variable (%s) was not created in the distribution strategy ' 2662 'scope of (%s). It is most likely due to not all layers or ' 2663 'the model or optimizer being created outside the distribution ' 2664 'strategy scope. Try to make sure your code looks similar ' 2665 'to the following.\n' 2666 'with strategy.scope():\n' 2667 ' model=_create_model()\n' 2668 ' model.compile(...)' % (v, strategy)) 2669 2670 # Model metrics must be created in the same distribution strategy scope 2671 # as the model. 2672 strategy = self.distribute_strategy 2673 for metric in nest.flatten(metrics): 2674 for v in getattr(metric, 'variables', []): 2675 if not strategy.extended.variable_created_in_scope(v): 2676 raise ValueError( 2677 'Metric (%s) passed to model.compile was created inside of a ' 2678 'different distribution strategy scope than the model. All ' 2679 'metrics must be created in the same distribution strategy ' 2680 'scope as the model (in this case %s). If you pass in a string ' 2681 'identifier for a metric to compile the metric will ' 2682 'automatically be created in the correct distribution ' 2683 'strategy scope.' % (metric, strategy) 2684 ) 2685 2686 # Model metrics must be created in the same distribution strategy scope 2687 # as the model. 2688 for opt in nest.flatten(optimizer): 2689 for v in getattr(opt, '_weights', []): 2690 if not strategy.extended.variable_created_in_scope(v): 2691 raise ValueError( 2692 'Optimizer (%s) passed to model.compile was created inside of a ' 2693 'different distribution strategy scope than the model. All ' 2694 'optimizers must be created in the same distribution strategy ' 2695 'scope as the model (in this case %s). 
If you pass in a string ' 2696 'identifier for an optimizer to compile the optimizer will ' 2697 'automatically be created in the correct distribution ' 2698 'strategy scope.' % (opt, strategy)) 2699 2700 def _maybe_load_initial_epoch_from_ckpt(self, initial_epoch): 2701 """Maybe load initial epoch from ckpt considering possible worker recovery. 2702 2703 Refer to tensorflow/python/keras/distribute/worker_training_state.py 2704 for more information. 2705 2706 Args: 2707 initial_epoch: The original initial_epoch user passes in in `fit()`. 2708 2709 Returns: 2710 If the training is recovering from previous failure under multi-worker 2711 training setting, return the epoch the training is supposed to continue 2712 at. Otherwise, return the `initial_epoch` the user passes in. 2713 """ 2714 if self._training_state is not None: 2715 return self._training_state.maybe_load_initial_epoch_from_ckpt( 2716 initial_epoch, mode=ModeKeys.TRAIN) 2717 return initial_epoch 2718 2719 def _assert_compile_was_called(self): 2720 # Checks whether `compile` has been called. If it has been called, 2721 # then the optimizer is set. This is different from whether the 2722 # model is compiled 2723 # (i.e. whether the model is built and its inputs/outputs are set). 2724 if not self._is_compiled: 2725 raise RuntimeError('You must compile your model before ' 2726 'training/testing. ' 2727 'Use `model.compile(optimizer, loss)`.') 2728 2729 def _set_inputs(self, inputs, outputs=None, training=None): 2730 """This method is for compat with Modelv1. Only inputs are needed here.""" 2731 self._set_save_spec(inputs) 2732 2733 @property 2734 def _trackable_saved_model_saver(self): 2735 return model_serialization.ModelSavedModelSaver(self) 2736 2737 def _list_functions_for_serialization(self, serialization_cache): 2738 # SavedModel needs to ignore the execution functions. 
2739 train_function = self.train_function 2740 test_function = self.test_function 2741 predict_function = self.predict_function 2742 train_tf_function = self.train_tf_function 2743 self.train_function = None 2744 self.test_function = None 2745 self.predict_function = None 2746 self.train_tf_function = None 2747 functions = super( 2748 Model, self)._list_functions_for_serialization(serialization_cache) 2749 self.train_function = train_function 2750 self.test_function = test_function 2751 self.predict_function = predict_function 2752 self.train_tf_function = train_tf_function 2753 return functions 2754 2755 def _should_eval(self, epoch, validation_freq): 2756 epoch = epoch + 1 # one-index the user-facing epoch. 2757 if isinstance(validation_freq, int): 2758 return epoch % validation_freq == 0 2759 elif isinstance(validation_freq, list): 2760 return epoch in validation_freq 2761 else: 2762 raise ValueError('Expected `validation_freq` to be a list or int.') 2763 2764 ###################################################################### 2765 # Functions below exist only as v1 / v2 compatibility shims. 2766 ###################################################################### 2767 2768 def _get_compile_args(self, user_metrics=True): 2769 """Used for saving or cloning a Model. 2770 2771 Args: 2772 user_metrics: Whether to return user-supplied metrics or `Metric` objects. 2773 Defaults to returning the user-supplied metrics. 2774 2775 Returns: 2776 Dictionary of arguments that were used when compiling the model. 
2777 """ 2778 self._assert_compile_was_called() 2779 # pylint: disable=protected-access 2780 2781 saved_metrics = self.compiled_metrics._user_metrics 2782 saved_weighted_metrics = self.compiled_metrics._user_weighted_metrics 2783 2784 if not user_metrics: 2785 if saved_metrics is not None: 2786 saved_metrics = self.compiled_metrics._metrics 2787 if saved_weighted_metrics is not None: 2788 saved_weighted_metrics = self.compiled_metrics._weighted_metrics 2789 2790 compile_args = { 2791 'optimizer': self.optimizer, 2792 'loss': self.compiled_loss._user_losses, 2793 'metrics': saved_metrics, 2794 'weighted_metrics': saved_weighted_metrics, 2795 'loss_weights': self.compiled_loss._user_loss_weights, 2796 } 2797 # pylint: enable=protected-access 2798 return compile_args 2799 2800 def _get_callback_model(self): 2801 return self 2802 2803 def _in_multi_worker_mode(self): 2804 return self.distribute_strategy.extended._in_multi_worker_mode() # pylint: disable=protected-access 2805 2806 @property 2807 def _compile_was_called(self): 2808 return self._is_compiled 2809 2810 2811def reduce_per_replica(values, strategy, reduction='first'): 2812 """Reduce PerReplica objects. 2813 2814 Args: 2815 values: Structure of `PerReplica` objects or `Tensor`s. `Tensor`s are 2816 returned as-is. 2817 strategy: `tf.distribute.Strategy` object. 2818 reduction: One of 'first', 'concat'. 2819 2820 Returns: 2821 Structure of `Tensor`s. 
def reduce_per_replica(values, strategy, reduction='first'):
  """Reduce PerReplica objects.

  Args:
    values: Structure of `PerReplica` objects or `Tensor`s. `Tensor`s are
      returned as-is.
    strategy: `tf.distribute.Strategy` object.
    reduction: One of 'first', 'concat'.

  Returns:
    Structure of `Tensor`s.
  """
  concat_requested = reduction == 'concat'

  def _reduce_one(value):
    """Reduces one leaf of `values`."""
    # Multi-worker collective strategies have a dedicated concat path that
    # also handles plain tensors, so it is checked before anything else.
    if concat_requested and _collective_all_reduce_multi_worker(strategy):
      return _multi_worker_concat(value, strategy)
    if not _is_per_replica_instance(value):
      return value
    if reduction == 'first':
      return strategy.unwrap(value)[0]
    if concat_requested:
      if _is_tpu_multi_host(strategy):
        return _tpu_multi_host_concat(value, strategy)
      return concat(strategy.unwrap(value))
    raise ValueError('`reduction` must be "first" or "concat".')

  return nest.map_structure(_reduce_one, values)


def concat(tensors, axis=0):
  """Concats `tensor`s along `axis`.

  The type of the first element picks the op: sparse concat when it is a
  `SparseTensor`, regular concat otherwise.
  """
  first = tensors[0]
  if not isinstance(first, sparse_tensor.SparseTensor):
    return array_ops.concat(tensors, axis=axis)
  return sparse_ops.sparse_concat_v2(axis=axis, sp_inputs=tensors)


def _is_tpu_multi_host(strategy):
  """Returns whether `strategy` is a TPU strategy with more than one host."""
  on_tpu = backend.is_tpu_strategy(strategy)
  return on_tpu and strategy.extended.num_hosts > 1


def _tpu_multi_host_concat(v, strategy):
  """Correctly order TPU PerReplica objects."""
  unwrapped = strategy.unwrap(v)
  # When distributed datasets are created from Tensors / NumPy,
  # TPUStrategy.experimental_distribute_dataset shards data in
  # (Replica, Host) order, and TPUStrategy.unwrap returns it in
  # (Host, Replica) order.
  # TODO(b/150317897): Figure out long-term plan here.
  per_host = strategy.extended.num_replicas_per_host
  # Take every `per_host`-th value starting at each offset in turn.
  reordered = list(
      itertools.chain.from_iterable(
          unwrapped[offset::per_host] for offset in range(per_host)))
  return concat(reordered)
def _collective_all_reduce_multi_worker(strategy):
  """Returns whether `strategy` is a multi-worker collective strategy.

  True-ish only when `strategy` is a `CollectiveAllReduceStrategy` and its
  `extended._in_multi_worker_mode()` reports multi-worker mode.
  """
  return (isinstance(strategy,
                     collective_all_reduce_strategy.CollectiveAllReduceStrategy)
         ) and strategy.extended._in_multi_worker_mode()  # pylint: disable=protected-access


# TODO(wxinyi): merge this with _tpu_multi_host_concat once we have all_gather
# for all strategies
def _multi_worker_concat(v, strategy):
  """Order PerReplica objects for CollectiveAllReduceStrategy and concat.

  Args:
    v: A per-replica value or a plain tensor (see the `else` branch below).
    strategy: Strategy providing `gather`, `num_replicas_in_sync` and
      `extended.worker_devices`.

  Returns:
    The result of `concat` over the re-ordered per-replica pieces.
  """
  # Gather `v` from all replicas along axis 0 into one tensor.
  replicas = strategy.gather(v, axis=0)
  # v might not have the same shape on different replicas
  if _is_per_replica_instance(v):
    # Collect the axis-0 size of every local replica value, then gather those
    # sizes as well so the gathered tensor can be split back below.
    shapes = array_ops.concat([
        array_ops.expand_dims_v2(array_ops.shape(single_value)[0], axis=0)
        for single_value in v.values
    ],
                              axis=0)
    all_shapes = strategy.gather(shapes, axis=0)
  else:
    # v is a tensor. This may happen when, say, we have 2x1 multi-worker.
    all_shapes = strategy.gather(
        array_ops.expand_dims_v2(array_ops.shape(v)[0], axis=0), axis=0)

  # Split the gathered tensor into one piece per replica, using the gathered
  # axis-0 sizes.
  replicas = array_ops.split(
      replicas,
      num_or_size_splits=all_shapes,
      num=strategy.num_replicas_in_sync)
  ordered_replicas = []
  num_replicas_per_worker = len(strategy.extended.worker_devices)
  # Same strided re-ordering as `_tpu_multi_host_concat`: take every
  # `num_replicas_per_worker`-th piece starting at each replica id.
  # NOTE(review): presumably this turns (worker, replica) order into
  # (replica, worker) order to match how the data was sharded -- confirm.
  for replica_id in range(num_replicas_per_worker):
    ordered_replicas += replicas[replica_id::num_replicas_per_worker]
  return concat(ordered_replicas)
def _is_scalar(x):
  """Returns whether `x` is a rank-0 `Tensor` or `Variable`."""
  tensor_like = (ops.Tensor, variables.Variable)
  return isinstance(x, tensor_like) and x.shape.rank == 0


def write_scalar_summaries(logs, step):
  """Writes a `batch_<name>` scalar summary for each scalar entry of `logs`.

  Args:
    logs: Mapping from name to value; entries that are not rank-0
      tensors/variables are ignored.
    step: Step passed to the summary op.
  """
  for name, value in logs.items():
    if not _is_scalar(value):
      continue
    summary_ops_v2.scalar('batch_' + name, value, step=step)


def _minimum_control_deps(outputs):
  """Returns the minimum control dependencies to ensure step succeeded."""
  if context.executing_eagerly():
    return []  # Control dependencies not needed.
  flat_outputs = nest.flatten(outputs, expand_composites=True)
  # Variables can't be control dependencies, so they are filtered out.
  usable = [
      out for out in flat_outputs if not isinstance(out, variables.Variable)
  ]
  # The first remaining Tensor or Op is enough; the result is empty when
  # nothing viable is left.
  return usable[:1]


def _disallow_inside_tf_function(method_name):
  """Raises if called while a function graph is being built.

  Args:
    method_name: Name of the `Model` method, used in the error message.

  Raises:
    RuntimeError: If `ops.inside_function()` is true.
  """
  if not ops.inside_function():
    return
  template = (
      'Detected a call to `Model.{method_name}` inside a `tf.function`. '
      '`Model.{method_name} is a high-level endpoint that manages its own '
      '`tf.function`. Please move the call to `Model.{method_name}` outside '
      'of all enclosing `tf.function`s. Note that you can call a `Model` '
      'directly on `Tensor`s inside a `tf.function` like: `model(x)`.')
  raise RuntimeError(template.format(method_name=method_name))
def _detect_save_format(filepath):
  """Returns path to weights file and save format.

  Args:
    filepath: Path-like location of the weights.

  Returns:
    A `(filepath, save_format)` tuple where `save_format` is `'h5'` or
    `'tf'`. For a SavedModel directory the returned path points at the
    checkpoint inside its variables directory.

  Raises:
    ValueError: If `filepath` contains a SavedModel whose checkpoint is not
      readable.
  """
  weights_path = path_to_string(filepath)
  if saving_utils.is_hdf5_filepath(weights_path):
    return weights_path, 'h5'

  # Filepath could be a TensorFlow checkpoint file prefix or SavedModel
  # directory. It's possible for filepath to be both a prefix and directory.
  # Prioritize checkpoint over SavedModel.
  if _is_readable_tf_checkpoint(weights_path):
    return weights_path, 'tf'

  if sm_loader.contains_saved_model(weights_path):
    variables_ckpt = os.path.join(weights_path,
                                  sm_constants.VARIABLES_DIRECTORY,
                                  sm_constants.VARIABLES_FILENAME)
    if not _is_readable_tf_checkpoint(variables_ckpt):
      raise ValueError('Unable to load weights. filepath {} appears to be a '
                       'SavedModel directory, but checkpoint either doesn\'t '
                       'exist, or is incorrectly formatted.'.format(
                           weights_path))
    return variables_ckpt, 'tf'

  # Not a TensorFlow checkpoint. This filepath is likely an H5 file that
  # doesn't have the hdf5/keras extensions.
  return weights_path, 'h5'


def _is_readable_tf_checkpoint(filepath):
  """Returns whether a checkpoint reader can be opened on `filepath`.

  Only `DataLossError` is treated as "not a TensorFlow checkpoint"; any other
  error raised by the reader propagates.
  """
  try:
    py_checkpoint_reader.NewCheckpointReader(filepath)
  except errors_impl.DataLossError:
    # The checkpoint is not readable in TensorFlow format.
    return False
  return True
def flatten_metrics_in_order(logs, metrics_names):
  """Turns the `logs` dict into a list as per key order of `metrics_names`.

  Values whose key appears in `metrics_names` come first, in that order;
  the remaining values follow, sorted by key.

  Args:
    logs: Dict mapping metric name to value.
    metrics_names: Names that fix the order of the leading entries.

  Returns:
    The ordered list of values, or the bare value when there is exactly one.
  """
  ordered = [logs[name] for name in metrics_names if name in logs]
  ordered.extend(
      logs[key] for key in sorted(logs.keys()) if key not in metrics_names)
  return ordered[0] if len(ordered) == 1 else ordered


def _is_per_replica_instance(obj):
  """Returns whether `obj` is both a `DistributedValues` and a composite."""
  required_types = (ds_values.DistributedValues,
                    composite_tensor.CompositeTensor)
  return all(isinstance(obj, required) for required in required_types)


def saver_with_op_caching(obj):
  """Builds a `TrackableSaver` for `obj`, with a saveables cache in graph mode.

  Args:
    obj: Object to save; only a weak reference to it is handed to the graph
      view.

  Returns:
    A `trackable_utils.TrackableSaver`. No saveables cache is used when
    executing eagerly.
  """
  saveables_cache = (
      None if context.executing_eagerly() else
      object_identity.ObjectIdentityWeakKeyDictionary())
  graph_view = graph_view_lib.ObjectGraphView(
      weakref.ref(obj), saveables_cache=saveables_cache)
  return trackable_utils.TrackableSaver(graph_view)