1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""V1 Training-related part of the Keras engine.""" 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20import collections 21import warnings 22 23import numpy as np 24 25from tensorflow.python import tf2 26from tensorflow.python.data.ops import dataset_ops 27from tensorflow.python.data.ops import iterator_ops 28from tensorflow.python.distribute import distribution_strategy_context 29from tensorflow.python.distribute import parameter_server_strategy 30from tensorflow.python.distribute import parameter_server_strategy_v2 31from tensorflow.python.eager import context 32from tensorflow.python.eager import def_function 33from tensorflow.python.framework import constant_op 34from tensorflow.python.framework import ops 35from tensorflow.python.framework import sparse_tensor 36from tensorflow.python.framework import tensor_shape 37from tensorflow.python.framework import tensor_spec 38from tensorflow.python.framework import tensor_util 39from tensorflow.python.framework import type_spec 40from tensorflow.python.keras import backend as K 41from tensorflow.python.keras import losses 42from tensorflow.python.keras import metrics as metrics_module 43from tensorflow.python.keras import optimizer_v1 44from tensorflow.python.keras import optimizers 45from tensorflow.python.keras.distribute import distributed_training_utils 46from tensorflow.python.keras.distribute import distributed_training_utils_v1 47from tensorflow.python.keras.engine import base_layer 48from tensorflow.python.keras.engine import training as training_lib 49from tensorflow.python.keras.engine import training_arrays_v1 50from tensorflow.python.keras.engine import training_distributed_v1 51from tensorflow.python.keras.engine import training_eager_v1 52from tensorflow.python.keras.engine import training_generator_v1 53from tensorflow.python.keras.engine import training_utils 54from tensorflow.python.keras.engine import training_utils_v1 55from tensorflow.python.keras.mixed_precision import loss_scale_optimizer 56from tensorflow.python.keras.mixed_precision import policy 57from tensorflow.python.keras.optimizer_v2 import optimizer_v2 58from tensorflow.python.keras.saving import saving_utils 59from tensorflow.python.keras.saving.saved_model import model_serialization 60from tensorflow.python.keras.utils import data_utils 61from tensorflow.python.keras.utils import layer_utils 62from tensorflow.python.keras.utils import losses_utils 63from tensorflow.python.keras.utils import tf_inspect 64from tensorflow.python.keras.utils import tf_utils 65from tensorflow.python.keras.utils.mode_keys import ModeKeys 66from tensorflow.python.ops import array_ops 67from tensorflow.python.ops import math_ops 68from tensorflow.python.platform import tf_logging as logging 69from tensorflow.python.training.tracking import base as trackable 70from tensorflow.python.util import nest 71 72try: 73 from scipy.sparse import issparse # pylint: disable=g-import-not-at-top 74except ImportError: 75 issparse = None 76 77 78class Model(training_lib.Model): 79 """`Model` groups layers into an object with training and inference features. 80 81 There are two ways to instantiate a `Model`: 82 83 1 - With the "functional API", where you start from `Input`, 84 you chain layer calls to specify the model's forward pass, 85 and finally you create your model from inputs and outputs: 86 87 ```python 88 import tensorflow as tf 89 90 inputs = tf.keras.Input(shape=(3,)) 91 x = tf.keras.layers.Dense(4, activation=tf.nn.relu)(inputs) 92 outputs = tf.keras.layers.Dense(5, activation=tf.nn.softmax)(x) 93 model = tf.keras.Model(inputs=inputs, outputs=outputs) 94 ``` 95 96 2 - By subclassing the `Model` class: in that case, you should define your 97 layers in `__init__` and you should implement the model's forward pass 98 in `call`. 99 100 ```python 101 import tensorflow as tf 102 103 class MyModel(tf.keras.Model): 104 105 def __init__(self): 106 super(MyModel, self).__init__() 107 self.dense1 = tf.keras.layers.Dense(4, activation=tf.nn.relu) 108 self.dense2 = tf.keras.layers.Dense(5, activation=tf.nn.softmax) 109 110 def call(self, inputs): 111 x = self.dense1(inputs) 112 return self.dense2(x) 113 114 model = MyModel() 115 ``` 116 117 If you subclass `Model`, you can optionally have 118 a `training` argument (boolean) in `call`, which you can use to specify 119 a different behavior in training and inference: 120 121 ```python 122 import tensorflow as tf 123 124 class MyModel(tf.keras.Model): 125 126 def __init__(self): 127 super(MyModel, self).__init__() 128 self.dense1 = tf.keras.layers.Dense(4, activation=tf.nn.relu) 129 self.dense2 = tf.keras.layers.Dense(5, activation=tf.nn.softmax) 130 self.dropout = tf.keras.layers.Dropout(0.5) 131 132 def call(self, inputs, training=False): 133 x = self.dense1(inputs) 134 if training: 135 x = self.dropout(x, training=training) 136 return self.dense2(x) 137 138 model = MyModel() 139 ``` 140 """ 141 142 def __init__(self, *args, **kwargs): 143 super(Model, self).__init__(*args, **kwargs) 144 # initializing _distribution_strategy here since it is possible to call 145 # predict on a model without compiling it. 146 self._distribution_strategy = None 147 self._compile_time_distribution_strategy = None 148 if (ops.executing_eagerly_outside_functions() and 149 distribution_strategy_context.has_strategy()): 150 self._set_strategy( 151 distribution_strategy_context.get_strategy()) 152 153 # This flag is used to track if the user is using the deprecated path of 154 # passing distribution strategy to compile rather than creating the model 155 # under distribution strategy scope. 156 self._compile_distribution = False 157 158 self._run_eagerly = None 159 self._experimental_run_tf_function = ( 160 ops.executing_eagerly_outside_functions()) 161 162 self._v1_compile_was_called = False 163 164 def _init_batch_counters(self): 165 pass # Batch counters should not be created in legacy graph mode. 166 167 @trackable.no_automatic_dependency_tracking 168 def _set_strategy(self, strategy): 169 self._compile_time_distribution_strategy = strategy 170 171 def get_weights(self): 172 """Retrieves the weights of the model. 173 174 Returns: 175 A flat list of Numpy arrays. 176 """ 177 strategy = (self._distribution_strategy or 178 self._compile_time_distribution_strategy) 179 if strategy: 180 with strategy.scope(): 181 return base_layer.Layer.get_weights(self) 182 return base_layer.Layer.get_weights(self) 183 184 def load_weights(self, filepath, by_name=False, skip_mismatch=False): 185 """Loads all layer weights, either from a TensorFlow or an HDF5 weight file. 186 187 If `by_name` is False weights are loaded based on the network's 188 topology. This means the architecture should be the same as when the weights 189 were saved. Note that layers that don't have weights are not taken into 190 account in the topological ordering, so adding or removing layers is fine as 191 long as they don't have weights. 192 193 If `by_name` is True, weights are loaded into layers only if they share the 194 same name. This is useful for fine-tuning or transfer-learning models where 195 some of the layers have changed. 196 197 Only topological loading (`by_name=False`) is supported when loading weights 198 from the TensorFlow format. Note that topological loading differs slightly 199 between TensorFlow and HDF5 formats for user-defined classes inheriting from 200 `tf.keras.Model`: HDF5 loads based on a flattened list of weights, while the 201 TensorFlow format loads based on the object-local names of attributes to 202 which layers are assigned in the `Model`'s constructor. 203 204 Args: 205 filepath: String, path to the weights file to load. For weight files in 206 TensorFlow format, this is the file prefix (the same as was passed 207 to `save_weights`). 208 by_name: Boolean, whether to load weights by name or by topological 209 order. Only topological loading is supported for weight files in 210 TensorFlow format. 211 skip_mismatch: Boolean, whether to skip loading of layers where there is 212 a mismatch in the number of weights, or a mismatch in the shape of 213 the weight (only valid when `by_name=True`). 214 215 Returns: 216 When loading a weight file in TensorFlow format, returns the same status 217 object as `tf.train.Checkpoint.restore`. When graph building, restore 218 ops are run automatically as soon as the network is built (on first call 219 for user-defined classes inheriting from `Model`, immediately if it is 220 already built). 221 222 When loading weights in HDF5 format, returns `None`. 223 224 Raises: 225 ImportError: If h5py is not available and the weight file is in HDF5 226 format. 227 ValueError: If `skip_mismatch` is set to `True` when `by_name` is 228 `False`. 229 """ 230 if K.is_tpu_strategy(self._distribution_strategy): 231 if (self._distribution_strategy.extended.steps_per_run > 1 and 232 (not saving_utils.is_hdf5_filepath(filepath))): # pylint: disable=protected-access 233 raise ValueError('Load weights is not yet supported with TPUStrategy ' 234 'with steps_per_run greater than 1.') 235 return super(Model, self).load_weights(filepath, by_name, skip_mismatch) 236 237 @trackable.no_automatic_dependency_tracking 238 def compile(self, 239 optimizer='rmsprop', 240 loss=None, 241 metrics=None, 242 loss_weights=None, 243 sample_weight_mode=None, 244 weighted_metrics=None, 245 target_tensors=None, 246 distribute=None, 247 **kwargs): 248 """Configures the model for training. 249 250 Args: 251 optimizer: String (name of optimizer) or optimizer instance. 252 See `tf.keras.optimizers`. 253 loss: String (name of objective function), objective function or 254 `tf.keras.losses.Loss` instance. See `tf.keras.losses`. An objective 255 function is any callable with the signature 256 `scalar_loss = fn(y_true, y_pred)`. If the model has multiple 257 outputs, you can use a different loss on each output by passing a 258 dictionary or a list of losses. The loss value that will be 259 minimized by the model will then be the sum of all individual 260 losses. 261 metrics: List of metrics to be evaluated by the model during training 262 and testing. Typically you will use `metrics=['accuracy']`. 263 To specify different metrics for different outputs of a 264 multi-output model, you could also pass a dictionary, such as 265 `metrics={'output_a': 'accuracy', 'output_b': ['accuracy', 'mse']}`. 266 You can also pass a list (len = len(outputs)) of lists of metrics 267 such as `metrics=[['accuracy'], ['accuracy', 'mse']]` or 268 `metrics=['accuracy', ['accuracy', 'mse']]`. 269 loss_weights: Optional list or dictionary specifying scalar 270 coefficients (Python floats) to weight the loss contributions 271 of different model outputs. 272 The loss value that will be minimized by the model 273 will then be the *weighted sum* of all individual losses, 274 weighted by the `loss_weights` coefficients. 275 If a list, it is expected to have a 1:1 mapping 276 to the model's outputs. If a tensor, it is expected to map 277 output names (strings) to scalar coefficients. 278 sample_weight_mode: If you need to do timestep-wise 279 sample weighting (2D weights), set this to `"temporal"`. 280 `None` defaults to sample-wise weights (1D). 281 If the model has multiple outputs, you can use a different 282 `sample_weight_mode` on each output by passing a 283 dictionary or a list of modes. 284 weighted_metrics: List of metrics to be evaluated and weighted 285 by sample_weight or class_weight during training and testing. 286 target_tensors: By default, Keras will create placeholders for the 287 model's target, which will be fed with the target data during 288 training. If instead you would like to use your own 289 target tensors (in turn, Keras will not expect external 290 Numpy data for these targets at training time), you 291 can specify them via the `target_tensors` argument. It can be 292 a single tensor (for a single-output model), a list of tensors, 293 or a dict mapping output names to target tensors. 294 distribute: NOT SUPPORTED IN TF 2.0, please create and compile the 295 model under distribution strategy scope instead of passing it to 296 compile. 297 **kwargs: Any additional arguments. 298 299 Raises: 300 ValueError: In case of invalid arguments for 301 `optimizer`, `loss`, `metrics` or `sample_weight_mode`. 302 """ 303 self._assert_built_as_v1() 304 self._run_eagerly = kwargs.pop('run_eagerly', None) 305 self._experimental_run_tf_function = kwargs.pop( 306 'experimental_run_tf_function', True) 307 self._v1_compile_was_called = True 308 309 # Prepare Session arguments (legacy). 310 kwargs.pop('cloning', None) # Legacy DistStrat argument, never used. 311 allowed_kwargs = {'feed_dict', 'fetches', 'options', 'run_metadata'} 312 unknown_kwargs = set(kwargs.keys()) - allowed_kwargs 313 if unknown_kwargs: 314 raise TypeError( 315 'Invalid keyword argument(s) in `compile`: %s' % (unknown_kwargs,)) 316 self._function_kwargs = kwargs 317 if self._function_kwargs: 318 self._experimental_run_tf_function = False 319 if self.run_eagerly: 320 raise ValueError( 321 'Session keyword arguments are not supported ' 322 'when `run_eagerly=True`. You passed the following ' 323 'Session arguments: %s' % (self._function_kwargs,)) 324 325 self._set_optimizer(optimizer) 326 is_any_keras_optimizer_v1 = any( 327 (isinstance(opt, optimizer_v1.Optimizer) 328 and not isinstance(opt, optimizer_v1.TFOptimizer) 329 ) for opt in nest.flatten(self.optimizer)) 330 331 if is_any_keras_optimizer_v1 and ops.executing_eagerly_outside_functions(): 332 raise ValueError('`tf.compat.v1.keras` Optimizer (', optimizer, ') is ' 333 'not supported when eager execution is enabled. Use a ' 334 '`tf.keras` Optimizer instead, or disable eager ' 335 'execution.') 336 337 if ((target_tensors is not None) 338 or not ops.executing_eagerly_outside_functions()): 339 # Fallback out of things that aren't supported with v2 loops 340 self._experimental_run_tf_function = False 341 342 if distribute is not None: 343 if tf2.enabled() or self._experimental_run_tf_function: 344 raise ValueError( 345 'Distribute argument in compile is not available in TF 2.0 please ' 346 'create the model under the distribution strategy scope.') 347 logging.warning('Distribute argument in compile is deprecated please ' 348 'create the model under the distribution strategy scope.') 349 self._distribution_strategy = distribute 350 self._compile_distribution = True 351 else: 352 if distribution_strategy_context.has_strategy(): 353 # When the user builds the model in the DS scope and cross replica 354 # context we want distribution strategy to be set but when building the 355 # replica copies of the models internally we should not be compiling 356 # with distribution strategy and use the default compilation path. 357 if distribution_strategy_context.in_cross_replica_context(): 358 self._distribution_strategy = ( 359 distribution_strategy_context.get_strategy()) 360 361 if isinstance(self._distribution_strategy, 362 parameter_server_strategy.ParameterServerStrategyV1): 363 raise NotImplementedError( 364 '`tf.compat.v1.distribute.experimental.ParameterServerStrategy` ' 365 'currently only works with the tf.Estimator API') 366 367 if isinstance(self._distribution_strategy, 368 parameter_server_strategy_v2.ParameterServerStrategyV2): 369 raise NotImplementedError( 370 '`tf.distribute.experimental.ParameterServerStrategy` is only ' 371 'supported in TF2.') 372 373 if not self._experimental_run_tf_function: 374 self._validate_compile_param_for_distribution_strategy(self.run_eagerly, 375 sample_weight_mode, 376 target_tensors, 377 weighted_metrics) 378 # We've disabled automatic dependency tracking for this method, but do want 379 # to add a checkpoint dependency on the optimizer if it's trackable. 380 if isinstance(self.optimizer, trackable.Trackable): 381 self._track_trackable( 382 self.optimizer, name='optimizer', overwrite=True) 383 self.loss = loss or {} 384 self.loss_weights = loss_weights 385 self.sample_weight_mode = sample_weight_mode 386 self._compile_metrics = metrics or [] 387 self._compile_weighted_metrics = weighted_metrics 388 if self.run_eagerly and target_tensors is not None: 389 raise ValueError( 390 'target_tensors argument is not supported when ' 391 'running a model eagerly.') 392 393 # _training_endpoints contains a list of _TrainingEndpoint object, which has 394 # all the model output/target/loss and related metadata. 395 self._training_endpoints = [] 396 397 # Used to freeze the behavior of the Model once `compile` has been called. 398 self._compiled_trainable_state = self._get_trainable_state() 399 400 # Set tf.distribute.Strategy specific parameters. 401 self._distributed_model_cache = {} 402 self._distributed_function_cache = {} 403 404 # Clear any `_eager_losses` that was added. 405 self._clear_losses() 406 407 if (not context.executing_eagerly() and 408 self._distribution_strategy is not None): 409 # Ensures a Session is created and configured correctly for Distribution 410 # Strategy. 411 K.configure_and_create_distributed_session(self._distribution_strategy) 412 # Initialize model metric attributes. 413 self._init_metric_attributes() 414 if not self.built or not self.inputs or not self.outputs: 415 # Model is not compilable because it does not know its number of inputs 416 # and outputs, nor their shapes and names. We will compile after the first 417 # time the model gets called on training data. 418 return 419 self._is_compiled = True 420 base_layer.keras_api_gauge.get_cell('compile').set(True) 421 422 # Prepare list of loss functions, same size of model outputs. 423 self.loss_functions = training_utils_v1.prepare_loss_functions( 424 self.loss, self.output_names) 425 426 target_tensors = self._process_target_tensor_for_compile(target_tensors) 427 428 for o, n, l, t in zip(self.outputs, self.output_names, 429 self.loss_functions, target_tensors): 430 endpoint = _TrainingEndpoint(o, n, l) 431 endpoint.create_training_target(t, run_eagerly=self.run_eagerly) 432 self._training_endpoints.append(endpoint) 433 434 # Prepare list loss weights, same size of model outputs. 435 training_utils_v1.prepare_loss_weights(self._training_endpoints, 436 loss_weights) 437 438 # Initialization for Eager mode execution. 439 if self.run_eagerly: 440 self._compile_eagerly(metrics, weighted_metrics, sample_weight_mode) 441 return 442 443 with K.get_graph().as_default(): 444 # Save all metric attributes per output of the model. 445 self._cache_output_metric_attributes(metrics, weighted_metrics) 446 447 # Set metric attributes on model. 448 self._set_metric_attributes() 449 450 # Invoke metric functions (unweighted) for all the outputs. 451 self._handle_metrics( 452 self.outputs, 453 targets=self._targets, 454 skip_target_masks=self._prepare_skip_target_masks(), 455 masks=self._prepare_output_masks()) 456 457 # Prepare sample weight modes. List with the same length as model outputs. 458 training_utils_v1.prepare_sample_weight_modes( 459 self._training_endpoints, sample_weight_mode) 460 461 # Creates the model loss and weighted metrics sub-graphs. 462 self._compile_weights_loss_and_weighted_metrics() 463 464 # Functions for train, test and predict will 465 # be compiled lazily when required. 466 # This saves time when the user is not using all functions. 467 self.train_function = None 468 self.test_function = None 469 self.predict_function = None 470 471 # Collected trainable weights, sorted in topological order. 472 self._collected_trainable_weights = self.trainable_weights 473 474 # Validate all variables were correctly created in distribution scope. 475 if self._distribution_strategy and not self._compile_distribution: 476 for v in self.variables: 477 strategy = self._distribution_strategy 478 if not strategy.extended.variable_created_in_scope(v): 479 raise ValueError( 480 'Variable (%s) was not created in the distribution strategy ' 481 'scope of (%s). It is most likely due to not all layers or ' 482 'the model or optimizer being created outside the distribution ' 483 'strategy scope. Try to make sure your code looks similar ' 484 'to the following.\n' 485 'with strategy.scope():\n' 486 ' model=_create_model()\n' 487 ' model.compile(...)'% (v, strategy)) 488 489 @trackable.no_automatic_dependency_tracking 490 def _init_distributed_function_cache_if_not_compiled(self): 491 if not hasattr(self, '_distributed_function_cache'): 492 self._distributed_function_cache = {} 493 494 @property 495 def metrics(self): 496 """Returns the model's metrics added using `compile`, `add_metric` APIs.""" 497 metrics = [] 498 if self._is_compiled: 499 if not hasattr(self, '_v1_compile_was_called'): 500 # See b/155687393 for more details, the model is created as a v2 501 # instance but converted to v1. Fallback to use base Model to retrieve 502 # the metrics. 503 return super(Model, self).metrics 504 metrics += self._compile_metric_functions 505 metrics.extend(self._metrics) 506 metrics.extend( 507 _get_metrics_from_layers( 508 list(self._flatten_layers(include_self=False, recursive=False)))) 509 return metrics 510 511 @property 512 def metrics_names(self): 513 """Returns the model's display labels for all outputs.""" 514 515 # This property includes all output names including `loss` and per-output 516 # losses for backward compatibility. 517 metrics_names = ['loss'] 518 if self._is_compiled: 519 if not hasattr(self, '_v1_compile_was_called'): 520 # See b/155687393 for more details, the model is created as a v2 521 # instance but converted to v1. Fallback to use base Model to retrieve 522 # the metrics name 523 return super(Model, self).metrics_names 524 525 # Add output loss metric names to the metric names list. 526 if len(self._training_endpoints) > 1: 527 metrics_names.extend([ 528 e.loss_name() 529 for e in self._training_endpoints 530 if not e.should_skip_target() 531 ]) 532 533 # Add all metric names. 534 metrics_names += [m.name for m in self.metrics] 535 return metrics_names 536 537 @property 538 def run_eagerly(self): 539 """Settable attribute indicating whether the model should run eagerly. 540 541 Running eagerly means that your model will be run step by step, 542 like Python code. Your model might run slower, but it should become easier 543 for you to debug it by stepping into individual layer calls. 544 545 By default, we will attempt to compile your model to a static graph to 546 deliver the best execution performance. 547 548 Returns: 549 Boolean, whether the model should run eagerly. 550 """ 551 if self._run_eagerly is True and not context.executing_eagerly(): 552 raise ValueError('You can only set `run_eagerly=True` if eager execution ' 553 'is enabled.') 554 if not self.dynamic: 555 if self._run_eagerly is None: 556 # Respect `tf.config.run_functions_eagerly` unless 557 # `run_eagerly` was explicitly passed to `compile`. 558 return def_function.functions_run_eagerly() 559 else: 560 return self._run_eagerly 561 else: 562 if not context.executing_eagerly(): 563 raise ValueError('Your model contains layers that can only be ' 564 'successfully run in eager execution (layers ' 565 'constructed with `dynamic=True`). ' 566 'You must enable eager execution with ' 567 '`tf.enable_eager_execution()`.') 568 if self._run_eagerly is False: 569 # TODO(fchollet): consider using py_func to enable this. 570 raise ValueError('Your model contains layers that can only be ' 571 'successfully run in eager execution (layers ' 572 'constructed with `dynamic=True`). ' 573 'You cannot set `run_eagerly=False`.') 574 return context.executing_eagerly() 575 576 @run_eagerly.setter 577 def run_eagerly(self, value): 578 self._run_eagerly = value 579 580 def _select_training_loop(self, inputs): 581 """Select training loop for fit/eval/predict based on the inputs.""" 582 # TODO(kaftan) or TODO(scottzhu): This check should eventually be nicely 583 # integrated into the data adapters in the v2 loop. We can't do this yet 584 # because we currently have to fall back for unhandled data types. 585 if isinstance(inputs, (iterator_ops.Iterator, 586 iterator_ops.IteratorBase)): 587 raise ValueError('For performance reasons Keras `fit`, `evaluate` and' 588 '`predict` accept tf.data `Datasets` as input but not ' 589 'iterators that have been manually generated from ' 590 'Datasets by users. Please directly pass in the ' 591 'original `Dataset` object instead of passing in ' 592 '`iter(dataset)`.') 593 594 # Case 1: distribution strategy. 595 if self._distribution_strategy: 596 if self._in_multi_worker_mode(): 597 return training_distributed_v1.DistributionMultiWorkerTrainingLoop( 598 training_distributed_v1.DistributionSingleWorkerTrainingLoop()) 599 else: 600 return training_distributed_v1.DistributionSingleWorkerTrainingLoop() 601 602 # Case 2: generator-like. Input is Python generator, or Sequence object, 603 # or a non-distributed Dataset or iterator in eager execution. 604 if data_utils.is_generator_or_sequence(inputs): 605 return training_generator_v1.GeneratorOrSequenceTrainingLoop() 606 if training_utils_v1.is_eager_dataset_or_iterator(inputs): 607 return training_generator_v1.EagerDatasetOrIteratorTrainingLoop() 608 609 # Case 3: Symbolic tensors or Numpy array-like. 610 # This includes Datasets and iterators in graph mode (since they 611 # generate symbolic tensors). 612 if self.run_eagerly: 613 return training_generator_v1.GeneratorLikeTrainingLoop() 614 else: 615 return training_arrays_v1.ArrayLikeTrainingLoop() 616 617 def fit(self, 618 x=None, 619 y=None, 620 batch_size=None, 621 epochs=1, 622 verbose=1, 623 callbacks=None, 624 validation_split=0., 625 validation_data=None, 626 shuffle=True, 627 class_weight=None, 628 sample_weight=None, 629 initial_epoch=0, 630 steps_per_epoch=None, 631 validation_steps=None, 632 validation_freq=1, 633 max_queue_size=10, 634 workers=1, 635 use_multiprocessing=False, 636 **kwargs): 637 """Trains the model for a fixed number of epochs (iterations on a dataset). 638 639 Args: 640 x: Input data. It could be: 641 - A Numpy array (or array-like), or a list of arrays 642 (in case the model has multiple inputs). 643 - A TensorFlow tensor, or a list of tensors 644 (in case the model has multiple inputs). 645 - A dict mapping input names to the corresponding array/tensors, 646 if the model has named inputs. 647 - A `tf.data` dataset. Should return a tuple 648 of either `(inputs, targets)` or 649 `(inputs, targets, sample_weights)`. 650 - A generator or `keras.utils.Sequence` returning `(inputs, targets)` 651 or `(inputs, targets, sample weights)`. 652 y: Target data. Like the input data `x`, 653 it could be either Numpy array(s) or TensorFlow tensor(s). 654 It should be consistent with `x` (you cannot have Numpy inputs and 655 tensor targets, or inversely). If `x` is a dataset, generator, 656 or `keras.utils.Sequence` instance, `y` should 657 not be specified (since targets will be obtained from `x`). 658 batch_size: Integer or `None`. 659 Number of samples per gradient update. 660 If unspecified, `batch_size` will default to 32. 661 Do not specify the `batch_size` if your data is in the 662 form of symbolic tensors, datasets, 663 generators, or `keras.utils.Sequence` instances (since they generate 664 batches). 665 epochs: Integer. Number of epochs to train the model. 666 An epoch is an iteration over the entire `x` and `y` 667 data provided. 668 Note that in conjunction with `initial_epoch`, 669 `epochs` is to be understood as "final epoch". 670 The model is not trained for a number of iterations 671 given by `epochs`, but merely until the epoch 672 of index `epochs` is reached. 673 verbose: 0, 1, or 2. Verbosity mode. 674 0 = silent, 1 = progress bar, 2 = one line per epoch. 675 Note that the progress bar is not particularly useful when 676 logged to a file, so verbose=2 is recommended when not running 677 interactively (eg, in a production environment). 678 callbacks: List of `keras.callbacks.Callback` instances. 679 List of callbacks to apply during training. 680 See `tf.keras.callbacks`. 681 validation_split: Float between 0 and 1. 682 Fraction of the training data to be used as validation data. 683 The model will set apart this fraction of the training data, 684 will not train on it, and will evaluate 685 the loss and any model metrics 686 on this data at the end of each epoch. 687 The validation data is selected from the last samples 688 in the `x` and `y` data provided, before shuffling. This argument is 689 not supported when `x` is a dataset, generator or 690 `keras.utils.Sequence` instance. 691 validation_data: Data on which to evaluate 692 the loss and any model metrics at the end of each epoch. 693 The model will not be trained on this data. 694 `validation_data` will override `validation_split`. 695 `validation_data` could be: 696 - tuple `(x_val, y_val)` of Numpy arrays or tensors 697 - tuple `(x_val, y_val, val_sample_weights)` of Numpy arrays 698 - dataset 699 For the first two cases, `batch_size` must be provided. 700 For the last case, `validation_steps` could be provided. 701 shuffle: Boolean (whether to shuffle the training data 702 before each epoch) or str (for 'batch'). 703 'batch' is a special option for dealing with the 704 limitations of HDF5 data; it shuffles in batch-sized chunks. 705 Has no effect when `steps_per_epoch` is not `None`. 706 class_weight: Optional dictionary mapping class indices (integers) 707 to a weight (float) value, used for weighting the loss function 708 (during training only). 709 This can be useful to tell the model to 710 "pay more attention" to samples from 711 an under-represented class. 712 sample_weight: Optional Numpy array of weights for 713 the training samples, used for weighting the loss function 714 (during training only). You can either pass a flat (1D) 715 Numpy array with the same length as the input samples 716 (1:1 mapping between weights and samples), 717 or in the case of temporal data, 718 you can pass a 2D array with shape 719 `(samples, sequence_length)`, 720 to apply a different weight to every timestep of every sample. 721 In this case you should make sure to specify 722 `sample_weight_mode="temporal"` in `compile()`. This argument is not 723 supported when `x` is a dataset, generator, or 724 `keras.utils.Sequence` instance, instead provide the sample_weights 725 as the third element of `x`. 726 initial_epoch: Integer. 727 Epoch at which to start training 728 (useful for resuming a previous training run). 729 steps_per_epoch: Integer or `None`. 730 Total number of steps (batches of samples) 731 before declaring one epoch finished and starting the 732 next epoch. When training with input tensors such as 733 TensorFlow data tensors, the default `None` is equal to 734 the number of samples in your dataset divided by 735 the batch size, or 1 if that cannot be determined. If x is a 736 `tf.data` dataset, and 'steps_per_epoch' 737 is None, the epoch will run until the input dataset is exhausted. 738 This argument is not supported with array inputs. 739 validation_steps: Only relevant if `validation_data` is provided and 740 is a `tf.data` dataset. Total number of steps (batches of 741 samples) to draw before stopping when performing validation 742 at the end of every epoch. If 'validation_steps' is None, validation 743 will run until the `validation_data` dataset is exhausted. In the 744 case of a infinite dataset, it will run into a infinite loop. 745 If 'validation_steps' is specified and only part of the dataset 746 will be consumed, the evaluation will start from the beginning of 747 the dataset at each epoch. This ensures that the same validation 748 samples are used every time. 749 validation_freq: Only relevant if validation data is provided. Integer 750 or `collections.abc.Container` instance (e.g. list, tuple, etc.). 751 If an integer, specifies how many training epochs to run before a 752 new validation run is performed, e.g. `validation_freq=2` runs 753 validation every 2 epochs. If a Container, specifies the epochs on 754 which to run validation, e.g. `validation_freq=[1, 2, 10]` runs 755 validation at the end of the 1st, 2nd, and 10th epochs. 756 max_queue_size: Integer. Used for generator or `keras.utils.Sequence` 757 input only. Maximum size for the generator queue. 758 If unspecified, `max_queue_size` will default to 10. 759 workers: Integer. Used for generator or `keras.utils.Sequence` input 760 only. Maximum number of processes to spin up 761 when using process-based threading. If unspecified, `workers` 762 will default to 1. If 0, will execute the generator on the main 763 thread. 764 use_multiprocessing: Boolean. Used for generator or 765 `keras.utils.Sequence` input only. If `True`, use process-based 766 threading. If unspecified, `use_multiprocessing` will default to 767 `False`. Note that because this implementation relies on 768 multiprocessing, you should not pass non-picklable arguments to 769 the generator as they can't be passed easily to children processes. 770 **kwargs: Used for backwards compatibility. 771 772 Returns: 773 A `History` object. Its `History.history` attribute is 774 a record of training loss values and metrics values 775 at successive epochs, as well as validation loss values 776 and validation metrics values (if applicable). 777 778 Raises: 779 RuntimeError: If the model was never compiled. 780 ValueError: In case of mismatch between the provided input data 781 and what the model expects. 782 """ 783 self._assert_built_as_v1() 784 base_layer.keras_api_gauge.get_cell('fit').set(True) 785 # Legacy support 786 if 'nb_epoch' in kwargs: 787 logging.warning( 788 'The `nb_epoch` argument in `fit` has been renamed `epochs`.') 789 epochs = kwargs.pop('nb_epoch') 790 if kwargs: 791 raise TypeError('Unrecognized keyword arguments: ' + str(kwargs)) 792 self._assert_compile_was_called() 793 self._check_call_args('fit') 794 795 func = self._select_training_loop(x) 796 return func.fit( 797 self, 798 x=x, 799 y=y, 800 batch_size=batch_size, 801 epochs=epochs, 802 verbose=verbose, 803 callbacks=callbacks, 804 validation_split=validation_split, 805 validation_data=validation_data, 806 shuffle=shuffle, 807 class_weight=class_weight, 808 sample_weight=sample_weight, 809 initial_epoch=initial_epoch, 810 steps_per_epoch=steps_per_epoch, 811 validation_steps=validation_steps, 812 validation_freq=validation_freq, 813 max_queue_size=max_queue_size, 814 workers=workers, 815 use_multiprocessing=use_multiprocessing) 816 817 def evaluate(self, 818 x=None, 819 y=None, 820 batch_size=None, 821 verbose=1, 822 sample_weight=None, 823 steps=None, 824 callbacks=None, 825 max_queue_size=10, 826 workers=1, 827 use_multiprocessing=False): 828 """Returns the loss value & metrics values for the model in test mode. 829 830 Computation is done in batches (see the `batch_size` arg.) 831 832 Args: 833 x: Input data. It could be: 834 - A Numpy array (or array-like), or a list of arrays 835 (in case the model has multiple inputs). 836 - A TensorFlow tensor, or a list of tensors 837 (in case the model has multiple inputs). 838 - A dict mapping input names to the corresponding array/tensors, 839 if the model has named inputs. 840 - A `tf.data` dataset. 841 - A generator or `keras.utils.Sequence` instance. 842 y: Target data. Like the input data `x`, 843 it could be either Numpy array(s) or TensorFlow tensor(s). 844 It should be consistent with `x` (you cannot have Numpy inputs and 845 tensor targets, or inversely). 846 If `x` is a dataset, generator or 847 `keras.utils.Sequence` instance, `y` should not be specified (since 848 targets will be obtained from the iterator/dataset). 849 batch_size: Integer or `None`. 850 Number of samples per batch of computation. 851 If unspecified, `batch_size` will default to 32. 852 Do not specify the `batch_size` if your data is in the 853 form of symbolic tensors, dataset, 854 generators, or `keras.utils.Sequence` instances (since they generate 855 batches). 856 verbose: 0 or 1. Verbosity mode. 857 0 = silent, 1 = progress bar. 858 sample_weight: Optional Numpy array of weights for 859 the test samples, used for weighting the loss function. 860 You can either pass a flat (1D) 861 Numpy array with the same length as the input samples 862 (1:1 mapping between weights and samples), 863 or in the case of temporal data, 864 you can pass a 2D array with shape 865 `(samples, sequence_length)`, 866 to apply a different weight to every timestep of every sample. 867 In this case you should make sure to specify 868 `sample_weight_mode="temporal"` in `compile()`. This argument is not 869 supported when `x` is a dataset, instead pass 870 sample weights as the third element of `x`. 871 steps: Integer or `None`. 872 Total number of steps (batches of samples) 873 before declaring the evaluation round finished. 874 Ignored with the default value of `None`. 875 If x is a `tf.data` dataset and `steps` is 876 None, 'evaluate' will run until the dataset is exhausted. 877 This argument is not supported with array inputs. 878 callbacks: List of `keras.callbacks.Callback` instances. 879 List of callbacks to apply during evaluation. 880 See [callbacks](/api_docs/python/tf/keras/callbacks). 881 max_queue_size: Integer. Used for generator or `keras.utils.Sequence` 882 input only. Maximum size for the generator queue. 883 If unspecified, `max_queue_size` will default to 10. 884 workers: Integer. Used for generator or `keras.utils.Sequence` input 885 only. Maximum number of processes to spin up when using 886 process-based threading. If unspecified, `workers` will default 887 to 1. If 0, will execute the generator on the main thread. 888 use_multiprocessing: Boolean. Used for generator or 889 `keras.utils.Sequence` input only. If `True`, use process-based 890 threading. If unspecified, `use_multiprocessing` will default to 891 `False`. Note that because this implementation relies on 892 multiprocessing, you should not pass non-picklable arguments to 893 the generator as they can't be passed easily to children processes. 894 895 Returns: 896 Scalar test loss (if the model has a single output and no metrics) 897 or list of scalars (if the model has multiple outputs 898 and/or metrics). The attribute `model.metrics_names` will give you 899 the display labels for the scalar outputs. 900 901 Raises: 902 ValueError: in case of invalid arguments. 903 """ 904 self._assert_built_as_v1() 905 base_layer.keras_api_gauge.get_cell('evaluate').set(True) 906 self._assert_compile_was_called() 907 self._check_call_args('evaluate') 908 909 func = self._select_training_loop(x) 910 return func.evaluate( 911 self, 912 x=x, 913 y=y, 914 batch_size=batch_size, 915 verbose=verbose, 916 sample_weight=sample_weight, 917 steps=steps, 918 callbacks=callbacks, 919 max_queue_size=max_queue_size, 920 workers=workers, 921 use_multiprocessing=use_multiprocessing) 922 923 def predict(self, 924 x, 925 batch_size=None, 926 verbose=0, 927 steps=None, 928 callbacks=None, 929 max_queue_size=10, 930 workers=1, 931 use_multiprocessing=False): 932 """Generates output predictions for the input samples. 933 934 Computation is done in batches (see the `batch_size` arg.) 935 936 Args: 937 x: Input samples. It could be: 938 - A Numpy array (or array-like), or a list of arrays 939 (in case the model has multiple inputs). 940 - A TensorFlow tensor, or a list of tensors 941 (in case the model has multiple inputs). 942 - A `tf.data` dataset. 943 - A generator or `keras.utils.Sequence` instance. 944 batch_size: Integer or `None`. 945 Number of samples per batch of computation. 946 If unspecified, `batch_size` will default to 32. 947 Do not specify the `batch_size` if your data is in the 948 form of symbolic tensors, dataset, 949 generators, or `keras.utils.Sequence` instances (since they generate 950 batches). 951 verbose: Verbosity mode, 0 or 1. 952 steps: Total number of steps (batches of samples) 953 before declaring the prediction round finished. 954 Ignored with the default value of `None`. If x is a `tf.data` 955 dataset and `steps` is None, `predict` will 956 run until the input dataset is exhausted. 957 callbacks: List of `keras.callbacks.Callback` instances. 958 List of callbacks to apply during prediction. 959 See [callbacks](/api_docs/python/tf/keras/callbacks). 960 max_queue_size: Integer. Used for generator or `keras.utils.Sequence` 961 input only. Maximum size for the generator queue. 962 If unspecified, `max_queue_size` will default to 10. 963 workers: Integer. Used for generator or `keras.utils.Sequence` input 964 only. Maximum number of processes to spin up when using 965 process-based threading. If unspecified, `workers` will default 966 to 1. If 0, will execute the generator on the main thread. 967 use_multiprocessing: Boolean. Used for generator or 968 `keras.utils.Sequence` input only. If `True`, use process-based 969 threading. If unspecified, `use_multiprocessing` will default to 970 `False`. Note that because this implementation relies on 971 multiprocessing, you should not pass non-picklable arguments to 972 the generator as they can't be passed easily to children processes. 973 974 975 Returns: 976 Numpy array(s) of predictions. 977 978 Raises: 979 ValueError: In case of mismatch between the provided 980 input data and the model's expectations, 981 or in case a stateful model receives a number of samples 982 that is not a multiple of the batch size. 983 """ 984 self._assert_built_as_v1() 985 base_layer.keras_api_gauge.get_cell('predict').set(True) 986 self._check_call_args('predict') 987 988 func = self._select_training_loop(x) 989 return func.predict( 990 self, 991 x=x, 992 batch_size=batch_size, 993 verbose=verbose, 994 steps=steps, 995 callbacks=callbacks, 996 max_queue_size=max_queue_size, 997 workers=workers, 998 use_multiprocessing=use_multiprocessing) 999 1000 def reset_metrics(self): 1001 """Resets the state of metrics.""" 1002 metrics = self._get_training_eval_metrics() 1003 for m in metrics: 1004 m.reset_states() 1005 1006 # Reset metrics on all the distributed (cloned) models. 1007 if self._distribution_strategy: 1008 distributed_training_utils_v1._reset_metrics(self) # pylint: disable=protected-access 1009 1010 def train_on_batch(self, 1011 x, 1012 y=None, 1013 sample_weight=None, 1014 class_weight=None, 1015 reset_metrics=True): 1016 """Runs a single gradient update on a single batch of data. 1017 1018 Args: 1019 x: Input data. It could be: 1020 - A Numpy array (or array-like), or a list of arrays 1021 (in case the model has multiple inputs). 1022 - A TensorFlow tensor, or a list of tensors 1023 (in case the model has multiple inputs). 1024 - A dict mapping input names to the corresponding array/tensors, 1025 if the model has named inputs. 1026 - A `tf.data` dataset. 1027 y: Target data. Like the input data `x`, it could be either Numpy 1028 array(s) or TensorFlow tensor(s). It should be consistent with `x` 1029 (you cannot have Numpy inputs and tensor targets, or inversely). If 1030 `x` is a dataset, `y` should not be specified 1031 (since targets will be obtained from the iterator). 1032 sample_weight: Optional array of the same length as x, containing 1033 weights to apply to the model's loss for each sample. In the case of 1034 temporal data, you can pass a 2D array with shape (samples, 1035 sequence_length), to apply a different weight to every timestep of 1036 every sample. In this case you should make sure to specify 1037 sample_weight_mode="temporal" in compile(). This argument is not 1038 supported when `x` is a dataset. 1039 class_weight: Optional dictionary mapping class indices (integers) to a 1040 weight (float) to apply to the model's loss for the samples from this 1041 class during training. This can be useful to tell the model to "pay 1042 more attention" to samples from an under-represented class. 1043 reset_metrics: If `True`, the metrics returned will be only for this 1044 batch. If `False`, the metrics will be statefully accumulated across 1045 batches. 1046 1047 Returns: 1048 Scalar training loss 1049 (if the model has a single output and no metrics) 1050 or list of scalars (if the model has multiple outputs 1051 and/or metrics). The attribute `model.metrics_names` will give you 1052 the display labels for the scalar outputs. 1053 1054 Raises: 1055 ValueError: In case of invalid user-provided arguments. 1056 """ 1057 self._assert_compile_was_called() 1058 self._check_call_args('train_on_batch') 1059 1060 # If at this point we are in the replica context, then it is okay to execute 1061 # the Eager code path. The expected way to get here is to call `fit` that 1062 # calls `train_on_batch` on each replica. 1063 if (self._distribution_strategy and 1064 distribution_strategy_context.in_cross_replica_context()): 1065 raise NotImplementedError('`train_on_batch` is not supported for models ' 1066 'distributed with tf.distribute.Strategy.') 1067 # Validate and standardize user data. 1068 x, y, sample_weights = self._standardize_user_data( 1069 x, y, sample_weight=sample_weight, class_weight=class_weight, 1070 extract_tensors_from_dataset=True) 1071 1072 # If `self._distribution_strategy` is True, then we are in a replica context 1073 # at this point because of the check above. `train_on_batch` is being run 1074 # for each replica by `self._distribution_strategy` and the same code path 1075 # as Eager is expected to be taken. 1076 if self.run_eagerly or self._distribution_strategy: 1077 output_dict = training_eager_v1.train_on_batch( 1078 self, 1079 x, 1080 y, 1081 sample_weights=sample_weights, 1082 output_loss_metrics=self._output_loss_metrics) 1083 outputs = (output_dict['total_loss'] + output_dict['output_losses'] 1084 + output_dict['metrics']) 1085 outputs = [_non_none_constant_value(v) for v in outputs] # pylint: disable=protected-access 1086 else: 1087 x = training_utils_v1.ModelInputs(x).as_list() 1088 ins = x + list(y or []) + list(sample_weights or []) 1089 1090 if not isinstance(K.symbolic_learning_phase(), int): 1091 ins += [True] # Add learning phase value. 1092 1093 self._update_sample_weight_modes(sample_weights=sample_weights) 1094 self._make_train_function() 1095 outputs = self.train_function(ins) # pylint: disable=not-callable 1096 1097 if reset_metrics: 1098 self.reset_metrics() 1099 1100 if len(outputs) == 1: 1101 return outputs[0] 1102 return outputs 1103 1104 def test_on_batch(self, x, y=None, sample_weight=None, reset_metrics=True): 1105 """Test the model on a single batch of samples. 1106 1107 Args: 1108 x: Input data. It could be: 1109 - A Numpy array (or array-like), or a list of arrays 1110 (in case the model has multiple inputs). 1111 - A TensorFlow tensor, or a list of tensors 1112 (in case the model has multiple inputs). 1113 - A dict mapping input names to the corresponding array/tensors, 1114 if the model has named inputs. 1115 - A `tf.data` dataset. 1116 y: Target data. Like the input data `x`, 1117 it could be either Numpy array(s) or TensorFlow tensor(s). 1118 It should be consistent with `x` (you cannot have Numpy inputs and 1119 tensor targets, or inversely). If `x` is a dataset `y` should 1120 not be specified (since targets will be obtained from the iterator). 1121 sample_weight: Optional array of the same length as x, containing 1122 weights to apply to the model's loss for each sample. 1123 In the case of temporal data, you can pass a 2D array 1124 with shape (samples, sequence_length), 1125 to apply a different weight to every timestep of every sample. 1126 In this case you should make sure to specify 1127 sample_weight_mode="temporal" in compile(). This argument is not 1128 supported when `x` is a dataset. 1129 reset_metrics: If `True`, the metrics returned will be only for this 1130 batch. If `False`, the metrics will be statefully accumulated across 1131 batches. 1132 1133 Returns: 1134 Scalar test loss (if the model has a single output and no metrics) 1135 or list of scalars (if the model has multiple outputs 1136 and/or metrics). The attribute `model.metrics_names` will give you 1137 the display labels for the scalar outputs. 1138 1139 Raises: 1140 ValueError: In case of invalid user-provided arguments. 1141 """ 1142 self._assert_compile_was_called() 1143 self._check_call_args('test_on_batch') 1144 1145 if (self._distribution_strategy and 1146 distribution_strategy_context.in_cross_replica_context()): 1147 raise NotImplementedError('`test_on_batch` is not supported for models ' 1148 'distributed with tf.distribute.Strategy.') 1149 # Validate and standardize user data. 1150 x, y, sample_weights = self._standardize_user_data( 1151 x, y, sample_weight=sample_weight, extract_tensors_from_dataset=True) 1152 1153 # If `self._distribution_strategy` is True, then we are in a replica context 1154 # at this point. 1155 if self.run_eagerly or self._distribution_strategy: 1156 output_dict = training_eager_v1.test_on_batch( 1157 self, 1158 x, 1159 y, 1160 sample_weights=sample_weights, 1161 output_loss_metrics=self._output_loss_metrics) 1162 outputs = (output_dict['total_loss'] + output_dict['output_losses'] 1163 + output_dict['metrics']) 1164 outputs = [_non_none_constant_value(v) for v in outputs] # pylint: disable=protected-access 1165 else: 1166 x = training_utils_v1.ModelInputs(x).as_list() 1167 inputs = x + list(y or []) + list(sample_weights or []) 1168 1169 self._update_sample_weight_modes(sample_weights=sample_weights) 1170 self._make_test_function() 1171 outputs = self.test_function(inputs) # pylint: disable=not-callable 1172 1173 if reset_metrics: 1174 self.reset_metrics() 1175 1176 if len(outputs) == 1: 1177 return outputs[0] 1178 return outputs 1179 1180 def predict_on_batch(self, x): 1181 """Returns predictions for a single batch of samples. 1182 1183 Args: 1184 x: Input data. It could be: 1185 - A Numpy array (or array-like), or a list of arrays 1186 (in case the model has multiple inputs). 1187 - A TensorFlow tensor, or a list of tensors 1188 (in case the model has multiple inputs). 1189 - A `tf.data` dataset. 1190 1191 Returns: 1192 Numpy array(s) of predictions. 1193 1194 Raises: 1195 ValueError: In case of mismatch between given number of inputs and 1196 expectations of the model. 1197 """ 1198 self._check_call_args('predict_on_batch') 1199 1200 if (self._distribution_strategy and 1201 distribution_strategy_context.in_cross_replica_context()): 1202 raise NotImplementedError( 1203 '`predict_on_batch` is not supported for models distributed with' 1204 ' tf.distribute.Strategy.') 1205 # Validate and standardize user data. 1206 inputs, _, _ = self._standardize_user_data( 1207 x, extract_tensors_from_dataset=True) 1208 # If `self._distribution_strategy` is True, then we are in a replica context 1209 # at this point. 1210 if self.run_eagerly or self._distribution_strategy: 1211 inputs = training_utils_v1.cast_if_floating_dtype(inputs) 1212 if isinstance(inputs, collections.abc.Sequence): 1213 # Unwrap lists with only one input, as we do when training on batch 1214 if len(inputs) == 1: 1215 inputs = inputs[0] 1216 1217 return self(inputs) # pylint: disable=not-callable 1218 1219 self._make_predict_function() 1220 outputs = self.predict_function(inputs) 1221 1222 if len(outputs) == 1: 1223 return outputs[0] 1224 return outputs 1225 1226 def fit_generator(self, 1227 generator, 1228 steps_per_epoch=None, 1229 epochs=1, 1230 verbose=1, 1231 callbacks=None, 1232 validation_data=None, 1233 validation_steps=None, 1234 validation_freq=1, 1235 class_weight=None, 1236 max_queue_size=10, 1237 workers=1, 1238 use_multiprocessing=False, 1239 shuffle=True, 1240 initial_epoch=0): 1241 """Fits the model on data yielded batch-by-batch by a Python generator. 1242 1243 DEPRECATED: 1244 `Model.fit` now supports generators, so there is no longer any need to use 1245 this endpoint. 1246 """ 1247 warnings.warn('`model.fit_generator` is deprecated and ' 1248 'will be removed in a future version. ' 1249 'Please use `Model.fit`, which supports generators.') 1250 return self.fit( 1251 generator, 1252 steps_per_epoch=steps_per_epoch, 1253 epochs=epochs, 1254 verbose=verbose, 1255 callbacks=callbacks, 1256 validation_data=validation_data, 1257 validation_steps=validation_steps, 1258 validation_freq=validation_freq, 1259 class_weight=class_weight, 1260 max_queue_size=max_queue_size, 1261 workers=workers, 1262 use_multiprocessing=use_multiprocessing, 1263 shuffle=shuffle, 1264 initial_epoch=initial_epoch) 1265 1266 def evaluate_generator(self, 1267 generator, 1268 steps=None, 1269 callbacks=None, 1270 max_queue_size=10, 1271 workers=1, 1272 use_multiprocessing=False, 1273 verbose=0): 1274 """Evaluates the model on a data generator. 1275 1276 DEPRECATED: 1277 `Model.evaluate` now supports generators, so there is no longer any need 1278 to use this endpoint. 1279 """ 1280 warnings.warn('`Model.evaluate_generator` is deprecated and ' 1281 'will be removed in a future version. ' 1282 'Please use `Model.evaluate`, which supports generators.') 1283 self._check_call_args('evaluate_generator') 1284 1285 return self.evaluate( 1286 generator, 1287 steps=steps, 1288 max_queue_size=max_queue_size, 1289 workers=workers, 1290 use_multiprocessing=use_multiprocessing, 1291 verbose=verbose, 1292 callbacks=callbacks) 1293 1294 def predict_generator(self, 1295 generator, 1296 steps=None, 1297 callbacks=None, 1298 max_queue_size=10, 1299 workers=1, 1300 use_multiprocessing=False, 1301 verbose=0): 1302 """Generates predictions for the input samples from a data generator. 1303 1304 DEPRECATED: 1305 `Model.predict` now supports generators, so there is no longer any need 1306 to use this endpoint. 1307 """ 1308 warnings.warn('`Model.predict_generator` is deprecated and ' 1309 'will be removed in a future version. ' 1310 'Please use `Model.predict`, which supports generators.') 1311 return self.predict( 1312 generator, 1313 steps=steps, 1314 max_queue_size=max_queue_size, 1315 workers=workers, 1316 use_multiprocessing=use_multiprocessing, 1317 verbose=verbose, 1318 callbacks=callbacks) 1319 1320 def _check_call_args(self, method_name): 1321 """Check that `call` has only one positional arg.""" 1322 # Always allow first arg, regardless of arg name. 1323 fullargspec = self._call_full_argspec 1324 if fullargspec.defaults: 1325 positional_args = fullargspec.args[:-len(fullargspec.defaults)] 1326 else: 1327 positional_args = fullargspec.args 1328 if 'training' in positional_args: 1329 positional_args.remove('training') 1330 1331 # self and first arg can be positional. 1332 if len(positional_args) > 2: 1333 extra_args = positional_args[2:] 1334 raise ValueError( 1335 'Models passed to `' + method_name + '` can only have `training` ' 1336 'and the first argument in `call` as positional arguments, ' 1337 'found: ' + str(extra_args) + '.') 1338 1339 def _set_optimizer(self, optimizer): 1340 """Sets self.optimizer. 1341 1342 Sets self.optimizer to `optimizer`, potentially wrapping it with a 1343 LossScaleOptimizer. 1344 1345 Args: 1346 optimizer: The optimizer(s) to assign to self.optimizer. 1347 """ 1348 if isinstance(optimizer, (list, tuple)): 1349 self.optimizer = [optimizers.get(opt) for opt in optimizer] 1350 else: 1351 self.optimizer = optimizers.get(optimizer) 1352 1353 if isinstance(self._dtype_policy, policy.PolicyV1): 1354 loss_scale = self._dtype_policy.loss_scale 1355 elif self._dtype_policy.name == 'mixed_float16': 1356 loss_scale = 'dynamic' 1357 else: 1358 loss_scale = None 1359 1360 if (loss_scale is not None and 1361 not isinstance(self.optimizer, 1362 loss_scale_optimizer.LossScaleOptimizer)): 1363 if isinstance(self.optimizer, list): 1364 raise ValueError('When a dtype policy with a loss scale is used, you ' 1365 'can only pass a single optimizer. Using policy %s ' 1366 'and got optimizers: %s' % 1367 self._dtype_policy, self.optimizer) 1368 if not isinstance(self.optimizer, optimizer_v2.OptimizerV2): 1369 raise ValueError('"optimizer" must be an instance of ' 1370 'tf.keras.optimizers.Optimizer when a dype policy ' 1371 'with a loss scale used, but got: %s. Using policy: ' 1372 '%s' % 1373 (self.optimizer, self._dtype_policy)) 1374 if loss_scale == 'dynamic': 1375 self.optimizer = loss_scale_optimizer.LossScaleOptimizer(self.optimizer) 1376 else: 1377 self.optimizer = loss_scale_optimizer.LossScaleOptimizerV1( 1378 self.optimizer, loss_scale) 1379 1380 def _prepare_validation_data(self, validation_data, batch_size, 1381 validation_steps): 1382 """Unpack and check the validation data.""" 1383 val_x, val_y, val_sample_weights = training_utils_v1.unpack_validation_data( 1384 validation_data) 1385 return self._standardize_user_data( 1386 val_x, 1387 val_y, 1388 sample_weight=val_sample_weights, 1389 batch_size=batch_size, 1390 steps=validation_steps, 1391 steps_name='validation_steps') 1392 1393 def _validate_compile_param_for_distribution_strategy( 1394 self, run_eagerly, sample_weight_mode, target_tensors, weighted_metrics): 1395 # Validate that arguments passed by the user to `compile` are supported by 1396 # tf.distribute.Strategy. 1397 if self._distribution_strategy: 1398 if sample_weight_mode: 1399 raise NotImplementedError('sample_weight_mode is not supported with ' 1400 'tf.distribute.Strategy.') 1401 if weighted_metrics: 1402 raise NotImplementedError('weighted_metrics is not supported with ' 1403 'tf.distribute.Strategy.') 1404 if target_tensors: 1405 raise ValueError('target_tensors is not supported with ' 1406 'tf.distribute.Strategy.') 1407 1408 if run_eagerly: 1409 raise ValueError( 1410 'We currently do not support enabling `run_eagerly` with ' 1411 'distribution strategy.') 1412 1413 if (distributed_training_utils_v1.is_distributing_by_cloning(self) and 1414 (not self.built or not self.inputs or not self.outputs)): 1415 raise ValueError( 1416 'We currently do not support distribution strategy with a ' 1417 '`Sequential` model that is created without `input_shape`/' 1418 '`input_dim` set in its first layer or a subclassed model.') 1419 1420 def _process_target_tensor_for_compile(self, target_tensors): 1421 if self.run_eagerly: 1422 # target tensor is not supported with run_eagerly. Create a list with None 1423 # as placeholder for each output. 1424 return [None for _ in self.output_names] 1425 1426 if target_tensors is not None and not (isinstance(target_tensors, list) and 1427 target_tensors == []): # pylint: disable=g-explicit-bool-comparison 1428 if isinstance(target_tensors, list): 1429 if len(target_tensors) != len(self.outputs): 1430 raise ValueError( 1431 'When passing a list as `target_tensors`, ' 1432 'it should have one entry per model output. ' 1433 'The model has %s outputs, but you passed target_tensors=%s' % 1434 (len(self.outputs), target_tensors)) 1435 elif isinstance(target_tensors, dict): 1436 unexpected_target_tensor_names = set(target_tensors.keys()).difference( 1437 self.output_names) 1438 if unexpected_target_tensor_names: 1439 raise ValueError( 1440 'Unknown entry in `target_tensors` dictionary: "{name}". ' 1441 'Only expected the following keys: {keys}'.format( 1442 name=unexpected_target_tensor_names, 1443 keys=str(self.output_names))) 1444 tmp_target_tensors = [] 1445 for name in self.output_names: 1446 tmp_target_tensors.append(target_tensors.get(name, None)) 1447 target_tensors = tmp_target_tensors 1448 elif tensor_util.is_tf_type(target_tensors): 1449 target_tensors = [target_tensors] 1450 else: 1451 raise TypeError('Expected `target_tensors` to be a list or tuple or ' 1452 'dict or a single tensor, but got:', target_tensors) 1453 else: 1454 # In case target tensor is empty or None, create a list with Nones 1455 # that has same length as self.output_names. With that, the None check of 1456 # target tensor can be skipped downstream. 1457 target_tensors = [None for _ in self.output_names] 1458 return target_tensors 1459 1460 def _compile_eagerly(self, metrics, weighted_metrics, sample_weight_mode): 1461 # Prepare sample weight modes. List with the same length as model outputs. 1462 training_utils_v1.prepare_sample_weight_modes( 1463 self._training_endpoints, sample_weight_mode) 1464 # Prepare sample weights. 1465 self._prepare_sample_weights() 1466 # Save all metric attributes per output of the model. 1467 self._cache_output_metric_attributes(metrics, weighted_metrics) 1468 self.total_loss = None 1469 # Set metric attributes on model. 1470 self._set_metric_attributes() 1471 1472 self._collected_trainable_weights = self.trainable_weights 1473 1474 def _update_sample_weight_modes(self, sample_weights=None): 1475 """Updates sample weight modes based on training/eval inputs. 1476 1477 Sample weight placeholders will be created for all or no outputs 1478 based on whether sample_weight is provided for any output. 1479 1480 If model contains `_sample_weight_modes` we check if the input 1481 `sample_weights` corresponds to the sample weight modes. 1482 1. Set sample weight mode to be 'temporal' for output i, if `compile` 1483 sample_weight_mode was set to `temporal` and sample weight inputs 1484 are given for one or more outputs. 1485 2. Set sample weight mode to be 'samplewise' for output i, if `compile` 1486 sample_weight_mode was not set and sample weight inputs are given for 1487 one or more outputs. 1488 3. Reset sample weight mode to None for output i if sample weight mode 1489 was set but there is no sample weight input. 1490 1491 Args: 1492 sample_weights: List of sample weights of the same length as model outputs 1493 or None. 1494 """ 1495 if not self._is_compiled: 1496 return 1497 if sample_weights and any(s is not None for s in sample_weights): 1498 for endpoint in self._training_endpoints: 1499 endpoint.sample_weight_mode = ( 1500 endpoint.sample_weight_mode or 'samplewise') 1501 else: 1502 for endpoint in self._training_endpoints: 1503 endpoint.sample_weight_mode = None 1504 1505 def _recompile_weights_loss_and_weighted_metrics(self): 1506 if not self._is_compiled: 1507 return False 1508 recompile = any( 1509 e.sample_weights_mismatch() for e in self._training_endpoints) 1510 1511 if recompile: 1512 self._compile_weights_loss_and_weighted_metrics() 1513 return recompile 1514 1515 @trackable.no_automatic_dependency_tracking 1516 def _compile_weights_loss_and_weighted_metrics(self, sample_weights=None): 1517 """Compiles the model loss and weighted metric sub-graphs. 1518 1519 This may be used to set graph tensors as sample weights (instead of creating 1520 placeholders). This functionality is necessary for 1521 `tf.keras.estimator.model_to_estimator`, which calls Keras models in a v1 1522 graph, and creates iterator tensors for inputs, targets, and sample weights. 1523 1524 Args: 1525 sample_weights: List of tensors to use as the sample weights. Must be the 1526 same length as the number of outputs. If left as `None`, placeholders 1527 are used instead. 1528 """ 1529 with K.get_graph().as_default(): 1530 if sample_weights is not None: 1531 self._update_sample_weight_modes(sample_weights) 1532 self._prepare_sample_weights(sample_weights) 1533 1534 masks = self._prepare_output_masks() 1535 1536 # Compute weighted metrics. 1537 self._handle_metrics( 1538 self.outputs, 1539 targets=self._targets, 1540 skip_target_masks=self._prepare_skip_target_masks(), 1541 sample_weights=self.sample_weights, 1542 masks=masks, 1543 return_weighted_metrics=True) 1544 1545 # Compute total loss. 1546 # Used to keep track of the total loss value (stateless). 1547 # eg., total_loss = loss_weight_1 * output_1_loss_fn(...) + 1548 # loss_weight_2 * output_2_loss_fn(...) + 1549 # layer losses. 1550 self.total_loss = self._prepare_total_loss(masks) 1551 1552 def _prepare_skip_target_masks(self): 1553 """Boolean mask for whether the target in the output list should be skipped. 1554 1555 If the loss function corresponding to a model output is None, then this 1556 output will be skipped during total loss calculation and feed targets 1557 preparation. 1558 1559 Returns: 1560 A boolean list for whether the corresponding target in the output list 1561 should be skipped during loss calculation. 1562 """ 1563 return [l is None for l in self.loss_functions] 1564 1565 def _prepare_output_masks(self): 1566 """Returns masks corresponding to model outputs.""" 1567 return [getattr(x, '_keras_mask', None) for x in self.outputs] 1568 1569 def _prepare_total_loss(self, masks): 1570 """Computes total loss from loss functions. 1571 1572 Args: 1573 masks: List of mask values corresponding to each model output. 1574 1575 Returns: 1576 A list of loss weights of python floats. 1577 1578 Raises: 1579 TypeError: If model run_eagerly is True. 1580 """ 1581 if self.run_eagerly: 1582 raise TypeError('total loss can not be computed when compiled with ' 1583 'run_eagerly = True.') 1584 loss_list = [] 1585 with K.name_scope('loss'): 1586 for endpoint, mask in zip(self._training_endpoints, masks): 1587 if endpoint.should_skip_target(): 1588 continue 1589 y_true = endpoint.training_target.target 1590 y_pred = endpoint.output 1591 loss_fn = endpoint.loss_fn 1592 loss_weight = endpoint.loss_weight 1593 loss_name = endpoint.loss_name() 1594 sample_weight = endpoint.sample_weight 1595 1596 with K.name_scope(loss_name): 1597 if mask is not None: 1598 mask = math_ops.cast(mask, y_pred.dtype) 1599 # Update weights with mask. 1600 if sample_weight is None: 1601 sample_weight = mask 1602 else: 1603 # Update dimensions of weights to match with mask if possible. 1604 mask, _, sample_weight = ( 1605 losses_utils.squeeze_or_expand_dimensions( 1606 mask, sample_weight=sample_weight)) 1607 sample_weight *= mask 1608 1609 if hasattr(loss_fn, 'reduction'): 1610 per_sample_losses = loss_fn.call(y_true, y_pred) 1611 weighted_losses = losses_utils.compute_weighted_loss( 1612 per_sample_losses, 1613 sample_weight=sample_weight, 1614 reduction=losses_utils.ReductionV2.NONE) 1615 loss_reduction = loss_fn.reduction 1616 1617 # `AUTO` loss reduction defaults to `SUM_OVER_BATCH_SIZE` for all 1618 # compile use cases. 1619 if loss_reduction == losses_utils.ReductionV2.AUTO: 1620 loss_reduction = losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE 1621 1622 # Compute the stateless loss value. 1623 output_loss = losses_utils.reduce_weighted_loss( 1624 weighted_losses, reduction=loss_reduction) 1625 else: 1626 # Compute the stateless loss value for a custom loss class. 1627 # Here we assume that the class takes care of loss reduction 1628 # because if this class returns a vector value we cannot 1629 # differentiate between use case where a custom optimizer 1630 # expects a vector loss value vs unreduced per-sample loss value. 1631 output_loss = loss_fn(y_true, y_pred, sample_weight=sample_weight) 1632 loss_reduction = losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE 1633 1634 if len(self.outputs) > 1: 1635 # Keep track of stateful result tensor for the loss. 1636 endpoint.output_loss_metric(output_loss) 1637 1638 # Scale output loss for distribution. For custom losses we assume 1639 # reduction was mean. 1640 if loss_reduction == losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE: 1641 output_loss = losses_utils.scale_loss_for_distribution(output_loss) 1642 1643 loss_list.append(loss_weight * output_loss) 1644 if not loss_list and not self.losses: 1645 raise ValueError('The model cannot be compiled ' 1646 'because it has no loss to optimize.') 1647 1648 # Add regularization penalties and other layer-specific losses. 1649 custom_losses = self.get_losses_for(None) + self.get_losses_for( 1650 self.inputs) 1651 if custom_losses: 1652 total_custom_loss = math_ops.add_n( 1653 losses_utils.cast_losses_to_common_dtype(custom_losses)) 1654 loss_list.append( 1655 losses_utils.scale_loss_for_distribution(total_custom_loss)) 1656 1657 loss_list = losses_utils.cast_losses_to_common_dtype(loss_list) 1658 if loss_list: 1659 total_loss = math_ops.add_n(loss_list) 1660 else: 1661 total_loss = 0. 1662 return total_loss 1663 1664 def _get_callback_model(self): 1665 """Returns the Callback Model for this Model.""" 1666 1667 if hasattr(self, '_replicated_model') and self._replicated_model: 1668 # When using training_distributed, we set the callback model 1669 # to an instance of the `DistributedModel` that we create in 1670 # the `compile` call. The `DistributedModel` is initialized 1671 # with the first replicated model. We need to set the callback 1672 # model to a DistributedModel to allow us to override saving 1673 # and loading weights when we checkpoint the model during training. 1674 return self._replicated_model 1675 if hasattr(self, 'callback_model') and self.callback_model: 1676 return self.callback_model 1677 return self 1678 1679 @trackable.no_automatic_dependency_tracking 1680 def _make_callback_model(self, grouped_model): 1681 first_replicated_model = self._distribution_strategy.unwrap( 1682 grouped_model)[0] 1683 # We initialize the callback model with the first replicated model. 1684 self._replicated_model = DistributedCallbackModel(first_replicated_model) 1685 self._replicated_model.set_original_model(self) 1686 1687 def _validate_or_infer_batch_size(self, batch_size, steps, x): 1688 """Validates that the `batch_size` provided is consistent with InputLayer. 1689 1690 It's possible that the user specified a static batch size in their 1691 InputLayer. If so, this method checks the provided `batch_size` and `x` 1692 arguments are consistent with this static batch size. Also, if 1693 `batch_size` is `None`, this method will attempt to infer the batch size 1694 from the static batch size of the InputLayer. Lastly, ValueError will be 1695 raised if `x` is a tf.data.Dataset and `batch_size` is specified as we 1696 expect users to provide batched datasets. 1697 1698 Args: 1699 batch_size: The batch_size provided as an argument to 1700 fit/evaluate/predict. 1701 steps: The steps provided as an argument to fit/evaluate/predict. 1702 x: The data passed as `x` to fit/evaluate/predict. 1703 1704 Returns: 1705 The validated batch_size, auto-inferred from the first layer if not 1706 provided. 1707 """ 1708 if (isinstance(x, (dataset_ops.DatasetV1, 1709 dataset_ops.DatasetV2, 1710 data_utils.Sequence)) or 1711 tf_inspect.isgenerator(x)): 1712 if batch_size is not None: 1713 raise ValueError( 1714 'The `batch_size` argument must not be specified for the given ' 1715 'input type. Received input: {}, batch_size: {}'.format( 1716 x, batch_size)) 1717 return 1718 1719 # Avoids the override in Sequential.layers which filters Input layers. 1720 # (Which are often the very layers that we're after.) 1721 layers = self._flatten_layers(include_self=False, recursive=False) 1722 first_layer = next(layers, None) 1723 if first_layer: 1724 # The per-replica static batch size. 1725 static_batch_size = training_utils.get_static_batch_size(first_layer) 1726 if static_batch_size is not None: 1727 1728 # Determine number of times the user-supplied batch size will be split. 1729 if (self._distribution_strategy and 1730 distributed_training_utils.global_batch_size_supported( 1731 self._distribution_strategy)): 1732 num_splits_for_ds = self._distribution_strategy.num_replicas_in_sync 1733 else: 1734 num_splits_for_ds = 1 1735 1736 # Check `batch_size` argument is consistent with InputLayer. 1737 if batch_size is not None: 1738 if batch_size % num_splits_for_ds != 0: 1739 raise ValueError('The `batch_size` argument ({}) must be divisible ' 1740 'the by number of replicas ({})'.format( 1741 batch_size, num_splits_for_ds)) 1742 per_replica_batch_size = batch_size // num_splits_for_ds 1743 1744 if per_replica_batch_size != static_batch_size: 1745 raise ValueError('The `batch_size` argument value {} is ' 1746 'incompatible with the specified batch size of ' 1747 'your Input Layer: {}'.format( 1748 per_replica_batch_size, static_batch_size)) 1749 1750 # Check Dataset/Iterator batch size is consistent with InputLayer. 1751 if isinstance(x, (dataset_ops.DatasetV2, iterator_ops.Iterator, 1752 iterator_ops.IteratorBase)): 1753 ds_batch_size = tensor_shape.Dimension( 1754 nest.flatten(dataset_ops.get_legacy_output_shapes(x))[0][0]).value 1755 if ds_batch_size is not None: 1756 if ds_batch_size % num_splits_for_ds != 0: 1757 raise ValueError( 1758 'The batch output shape of your `Dataset` {} ' 1759 'cannot be divisible by number of replicas {}'.format( 1760 ds_batch_size, num_splits_for_ds)) 1761 1762 ds_per_replica_batch_size = ds_batch_size // num_splits_for_ds 1763 if ds_per_replica_batch_size != static_batch_size: 1764 raise ValueError('The batch output shape of your `Dataset` is ' 1765 '{}, which is incompatible with the specified ' 1766 'batch size of your Input Layer: {}'.format( 1767 ds_per_replica_batch_size, 1768 static_batch_size)) 1769 1770 # Set inferred batch size from the InputLayer. 1771 if steps is None: 1772 batch_size = static_batch_size * num_splits_for_ds 1773 1774 if batch_size is None and steps is None: 1775 # Backwards compatibility 1776 batch_size = 32 1777 return batch_size 1778 1779 def _prepare_sample_weights(self, sample_weights=None): 1780 """Sets sample weight attribute on the model.""" 1781 # List with the same length as model outputs. 1782 if sample_weights is not None: 1783 if len(sample_weights) != len(self._training_endpoints): 1784 raise ValueError('Provided sample weights must have same length as the ' 1785 'number of outputs. Expected: {}, got: {}.'.format( 1786 len(self._training_endpoints), 1787 len(sample_weights))) 1788 else: 1789 sample_weights = [None] * len(self._training_endpoints) 1790 for endpoint, weight in zip(self._training_endpoints, sample_weights): 1791 endpoint.populate_sample_weight(weight, endpoint.sample_weight_mode) 1792 1793 def _cache_output_metric_attributes(self, metrics, weighted_metrics): 1794 """Caches metric name and function attributes for every model output.""" 1795 output_shapes = [] 1796 for output in self.outputs: 1797 if output is None or output.shape.rank is None: 1798 output_shapes.append(None) 1799 else: 1800 output_shapes.append(output.shape.as_list()) 1801 self._per_output_metrics = training_utils_v1.collect_per_output_metric_info( 1802 metrics, self.output_names, output_shapes, self.loss_functions) 1803 self._per_output_weighted_metrics = ( 1804 training_utils_v1.collect_per_output_metric_info( 1805 weighted_metrics, 1806 self.output_names, 1807 output_shapes, 1808 self.loss_functions, 1809 is_weighted=True)) 1810 1811 def _add_unique_metric_name(self, metric_name, output_index): 1812 """Makes the metric name unique and adds it to the model's metric name list. 1813 1814 If there are multiple outputs for which the metrics are calculated, the 1815 metric names have to be made unique by appending an integer. 1816 1817 Args: 1818 metric_name: Metric name that corresponds to the metric specified by the 1819 user. For example: 'acc'. 1820 output_index: The index of the model output for which the metric name is 1821 being added. 1822 1823 Returns: 1824 string, name of the model's unique metric name 1825 """ 1826 if len(self.output_names) > 1: 1827 metric_name = '%s_%s' % (self.output_names[output_index], metric_name) 1828 j = 1 1829 base_metric_name = metric_name 1830 while metric_name in self.metrics_names: 1831 metric_name = '%s_%d' % (base_metric_name, j) 1832 j += 1 1833 1834 return metric_name 1835 1836 def _init_metric_attributes(self): 1837 """Initialized model metric attributes.""" 1838 # List of stateful metric functions. Used for resetting metric state during 1839 # training/eval. 1840 self._compile_metric_functions = [] 1841 1842 def _set_per_output_metric_attributes(self, metrics_dict, output_index): 1843 """Sets the metric attributes on the model for the given output. 1844 1845 Args: 1846 metrics_dict: A dict with metric names as keys and metric fns as values. 1847 output_index: The index of the model output for which the metric 1848 attributes are added. 1849 1850 Returns: 1851 Metrics dict updated with unique metric names as keys. 1852 """ 1853 updated_metrics_dict = collections.OrderedDict() 1854 for metric_name, metric_fn in metrics_dict.items(): 1855 metric_name = self._add_unique_metric_name(metric_name, output_index) 1856 1857 # Update the name on the metric class to be the unique generated name. 1858 metric_fn._name = metric_name # pylint: disable=protected-access 1859 updated_metrics_dict[metric_name] = metric_fn 1860 # Keep track of metric name and function. 1861 self._compile_metric_functions.append(metric_fn) 1862 return updated_metrics_dict 1863 1864 def _set_metric_attributes(self): 1865 """Sets the metric attributes on the model for all the model outputs.""" 1866 updated_per_output_metrics = [] 1867 updated_per_output_weighted_metrics = [] 1868 for i, endpoint in enumerate(self._training_endpoints): 1869 if endpoint.should_skip_target(): 1870 updated_per_output_metrics.append(self._per_output_metrics[i]) 1871 updated_per_output_weighted_metrics.append( 1872 self._per_output_weighted_metrics[i]) 1873 continue 1874 updated_per_output_metrics.append( 1875 self._set_per_output_metric_attributes(self._per_output_metrics[i], 1876 i)) 1877 updated_per_output_weighted_metrics.append( 1878 self._set_per_output_metric_attributes( 1879 self._per_output_weighted_metrics[i], i)) 1880 1881 # Create a metric wrapper for each output loss. This computes mean of an 1882 # output loss across mini-batches (irrespective of how we reduce within a 1883 # batch). 1884 if len(self._training_endpoints) > 1: 1885 for endpoint in self._training_endpoints: 1886 if not endpoint.should_skip_target(): 1887 endpoint.output_loss_metric = metrics_module.Mean( 1888 name=endpoint.loss_name()) 1889 1890 self._per_output_metrics = updated_per_output_metrics 1891 self._per_output_weighted_metrics = updated_per_output_weighted_metrics 1892 1893 def _handle_per_output_metrics(self, 1894 metrics_dict, 1895 y_true, 1896 y_pred, 1897 mask, 1898 weights=None): 1899 """Calls metric functions for a single output. 1900 1901 Args: 1902 metrics_dict: A dict with metric names as keys and metric fns as values. 1903 y_true: Target output. 1904 y_pred: Predicted output. 1905 mask: Computed mask value for the current output. 1906 weights: Weights to be applied on the current output. 1907 1908 Returns: 1909 A list of metric result tensors. 1910 """ 1911 metric_results = [] 1912 for metric_name, metric_fn in metrics_dict.items(): 1913 with K.name_scope(metric_name): 1914 metric_result = training_utils_v1.call_metric_function( 1915 metric_fn, y_true, y_pred, weights=weights, mask=mask) 1916 metric_results.append(metric_result) 1917 return metric_results 1918 1919 def _handle_metrics(self, 1920 outputs, 1921 targets=None, 1922 skip_target_masks=None, 1923 sample_weights=None, 1924 masks=None, 1925 return_weighted_metrics=False, 1926 return_weighted_and_unweighted_metrics=False): 1927 """Handles calling metric functions. 1928 1929 Args: 1930 outputs: List of outputs (predictions). 1931 targets: List of targets. 1932 skip_target_masks: Optional. List of boolean for whether the corresponding 1933 target should be ignored or not. 1934 sample_weights: Optional list of sample weight arrays. 1935 masks: List of computed output mask values. 1936 return_weighted_metrics: Flag that indicates whether weighted metrics 1937 should be computed instead of unweighted metrics. This flag is ignored 1938 when `return_weighted_and_unweighted_metrics` is enabled. 1939 return_weighted_and_unweighted_metrics: Flag that is used to indicate 1940 whether both weighted and unweighted metrics should be computed. When 1941 this is not enabled, we use `return_weighted_metrics` param to indicate 1942 whether weighted or unweighted metrics should be returned. 1943 1944 Returns: 1945 A list of metric result tensors. 1946 """ 1947 # TODO(scottzhu): Update this to use the new training_endpoints. Currently 1948 # the eager and graph logic is bit different. 1949 skip_target_masks = skip_target_masks or [False] * len(outputs) 1950 metric_results = [] 1951 with K.name_scope('metrics'): 1952 # Invoke all metrics added using `compile`. 1953 for i in range(len(outputs)): 1954 if skip_target_masks[i]: 1955 continue 1956 output = outputs[i] if outputs else None 1957 target = targets[i] if targets else None 1958 output_mask = masks[i] if masks else None 1959 1960 if (return_weighted_and_unweighted_metrics or 1961 not return_weighted_metrics): 1962 metric_results.extend( 1963 self._handle_per_output_metrics(self._per_output_metrics[i], 1964 target, output, output_mask)) 1965 if return_weighted_and_unweighted_metrics or return_weighted_metrics: 1966 metric_results.extend( 1967 self._handle_per_output_metrics( 1968 self._per_output_weighted_metrics[i], 1969 target, 1970 output, 1971 output_mask, 1972 weights=sample_weights[i] if sample_weights else None)) 1973 return metric_results 1974 1975 def _check_trainable_weights_consistency(self): 1976 """Check trainable weights count consistency. 1977 1978 This will raise a warning if `trainable_weights` and 1979 `_collected_trainable_weights` are inconsistent (i.e. have different 1980 number of parameters). 1981 Inconsistency will typically arise when one modifies `model.trainable` 1982 without calling `model.compile` again. 1983 """ 1984 if not hasattr(self, '_collected_trainable_weights'): 1985 return 1986 1987 if len(self.trainable_weights) != len(self._collected_trainable_weights): 1988 logging.log_first_n( 1989 logging.WARN, 'Discrepancy between trainable weights and collected' 1990 ' trainable weights, did you set `model.trainable`' 1991 ' without calling `model.compile` after ?', 1) 1992 1993 def _make_train_function(self): 1994 has_recompiled = self._recompile_weights_loss_and_weighted_metrics() 1995 self._check_trainable_weights_consistency() 1996 if isinstance(self.optimizer, list): 1997 raise ValueError('The `optimizer` in `compile` should be a single ' 1998 'optimizer.') 1999 # If we have re-compiled the loss/weighted metric sub-graphs then create 2000 # train function even if one exists already. This is because 2001 # `_feed_sample_weights` list has been updated on re-compile. 2002 if getattr(self, 'train_function', None) is None or has_recompiled: 2003 # Restore the compiled trainable state. 2004 current_trainable_state = self._get_trainable_state() 2005 self._set_trainable_state(self._compiled_trainable_state) 2006 2007 inputs = (self._feed_inputs + 2008 self._feed_targets + 2009 self._feed_sample_weights) 2010 if not isinstance(K.symbolic_learning_phase(), int): 2011 inputs += [K.symbolic_learning_phase()] 2012 2013 with K.get_graph().as_default(): 2014 with K.name_scope('training'): 2015 # Training updates 2016 updates = self.optimizer.get_updates( 2017 params=self._collected_trainable_weights, loss=self.total_loss) 2018 # Unconditional updates 2019 updates += self.get_updates_for(None) 2020 # Conditional updates relevant to this model 2021 updates += self.get_updates_for(self.inputs) 2022 2023 metrics = self._get_training_eval_metrics() 2024 metrics_tensors = [ 2025 m._call_result for m in metrics if hasattr(m, '_call_result') # pylint: disable=protected-access 2026 ] 2027 2028 with K.name_scope('training'): 2029 # Gets loss and metrics. Updates weights at each call. 2030 fn = K.function( 2031 inputs, [self.total_loss] + metrics_tensors, 2032 updates=updates, 2033 name='train_function', 2034 **self._function_kwargs) 2035 setattr(self, 'train_function', fn) 2036 2037 # Restore the current trainable state 2038 self._set_trainable_state(current_trainable_state) 2039 2040 def _make_test_function(self): 2041 has_recompiled = self._recompile_weights_loss_and_weighted_metrics() 2042 # If we have re-compiled the loss/weighted metric sub-graphs then create 2043 # test function even if one exists already. This is because 2044 # `_feed_sample_weights` list has been updated on re-compile. 2045 if getattr(self, 'test_function', None) is None or has_recompiled: 2046 inputs = (self._feed_inputs + 2047 self._feed_targets + 2048 self._feed_sample_weights) 2049 2050 with K.get_graph().as_default(): 2051 metrics = self._get_training_eval_metrics() 2052 metrics_tensors = [ 2053 m._call_result for m in metrics if hasattr(m, '_call_result') # pylint: disable=protected-access 2054 ] 2055 2056 with K.name_scope('evaluation'): 2057 updates = self.state_updates 2058 # Return loss and metrics, no gradient updates. 2059 # Does update the network states. 2060 fn = K.function( 2061 inputs, [self.total_loss] + metrics_tensors, 2062 updates=updates, 2063 name='test_function', 2064 **self._function_kwargs) 2065 setattr(self, 'test_function', fn) 2066 2067 def _make_predict_function(self): 2068 if not hasattr(self, 'predict_function'): 2069 self.predict_function = None 2070 if self.predict_function is None: 2071 inputs = self._feed_inputs 2072 # Gets network outputs. Does not update weights. 2073 # Does update the network states. 2074 kwargs = getattr(self, '_function_kwargs', {}) 2075 with K.name_scope(ModeKeys.PREDICT): 2076 self.predict_function = K.function( 2077 inputs, 2078 self.outputs, 2079 updates=self.state_updates, 2080 name='predict_function', 2081 **kwargs) 2082 2083 def _make_execution_function(self, mode): 2084 if mode == ModeKeys.TRAIN: 2085 self._make_train_function() 2086 return self.train_function 2087 if mode == ModeKeys.TEST: 2088 self._make_test_function() 2089 return self.test_function 2090 if mode == ModeKeys.PREDICT: 2091 self._make_predict_function() 2092 return self.predict_function 2093 2094 def _distribution_standardize_user_data(self, 2095 x, 2096 y=None, 2097 sample_weight=None, 2098 class_weight=None, 2099 batch_size=None, 2100 validation_split=0, 2101 shuffle=False, 2102 epochs=1, 2103 allow_partial_batch=False): 2104 """Runs validation checks on input and target data passed by the user. 2105 2106 This is called when using tf.distribute.Strategy to train, evaluate or serve 2107 the model. 2108 2109 Args: 2110 x: Input data. A numpy array or `tf.data` dataset. 2111 y: Target data. A numpy array or None if x is a `tf.data` dataset. 2112 sample_weight: An optional sample-weight array passed by the user to 2113 weight the importance of each sample in `x`. 2114 class_weight: An optional class-weight array by the user to 2115 weight the importance of samples in `x` based on the class they belong 2116 to, as conveyed by `y`. 2117 batch_size: Integer batch size. If provided, it is used to run additional 2118 validation checks on stateful models. 2119 validation_split: Float between 0 and 1. 2120 Fraction of the training data to be used as validation data. 2121 shuffle: Boolean whether to shuffle the training data before each epoch. 2122 epochs: Integer epochs. If > 1, repeat the numpy training data epochs 2123 times when converting to training dataset. 2124 allow_partial_batch: Boolean whether to enforce that all batches have the 2125 same size. 2126 2127 Returns: 2128 Dataset instance. 2129 2130 Raises: 2131 ValueError: In case of invalid user-provided data. 2132 RuntimeError: If the model was never compiled. 2133 """ 2134 if class_weight: 2135 raise NotImplementedError('`class_weight` is currently not supported ' 2136 'when using tf.distribute.Strategy.') 2137 2138 if (sample_weight is not None and sample_weight.all() and 2139 K.is_tpu_strategy(self._distribution_strategy)): 2140 raise NotImplementedError('`sample_weight` is currently not supported ' 2141 'when using TPUStrategy.') 2142 2143 # Validates `steps` and `shuffle` arguments right at the beginning 2144 # since we use it to construct the dataset object. 2145 # TODO(anjalisridhar): Remove this check once we refactor the 2146 # _standardize_user_data code path. This check is already present elsewhere 2147 # in the codebase. 2148 if isinstance(x, dataset_ops.DatasetV2): 2149 if shuffle: 2150 training_utils_v1.verify_dataset_shuffled(x) 2151 2152 strategy = self._distribution_strategy 2153 with strategy.scope(): 2154 # We should be sure to call get_session() inside the strategy.scope() 2155 # so the strategy can affect the session options. 2156 if ops.executing_eagerly_outside_functions(): 2157 session = None 2158 else: 2159 session = K.get_session() 2160 2161 first_x_value = nest.flatten(x)[0] 2162 if isinstance(first_x_value, np.ndarray): 2163 x = training_utils.list_to_tuple(x) 2164 if y is not None: 2165 y = training_utils.list_to_tuple(y) 2166 if sample_weight is not None: 2167 sample_weight = training_utils.list_to_tuple(sample_weight) 2168 in_tuple = (x, y, sample_weight) 2169 else: 2170 in_tuple = (x, y) 2171 else: 2172 in_tuple = x 2173 2174 ds = strategy.extended.experimental_make_numpy_dataset(in_tuple, 2175 session=session) 2176 if shuffle: 2177 # We want a buffer size that is larger than the batch size provided by 2178 # the user and provides sufficient randomness. Note that larger 2179 # numbers introduce more memory usage based on the size of each 2180 # sample. 2181 ds = ds.shuffle(max(1024, batch_size * 8)) 2182 if epochs > 1: 2183 ds = ds.repeat(epochs) 2184 2185 # We need to use the drop_remainder argument to get a known static 2186 # input shape which is required for TPUs. 2187 drop_remainder = (not allow_partial_batch and 2188 strategy.extended.experimental_require_static_shapes) 2189 2190 # TODO(b/131720208): We still drop remainder here if number of examples 2191 # is divisible by batch size, as sometimes dynamic padder will time out 2192 # with keras.metrics.CategoricalAccuracy() metric. 2193 if K.is_tpu_strategy(strategy) and not drop_remainder: 2194 dataset_size = first_x_value.shape[0] 2195 if dataset_size % batch_size == 0: 2196 drop_remainder = True 2197 2198 x = ds.batch(batch_size, drop_remainder=drop_remainder) 2199 else: 2200 assert isinstance(x, dataset_ops.DatasetV2) 2201 training_utils_v1.validate_dataset_input(x, y, sample_weight, 2202 validation_split) 2203 return x 2204 2205 def _standardize_user_data(self, 2206 x, 2207 y=None, 2208 sample_weight=None, 2209 class_weight=None, 2210 batch_size=None, 2211 check_steps=False, 2212 steps_name='steps', 2213 steps=None, 2214 validation_split=0, 2215 shuffle=False, 2216 extract_tensors_from_dataset=False): 2217 """Runs validation checks on input and target data passed by the user. 2218 2219 Also standardizes the data to lists of arrays, in order. 2220 2221 Also builds and compiles the model on the fly if it is a subclassed model 2222 that has never been called before (and thus has no inputs/outputs). 2223 2224 This is a purely internal method, subject to refactoring at any time. 2225 2226 Args: 2227 x: Input data. It could be: 2228 - A Numpy array (or array-like), or a list of arrays 2229 (in case the model has multiple inputs). 2230 - A TensorFlow tensor, or a list of tensors 2231 (in case the model has multiple inputs). 2232 - A dict mapping input names to the corresponding array/tensors, 2233 if the model has named inputs. 2234 - A `tf.data` dataset. 2235 y: Target data. Like the input data `x`, 2236 it could be either Numpy array(s) or TensorFlow tensor(s). 2237 It should be consistent with `x` (you cannot have Numpy inputs and 2238 tensor targets, or inversely). If `x` is a dataset, `y` should not be 2239 specified (since targets will be obtained from the iterator). 2240 sample_weight: An optional sample-weight array passed by the user to 2241 weight the importance of each sample in `x`. 2242 class_weight: An optional class-weight array by the user to 2243 weight the importance of samples in `x` based on the class they belong 2244 to, as conveyed by `y`. If both `sample_weight` and `class_weight` are 2245 provided, the weights are multiplied. 2246 batch_size: Integer batch size. If provided, it is used to run additional 2247 validation checks on stateful models. 2248 check_steps: boolean, True if we want to check for validity of `steps` and 2249 False, otherwise. For example, when we are standardizing one batch of 2250 data for train_on_batch/predict_on_batch/test_on_batch APIs, `steps` 2251 value is not required and we should not check for its validity in these 2252 cases. 2253 steps_name: The public API's parameter name for `steps`. 2254 steps: Integer or `None`. Total number of steps (batches of samples) to 2255 execute. 2256 validation_split: Float between 0 and 1. 2257 Fraction of the training data to be used as validation data. 2258 shuffle: Boolean whether to shuffle the training data before each epoch. 2259 extract_tensors_from_dataset: Boolean. When `x` is a dataset instance, 2260 this indicates whether to extract actual tensors from the dataset or 2261 instead output the dataset instance itself. 2262 Set to True when calling from `train_on_batch`/etc. 2263 2264 Returns: 2265 A tuple of 3: inputs (arrays or dicts, depending on whether `x` was a dict 2266 or not), target arrays, sample-weight arrays. 2267 If the model's input and targets are symbolic, these lists are empty 2268 (since the model takes no user-provided data, instead the data comes 2269 from the symbolic inputs/targets). 2270 2271 Raises: 2272 ValueError: In case of invalid user-provided data. 2273 RuntimeError: If the model was never compiled. 2274 """ 2275 if isinstance(x, (dataset_ops.DatasetV1, dataset_ops.DatasetV2)): 2276 # Graph mode dataset. We'll pass the dataset as-is (unless 2277 # `extract_tensors_from_dataset` is True, in which case we extract 2278 # the tensors from the dataset and we output them. 2279 training_utils_v1.validate_dataset_input(x, y, sample_weight, 2280 validation_split) 2281 if shuffle: 2282 training_utils_v1.verify_dataset_shuffled(x) 2283 2284 is_dataset = True 2285 if extract_tensors_from_dataset: 2286 # We do this for `train_on_batch`/etc. 2287 x, y, sample_weight = training_utils_v1.extract_tensors_from_dataset(x) 2288 elif isinstance(x, iterator_ops.Iterator): 2289 # Graph mode iterator. We extract the symbolic tensors. 2290 training_utils_v1.validate_dataset_input(x, y, sample_weight, 2291 validation_split) 2292 iterator = x 2293 x, y, sample_weight = training_utils_v1.unpack_iterator_input(iterator) 2294 is_dataset = True 2295 else: 2296 is_dataset = False 2297 2298 # Validates `steps` argument based on x's type. 2299 if check_steps: 2300 training_utils_v1.check_steps_argument(x, steps, steps_name) 2301 2302 # First, we build the model on the fly if necessary. 2303 if not self.inputs: 2304 all_inputs, y_input, dict_inputs = self._build_model_with_inputs(x, y) 2305 is_build_called = True 2306 else: 2307 all_inputs = [] 2308 # Whether this is a subclassed model that expects dictionary inputs 2309 # rather than list inputs (e.g. FeatureColumn-based models). 2310 dict_inputs = isinstance(self.inputs, dict) 2311 is_build_called = False 2312 y_input = y 2313 2314 # Second, we compile the model on the fly if necessary, mostly for subclass 2315 # models. 2316 is_compile_called = False 2317 if not self._is_compiled and self.optimizer: 2318 self._compile_from_inputs(all_inputs, y_input, x, y) 2319 is_compile_called = True 2320 2321 # In graph mode, if we had just set inputs and targets as symbolic tensors 2322 # by invoking build and compile on the model respectively, we do not have to 2323 # feed anything to the model. Model already has input and target data as 2324 # part of the graph. 2325 # Note: in this case, `any` and `all` are equivalent since we disallow 2326 # mixed symbolic/value inputs. 2327 2328 # self.run_eagerly is not free to compute, so we want to reuse the value. 2329 run_eagerly = self.run_eagerly 2330 2331 if (not run_eagerly and is_build_called and is_compile_called and 2332 not is_dataset and any(_is_symbolic_tensor(v) for v in all_inputs)): 2333 return [], [], None 2334 2335 return self._standardize_tensors( 2336 x, y, sample_weight, 2337 run_eagerly=run_eagerly, 2338 dict_inputs=dict_inputs, 2339 is_dataset=is_dataset, 2340 class_weight=class_weight, 2341 batch_size=batch_size) 2342 2343 def _standardize_tensors(self, x, y, sample_weight, run_eagerly, dict_inputs, 2344 is_dataset, class_weight=None, batch_size=None): 2345 if run_eagerly: 2346 # In eager mode, do not do shape validation 2347 # since the network has no input nodes (placeholders) to be fed. 2348 feed_input_names = self.input_names 2349 feed_input_shapes = None 2350 elif not self._is_graph_network: 2351 # Case: symbolic-mode subclassed network. Do not do shape validation. 2352 feed_input_names = self._feed_input_names 2353 feed_input_shapes = None 2354 else: 2355 # Case: symbolic-mode graph network. 2356 # In this case, we run extensive shape validation checks. 2357 feed_input_names = self._feed_input_names 2358 feed_input_shapes = self._feed_input_shapes 2359 2360 # Standardize the inputs. 2361 if not isinstance(x, (dataset_ops.DatasetV1, dataset_ops.DatasetV2)): 2362 # TODO(fchollet): run static checks with dataset output shape(s). 2363 x = training_utils_v1.standardize_input_data( 2364 x, 2365 feed_input_names, 2366 feed_input_shapes, 2367 check_batch_axis=False, # Don't enforce the batch size. 2368 exception_prefix='input') 2369 2370 # Get typespecs for the input data and sanitize it if necessary. 2371 # TODO(momernick): This should be capable of doing full input validation 2372 # at all times - validate that this is so and refactor the standardization 2373 # code. 2374 if isinstance(x, dataset_ops.DatasetV2): 2375 x_shapes = dataset_ops.get_structure(x) 2376 if isinstance(x_shapes, tuple): 2377 # If the output of a Dataset is a tuple, we assume it's either of the 2378 # form (x_data, y_data) or (x_data, y_data, sample_weights). In either 2379 # case, we only care about x_data here. 2380 x_shapes = x_shapes[0] 2381 else: 2382 flat_inputs = nest.flatten(x, expand_composites=False) 2383 flat_expected_inputs = nest.flatten(self.inputs, expand_composites=False) 2384 converted_x = [] 2385 for (a, b) in zip(flat_inputs, flat_expected_inputs): 2386 converted_x.append(_convert_scipy_sparse_tensor(a, b)) 2387 x = nest.pack_sequence_as(x, converted_x, expand_composites=False) 2388 2389 def _type_spec_from_value(value): 2390 """Grab type_spec without converting array-likes to tensors.""" 2391 if tf_utils.is_extension_type(value): 2392 return value._type_spec # pylint: disable=protected-access 2393 # Get a TensorSpec for array-like data without 2394 # converting the data to a Tensor 2395 if hasattr(value, 'shape') and hasattr(value, 'dtype'): 2396 return tensor_spec.TensorSpec(value.shape, value.dtype) 2397 else: 2398 return type_spec.type_spec_from_value(value) 2399 2400 x_shapes = nest.map_structure(_type_spec_from_value, x) 2401 2402 flat_inputs = nest.flatten(x_shapes, expand_composites=False) 2403 flat_expected_inputs = nest.flatten(self.inputs, expand_composites=False) 2404 for (a, b) in zip(flat_inputs, flat_expected_inputs): 2405 nest.assert_same_structure(a, b, expand_composites=True) 2406 2407 if y is not None: 2408 # Prepare self._sample_weight_modes. List with the same length as 2409 # model outputs. 2410 training_utils_v1.prepare_sample_weight_modes(self._training_endpoints, 2411 self.sample_weight_mode) 2412 feed_output_names = self._feed_output_names 2413 feed_sample_weight_modes = self._sample_weight_modes 2414 if not self._is_graph_network: 2415 feed_output_shapes = None 2416 else: 2417 feed_output_shapes = self._feed_output_shapes 2418 2419 # Standardize the outputs. 2420 y = training_utils_v1.standardize_input_data( 2421 y, 2422 feed_output_names, 2423 # Don't enforce target shapes to match output shapes. 2424 # Precise checks will be run in `check_loss_and_target_compatibility`. 2425 shapes=None, 2426 check_batch_axis=False, # Don't enforce the batch size. 2427 exception_prefix='target') 2428 2429 # Generate sample-wise weight values given the `sample_weight` and 2430 # `class_weight` arguments. 2431 sample_weights = training_utils_v1.standardize_sample_weights( 2432 sample_weight, feed_output_names) 2433 class_weights = training_utils_v1.standardize_class_weights( 2434 class_weight, feed_output_names) 2435 2436 sample_weights = [ 2437 training_utils_v1.standardize_weights(ref, sw, cw, mode) 2438 for (ref, sw, cw, mode) in zip(y, sample_weights, class_weights, 2439 feed_sample_weight_modes) 2440 ] 2441 # Check that all arrays have the same length. 2442 if not self._distribution_strategy: 2443 training_utils_v1.check_array_lengths(x, y, sample_weights) 2444 if self._is_graph_network and not run_eagerly: 2445 # Additional checks to avoid users mistakenly using improper loss fns. 2446 training_utils_v1.check_loss_and_target_compatibility( 2447 y, self._feed_loss_fns, feed_output_shapes) 2448 2449 sample_weights, _, _ = training_utils.handle_partial_sample_weights( 2450 y, sample_weights, feed_sample_weight_modes, check_all_flat=True) 2451 else: 2452 y = [] 2453 sample_weights = None 2454 2455 if self.stateful and batch_size and not is_dataset: 2456 # Check that for stateful networks, number of samples is a multiple 2457 # of the static batch size. 2458 if x[0].shape[0] % batch_size != 0: 2459 raise ValueError('In a stateful network, ' 2460 'you should only pass inputs with ' 2461 'a number of samples that can be ' 2462 'divided by the batch size. Found: ' + 2463 str(x[0].shape[0]) + ' samples') 2464 2465 # If dictionary inputs were provided, we return a dictionary as well. 2466 if dict_inputs and not isinstance(x, (dataset_ops.DatasetV1, 2467 dataset_ops.DatasetV2)): 2468 x = dict(zip(feed_input_names, x)) 2469 return x, y, sample_weights 2470 2471 def _build_model_with_inputs(self, inputs, targets): 2472 """Build the model (set model inputs/outputs), mainly for subclass model.""" 2473 processed_inputs = [] 2474 is_dict_inputs = False 2475 orig_inputs = inputs 2476 # We need to use `inputs` to set the model inputs. 2477 # If input data is a dataset iterator in graph mode or if it is an eager 2478 # iterator and only one batch of samples is required, we fetch the data 2479 # tensors from the iterator and then standardize them. 2480 if isinstance(inputs, (dataset_ops.DatasetV1, dataset_ops.DatasetV2)): 2481 inputs, targets, _ = training_utils_v1.extract_tensors_from_dataset( 2482 inputs) 2483 # We type-check that `inputs` and `targets` are either single arrays 2484 # or lists of arrays, and extract a flat list of inputs from the passed 2485 # structure. 2486 training_utils_v1.validate_input_types(inputs, orig_inputs) 2487 2488 if isinstance(inputs, (list, tuple)): 2489 processed_inputs += list(inputs) 2490 elif isinstance(inputs, dict): 2491 is_dict_inputs = True 2492 keys = sorted(inputs.keys()) 2493 processed_inputs = [inputs[k] for k in keys] 2494 else: 2495 processed_inputs.append(inputs) 2496 # Now that we have a flat set of inputs, we make sure that none of them 2497 # are CompositeTensors or CompositeTensorValues of any type (or scipy 2498 # sparse arrays, which we treat as SparseTensor values). We cannot safely 2499 # infer input data from an arbitrary composite tensor, so we don't try - 2500 # users should explicitly add composite tensor inputs to their subclassed 2501 # models. 2502 for input_tensor in processed_inputs: 2503 if training_utils_v1.is_composite_or_composite_value(input_tensor): 2504 # TODO(b/132691975): Document subclass-model CT input handling. 2505 raise ValueError( 2506 'All SparseTensor and RaggedTensor inputs must be explicitly ' 2507 'declared using a keras.Input() with sparse=True or ragged=True. ' 2508 'We found an undeclared input %s. For Sequential models, please ' 2509 'add a keras.Input() as your first Layer. For subclassed models, ' 2510 'please call self._set_inputs() on your input set, which you can ' 2511 'create using keras.Input() for each input to your model.' % 2512 (input_tensor,)) 2513 # Build the model using the retrieved inputs (value or symbolic). 2514 # If values are generated from a dataset, then in symbolic-mode 2515 # placeholders will be created to match the value shapes. 2516 if isinstance(orig_inputs, (dataset_ops.DatasetV1, dataset_ops.DatasetV2, 2517 iterator_ops.Iterator)): 2518 if not self.inputs: 2519 # For subclassed models, a robust input spec is not available so we 2520 # must cast to the model dtype. 2521 inputs = training_utils_v1.cast_if_floating_dtype(inputs, self.dtype) 2522 2523 def create_tensor_spec(t): 2524 return tensor_spec.TensorSpec(t.shape, t.dtype) 2525 2526 cast_inputs = nest.map_structure(create_tensor_spec, inputs) 2527 elif training_utils_v1.has_tensors(inputs): 2528 cast_inputs = training_utils_v1.cast_if_floating_dtype(inputs) 2529 else: 2530 cast_inputs = inputs 2531 self._set_inputs(cast_inputs) 2532 return processed_inputs, targets, is_dict_inputs 2533 2534 def _compile_from_inputs(self, all_inputs, target, orig_inputs, orig_target): 2535 if target is not None: 2536 # We need to use `y` to set the model targets. 2537 if training_utils_v1.has_tensors(target): 2538 target = training_utils_v1.cast_if_floating_dtype_and_mismatch( 2539 target, self.outputs) 2540 training_utils_v1.validate_input_types( 2541 target, orig_target, allow_dict=False, field_name='target') 2542 if isinstance(target, (list, tuple)): 2543 all_inputs += list(target) 2544 else: 2545 all_inputs.append(target) 2546 # Type check that all inputs are *either* value *or* symbolic. 2547 # TODO(fchollet): this check could be removed in Eager mode? 2548 if any(tensor_util.is_tf_type(v) for v in all_inputs): 2549 if not all(tensor_util.is_tf_type(v) for v in all_inputs): 2550 raise ValueError('Do not pass inputs that mix Numpy arrays and ' 2551 'TensorFlow tensors. ' 2552 'You passed: x=' + str(orig_inputs) + 2553 '; y=' + str(orig_target)) 2554 is_dataset = isinstance(orig_inputs, (dataset_ops.DatasetV1, 2555 dataset_ops.DatasetV2, 2556 iterator_ops.Iterator)) 2557 if is_dataset or context.executing_eagerly(): 2558 target_tensors = None 2559 else: 2560 # Handle target tensors if any passed. 2561 if target is not None: 2562 if not isinstance(target, (list, tuple)): 2563 target = [target] 2564 target_tensors = [v for v in target if _is_symbolic_tensor(v)] 2565 else: 2566 target_tensors = None 2567 2568 self.compile( 2569 optimizer=self.optimizer, 2570 loss=self.loss, 2571 metrics=self._compile_metrics, 2572 weighted_metrics=self._compile_weighted_metrics, 2573 loss_weights=self.loss_weights, 2574 target_tensors=target_tensors, 2575 sample_weight_mode=self.sample_weight_mode, 2576 run_eagerly=self.run_eagerly, 2577 experimental_run_tf_function=self._experimental_run_tf_function) 2578 2579 # TODO(omalleyt): Consider changing to a more descriptive function name. 2580 def _set_inputs(self, inputs, outputs=None, training=None): 2581 """Set model's input and output specs based on the input data received. 2582 2583 This is to be used for Model subclasses, which do not know at instantiation 2584 time what their inputs look like. 2585 2586 Args: 2587 inputs: Single array, or list of arrays. The arrays could be placeholders, 2588 Numpy arrays, data tensors, or TensorSpecs. 2589 - if placeholders: the model is built on top of these placeholders, 2590 and we expect Numpy data to be fed for them when calling `fit`/etc. 2591 - if Numpy data or TensorShapes: we create placeholders matching the 2592 TensorShapes or shapes of the Numpy arrays. We expect Numpy data to be 2593 fed for these placeholders when calling `fit`/etc. 2594 - if data tensors: the model is built on top of these tensors. 2595 We do not expect any Numpy data to be provided when calling `fit`/etc. 2596 outputs: None, a data tensor, or a list of tensors. If None, the 2597 outputs will be determined by invoking `self.call()`, otherwise the 2598 provided value will be used. 2599 training: Boolean or None. Only relevant in symbolic mode. Specifies 2600 whether to build the model's graph in inference mode (False), training 2601 mode (True), or using the Keras learning phase (None). 2602 Raises: 2603 ValueError: If dict inputs are passed to a Sequential Model where the 2604 first layer isn't FeatureLayer. 2605 """ 2606 self._set_save_spec(inputs) 2607 inputs = self._set_input_attrs(inputs) 2608 2609 if outputs is None: 2610 kwargs = {} 2611 if self._expects_training_arg: 2612 # In V2 mode, feeding `training=None` is not allowed because any value 2613 # explicitly passed by the user is respected, even `None`.` 2614 if training is None and not ops.executing_eagerly_outside_functions(): 2615 training = K.learning_phase() 2616 if training is not None: 2617 kwargs['training'] = training 2618 try: 2619 outputs = self(inputs, **kwargs) 2620 except NotImplementedError: 2621 # This Model or a submodel is dynamic and hasn't overridden 2622 # `compute_output_shape`. 2623 outputs = None 2624 2625 self._set_output_attrs(outputs) 2626 2627 @trackable.no_automatic_dependency_tracking 2628 def _set_input_attrs(self, inputs): 2629 """Sets attributes related to the inputs of the Model.""" 2630 if self.inputs: 2631 raise ValueError('Model inputs are already set.') 2632 2633 if self.__class__.__name__ == 'Sequential' and not self.built: 2634 if tensor_util.is_tf_type(inputs): 2635 input_shape = (None,) + tuple(inputs.shape.as_list()[1:]) 2636 elif isinstance(inputs, tensor_shape.TensorShape): 2637 input_shape = (None,) + tuple(inputs.as_list()[1:]) 2638 elif isinstance(inputs, dict): 2639 # We assert that the first layer is a FeatureLayer. 2640 if not training_utils_v1.is_feature_layer(self.layers[0]): 2641 raise ValueError('Passing a dictionary input to a Sequential Model ' 2642 'which doesn\'t have FeatureLayer as the first layer' 2643 ' is an error.') 2644 input_shape = (None,) 2645 else: 2646 input_shape = (None,) + tuple(inputs.shape[1:]) 2647 self._build_input_shape = input_shape 2648 2649 # Cast inputs to the compute dtype. This is primarily used 2650 # when saving to determine the correct dtype in the input signature. 2651 inputs = self._maybe_cast_inputs(inputs) 2652 2653 # On-the-fly setting of symbolic model inputs (either by using the tensor 2654 # provided, or by creating a placeholder if Numpy data was provided). 2655 model_inputs = training_utils_v1.ModelInputs(inputs) 2656 inputs = model_inputs.get_symbolic_inputs() 2657 self.inputs = model_inputs.get_symbolic_inputs(return_single_as_list=True) 2658 self.input_names = model_inputs.get_input_names() 2659 2660 self._feed_inputs = [] 2661 self._feed_input_names = [] 2662 self._feed_input_shapes = [] 2663 2664 for k, v in model_inputs.as_dict(): 2665 if K.is_placeholder(v): 2666 self._feed_input_names.append(k) 2667 self._feed_inputs.append(v) 2668 self._feed_input_shapes.append(K.int_shape(v)) 2669 2670 return inputs 2671 2672 @trackable.no_automatic_dependency_tracking 2673 def _set_output_attrs(self, outputs): 2674 """Sets attributes related to the outputs of the Model.""" 2675 # NOTE(taylorrobie): This convention cannot be changed without updating the 2676 # data adapter since it assumes nest.flatten ordering. 2677 outputs = nest.flatten(outputs) 2678 self.outputs = outputs 2679 self.output_names = training_utils_v1.generic_output_names(outputs) 2680 # TODO(scottzhu): Should we cleanup the self._training_endpoints here? 2681 self.built = True 2682 2683 @property 2684 def _targets(self): 2685 """The output target tensors for the model.""" 2686 return [ 2687 e.training_target.target 2688 for e in self._training_endpoints 2689 if e.has_training_target() 2690 ] 2691 2692 @property 2693 def _feed_targets(self): 2694 return [ 2695 e.training_target.target 2696 for e in self._training_endpoints 2697 if e.has_feedable_training_target() 2698 ] 2699 2700 @property 2701 def _feed_output_names(self): 2702 return [ 2703 e.output_name 2704 for e in self._training_endpoints 2705 if e.has_feedable_training_target() 2706 ] 2707 2708 @property 2709 def _feed_output_shapes(self): 2710 return [ 2711 e.feed_output_shape 2712 for e in self._training_endpoints 2713 if e.has_feedable_training_target() 2714 ] 2715 2716 @property 2717 def _feed_loss_fns(self): 2718 return [ 2719 e.loss_fn 2720 for e in self._training_endpoints 2721 if e.has_feedable_training_target() 2722 ] 2723 2724 @property 2725 def _loss_weights_list(self): 2726 return [e.loss_weight for e in self._training_endpoints] 2727 2728 @property 2729 def _output_loss_metrics(self): 2730 if hasattr(self, '_training_endpoints'): 2731 return [ 2732 e.output_loss_metric 2733 for e in self._training_endpoints 2734 if e.output_loss_metric is not None 2735 ] 2736 return None 2737 2738 @property 2739 def sample_weights(self): 2740 return [e.sample_weight for e in self._training_endpoints] 2741 2742 @property 2743 def _sample_weight_modes(self): 2744 return [e.sample_weight_mode for e in self._training_endpoints] 2745 2746 @property 2747 def _feed_sample_weights(self): 2748 return [e.sample_weight for e in self._training_endpoints 2749 if e.sample_weight is not None] 2750 2751 def _maybe_load_initial_epoch_from_ckpt(self, initial_epoch, mode): 2752 """Maybe load initial epoch from ckpt considering possible worker recovery. 2753 2754 Refer to tensorflow/python/keras/distribute/worker_training_state.py 2755 for more information. 2756 2757 Args: 2758 initial_epoch: The original initial_epoch user passes in in `fit()`. 2759 mode: The mode for running `model.fit()`. 2760 2761 Returns: 2762 If the training is recovering from previous failure under multi-worker 2763 training setting, return the epoch the training is supposed to continue 2764 at. Otherwise, return the `initial_epoch` the user passes in. 2765 """ 2766 if self._training_state is not None: 2767 return self._training_state.maybe_load_initial_epoch_from_ckpt( 2768 initial_epoch, mode) 2769 return initial_epoch 2770 2771 def _get_training_eval_metrics(self): 2772 """Returns all the metrics that are to be reported. 2773 2774 This includes the output loss metrics, compile metrics/weighted metrics, 2775 add_metric metrics. 2776 """ 2777 metrics = [] 2778 metrics.extend(getattr(self, '_output_loss_metrics', None) or []) 2779 metrics.extend(getattr(self, 'metrics', None) or []) 2780 return metrics 2781 2782 def _assert_compile_was_called(self): 2783 # Checks whether `compile` has been called. If it has been called, 2784 # then the optimizer is set. This is different from whether the 2785 # model is compiled 2786 # (i.e. whether the model is built and its inputs/outputs are set). 2787 if not self._compile_was_called: 2788 raise RuntimeError('You must compile your model before ' 2789 'training/testing. ' 2790 'Use `model.compile(optimizer, loss)`.') 2791 2792 def _in_multi_worker_mode(self): 2793 """Method to infer if this `Model` is working in multi-worker settings. 2794 2795 Multi-worker training refers to the setup where the training is 2796 distributed across multiple workers, as opposed to the case where 2797 only a local process performs the training. This function is 2798 used to infer for example whether or not a distribute coordinator 2799 should be run, and thus TensorFlow servers should be started for 2800 communication with other servers in the cluster, or whether or not 2801 saving/restoring checkpoints is relevant for preemption fault tolerance. 2802 2803 Experimental. Signature and implementation are subject to change. 2804 2805 Returns: 2806 Whether this model indicates it's working in multi-worker settings. 2807 """ 2808 strategy = self._distribution_strategy 2809 2810 # Otherwise, use the strategy whose scope this is in. 2811 if not strategy and distribution_strategy_context.has_strategy(): 2812 strategy = distribution_strategy_context.get_strategy() 2813 return strategy and strategy.extended._in_multi_worker_mode() # pylint: disable=protected-access 2814 2815 @property 2816 def _trackable_saved_model_saver(self): 2817 return model_serialization.ModelSavedModelSaver(self) 2818 2819 def _get_compile_args(self, user_metrics=True): 2820 del user_metrics 2821 self._assert_compile_was_called() 2822 kwargs = { 2823 'loss': self.loss, 2824 'metrics': self._compile_metrics, 2825 'loss_weights': self.loss_weights, 2826 'sample_weight_mode': self.sample_weight_mode, 2827 'weighted_metrics': self._compile_weighted_metrics, 2828 } 2829 return kwargs 2830 2831 @property 2832 def _compile_was_called(self): 2833 return self._v1_compile_was_called 2834 2835 2836class DistributedCallbackModel(Model): 2837 """Model that is used for callbacks with tf.distribute.Strategy.""" 2838 2839 def __init__(self, model): 2840 super(DistributedCallbackModel, self).__init__() 2841 self.optimizer = model.optimizer 2842 2843 def set_original_model(self, orig_model): 2844 self._original_model = orig_model 2845 2846 def save_weights(self, filepath, overwrite=True, save_format=None): 2847 self._replicated_model.save_weights(filepath, overwrite=overwrite, 2848 save_format=save_format) 2849 2850 def save(self, filepath, overwrite=True, include_optimizer=True): 2851 # save weights from the distributed model to the original model 2852 distributed_model_weights = self.get_weights() 2853 self._original_model.set_weights(distributed_model_weights) 2854 # TODO(anjalisridhar): Do we need to save the original model here? 2855 # Saving the first replicated model works as well. 2856 self._original_model.save(filepath, overwrite=True, include_optimizer=False) 2857 2858 def load_weights(self, filepath, by_name=False): 2859 self._original_model.load_weights(filepath, by_name=False) 2860 # Copy the weights from the original model to each of the replicated models. 2861 orig_model_weights = self._original_model.get_weights() 2862 distributed_training_utils_v1.set_weights( 2863 self._original_model._distribution_strategy, self, # pylint: disable=protected-access 2864 orig_model_weights) 2865 2866 def __getattr__(self, item): 2867 # Allowed attributes of the model that can be accessed by the user 2868 # during a callback. 2869 if item not in ('_setattr_tracking', '_layers'): 2870 logging.warning('You are accessing attribute ' + item + ' of the ' 2871 'DistributedCallbackModel that may not have been set ' 2872 'correctly.') 2873 return super(DistributedCallbackModel, self).__getattr__(item) 2874 2875 2876class _TrainingEndpoint(object): 2877 """A container for the training output/target and related entities. 2878 2879 In the case of model with multiple outputs, there is a one-to-one mapping 2880 between model output (y_pred), model target (y_true), loss, metrics etc. 2881 By unifying these entities into one class, different entity can access 2882 information between each other, rather than currently access different list of 2883 attributes of the model. 2884 """ 2885 2886 def __init__(self, 2887 output, 2888 output_name, 2889 loss_fn, 2890 loss_weight=None, 2891 training_target=None, 2892 output_loss_metric=None, 2893 sample_weight=None, 2894 sample_weight_mode=None): 2895 """Initialize the _TrainingEndpoint. 2896 2897 Note that the output and output_name should be stable as long as the model 2898 structure doesn't change. The training_target suppose to be mutable since 2899 the information is provided via `compile()` 2900 2901 Args: 2902 output: the output tensor of the model. 2903 output_name: the unique name of the output tensor. 2904 loss_fn: the loss function for the output tensor. 2905 loss_weight: float, the weights for the loss. 2906 training_target: the _TrainingTarget for the model. 2907 output_loss_metric: the metric object for the loss function. 2908 sample_weight: the weights for how a sample is weighted during metric and 2909 loss calculation. Could be None. 2910 sample_weight_mode: string, 'temporal', 'samplewise' or None. The mode for 2911 how the sample_weight is populated. 2912 """ 2913 self._output = output 2914 self._output_name = output_name 2915 self._loss_fn = loss_fn 2916 self._loss_weight = loss_weight 2917 self._training_target = training_target 2918 self._output_loss_metric = output_loss_metric 2919 self._sample_weight = sample_weight 2920 self._sample_weight_mode = sample_weight_mode 2921 2922 @property 2923 def output(self): 2924 return self._output 2925 2926 @property 2927 def output_name(self): 2928 return self._output_name 2929 2930 @property 2931 def shape(self): 2932 return K.int_shape(self.output) 2933 2934 @property 2935 def loss_fn(self): 2936 return self._loss_fn 2937 2938 @property 2939 def loss_weight(self): 2940 return self._loss_weight 2941 2942 @loss_weight.setter 2943 def loss_weight(self, value): 2944 self._loss_weight = value 2945 2946 @property 2947 def training_target(self): 2948 return self._training_target 2949 2950 @training_target.setter 2951 def training_target(self, value): 2952 self._training_target = value 2953 2954 def create_training_target(self, target, run_eagerly=False): 2955 """Create training_target instance and update the self.training_target. 2956 2957 Note that the input target should just be a tensor or None, and 2958 corresponding training target will be created based on the output and 2959 loss_fn. 2960 2961 Args: 2962 target: the target tensor for the current output. Could be None. 2963 run_eagerly: boolean, whether the model is in run_eagerly mode. 2964 2965 Raises: 2966 ValueError if the training_target field for the current instance has 2967 already been populated. 2968 """ 2969 if self.has_training_target(): 2970 raise ValueError('The training_target field for the _TrainingEndpoint ' 2971 'instance has already been populated') 2972 if run_eagerly: 2973 # When run_eagerly, the target tensor is ignored, and the None placeholder 2974 # is created instead. 2975 self.training_target = _TrainingTarget( 2976 None, feedable=True, skip_target_weights=False) 2977 return 2978 2979 if self.should_skip_target(): 2980 self.training_target = _TrainingTarget(None) 2981 else: 2982 if target is not None and not K.is_placeholder(target): 2983 feedable = False 2984 skip_target_weights = True 2985 else: 2986 feedable = True 2987 skip_target_weights = False 2988 2989 if target is None: 2990 target_dtype = losses.LABEL_DTYPES_FOR_LOSSES.get( 2991 self.loss_fn, K.dtype(self.output)) 2992 2993 target = K.placeholder( 2994 ndim=len(self.shape), 2995 name=self.output_name + '_target', 2996 sparse=K.is_sparse(self.output), 2997 dtype=target_dtype) 2998 2999 self.training_target = _TrainingTarget( 3000 target, 3001 feedable=feedable, 3002 skip_target_weights=skip_target_weights) 3003 3004 @property 3005 def output_loss_metric(self): 3006 return self._output_loss_metric 3007 3008 @output_loss_metric.setter 3009 def output_loss_metric(self, value): 3010 self._output_loss_metric = value 3011 3012 @property 3013 def sample_weight(self): 3014 return self._sample_weight 3015 3016 @sample_weight.setter 3017 def sample_weight(self, value): 3018 self._sample_weight = value 3019 3020 @property 3021 def sample_weight_mode(self): 3022 return self._sample_weight_mode 3023 3024 @sample_weight_mode.setter 3025 def sample_weight_mode(self, value): 3026 self._sample_weight_mode = value 3027 3028 def should_skip_target(self): 3029 return self._loss_fn is None 3030 3031 def should_skip_target_weights(self): 3032 return (self.should_skip_target() or self.training_target is None or 3033 self.training_target.skip_target_weights) 3034 3035 def has_training_target(self): 3036 return self.training_target is not None 3037 3038 def has_feedable_training_target(self): 3039 return (not self.should_skip_target() and 3040 self.training_target is not None and self.training_target.feedable) 3041 3042 def loss_name(self): 3043 if self._loss_fn is not None: 3044 return self._output_name + '_loss' 3045 return None 3046 3047 @property 3048 def feed_output_shape(self): 3049 """The output shape for the feedable target.""" 3050 if not self.has_feedable_training_target(): 3051 return None 3052 3053 if ((isinstance(self.loss_fn, losses.LossFunctionWrapper) and 3054 self.loss_fn.fn == losses.sparse_categorical_crossentropy)) or ( 3055 isinstance(self.loss_fn, losses.SparseCategoricalCrossentropy)): 3056 if K.image_data_format() == 'channels_first': 3057 return (self.shape[0], 1) + self.shape[2:] 3058 else: 3059 return self.shape[:-1] + (1,) 3060 elif (not isinstance(self.loss_fn, losses.Loss) or 3061 (isinstance(self.loss_fn, losses.LossFunctionWrapper) and 3062 (getattr(losses, self.loss_fn.fn.__name__, None) is None))): 3063 # If the given loss is not an instance of the `Loss` class (custom 3064 # class) or if the loss function that is wrapped is not in the 3065 # `losses` module, then it is a user-defined loss and we make no 3066 # assumptions about it. 3067 return None 3068 else: 3069 return self.shape 3070 3071 def sample_weights_mismatch(self): 3072 """Check if the sample weight and the mode match or not.""" 3073 # If there is a mismatch between sample weight mode and the placeholders 3074 # created, then recompile the sub-graphs that depend on sample weights. 3075 return ( 3076 (self.sample_weight_mode is not None and self.sample_weight is None) or 3077 (self.sample_weight_mode is None and self.sample_weight is not None)) 3078 3079 def populate_sample_weight(self, sample_weight, sample_weight_mode): 3080 """Populate the sample weight and based on the sample weight mode.""" 3081 if (sample_weight is None and 3082 (self.should_skip_target_weights() or sample_weight_mode is None or 3083 context.executing_eagerly())): 3084 self._sample_weight = None 3085 return 3086 3087 assert sample_weight_mode in ['temporal', 'samplewise'] 3088 if sample_weight_mode == 'temporal': 3089 default_value = [[1.]] 3090 shape = [None, None] 3091 else: 3092 # sample_weight_mode == 'samplewise' 3093 default_value = [1.] 3094 shape = [None] 3095 3096 if sample_weight is not None: 3097 if not sample_weight.shape.is_compatible_with(shape): 3098 raise ValueError('Received sample weight with shape {}. Expected shape ' 3099 '{}.'.format(sample_weight.shape, shape)) 3100 self._sample_weight = sample_weight 3101 else: 3102 self._sample_weight = array_ops.placeholder_with_default( 3103 constant_op.constant(default_value, dtype=K.floatx()), 3104 shape=shape, 3105 name=self.output_name + '_sample_weights') 3106 3107 3108class _TrainingTarget(object): 3109 """Container for a target tensor (y_true) and its metadata (shape, loss...). 3110 3111 Args: 3112 target: A target tensor for the model. It may be `None` if the 3113 output is excluded from loss computation. It is still kept as None 3114 since each output of the model should have a corresponding target. If 3115 the target is None, the rest of the attributes will be None as well. 3116 feedable: Boolean, whether the target is feedable (requires data to be 3117 passed in `fit` or `train_on_batch`), or not (model compiled with 3118 `target_tensors` argument). 3119 skip_target_weights: Boolean, whether the target should be skipped during 3120 weights calculation. 3121 """ 3122 3123 def __init__(self, target, feedable=False, skip_target_weights=True): 3124 self._target = target 3125 self._feedable = feedable 3126 self._skip_target_weights = skip_target_weights 3127 3128 @property 3129 def target(self): 3130 return self._target 3131 3132 @property 3133 def feedable(self): 3134 return self._feedable 3135 3136 @property 3137 def skip_target_weights(self): 3138 return self._skip_target_weights 3139 3140 3141def _is_symbolic_tensor(x): 3142 return tensor_util.is_tf_type(x) 3143 3144 3145def _convert_scipy_sparse_tensor(value, expected_input): 3146 """Handle scipy sparse tensor conversions. 3147 3148 This method takes a value 'value' and returns the proper conversion. If 3149 value is a scipy sparse tensor and the expected input is a dense tensor, 3150 we densify 'value'. If value is a scipy sparse tensor and the expected input 3151 is a TF SparseTensor, we convert 'value' to a SparseTensor. If 'value' is 3152 not a scipy sparse tensor, or scipy is not imported, we pass it through 3153 unchanged. 3154 3155 Args: 3156 value: An object that may be a scipy sparse tensor 3157 expected_input: The expected input placeholder. 3158 3159 Returns: 3160 The possibly-converted 'value'. 3161 """ 3162 if issparse is not None and issparse(value): 3163 if K.is_sparse(expected_input): 3164 sparse_coo = value.tocoo() 3165 row, col = sparse_coo.row, sparse_coo.col 3166 data, shape = sparse_coo.data, sparse_coo.shape 3167 indices = np.concatenate((np.expand_dims(row, 1), np.expand_dims(col, 1)), 3168 1) 3169 return sparse_tensor.SparseTensor(indices, data, shape) 3170 else: 3171 if ops.executing_eagerly_outside_functions(): 3172 # In TF2 we do not silently densify sparse matrices. 3173 raise ValueError('A SciPy sparse matrix was passed to a model ' 3174 'that expects dense inputs. Please densify your ' 3175 'inputs first, such as by calling `x.toarray().') 3176 return value.toarray() 3177 else: 3178 return value 3179 3180 3181def _get_metrics_from_layers(layers): 3182 """Returns list of metrics from the given layers. 3183 3184 This will not include the `compile` metrics of a model layer. 3185 3186 Args: 3187 layers: List of layers. 3188 3189 Returns: 3190 List of metrics. 3191 """ 3192 metrics = [] 3193 layers = layer_utils.filter_empty_layer_containers(layers) 3194 for layer in layers: 3195 if isinstance(layer, Model): 3196 # We cannot call 'metrics' on the model because we do not want to 3197 # include the metrics that were added in compile API of a nested model. 3198 metrics.extend(layer._metrics) # pylint: disable=protected-access 3199 metrics.extend(_get_metrics_from_layers(layer.layers)) 3200 else: 3201 metrics.extend(layer.metrics) 3202 return metrics 3203 3204 3205def _non_none_constant_value(v): 3206 constant_value = tensor_util.constant_value(v) 3207 return constant_value if constant_value is not None else v 3208