1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Training-related part of the Keras engine. 16""" 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import collections 22import numpy as np 23 24from tensorflow.python import tf2 25from tensorflow.python.data.ops import dataset_ops 26from tensorflow.python.data.ops import iterator_ops 27from tensorflow.python.distribute import distribute_coordinator as dc 28from tensorflow.python.distribute import distribution_strategy_context 29from tensorflow.python.eager import context 30from tensorflow.python.framework import ops 31from tensorflow.python.framework import tensor_shape 32from tensorflow.python.framework import tensor_util 33from tensorflow.python.keras import backend as K 34from tensorflow.python.keras import losses 35from tensorflow.python.keras import metrics as metrics_module 36from tensorflow.python.keras import optimizers 37from tensorflow.python.keras.engine import distributed_training_utils 38from tensorflow.python.keras.engine import network 39from tensorflow.python.keras.engine import training_arrays 40from tensorflow.python.keras.engine import training_distributed 41from tensorflow.python.keras.engine import training_eager 42from tensorflow.python.keras.engine import training_generator 43from tensorflow.python.keras.engine import training_utils 44from tensorflow.python.keras.saving import saving_utils 45from tensorflow.python.keras.utils import data_utils 46from tensorflow.python.keras.utils import losses_utils 47from tensorflow.python.keras.utils.generic_utils import slice_arrays 48from tensorflow.python.keras.utils.mode_keys import ModeKeys 49from tensorflow.python.ops import math_ops 50from tensorflow.python.platform import tf_logging as logging 51from tensorflow.python.training.tracking import base as trackable 52from tensorflow.python.util import nest 53from tensorflow.python.util.tf_export import keras_export 54 55 56@keras_export('keras.models.Model', 'keras.Model') 57class Model(network.Network): 58 """`Model` groups layers into an object with training and inference features. 59 60 There are two ways to instantiate a `Model`: 61 62 1 - With the "functional API", where you start from `Input`, 63 you chain layer calls to specify the model's forward pass, 64 and finally you create your model from inputs and outputs: 65 66 ```python 67 import tensorflow as tf 68 69 inputs = tf.keras.Input(shape=(3,)) 70 x = tf.keras.layers.Dense(4, activation=tf.nn.relu)(inputs) 71 outputs = tf.keras.layers.Dense(5, activation=tf.nn.softmax)(x) 72 model = tf.keras.Model(inputs=inputs, outputs=outputs) 73 ``` 74 75 2 - By subclassing the `Model` class: in that case, you should define your 76 layers in `__init__` and you should implement the model's forward pass 77 in `call`. 78 79 ```python 80 import tensorflow as tf 81 82 class MyModel(tf.keras.Model): 83 84 def __init__(self): 85 super(MyModel, self).__init__() 86 self.dense1 = tf.keras.layers.Dense(4, activation=tf.nn.relu) 87 self.dense2 = tf.keras.layers.Dense(5, activation=tf.nn.softmax) 88 89 def call(self, inputs): 90 x = self.dense1(inputs) 91 return self.dense2(x) 92 93 model = MyModel() 94 ``` 95 96 If you subclass `Model`, you can optionally have 97 a `training` argument (boolean) in `call`, which you can use to specify 98 a different behavior in training and inference: 99 100 ```python 101 import tensorflow as tf 102 103 class MyModel(tf.keras.Model): 104 105 def __init__(self): 106 super(MyModel, self).__init__() 107 self.dense1 = tf.keras.layers.Dense(4, activation=tf.nn.relu) 108 self.dense2 = tf.keras.layers.Dense(5, activation=tf.nn.softmax) 109 self.dropout = tf.keras.layers.Dropout(0.5) 110 111 def call(self, inputs, training=False): 112 x = self.dense1(inputs) 113 if training: 114 x = self.dropout(x, training=training) 115 return self.dense2(x) 116 117 model = MyModel() 118 ``` 119 """ 120 121 def __init__(self, *args, **kwargs): 122 super(Model, self).__init__(*args, **kwargs) 123 # initializing _distribution_strategy here since it is possible to call 124 # predict on a model without compiling it. 125 self._distribution_strategy = None 126 # This flag is used to track if the user is using the deprecated path of 127 # passing distribution strategy to compile rather than creating the model 128 # under distribution strategy scope. 129 self._compile_distribution = False 130 131 self.run_eagerly = None 132 133 def get_weights(self): 134 """Retrieves the weights of the model. 135 136 Returns: 137 A flat list of Numpy arrays. 138 """ 139 if self._distribution_strategy: 140 with self._distribution_strategy.scope(): 141 return super(Model, self).get_weights() 142 return super(Model, self).get_weights() 143 144 def load_weights(self, filepath, by_name=False): 145 """Loads all layer weights, either from a TensorFlow or an HDF5 file.""" 146 if distributed_training_utils.is_tpu_strategy(self._distribution_strategy): 147 if (self._distribution_strategy.extended.steps_per_run > 1 and 148 (not network._is_hdf5_filepath(filepath))): # pylint: disable=protected-access 149 raise ValueError('Load weights is not yet supported with TPUStrategy ' 150 'with steps_per_run greater than 1.') 151 return super(Model, self).load_weights(filepath, by_name) 152 153 @trackable.no_automatic_dependency_tracking 154 def compile(self, 155 optimizer, 156 loss=None, 157 metrics=None, 158 loss_weights=None, 159 sample_weight_mode=None, 160 weighted_metrics=None, 161 target_tensors=None, 162 distribute=None, 163 **kwargs): 164 """Configures the model for training. 165 166 Arguments: 167 optimizer: String (name of optimizer) or optimizer instance. 168 See `tf.keras.optimizers`. 169 loss: String (name of objective function), objective function or 170 `tf.losses.Loss` instance. See `tf.losses`. If the model has 171 multiple outputs, you can use a different loss on each output by 172 passing a dictionary or a list of losses. The loss value that will 173 be minimized by the model will then be the sum of all individual 174 losses. 175 metrics: List of metrics to be evaluated by the model during training 176 and testing. Typically you will use `metrics=['accuracy']`. 177 To specify different metrics for different outputs of a 178 multi-output model, you could also pass a dictionary, such as 179 `metrics={'output_a': 'accuracy', 'output_b': ['accuracy', 'mse']}`. 180 You can also pass a list (len = len(outputs)) of lists of metrics 181 such as `metrics=[['accuracy'], ['accuracy', 'mse']]` or 182 `metrics=['accuracy', ['accuracy', 'mse']]`. 183 loss_weights: Optional list or dictionary specifying scalar 184 coefficients (Python floats) to weight the loss contributions 185 of different model outputs. 186 The loss value that will be minimized by the model 187 will then be the *weighted sum* of all individual losses, 188 weighted by the `loss_weights` coefficients. 189 If a list, it is expected to have a 1:1 mapping 190 to the model's outputs. If a tensor, it is expected to map 191 output names (strings) to scalar coefficients. 192 sample_weight_mode: If you need to do timestep-wise 193 sample weighting (2D weights), set this to `"temporal"`. 194 `None` defaults to sample-wise weights (1D). 195 If the model has multiple outputs, you can use a different 196 `sample_weight_mode` on each output by passing a 197 dictionary or a list of modes. 198 weighted_metrics: List of metrics to be evaluated and weighted 199 by sample_weight or class_weight during training and testing. 200 target_tensors: By default, Keras will create placeholders for the 201 model's target, which will be fed with the target data during 202 training. If instead you would like to use your own 203 target tensors (in turn, Keras will not expect external 204 Numpy data for these targets at training time), you 205 can specify them via the `target_tensors` argument. It can be 206 a single tensor (for a single-output model), a list of tensors, 207 or a dict mapping output names to target tensors. 208 distribute: NOT SUPPORTED IN TF 2.0, please create and compile the 209 model under distribution strategy scope instead of passing it to 210 compile. 211 **kwargs: Any additional arguments. 212 213 Raises: 214 ValueError: In case of invalid arguments for 215 `optimizer`, `loss`, `metrics` or `sample_weight_mode`. 216 """ 217 run_eagerly = kwargs.pop('run_eagerly', None) 218 if run_eagerly and getattr(self, '_contains_symbolic_tensors', False): 219 raise ValueError( 220 'We currently do not support enabling `run_eagerly` on compile if ' 221 '`model.add_loss(tensor)` or `model.add_metric(tensor)` ' 222 'has been called.') 223 224 self._run_eagerly = run_eagerly 225 optimizer = optimizers.get(optimizer) 226 227 if distribute is not None: 228 if tf2.enabled(): 229 raise ValueError( 230 'Distribute argument in compile is not available in TF 2.0 please ' 231 'create the model under the distribution strategy scope.') 232 logging.warning('Distribute argument in compile is deprecated please ' 233 'create the model under the distribution strategy scope.') 234 self._distribution_strategy = distribute 235 self._compile_distribution = True 236 else: 237 if distribution_strategy_context.has_strategy(): 238 # When the user builds the model in the DS scope and cross replica 239 # context we want distribution strategy to be set but when building the 240 # replica copies of the models internally we should not be compiling 241 # with distribution strategy and use the default compilation path. 242 if distribution_strategy_context.in_cross_replica_context(): 243 self._distribution_strategy = ( 244 distribution_strategy_context.get_strategy()) 245 246 # Validate that arguments passed by the user to `compile` are supported by 247 # DistributionStrategy. 248 if self._distribution_strategy: 249 if sample_weight_mode: 250 raise NotImplementedError('sample_weight_mode is not supported with ' 251 'DistributionStrategy.') 252 if weighted_metrics: 253 raise NotImplementedError('weighted_metrics is not supported with ' 254 'DistributionStrategy.') 255 if target_tensors: 256 raise ValueError('target_tensors is not supported with ' 257 'DistributionStrategy.') 258 259 if run_eagerly: 260 raise ValueError( 261 'We currently do not support enabling `run_eagerly` with ' 262 'distribution strategy.') 263 264 if getattr(self, '_contains_symbolic_tensors', False): 265 raise ValueError( 266 'We currently do not support compiling the model with distribution ' 267 'strategy if `model.add_loss(tensor)` or `model.add_metric(tensor)`' 268 ' has been called.') 269 270 if not self.built or not self.inputs or not self.outputs: 271 raise ValueError( 272 'We currently do not support distribution strategy with a ' 273 '`Sequential` model that is created without `input_shape`/' 274 '`input_dim` set in its first layer or a subclassed model.') 275 276 loss = loss or {} 277 278 self.optimizer = optimizer 279 # We've disabled automatic dependency tracking for this method, but do want 280 # to add a checkpoint dependency on the optimizer if it's trackable. 281 if isinstance(self.optimizer, trackable.Trackable): 282 self._track_trackable( 283 self.optimizer, name='optimizer', overwrite=True) 284 self.loss = loss 285 self._compile_metrics = metrics or [] 286 self.loss_weights = loss_weights 287 self.sample_weight_mode = sample_weight_mode 288 self._compile_weighted_metrics = weighted_metrics 289 if self.run_eagerly and target_tensors is not None: 290 raise ValueError( 291 'target_tensors argument is not supported when ' 292 'running a model eagerly.') 293 self.target_tensors = target_tensors 294 295 # Set DistributionStrategy specific parameters. 296 self._distributed_model_cache = {} 297 298 if self._distribution_strategy is not None: 299 # Ensures a Session is created and configured correctly for Distribution 300 # Strategy. 301 K.configure_and_create_distributed_session(self._distribution_strategy) 302 # Initialize model metric attributes. 303 self._init_metric_attributes() 304 if not self.built or not self.inputs or not self.outputs: 305 # Model is not compilable because it does not know its number of inputs 306 # and outputs, nor their shapes and names. We will compile after the first 307 # time the model gets called on training data. 308 return 309 self._is_compiled = True 310 311 # Prepare list of loss functions, same size of model outputs. 312 self.loss_functions = training_utils.prepare_loss_functions( 313 loss, self.output_names) 314 315 self._feed_outputs = [] 316 self._feed_output_names = [] 317 self._feed_output_shapes = [] 318 self._feed_loss_fns = [] 319 # if loss function is None, then this output will be skipped during total 320 # loss calculation and feed targets preparation. 321 skip_target_indices = [] 322 skip_target_weighing_indices = [] 323 for i, loss_function in enumerate(self.loss_functions): 324 if loss_function is None: 325 skip_target_indices.append(i) 326 skip_target_weighing_indices.append(i) 327 328 # Prepare output masks. 329 if not self.run_eagerly: 330 masks = [getattr(x, '_keras_mask', None) for x in self.outputs] 331 332 # Prepare list loss weights, same size of model outputs. 333 self.loss_weights_list = training_utils.prepare_loss_weights( 334 self.output_names, loss_weights) 335 336 # Initialization for Eager mode execution. 337 if self.run_eagerly: 338 # Prepare sample weights. 339 self._set_sample_weight_attributes(sample_weight_mode, 340 skip_target_weighing_indices) 341 # Save all metric attributes per output of the model. 342 self._cache_output_metric_attributes(metrics, weighted_metrics) 343 344 if target_tensors is not None: 345 raise ValueError('target_tensors are not currently supported in Eager ' 346 'mode.') 347 self.total_loss = None 348 349 # Set metric attributes on model. 350 self._set_metric_attributes(skip_target_indices=skip_target_indices) 351 352 self.targets = [] 353 for i in range(len(self.outputs)): 354 self._feed_output_names.append(self.output_names[i]) 355 self._collected_trainable_weights = self.trainable_weights 356 return 357 358 with K.get_graph().as_default(): 359 # Prepare targets of model. 360 self.targets = [] 361 self._feed_targets = [] 362 if target_tensors not in (None, []): 363 if isinstance(target_tensors, list): 364 if len(target_tensors) != len(self.outputs): 365 raise ValueError( 366 'When passing a list as `target_tensors`, ' 367 'it should have one entry per model output. ' 368 'The model has %s outputs, but you passed target_tensors=%s' % 369 (len(self.outputs), target_tensors)) 370 elif isinstance(target_tensors, dict): 371 for name in target_tensors: 372 if name not in self.output_names: 373 raise ValueError( 374 'Unknown entry in `target_tensors` ' 375 'dictionary: "' + name + '". ' 376 'Only expected the following keys: ' + str(self.output_names)) 377 tmp_target_tensors = [] 378 for name in self.output_names: 379 tmp_target_tensors.append(target_tensors.get(name, None)) 380 target_tensors = tmp_target_tensors 381 elif tensor_util.is_tensor(target_tensors): 382 target_tensors = [target_tensors] 383 else: 384 raise TypeError('Expected `target_tensors` to be a list or tuple or ' 385 'dict or a single tensor, but got:', target_tensors) 386 387 for i in range(len(self.outputs)): 388 if i in skip_target_indices: 389 self.targets.append(None) 390 else: 391 shape = K.int_shape(self.outputs[i]) 392 name = self.output_names[i] 393 if target_tensors not in (None, []): 394 target = target_tensors[i] 395 else: 396 target = None 397 if target is None or K.is_placeholder(target): 398 if target is None: 399 target_dtype = losses.LABEL_DTYPES_FOR_LOSSES.get( 400 self.loss_functions[i], 401 K.dtype(self.outputs[i])) 402 403 target = K.placeholder( 404 ndim=len(shape), 405 name=name + '_target', 406 sparse=K.is_sparse(self.outputs[i]), 407 dtype=target_dtype) 408 self._feed_targets.append(target) 409 self._feed_outputs.append(self.outputs[i]) 410 self._feed_output_names.append(name) 411 self._feed_output_shapes.append(shape) 412 self._feed_loss_fns.append(self.loss_functions[i]) 413 else: 414 skip_target_weighing_indices.append(i) 415 self.targets.append(target) 416 417 # Prepare sample weights. 418 self._set_sample_weight_attributes(sample_weight_mode, 419 skip_target_weighing_indices) 420 # Save all metric attributes per output of the model. 421 self._cache_output_metric_attributes(metrics, weighted_metrics) 422 423 # Set metric attributes on model. 424 self._set_metric_attributes(skip_target_indices=skip_target_indices) 425 426 # Invoke metric functions for all the outputs. 427 self._handle_metrics( 428 self.outputs, 429 masks=masks, 430 targets=self.targets, 431 skip_target_indices=skip_target_indices, 432 sample_weights=self.sample_weights) 433 434 # Compute total loss. 435 # Used to keep track of the total loss value (stateless). 436 # eg., total_loss = loss_weight_1 * output_1_loss_fn(...) + 437 # loss_weight_2 * output_2_loss_fn(...) + 438 # layer losses. 439 self.total_loss = self._prepare_total_loss(skip_target_indices, masks) 440 441 # Functions for train, test and predict will 442 # be compiled lazily when required. 443 # This saves time when the user is not using all functions. 444 self._function_kwargs = kwargs 445 446 self.train_function = None 447 self.test_function = None 448 self.predict_function = None 449 450 # Collected trainable weights, sorted in topological order. 451 trainable_weights = self.trainable_weights 452 self._collected_trainable_weights = trainable_weights 453 454 # Validate all variables were correctly created in distribution scope. 455 if self._distribution_strategy and not self._compile_distribution: 456 for v in self.variables: 457 strategy = self._distribution_strategy 458 if not strategy.extended.variable_created_in_scope(v): 459 raise ValueError( 460 'Variable (%s) was not created in the distribution strategy ' 461 'scope of (%s). It is most likely due to not all layers or ' 462 'the model or optimizer being created outside the distribution ' 463 'strategy scope. Try to make sure your code looks similar ' 464 'to the following.\n' 465 'with strategy.scope():\n' 466 ' model=_create_model()\n' 467 ' model.compile(...)'% (v, strategy)) 468 469 @property 470 def metrics(self): 471 """Returns the model's metrics added using `compile`, `add_metric` APIs.""" 472 metrics = [] 473 if self._is_compiled: 474 metrics += self._compile_metric_functions 475 return metrics + super(Model, self).metrics 476 477 @property 478 def metrics_names(self): 479 """Returns the model's display labels for all outputs.""" 480 metrics_names = [] 481 if self._is_compiled: 482 metrics_names += self._compile_metrics_names # Includes names of losses. 483 484 # Add metric names from layers. 485 for layer in self.layers: 486 metrics_names += [m.name for m in layer._metrics] # pylint: disable=protected-access 487 metrics_names += [m.name for m in self._metrics] 488 return metrics_names 489 490 @property 491 def run_eagerly(self): 492 """Settable attribute indicating whether the model should run eagerly. 493 494 Running eagerly means that your model will be run step by step, 495 like Python code. Your model might run slower, but it should become easier 496 for you to debug it by stepping into individual layer calls. 497 498 By default, we will attempt to compile your model to a static graph to 499 deliver the best execution performance. 500 501 Returns: 502 Boolean, whether the model should run eagerly. 503 """ 504 if self._run_eagerly is True and not context.executing_eagerly(): 505 raise ValueError('You can only set `run_eagerly=True` if eager execution ' 506 'is enabled.') 507 if not self.dynamic: 508 if self._run_eagerly is None: 509 return False 510 else: 511 return self._run_eagerly 512 else: 513 if not context.executing_eagerly(): 514 raise ValueError('Your model contains layers that can only be ' 515 'successfully run in eager execution (layers ' 516 'constructed with `dynamic=True`). ' 517 'You must enable eager execution with ' 518 '`tf.enable_eager_execution()`.') 519 if self._run_eagerly is False: 520 # TODO(fchollet): consider using py_func to enable this. 521 raise ValueError('Your model contains layers that can only be ' 522 'successfully run in eager execution (layers ' 523 'constructed with `dynamic=True`). ' 524 'You cannot set `run_eagerly=False`.') 525 return context.executing_eagerly() 526 527 @run_eagerly.setter 528 def run_eagerly(self, value): 529 self._run_eagerly = value 530 531 def fit(self, 532 x=None, 533 y=None, 534 batch_size=None, 535 epochs=1, 536 verbose=1, 537 callbacks=None, 538 validation_split=0., 539 validation_data=None, 540 shuffle=True, 541 class_weight=None, 542 sample_weight=None, 543 initial_epoch=0, 544 steps_per_epoch=None, 545 validation_steps=None, 546 validation_freq=1, 547 max_queue_size=10, 548 workers=1, 549 use_multiprocessing=False, 550 **kwargs): 551 """Trains the model for a fixed number of epochs (iterations on a dataset). 552 553 Arguments: 554 x: Input data. It could be: 555 - A Numpy array (or array-like), or a list of arrays 556 (in case the model has multiple inputs). 557 - A TensorFlow tensor, or a list of tensors 558 (in case the model has multiple inputs). 559 - A dict mapping input names to the corresponding array/tensors, 560 if the model has named inputs. 561 - A `tf.data` dataset or a dataset iterator. Should return a tuple 562 of either `(inputs, targets)` or 563 `(inputs, targets, sample_weights)`. 564 - A generator or `keras.utils.Sequence` returning `(inputs, targets)` 565 or `(inputs, targets, sample weights)`. 566 y: Target data. Like the input data `x`, 567 it could be either Numpy array(s) or TensorFlow tensor(s). 568 It should be consistent with `x` (you cannot have Numpy inputs and 569 tensor targets, or inversely). If `x` is a dataset, dataset 570 iterator, generator, or `keras.utils.Sequence` instance, `y` should 571 not be specified (since targets will be obtained from `x`). 572 batch_size: Integer or `None`. 573 Number of samples per gradient update. 574 If unspecified, `batch_size` will default to 32. 575 Do not specify the `batch_size` if your data is in the 576 form of symbolic tensors, dataset, dataset iterators, 577 generators, or `keras.utils.Sequence` instances (since they generate 578 batches). 579 epochs: Integer. Number of epochs to train the model. 580 An epoch is an iteration over the entire `x` and `y` 581 data provided. 582 Note that in conjunction with `initial_epoch`, 583 `epochs` is to be understood as "final epoch". 584 The model is not trained for a number of iterations 585 given by `epochs`, but merely until the epoch 586 of index `epochs` is reached. 587 verbose: Integer. 0, 1, or 2. Verbosity mode. 588 0 = silent, 1 = progress bar, 2 = one line per epoch. 589 callbacks: List of `keras.callbacks.Callback` instances. 590 List of callbacks to apply during training. 591 See `tf.keras.callbacks`. 592 validation_split: Float between 0 and 1. 593 Fraction of the training data to be used as validation data. 594 The model will set apart this fraction of the training data, 595 will not train on it, and will evaluate 596 the loss and any model metrics 597 on this data at the end of each epoch. 598 The validation data is selected from the last samples 599 in the `x` and `y` data provided, before shuffling. This argument is 600 not supported when `x` is a dataset, dataset iterator, generator or 601 `keras.utils.Sequence` instance. 602 validation_data: Data on which to evaluate 603 the loss and any model metrics at the end of each epoch. 604 The model will not be trained on this data. 605 `validation_data` will override `validation_split`. 606 `validation_data` could be: 607 - tuple `(x_val, y_val)` of Numpy arrays or tensors 608 - tuple `(x_val, y_val, val_sample_weights)` of Numpy arrays 609 - dataset or a dataset iterator 610 For the first two cases, `batch_size` must be provided. 611 For the last case, `validation_steps` must be provided. 612 shuffle: Boolean (whether to shuffle the training data 613 before each epoch) or str (for 'batch'). 614 'batch' is a special option for dealing with the 615 limitations of HDF5 data; it shuffles in batch-sized chunks. 616 Has no effect when `steps_per_epoch` is not `None`. 617 class_weight: Optional dictionary mapping class indices (integers) 618 to a weight (float) value, used for weighting the loss function 619 (during training only). 620 This can be useful to tell the model to 621 "pay more attention" to samples from 622 an under-represented class. 623 sample_weight: Optional Numpy array of weights for 624 the training samples, used for weighting the loss function 625 (during training only). You can either pass a flat (1D) 626 Numpy array with the same length as the input samples 627 (1:1 mapping between weights and samples), 628 or in the case of temporal data, 629 you can pass a 2D array with shape 630 `(samples, sequence_length)`, 631 to apply a different weight to every timestep of every sample. 632 In this case you should make sure to specify 633 `sample_weight_mode="temporal"` in `compile()`. This argument is not 634 supported when `x` is a dataset, dataset iterator, generator, or 635 `keras.utils.Sequence` instance, instead provide the sample_weights 636 as the third element of `x`. 637 initial_epoch: Integer. 638 Epoch at which to start training 639 (useful for resuming a previous training run). 640 steps_per_epoch: Integer or `None`. 641 Total number of steps (batches of samples) 642 before declaring one epoch finished and starting the 643 next epoch. When training with input tensors such as 644 TensorFlow data tensors, the default `None` is equal to 645 the number of samples in your dataset divided by 646 the batch size, or 1 if that cannot be determined. If x is a 647 `tf.data` dataset or a dataset iterator, and 'steps_per_epoch' 648 is None, the epoch will run until the input dataset is exhausted. 649 validation_steps: Only relevant if `validation_data` is provided and 650 is a dataset or dataset iterator. Total number of steps (batches of 651 samples) to draw before stopping when performing validation 652 at the end of every epoch. If validation_data is a `tf.data` dataset 653 or a dataset iterator, and 'validation_steps' is None, validation 654 will run until the `validation_data` dataset is exhausted. 655 validation_freq: Only relevant if validation data is provided. Integer 656 or `collections.Container` instance (e.g. list, tuple, etc.). If an 657 integer, specifies how many training epochs to run before a new 658 validation run is performed, e.g. `validation_freq=2` runs 659 validation every 2 epochs. If a Container, specifies the epochs on 660 which to run validation, e.g. `validation_freq=[1, 2, 10]` runs 661 validation at the end of the 1st, 2nd, and 10th epochs. 662 max_queue_size: Integer. Used for generator or `keras.utils.Sequence` 663 input only. Maximum size for the generator queue. 664 If unspecified, `max_queue_size` will default to 10. 665 workers: Integer. Used for generator or `keras.utils.Sequence` input 666 only. Maximum number of processes to spin up 667 when using process-based threading. If unspecified, `workers` 668 will default to 1. If 0, will execute the generator on the main 669 thread. 670 use_multiprocessing: Boolean. Used for generator or 671 `keras.utils.Sequence` input only. If `True`, use process-based 672 threading. If unspecified, `use_multiprocessing` will default to 673 `False`. Note that because this implementation relies on 674 multiprocessing, you should not pass non-picklable arguments to 675 the generator as they can't be passed easily to children processes. 676 **kwargs: Used for backwards compatibility. 677 678 Returns: 679 A `History` object. Its `History.history` attribute is 680 a record of training loss values and metrics values 681 at successive epochs, as well as validation loss values 682 and validation metrics values (if applicable). 683 684 Raises: 685 RuntimeError: If the model was never compiled. 686 ValueError: In case of mismatch between the provided input data 687 and what the model expects. 688 """ 689 # Legacy support 690 if 'nb_epoch' in kwargs: 691 logging.warning( 692 'The `nb_epoch` argument in `fit` ' 693 'has been renamed `epochs`.') 694 epochs = kwargs.pop('nb_epoch') 695 if kwargs: 696 raise TypeError('Unrecognized keyword arguments: ' + str(kwargs)) 697 698 # Case 1: distribution strategy. 699 if self._distribution_strategy: 700 if K.in_multi_worker_mode(): 701 # Multi-Worker mode runs the Keras training loop on multiple 702 # servers via the Distribute Coordinator. 703 def _worker_fn(_): 704 """Run training inside the distributed coordinator.""" 705 filtered_callbacks = distributed_training_utils \ 706 .filter_distributed_callbacks(callbacks) 707 return training_distributed.fit_distributed( 708 self, 709 x=x, 710 y=y, 711 batch_size=batch_size, 712 epochs=epochs, 713 verbose=verbose, 714 callbacks=filtered_callbacks, 715 validation_split=validation_split, 716 validation_data=validation_data, 717 shuffle=shuffle, 718 class_weight=class_weight, 719 sample_weight=sample_weight, 720 initial_epoch=initial_epoch, 721 steps_per_epoch=steps_per_epoch, 722 validation_steps=validation_steps, 723 validation_freq=validation_freq) 724 725 # Independent worker only for now. 726 return dc.run_distribute_coordinator( 727 _worker_fn, 728 self._distribution_strategy, 729 mode=dc.CoordinatorMode.INDEPENDENT_WORKER) 730 else: 731 return training_distributed.fit_distributed( 732 self, 733 x=x, 734 y=y, 735 batch_size=batch_size, 736 epochs=epochs, 737 verbose=verbose, 738 callbacks=callbacks, 739 validation_split=validation_split, 740 validation_data=validation_data, 741 shuffle=shuffle, 742 class_weight=class_weight, 743 sample_weight=sample_weight, 744 initial_epoch=initial_epoch, 745 steps_per_epoch=steps_per_epoch, 746 validation_steps=validation_steps, 747 validation_freq=validation_freq) 748 749 batch_size = self._validate_or_infer_batch_size( 750 batch_size, steps_per_epoch, x) 751 752 # Case 2: generator-like. Input is Python generator, or Sequence object, 753 # or a non-distributed Dataset or iterator in eager execution. 754 if data_utils.is_generator_or_sequence(x): 755 training_utils.check_generator_arguments( 756 y, sample_weight, validation_split=validation_split) 757 return self.fit_generator( 758 x, 759 steps_per_epoch=steps_per_epoch, 760 epochs=epochs, 761 verbose=verbose, 762 callbacks=callbacks, 763 validation_data=validation_data, 764 validation_steps=validation_steps, 765 validation_freq=validation_freq, 766 class_weight=class_weight, 767 max_queue_size=max_queue_size, 768 workers=workers, 769 use_multiprocessing=use_multiprocessing, 770 shuffle=shuffle, 771 initial_epoch=initial_epoch) 772 if training_utils.is_eager_dataset_or_iterator(x): 773 # Make sure that y, sample_weights, validation_split are not passed. 774 training_utils.validate_dataset_input(x, y, sample_weight, 775 validation_split) 776 if (isinstance(x, (dataset_ops.DatasetV1, dataset_ops.DatasetV2)) 777 and shuffle): 778 training_utils.verify_dataset_shuffled(x) 779 780 return self.fit_generator( 781 x, 782 steps_per_epoch=steps_per_epoch, 783 epochs=epochs, 784 verbose=verbose, 785 callbacks=callbacks, 786 validation_data=validation_data, 787 validation_steps=validation_steps, 788 validation_freq=validation_freq, 789 class_weight=class_weight, 790 workers=0, 791 shuffle=shuffle, 792 initial_epoch=initial_epoch) 793 794 # Case 3: Symbolic tensors or Numpy array-like. 795 # This includes Datasets and iterators in graph mode (since they 796 # generate symbolic tensors). 797 x, y, sample_weights = self._standardize_user_data( 798 x, 799 y, 800 sample_weight=sample_weight, 801 class_weight=class_weight, 802 batch_size=batch_size, 803 check_steps=True, 804 steps_name='steps_per_epoch', 805 steps=steps_per_epoch, 806 validation_split=validation_split, 807 shuffle=shuffle) 808 809 # Prepare validation data. 810 if validation_data: 811 val_x, val_y, val_sample_weights = self._unpack_validation_data( 812 validation_data) 813 val_x, val_y, val_sample_weights = self._standardize_user_data( 814 val_x, 815 val_y, 816 sample_weight=val_sample_weights, 817 batch_size=batch_size, 818 steps=validation_steps, 819 steps_name='validation_steps') 820 elif validation_split and 0. < validation_split < 1.: 821 if training_utils.has_symbolic_tensors(x): 822 raise ValueError('If your data is in the form of symbolic tensors, ' 823 'you cannot use `validation_split`.') 824 if hasattr(x[0], 'shape'): 825 split_at = int(x[0].shape[0] * (1. - validation_split)) 826 else: 827 split_at = int(len(x[0]) * (1. - validation_split)) 828 x, val_x = (slice_arrays(x, 0, split_at), slice_arrays(x, split_at)) 829 y, val_y = (slice_arrays(y, 0, split_at), slice_arrays(y, split_at)) 830 sample_weights, val_sample_weights = (slice_arrays( 831 sample_weights, 0, split_at), slice_arrays(sample_weights, split_at)) 832 elif validation_steps: 833 val_x = [] 834 val_y = [] 835 val_sample_weights = [] 836 else: 837 val_x = None 838 val_y = None 839 val_sample_weights = None 840 841 if self.run_eagerly: 842 return training_generator.fit_generator( 843 self, (x, y, sample_weights), 844 steps_per_epoch=steps_per_epoch, 845 batch_size=batch_size, 846 epochs=epochs, 847 verbose=verbose, 848 callbacks=callbacks, 849 validation_data=validation_data, 850 validation_steps=validation_steps, 851 validation_freq=validation_freq, 852 workers=0, 853 shuffle=shuffle, 854 initial_epoch=initial_epoch, 855 steps_name='steps_per_epoch') 856 else: 857 return training_arrays.fit_loop( 858 self, 859 x, 860 y, 861 sample_weights=sample_weights, 862 batch_size=batch_size, 863 epochs=epochs, 864 verbose=verbose, 865 callbacks=callbacks, 866 val_inputs=val_x, 867 val_targets=val_y, 868 val_sample_weights=val_sample_weights, 869 shuffle=shuffle, 870 initial_epoch=initial_epoch, 871 steps_per_epoch=steps_per_epoch, 872 validation_steps=validation_steps, 873 validation_freq=validation_freq, 874 steps_name='steps_per_epoch') 875 876 def evaluate(self, 877 x=None, 878 y=None, 879 batch_size=None, 880 verbose=1, 881 sample_weight=None, 882 steps=None, 883 callbacks=None, 884 max_queue_size=10, 885 workers=1, 886 use_multiprocessing=False): 887 """Returns the loss value & metrics values for the model in test mode. 888 889 Computation is done in batches. 890 891 Arguments: 892 x: Input data. It could be: 893 - A Numpy array (or array-like), or a list of arrays 894 (in case the model has multiple inputs). 895 - A TensorFlow tensor, or a list of tensors 896 (in case the model has multiple inputs). 897 - A dict mapping input names to the corresponding array/tensors, 898 if the model has named inputs. 899 - A `tf.data` dataset or a dataset iterator. 900 - A generator or `keras.utils.Sequence` instance. 901 y: Target data. Like the input data `x`, 902 it could be either Numpy array(s) or TensorFlow tensor(s). 903 It should be consistent with `x` (you cannot have Numpy inputs and 904 tensor targets, or inversely). 905 If `x` is a dataset, dataset iterator, generator or 906 `keras.utils.Sequence` instance, `y` should not be specified (since 907 targets will be obtained from the iterator/dataset). 908 batch_size: Integer or `None`. 909 Number of samples per gradient update. 910 If unspecified, `batch_size` will default to 32. 911 Do not specify the `batch_size` is your data is in the 912 form of symbolic tensors, dataset, dataset iterators, 913 generators, or `keras.utils.Sequence` instances (since they generate 914 batches). 915 verbose: 0 or 1. Verbosity mode. 916 0 = silent, 1 = progress bar. 917 sample_weight: Optional Numpy array of weights for 918 the test samples, used for weighting the loss function. 919 You can either pass a flat (1D) 920 Numpy array with the same length as the input samples 921 (1:1 mapping between weights and samples), 922 or in the case of temporal data, 923 you can pass a 2D array with shape 924 `(samples, sequence_length)`, 925 to apply a different weight to every timestep of every sample. 926 In this case you should make sure to specify 927 `sample_weight_mode="temporal"` in `compile()`. This argument is not 928 supported when `x` is a dataset or a dataset iterator, instead pass 929 sample weights as the third element of `x`. 930 steps: Integer or `None`. 931 Total number of steps (batches of samples) 932 before declaring the evaluation round finished. 933 Ignored with the default value of `None`. 934 If x is a `tf.data` dataset or a dataset iterator, and `steps` is 935 None, 'evaluate' will run until the dataset is exhausted. 936 callbacks: List of `keras.callbacks.Callback` instances. 937 List of callbacks to apply during evaluation. 938 See [callbacks](/api_docs/python/tf/keras/callbacks). 939 max_queue_size: Integer. Used for generator or `keras.utils.Sequence` 940 input only. Maximum size for the generator queue. 941 If unspecified, `max_queue_size` will default to 10. 942 workers: Integer. Used for generator or `keras.utils.Sequence` input 943 only. Maximum number of processes to spin up when using 944 process-based threading. If unspecified, `workers` will default 945 to 1. If 0, will execute the generator on the main thread. 946 use_multiprocessing: Boolean. Used for generator or 947 `keras.utils.Sequence` input only. If `True`, use process-based 948 threading. If unspecified, `use_multiprocessing` will default to 949 `False`. Note that because this implementation relies on 950 multiprocessing, you should not pass non-picklable arguments to 951 the generator as they can't be passed easily to children processes. 952 953 Returns: 954 Scalar test loss (if the model has a single output and no metrics) 955 or list of scalars (if the model has multiple outputs 956 and/or metrics). The attribute `model.metrics_names` will give you 957 the display labels for the scalar outputs. 958 959 Raises: 960 ValueError: in case of invalid arguments. 961 """ 962 # Case 1: distribution strategy. 963 if self._distribution_strategy: 964 if K.in_multi_worker_mode(): 965 # Multi-Worker mode runs the Keras evaluation loop on multiple 966 # servers via the Distribute Coordinator. 967 def _worker_fn(_): 968 """Run evaluation inside the distributed coordinator.""" 969 filtered_callbacks = distributed_training_utils \ 970 .filter_distributed_callbacks(callbacks) 971 return training_distributed.evaluate_distributed( 972 self, 973 x=x, 974 y=y, 975 batch_size=batch_size, 976 verbose=verbose, 977 sample_weight=sample_weight, 978 steps=steps, 979 callbacks=filtered_callbacks) 980 981 # Independent worker only for now. 982 return dc.run_distribute_coordinator( 983 _worker_fn, 984 self._distribution_strategy, 985 mode=dc.CoordinatorMode.INDEPENDENT_WORKER) 986 else: 987 return training_distributed.evaluate_distributed( 988 self, 989 x=x, 990 y=y, 991 batch_size=batch_size, 992 verbose=verbose, 993 sample_weight=sample_weight, 994 steps=steps, 995 callbacks=callbacks) 996 997 batch_size = self._validate_or_infer_batch_size(batch_size, steps, x) 998 999 # Case 2: generator-like. Input is Python generator, or Sequence object, 1000 # or a non-distributed Dataset or iterator in eager execution. 1001 if data_utils.is_generator_or_sequence(x): 1002 training_utils.check_generator_arguments(y, sample_weight) 1003 return self.evaluate_generator( 1004 x, 1005 steps=steps, 1006 verbose=verbose, 1007 callbacks=callbacks, 1008 max_queue_size=max_queue_size, 1009 workers=workers, 1010 use_multiprocessing=use_multiprocessing) 1011 if training_utils.is_eager_dataset_or_iterator(x): 1012 # Make sure that y, sample_weights are not passed. 1013 training_utils.validate_dataset_input(x, y, sample_weight) 1014 return training_generator.evaluate_generator( 1015 self, x, 1016 steps=steps, 1017 batch_size=batch_size, 1018 verbose=verbose, 1019 workers=0, 1020 callbacks=callbacks) 1021 1022 # Case 3: Symbolic tensors or Numpy array-like. 1023 # This includes Datasets and iterators in graph mode (since they 1024 # generate symbolic tensors). 1025 x, y, sample_weights = self._standardize_user_data( 1026 x, 1027 y, 1028 sample_weight=sample_weight, 1029 batch_size=batch_size, 1030 check_steps=True, 1031 steps_name='steps', 1032 steps=steps) 1033 1034 if self.run_eagerly: 1035 return training_generator.evaluate_generator( 1036 self, (x, y, sample_weights), 1037 steps=steps, 1038 batch_size=batch_size, 1039 verbose=verbose, 1040 workers=0, 1041 callbacks=callbacks) 1042 else: 1043 return training_arrays.test_loop( 1044 self, 1045 inputs=x, 1046 targets=y, 1047 sample_weights=sample_weights, 1048 batch_size=batch_size, 1049 verbose=verbose, 1050 steps=steps, 1051 callbacks=callbacks) 1052 1053 def predict(self, 1054 x, 1055 batch_size=None, 1056 verbose=0, 1057 steps=None, 1058 callbacks=None, 1059 max_queue_size=10, 1060 workers=1, 1061 use_multiprocessing=False): 1062 """Generates output predictions for the input samples. 1063 1064 Computation is done in batches. 1065 1066 Arguments: 1067 x: Input samples. It could be: 1068 - A Numpy array (or array-like), or a list of arrays 1069 (in case the model has multiple inputs). 1070 - A TensorFlow tensor, or a list of tensors 1071 (in case the model has multiple inputs). 1072 - A `tf.data` dataset or a dataset iterator. 1073 - A generator or `keras.utils.Sequence` instance. 1074 batch_size: Integer or `None`. 1075 Number of samples per gradient update. 1076 If unspecified, `batch_size` will default to 32. 1077 Do not specify the `batch_size` is your data is in the 1078 form of symbolic tensors, dataset, dataset iterators, 1079 generators, or `keras.utils.Sequence` instances (since they generate 1080 batches). 1081 verbose: Verbosity mode, 0 or 1. 1082 steps: Total number of steps (batches of samples) 1083 before declaring the prediction round finished. 1084 Ignored with the default value of `None`. If x is a `tf.data` 1085 dataset or a dataset iterator, and `steps` is None, `predict` will 1086 run until the input dataset is exhausted. 1087 callbacks: List of `keras.callbacks.Callback` instances. 1088 List of callbacks to apply during prediction. 1089 See [callbacks](/api_docs/python/tf/keras/callbacks). 1090 max_queue_size: Integer. Used for generator or `keras.utils.Sequence` 1091 input only. Maximum size for the generator queue. 1092 If unspecified, `max_queue_size` will default to 10. 1093 workers: Integer. Used for generator or `keras.utils.Sequence` input 1094 only. Maximum number of processes to spin up when using 1095 process-based threading. If unspecified, `workers` will default 1096 to 1. If 0, will execute the generator on the main thread. 1097 use_multiprocessing: Boolean. Used for generator or 1098 `keras.utils.Sequence` input only. If `True`, use process-based 1099 threading. If unspecified, `use_multiprocessing` will default to 1100 `False`. Note that because this implementation relies on 1101 multiprocessing, you should not pass non-picklable arguments to 1102 the generator as they can't be passed easily to children processes. 1103 1104 1105 Returns: 1106 Numpy array(s) of predictions. 1107 1108 Raises: 1109 ValueError: In case of mismatch between the provided 1110 input data and the model's expectations, 1111 or in case a stateful model receives a number of samples 1112 that is not a multiple of the batch size. 1113 """ 1114 # Case 1: distribution strategy. 1115 if self._distribution_strategy: 1116 return training_distributed.predict_distributed(self, 1117 x=x, 1118 batch_size=batch_size, 1119 verbose=verbose, 1120 steps=steps, 1121 callbacks=callbacks) 1122 1123 batch_size = self._validate_or_infer_batch_size(batch_size, steps, x) 1124 1125 # Case 2: generator-like. Input is Python generator, or Sequence object, 1126 # or a non-distributed Dataset or iterator in eager execution. 1127 if data_utils.is_generator_or_sequence(x): 1128 return self.predict_generator( 1129 x, 1130 steps=steps, 1131 verbose=verbose, 1132 callbacks=callbacks, 1133 max_queue_size=max_queue_size, 1134 workers=workers, 1135 use_multiprocessing=use_multiprocessing) 1136 if training_utils.is_eager_dataset_or_iterator(x): 1137 return training_generator.predict_generator( 1138 self, 1139 x, 1140 steps=steps, 1141 batch_size=batch_size, 1142 verbose=verbose, 1143 workers=0, 1144 callbacks=callbacks) 1145 1146 # Case 3: Symbolic tensors or Numpy array-like. 1147 # This includes Datasets and iterators in graph mode (since they 1148 # generate symbolic tensors). 1149 x, _, _ = self._standardize_user_data( 1150 x, check_steps=True, steps_name='steps', steps=steps) 1151 1152 if self.run_eagerly: 1153 return training_generator.predict_generator( 1154 self, 1155 x, 1156 steps=steps, 1157 batch_size=batch_size, 1158 verbose=verbose, 1159 workers=0, 1160 callbacks=callbacks) 1161 else: 1162 return training_arrays.predict_loop( 1163 self, 1164 x, 1165 batch_size=batch_size, 1166 verbose=verbose, 1167 steps=steps, 1168 callbacks=callbacks) 1169 1170 def reset_metrics(self): 1171 """Resets the state of metrics.""" 1172 if hasattr(self, 'metrics'): 1173 for m in self.metrics: 1174 m.reset_states() 1175 1176 # Reset the state of loss metric wrappers. 1177 if getattr(self, '_output_loss_metrics', None) is not None: 1178 for m in self._output_loss_metrics: 1179 m.reset_states() 1180 1181 # Reset metrics on all the distributed (cloned) models. 1182 if self._distribution_strategy: 1183 distributed_training_utils._reset_metrics(self) # pylint: disable=protected-access 1184 1185 def train_on_batch(self, 1186 x, 1187 y=None, 1188 sample_weight=None, 1189 class_weight=None, 1190 reset_metrics=True): 1191 """Runs a single gradient update on a single batch of data. 1192 1193 Arguments: 1194 x: Input data. It could be: 1195 - A Numpy array (or array-like), or a list of arrays 1196 (in case the model has multiple inputs). 1197 - A TensorFlow tensor, or a list of tensors 1198 (in case the model has multiple inputs). 1199 - A dict mapping input names to the corresponding array/tensors, 1200 if the model has named inputs. 1201 - A `tf.data` dataset or a dataset iterator. 1202 y: Target data. Like the input data `x`, it could be either Numpy 1203 array(s) or TensorFlow tensor(s). It should be consistent with `x` 1204 (you cannot have Numpy inputs and tensor targets, or inversely). If 1205 `x` is a dataset or a dataset iterator, `y` should not be specified 1206 (since targets will be obtained from the iterator). 1207 sample_weight: Optional array of the same length as x, containing 1208 weights to apply to the model's loss for each sample. In the case of 1209 temporal data, you can pass a 2D array with shape (samples, 1210 sequence_length), to apply a different weight to every timestep of 1211 every sample. In this case you should make sure to specify 1212 sample_weight_mode="temporal" in compile(). This argument is not 1213 supported when `x` is a dataset or a dataset iterator. 1214 class_weight: Optional dictionary mapping class indices (integers) to a 1215 weight (float) to apply to the model's loss for the samples from this 1216 class during training. This can be useful to tell the model to "pay 1217 more attention" to samples from an under-represented class. 1218 reset_metrics: If `True`, the metrics returned will be only for this 1219 batch. If `False`, the metrics will be statefully accumulated across 1220 batches. 1221 1222 Returns: 1223 Scalar training loss 1224 (if the model has a single output and no metrics) 1225 or list of scalars (if the model has multiple outputs 1226 and/or metrics). The attribute `model.metrics_names` will give you 1227 the display labels for the scalar outputs. 1228 1229 Raises: 1230 ValueError: In case of invalid user-provided arguments. 1231 """ 1232 if self._distribution_strategy: 1233 raise NotImplementedError('`train_on_batch` is not supported for models ' 1234 'compiled with DistributionStrategy.') 1235 # Validate and standardize user data. 1236 x, y, sample_weights = self._standardize_user_data( 1237 x, y, sample_weight=sample_weight, class_weight=class_weight, 1238 extract_tensors_from_dataset=True) 1239 1240 if self.run_eagerly: 1241 outputs = training_eager.train_on_batch( 1242 self, 1243 x, 1244 y, 1245 sample_weights=sample_weights, 1246 output_loss_metrics=self._output_loss_metrics) 1247 else: 1248 x = training_utils.ModelInputs(x).as_list() 1249 ins = x + (y or []) + (sample_weights or []) 1250 1251 if not isinstance(K.symbolic_learning_phase(), int): 1252 ins += [True] # Add learning phase value. 1253 1254 self._make_train_function() 1255 outputs = self.train_function(ins) # pylint: disable=not-callable 1256 1257 if reset_metrics: 1258 self.reset_metrics() 1259 1260 if len(outputs) == 1: 1261 return outputs[0] 1262 return outputs 1263 1264 def test_on_batch(self, x, y=None, sample_weight=None, reset_metrics=True): 1265 """Test the model on a single batch of samples. 1266 1267 Arguments: 1268 x: Input data. It could be: 1269 - A Numpy array (or array-like), or a list of arrays 1270 (in case the model has multiple inputs). 1271 - A TensorFlow tensor, or a list of tensors 1272 (in case the model has multiple inputs). 1273 - A dict mapping input names to the corresponding array/tensors, 1274 if the model has named inputs. 1275 - A `tf.data` dataset or a dataset iterator. 1276 y: Target data. Like the input data `x`, 1277 it could be either Numpy array(s) or TensorFlow tensor(s). 1278 It should be consistent with `x` (you cannot have Numpy inputs and 1279 tensor targets, or inversely). If `x` is a dataset or a 1280 dataset iterator, `y` should not be specified 1281 (since targets will be obtained from the iterator). 1282 sample_weight: Optional array of the same length as x, containing 1283 weights to apply to the model's loss for each sample. 1284 In the case of temporal data, you can pass a 2D array 1285 with shape (samples, sequence_length), 1286 to apply a different weight to every timestep of every sample. 1287 In this case you should make sure to specify 1288 sample_weight_mode="temporal" in compile(). This argument is not 1289 supported when `x` is a dataset or a dataset iterator. 1290 reset_metrics: If `True`, the metrics returned will be only for this 1291 batch. If `False`, the metrics will be statefully accumulated across 1292 batches. 1293 1294 Returns: 1295 Scalar test loss (if the model has a single output and no metrics) 1296 or list of scalars (if the model has multiple outputs 1297 and/or metrics). The attribute `model.metrics_names` will give you 1298 the display labels for the scalar outputs. 1299 1300 Raises: 1301 ValueError: In case of invalid user-provided arguments. 1302 """ 1303 if self._distribution_strategy: 1304 raise NotImplementedError('`test_on_batch` is not supported for models ' 1305 'compiled with DistributionStrategy.') 1306 # Validate and standardize user data. 1307 x, y, sample_weights = self._standardize_user_data( 1308 x, y, sample_weight=sample_weight, extract_tensors_from_dataset=True) 1309 1310 if self.run_eagerly: 1311 outputs = training_eager.test_on_batch( 1312 self, 1313 x, 1314 y, 1315 sample_weights=sample_weights, 1316 output_loss_metrics=self._output_loss_metrics) 1317 else: 1318 x = training_utils.ModelInputs(x).as_list() 1319 inputs = x + (y or []) + (sample_weights or []) 1320 1321 self._make_test_function() 1322 outputs = self.test_function(inputs) # pylint: disable=not-callable 1323 1324 if reset_metrics: 1325 self.reset_metrics() 1326 1327 if len(outputs) == 1: 1328 return outputs[0] 1329 return outputs 1330 1331 def predict_on_batch(self, x): 1332 """Returns predictions for a single batch of samples. 1333 1334 Arguments: 1335 x: Input data. It could be: 1336 - A Numpy array (or array-like), or a list of arrays 1337 (in case the model has multiple inputs). 1338 - A TensorFlow tensor, or a list of tensors 1339 (in case the model has multiple inputs). 1340 - A `tf.data` dataset or a dataset iterator. 1341 1342 Returns: 1343 Numpy array(s) of predictions. 1344 1345 Raises: 1346 ValueError: In case of mismatch between given number of inputs and 1347 expectations of the model. 1348 """ 1349 if self._distribution_strategy: 1350 raise NotImplementedError('`predict_on_batch` is not supported for ' 1351 'models compiled with DistributionStrategy.') 1352 # Validate and standardize user data. 1353 inputs, _, _ = self._standardize_user_data( 1354 x, extract_tensors_from_dataset=True) 1355 if self.run_eagerly: 1356 if (isinstance(inputs, iterator_ops.EagerIterator) or 1357 (isinstance(inputs, dataset_ops.DatasetV2))): 1358 inputs = training_utils.cast_if_floating_dtype(inputs) 1359 elif isinstance(inputs, collections.Sequence): 1360 inputs = [ 1361 ops.convert_to_tensor(val, dtype=K.floatx()) for val in inputs] 1362 1363 # Unwrap lists with only one input, as we do when training on batch 1364 if len(inputs) == 1: 1365 inputs = inputs[0] 1366 1367 return self(inputs) # pylint: disable=not-callable 1368 1369 self._make_predict_function() 1370 outputs = self.predict_function(inputs) 1371 1372 if len(outputs) == 1: 1373 return outputs[0] 1374 return outputs 1375 1376 def fit_generator(self, 1377 generator, 1378 steps_per_epoch=None, 1379 epochs=1, 1380 verbose=1, 1381 callbacks=None, 1382 validation_data=None, 1383 validation_steps=None, 1384 validation_freq=1, 1385 class_weight=None, 1386 max_queue_size=10, 1387 workers=1, 1388 use_multiprocessing=False, 1389 shuffle=True, 1390 initial_epoch=0): 1391 """Fits the model on data yielded batch-by-batch by a Python generator. 1392 1393 The generator is run in parallel to the model, for efficiency. 1394 For instance, this allows you to do real-time data augmentation 1395 on images on CPU in parallel to training your model on GPU. 1396 1397 The use of `keras.utils.Sequence` guarantees the ordering 1398 and guarantees the single use of every input per epoch when 1399 using `use_multiprocessing=True`. 1400 1401 Arguments: 1402 generator: A generator or an instance of `Sequence` 1403 (`keras.utils.Sequence`) 1404 object in order to avoid duplicate data 1405 when using multiprocessing. 1406 The output of the generator must be either 1407 - a tuple `(inputs, targets)` 1408 - a tuple `(inputs, targets, sample_weights)`. 1409 This tuple (a single output of the generator) makes a single batch. 1410 Therefore, all arrays in this tuple must have the same length (equal 1411 to the size of this batch). Different batches may have different 1412 sizes. 1413 For example, the last batch of the epoch is commonly smaller than 1414 the 1415 others, if the size of the dataset is not divisible by the batch 1416 size. 1417 The generator is expected to loop over its data 1418 indefinitely. An epoch finishes when `steps_per_epoch` 1419 batches have been seen by the model. 1420 steps_per_epoch: Total number of steps (batches of samples) 1421 to yield from `generator` before declaring one epoch 1422 finished and starting the next epoch. It should typically 1423 be equal to the number of samples of your dataset 1424 divided by the batch size. 1425 Optional for `Sequence`: if unspecified, will use 1426 the `len(generator)` as a number of steps. 1427 epochs: Integer, total number of iterations on the data. 1428 verbose: Verbosity mode, 0, 1, or 2. 1429 callbacks: List of callbacks to be called during training. 1430 validation_data: This can be either 1431 - a generator for the validation data 1432 - a tuple (inputs, targets) 1433 - a tuple (inputs, targets, sample_weights). 1434 validation_steps: Only relevant if `validation_data` 1435 is a generator. Total number of steps (batches of samples) 1436 to yield from `generator` before stopping. 1437 Optional for `Sequence`: if unspecified, will use 1438 the `len(validation_data)` as a number of steps. 1439 validation_freq: Only relevant if validation data is provided. Integer 1440 or `collections.Container` instance (e.g. list, tuple, etc.). If an 1441 integer, specifies how many training epochs to run before a new 1442 validation run is performed, e.g. `validation_freq=2` runs 1443 validation every 2 epochs. If a Container, specifies the epochs on 1444 which to run validation, e.g. `validation_freq=[1, 2, 10]` runs 1445 validation at the end of the 1st, 2nd, and 10th epochs. 1446 class_weight: Dictionary mapping class indices to a weight 1447 for the class. 1448 max_queue_size: Integer. Maximum size for the generator queue. 1449 If unspecified, `max_queue_size` will default to 10. 1450 workers: Integer. Maximum number of processes to spin up 1451 when using process-based threading. 1452 If unspecified, `workers` will default to 1. If 0, will 1453 execute the generator on the main thread. 1454 use_multiprocessing: Boolean. 1455 If `True`, use process-based threading. 1456 If unspecified, `use_multiprocessing` will default to `False`. 1457 Note that because this implementation relies on multiprocessing, 1458 you should not pass non-picklable arguments to the generator 1459 as they can't be passed easily to children processes. 1460 shuffle: Boolean. Whether to shuffle the order of the batches at 1461 the beginning of each epoch. Only used with instances 1462 of `Sequence` (`keras.utils.Sequence`). 1463 Has no effect when `steps_per_epoch` is not `None`. 1464 initial_epoch: Epoch at which to start training 1465 (useful for resuming a previous training run) 1466 1467 Returns: 1468 A `History` object. 1469 1470 Example: 1471 1472 ```python 1473 def generate_arrays_from_file(path): 1474 while 1: 1475 f = open(path) 1476 for line in f: 1477 # create numpy arrays of input data 1478 # and labels, from each line in the file 1479 x1, x2, y = process_line(line) 1480 yield ({'input_1': x1, 'input_2': x2}, {'output': y}) 1481 f.close() 1482 1483 model.fit_generator(generate_arrays_from_file('/my_file.txt'), 1484 steps_per_epoch=10000, epochs=10) 1485 ``` 1486 Raises: 1487 ValueError: In case the generator yields data in an invalid format. 1488 """ 1489 if self._distribution_strategy: 1490 raise NotImplementedError('`fit_generator` is not supported for ' 1491 'models compiled with DistributionStrategy.') 1492 return training_generator.fit_generator( 1493 self, 1494 generator, 1495 steps_per_epoch=steps_per_epoch, 1496 epochs=epochs, 1497 verbose=verbose, 1498 callbacks=callbacks, 1499 validation_data=validation_data, 1500 validation_steps=validation_steps, 1501 validation_freq=validation_freq, 1502 class_weight=class_weight, 1503 max_queue_size=max_queue_size, 1504 workers=workers, 1505 use_multiprocessing=use_multiprocessing, 1506 shuffle=shuffle, 1507 initial_epoch=initial_epoch, 1508 steps_name='steps_per_epoch') 1509 1510 def evaluate_generator(self, 1511 generator, 1512 steps=None, 1513 callbacks=None, 1514 max_queue_size=10, 1515 workers=1, 1516 use_multiprocessing=False, 1517 verbose=0): 1518 """Evaluates the model on a data generator. 1519 1520 The generator should return the same kind of data 1521 as accepted by `test_on_batch`. 1522 1523 Arguments: 1524 generator: Generator yielding tuples (inputs, targets) 1525 or (inputs, targets, sample_weights) 1526 or an instance of `keras.utils.Sequence` 1527 object in order to avoid duplicate data 1528 when using multiprocessing. 1529 steps: Total number of steps (batches of samples) 1530 to yield from `generator` before stopping. 1531 Optional for `Sequence`: if unspecified, will use 1532 the `len(generator)` as a number of steps. 1533 callbacks: List of `keras.callbacks.Callback` instances. 1534 List of callbacks to apply during evaluation. 1535 See [callbacks](/api_docs/python/tf/keras/callbacks). 1536 max_queue_size: maximum size for the generator queue 1537 workers: Integer. Maximum number of processes to spin up 1538 when using process-based threading. 1539 If unspecified, `workers` will default to 1. If 0, will 1540 execute the generator on the main thread. 1541 use_multiprocessing: Boolean. 1542 If `True`, use process-based threading. 1543 If unspecified, `use_multiprocessing` will default to `False`. 1544 Note that because this implementation relies on multiprocessing, 1545 you should not pass non-picklable arguments to the generator 1546 as they can't be passed easily to children processes. 1547 verbose: Verbosity mode, 0 or 1. 1548 1549 Returns: 1550 Scalar test loss (if the model has a single output and no metrics) 1551 or list of scalars (if the model has multiple outputs 1552 and/or metrics). The attribute `model.metrics_names` will give you 1553 the display labels for the scalar outputs. 1554 1555 Raises: 1556 ValueError: in case of invalid arguments. 1557 1558 Raises: 1559 ValueError: In case the generator yields data in an invalid format. 1560 """ 1561 if self._distribution_strategy: 1562 raise NotImplementedError('`evaluate_generator` is not supported for ' 1563 'models compiled with DistributionStrategy.') 1564 return training_generator.evaluate_generator( 1565 self, 1566 generator, 1567 steps=steps, 1568 max_queue_size=max_queue_size, 1569 workers=workers, 1570 use_multiprocessing=use_multiprocessing, 1571 verbose=verbose, 1572 callbacks=callbacks) 1573 1574 def predict_generator(self, 1575 generator, 1576 steps=None, 1577 callbacks=None, 1578 max_queue_size=10, 1579 workers=1, 1580 use_multiprocessing=False, 1581 verbose=0): 1582 """Generates predictions for the input samples from a data generator. 1583 1584 The generator should return the same kind of data as accepted by 1585 `predict_on_batch`. 1586 1587 Arguments: 1588 generator: Generator yielding batches of input samples 1589 or an instance of `keras.utils.Sequence` object in order to 1590 avoid duplicate data when using multiprocessing. 1591 steps: Total number of steps (batches of samples) 1592 to yield from `generator` before stopping. 1593 Optional for `Sequence`: if unspecified, will use 1594 the `len(generator)` as a number of steps. 1595 callbacks: List of `keras.callbacks.Callback` instances. 1596 List of callbacks to apply during prediction. 1597 See [callbacks](/api_docs/python/tf/keras/callbacks). 1598 max_queue_size: Maximum size for the generator queue. 1599 workers: Integer. Maximum number of processes to spin up 1600 when using process-based threading. 1601 If unspecified, `workers` will default to 1. If 0, will 1602 execute the generator on the main thread. 1603 use_multiprocessing: Boolean. 1604 If `True`, use process-based threading. 1605 If unspecified, `use_multiprocessing` will default to `False`. 1606 Note that because this implementation relies on multiprocessing, 1607 you should not pass non-picklable arguments to the generator 1608 as they can't be passed easily to children processes. 1609 verbose: verbosity mode, 0 or 1. 1610 1611 Returns: 1612 Numpy array(s) of predictions. 1613 1614 Raises: 1615 ValueError: In case the generator yields data in an invalid format. 1616 """ 1617 if self._distribution_strategy: 1618 raise NotImplementedError('`predict_generator` is not supported for ' 1619 'models compiled with DistributionStrategy.') 1620 return training_generator.predict_generator( 1621 self, 1622 generator, 1623 steps=steps, 1624 max_queue_size=max_queue_size, 1625 workers=workers, 1626 use_multiprocessing=use_multiprocessing, 1627 verbose=verbose, 1628 callbacks=callbacks) 1629 1630 def _prepare_total_loss(self, skip_target_indices=None, masks=None): 1631 """Computes total loss from loss functions. 1632 1633 Arguments: 1634 skip_target_indices: A list of indices of model outputs where loss 1635 function is None. 1636 masks: List of mask values corresponding to each model output. 1637 1638 Returns: 1639 A list of loss weights of python floats. 1640 1641 Raises: 1642 TypeError: If model run_eagerly is True. 1643 """ 1644 if self.run_eagerly: 1645 raise TypeError('total loss can not be computed when compiled with ' 1646 'run_eagerly = True.') 1647 skip_target_indices = skip_target_indices or [] 1648 total_loss = None 1649 with K.name_scope('loss'): 1650 zipped_inputs = zip(self.targets, self.outputs, self.loss_functions, 1651 self.sample_weights, masks, self.loss_weights_list) 1652 for i, (y_true, y_pred, loss_fn, sample_weight, mask, 1653 loss_weight) in enumerate(zipped_inputs): 1654 if i in skip_target_indices: 1655 continue 1656 loss_name = self.output_names[i] + '_loss' 1657 with K.name_scope(loss_name): 1658 if mask is not None: 1659 mask = math_ops.cast(mask, y_pred.dtype) 1660 # Update weights with mask. 1661 if sample_weight is None: 1662 sample_weight = mask 1663 else: 1664 # Update dimensions of weights to match with mask if possible. 1665 mask, _, sample_weight = ( 1666 losses_utils.squeeze_or_expand_dimensions( 1667 mask, None, sample_weight)) 1668 sample_weight *= mask 1669 1670 # Reset reduction on the loss so that we can get the per sample loss 1671 # value. We use this to get both the stateless and stateful loss 1672 # values without having to compute the underlying loss function 1673 # twice. 1674 weighted_losses = None 1675 if hasattr(loss_fn, 'reduction'): 1676 current_loss_reduction = loss_fn.reduction 1677 loss_fn.reduction = losses_utils.ReductionV2.NONE 1678 weighted_losses = loss_fn( 1679 y_true, y_pred, sample_weight=sample_weight) 1680 loss_fn.reduction = current_loss_reduction 1681 1682 # Compute the stateless loss value. 1683 output_loss = losses_utils.reduce_weighted_loss( 1684 weighted_losses, reduction=current_loss_reduction) 1685 else: 1686 # Compute the stateless loss value for a custom loss class. 1687 # Here we assume that the class takes care of loss reduction 1688 # because if this class returns a vector value we cannot 1689 # differentiate between use case where a custom optimizer 1690 # expects a vector loss value vs unreduced per-sample loss value. 1691 output_loss = loss_fn(y_true, y_pred, sample_weight=sample_weight) 1692 1693 if len(self.outputs) > 1: 1694 # Keep track of stateful result tensor and function for the loss. 1695 # Compute the stateful loss value. 1696 if weighted_losses is not None: 1697 # TODO(b/120571621): Directly call metric when the bug is fixed. 1698 aggregated_output_loss = self._call_fn_for_each_replica( 1699 self._output_loss_metrics[i], weighted_losses) 1700 else: 1701 # Custom loss class. 1702 aggregated_output_loss = self._call_metric_fn( 1703 self._output_loss_metrics[i], y_true, y_pred, sample_weight) 1704 self._compile_metrics_tensors[loss_name] = aggregated_output_loss 1705 1706 if total_loss is None: 1707 total_loss = loss_weight * output_loss 1708 else: 1709 total_loss += loss_weight * output_loss 1710 if total_loss is None: 1711 if not self.losses: 1712 raise ValueError('The model cannot be compiled ' 1713 'because it has no loss to optimize.') 1714 else: 1715 total_loss = 0. 1716 1717 # Add regularization penalties and other layer-specific losses. 1718 if self.losses: 1719 total_loss += losses_utils.scale_loss_for_distribution( 1720 math_ops.add_n(self.losses)) 1721 return total_loss 1722 1723 def _get_callback_model(self): 1724 """Returns the Callback Model for this Model.""" 1725 1726 if hasattr(self, '_replicated_model') and self._replicated_model: 1727 # When using training_distributed, we set the callback model 1728 # to an instance of the `DistributedModel` that we create in 1729 # the `compile` call. The `DistributedModel` is initialized 1730 # with the first replicated model. We need to set the callback 1731 # model to a DistributedModel to allow us to override saving 1732 # and loading weights when we checkpoint the model during training. 1733 return self._replicated_model 1734 if hasattr(self, 'callback_model') and self.callback_model: 1735 return self.callback_model 1736 return self 1737 1738 def _make_callback_model(self, grouped_model): 1739 first_replicated_model = self._distribution_strategy.unwrap( 1740 grouped_model)[0] 1741 # We initialize the callback model with the first replicated model. 1742 self._replicated_model = DistributedCallbackModel(first_replicated_model) 1743 self._replicated_model.set_original_model(self) 1744 1745 def _validate_or_infer_batch_size(self, batch_size, steps, x): 1746 """Validates that the `batch_size` provided is consistent with InputLayer. 1747 1748 It's possible that the user specified a static batch size in their 1749 InputLayer. If so, this method checks the provided `batch_size` and `x` 1750 arguments are consistent with this static batch size. Also, if 1751 `batch_size` is `None`, this method will attempt to infer the batch size 1752 from the static batch size of the InputLayer. Lastly, ValueError will be 1753 raised if `x` is a tf.data.Dataset and `batch_size` is specified as we 1754 expect users to provide batched datasets. 1755 1756 Arguments: 1757 batch_size: The batch_size provided as an argument to 1758 fit/evaluate/predict. 1759 steps: The steps provided as an argument to fit/evaluate/predict. 1760 x: The data passed as `x` to fit/evaluate/predict. 1761 1762 Returns: 1763 The validated batch_size, auto-inferred from the first layer if not 1764 provided. 1765 """ 1766 if batch_size is not None and isinstance(x, dataset_ops.DatasetV2): 1767 raise ValueError('The `batch_size` argument must not be specified when' 1768 ' using dataset as an input.') 1769 1770 layers = super(Model, self).layers # Avoids the override in Sequential. 1771 if layers: 1772 first_layer = layers[0] 1773 static_batch_size = training_utils.get_static_batch_size(first_layer) 1774 if static_batch_size is not None: 1775 1776 # Check `batch_size` argument is consistent with InputLayer. 1777 if batch_size is not None and batch_size != static_batch_size: 1778 raise ValueError('The `batch_size` argument value {} is incompatible ' 1779 'with the specified batch size of your Input Layer: ' 1780 '{}'.format(batch_size, static_batch_size)) 1781 1782 # Check Dataset/Iterator batch size is consistent with InputLayer. 1783 if isinstance(x, (dataset_ops.DatasetV2, iterator_ops.Iterator, 1784 iterator_ops.EagerIterator)): 1785 ds_batch_size = tensor_shape.as_dimension( 1786 nest.flatten(dataset_ops.get_legacy_output_shapes(x))[0][0]).value 1787 if ds_batch_size is not None and ds_batch_size != static_batch_size: 1788 raise ValueError('The batch output shape of your `Dataset` is {}, ' 1789 'which is incompatible with the specified batch ' 1790 'size of your Input Layer: {}'.format( 1791 ds_batch_size, static_batch_size)) 1792 1793 # Set inferred batch size from the InputLayer. 1794 if steps is None: 1795 batch_size = static_batch_size 1796 1797 if batch_size is None and steps is None: 1798 # Backwards compatibility 1799 batch_size = 32 1800 return batch_size 1801 1802 def _list_functions_for_serialization(self): 1803 return { 1804 '_default_save_signature': saving_utils.trace_model_call(self) 1805 } 1806 1807 def _set_sample_weight_attributes(self, sample_weight_mode, 1808 skip_target_weighing_indices): 1809 """Sets sample weight related attributes on the model.""" 1810 sample_weights, sample_weight_modes = training_utils.prepare_sample_weights( 1811 self.output_names, sample_weight_mode, skip_target_weighing_indices) 1812 self.sample_weights = sample_weights 1813 self.sample_weight_modes = sample_weight_modes 1814 self._feed_sample_weight_modes = [ 1815 sample_weight_modes[i] 1816 for i in range(len(self.outputs)) 1817 if i not in skip_target_weighing_indices 1818 ] 1819 self._feed_sample_weights = [ 1820 sample_weights[i] 1821 for i in range(len(sample_weights)) 1822 if i not in skip_target_weighing_indices 1823 ] 1824 1825 def _cache_output_metric_attributes(self, metrics, weighted_metrics): 1826 """Caches metric name and function attributes for every model output.""" 1827 output_shapes = [] 1828 for output in self.outputs: 1829 if output is None or output.shape.rank is None: 1830 output_shapes.append(None) 1831 else: 1832 output_shapes.append(output.shape.as_list()) 1833 self._per_output_metrics = training_utils.collect_per_output_metric_info( 1834 metrics, self.output_names, output_shapes, self.loss_functions) 1835 self._per_output_weighted_metrics = ( 1836 training_utils.collect_per_output_metric_info( 1837 weighted_metrics, 1838 self.output_names, 1839 output_shapes, 1840 self.loss_functions, 1841 is_weighted=True)) 1842 1843 def _add_unique_metric_name(self, metric_name, output_index): 1844 """Makes the metric name unique and adds it to the model's metric name list. 1845 1846 If there are multiple outputs for which the metrics are calculated, the 1847 metric names have to be made unique by appending an integer. 1848 1849 Arguments: 1850 metric_name: Metric name that corresponds to the metric specified by the 1851 user. For example: 'acc'. 1852 output_index: The index of the model output for which the metric name is 1853 being added. 1854 1855 Returns: 1856 string, name of the model's unique metric name 1857 """ 1858 if len(self.output_names) > 1: 1859 metric_name = '%s_%s' % (self.output_names[output_index], metric_name) 1860 j = 1 1861 base_metric_name = metric_name 1862 while metric_name in self._compile_metrics_names: 1863 metric_name = '%s_%d' % (base_metric_name, j) 1864 j += 1 1865 1866 return metric_name 1867 1868 @property 1869 def _all_metrics_tensors(self): 1870 """Returns a dictionary that maps metric names to metric result tensors. 1871 1872 This maps metric names from `model.metric_names` to result tensors. 1873 Just like model.metric_names, this includes loss names and tensors. 1874 """ 1875 metrics_tensors = {} 1876 if self._is_compiled: 1877 metrics_tensors.update(self._compile_metrics_tensors) 1878 metrics_tensors.update(super(Model, self)._all_metrics_tensors) 1879 return metrics_tensors 1880 1881 def _init_metric_attributes(self): 1882 """Initialized model metric attributes.""" 1883 # List of all metric names in the model. This includes loss metrics. 1884 self._compile_metrics_names = ['loss'] 1885 # List of stateful metric functions. Used for resetting metric state during 1886 # training/eval. This includes loss metric functions. 1887 self._compile_metric_functions = [] 1888 # Dict of all aggregated metric result tensors. This includes aggregated 1889 # loss result tensors. 1890 self._compile_metrics_tensors = {} 1891 # List of metric wrappers on output losses. 1892 self._output_loss_metrics = None 1893 1894 def _set_per_output_metric_attributes(self, metrics_dict, output_index): 1895 """Sets the metric attributes on the model for the given output. 1896 1897 Arguments: 1898 metrics_dict: A dict with metric names as keys and metric fns as values. 1899 output_index: The index of the model output for which the metric 1900 attributes are added. 1901 1902 Returns: 1903 Metrics dict updated with unique metric names as keys. 1904 """ 1905 updated_metrics_dict = collections.OrderedDict() 1906 for metric_name, metric_fn in metrics_dict.items(): 1907 metric_name = self._add_unique_metric_name(metric_name, output_index) 1908 1909 # Update the name on the metric class to be the unique generated name. 1910 metric_fn._name = metric_name # pylint: disable=protected-access 1911 updated_metrics_dict[metric_name] = metric_fn 1912 # Keep track of metric name and function. 1913 self._compile_metrics_names.append(metric_name) 1914 self._compile_metric_functions.append(metric_fn) 1915 return updated_metrics_dict 1916 1917 def _set_metric_attributes(self, skip_target_indices=None): 1918 """Sets the metric attributes on the model for all the model outputs.""" 1919 # Add loss metric names to the model metric names list. 1920 if len(self.outputs) > 1: 1921 output_names = [ 1922 self.output_names[i] + '_loss' 1923 for i in range(len(self.outputs)) 1924 if i not in skip_target_indices 1925 ] 1926 self._compile_metrics_names.extend(output_names) 1927 1928 skip_target_indices = skip_target_indices or [] 1929 updated_per_output_metrics = [] 1930 updated_per_output_weighted_metrics = [] 1931 for i in range(len(self.outputs)): 1932 if i in skip_target_indices: 1933 updated_per_output_metrics.append(self._per_output_metrics[i]) 1934 updated_per_output_weighted_metrics.append( 1935 self._per_output_weighted_metrics[i]) 1936 continue 1937 updated_per_output_metrics.append( 1938 self._set_per_output_metric_attributes(self._per_output_metrics[i], 1939 i)) 1940 updated_per_output_weighted_metrics.append( 1941 self._set_per_output_metric_attributes( 1942 self._per_output_weighted_metrics[i], i)) 1943 1944 # Create a metric wrapper for each output loss. 1945 if len(self.outputs) > 1: 1946 self._output_loss_metrics = [ 1947 metrics_module.SumOverBatchSize() if hasattr(loss_fn, 'reduction') 1948 else metrics_module.SumOverBatchSizeMetricWrapper(loss_fn) 1949 for loss_fn in self.loss_functions 1950 ] 1951 1952 self._per_output_metrics = updated_per_output_metrics 1953 self._per_output_weighted_metrics = updated_per_output_weighted_metrics 1954 1955 def _call_metric_fn(self, metric_fn, y_true, y_pred, weights, mask=None): 1956 # TODO(b/120571621): Remove this function when the bug is fixed. 1957 """Helper function to call metric function with distribution strategy.""" 1958 return self._call_fn_for_each_replica( 1959 training_utils.call_metric_function, 1960 metric_fn, 1961 y_true, 1962 y_pred, 1963 weights=weights, 1964 mask=mask) 1965 1966 def _call_fn_for_each_replica(self, fn, *args, **kwargs): 1967 # TODO(b/120571621): We want to avoid metric reductions here since 1968 # since TPUStrategy does not implement replica local variables. 1969 # Remove this hack once we support TPUReplicaLocalVariables. 1970 is_tpu = distributed_training_utils.is_tpu_strategy( 1971 self._distribution_strategy) 1972 if ((not is_tpu) and self._distribution_strategy and 1973 distribution_strategy_context.in_cross_replica_context()): 1974 with self._distribution_strategy.scope(): 1975 return self._distribution_strategy.extended.call_for_each_replica( 1976 fn, args, kwargs) 1977 return fn(*args, **kwargs) 1978 1979 def _handle_per_output_metrics(self, 1980 metrics_dict, 1981 y_true, 1982 y_pred, 1983 mask, 1984 weights=None): 1985 """Calls metric functions for a single output. 1986 1987 Arguments: 1988 metrics_dict: A dict with metric names as keys and metric fns as values. 1989 y_true: Target output. 1990 y_pred: Predicted output. 1991 mask: Computed mask value for the current output. 1992 weights: Weights to be applied on the current output. 1993 1994 Returns: 1995 A list of metric result tensors. 1996 """ 1997 metric_results = [] 1998 for metric_name, metric_fn in metrics_dict.items(): 1999 with K.name_scope(metric_name): 2000 metric_result = self._call_metric_fn(metric_fn, y_true, y_pred, weights, 2001 mask) 2002 metric_results.append(metric_result) 2003 if not self.run_eagerly: 2004 self._compile_metrics_tensors[metric_name] = metric_result 2005 2006 return metric_results 2007 2008 def _handle_metrics(self, 2009 outputs, 2010 skip_target_indices=None, 2011 targets=None, 2012 sample_weights=None, 2013 masks=None): 2014 """Handles calling metric functions. 2015 2016 Arguments: 2017 outputs: List of outputs (predictions). 2018 skip_target_indices: Optional. List of target ids to skip. 2019 targets: List of targets. 2020 sample_weights: Optional list of sample weight arrays. 2021 masks: List of computed output mask values. 2022 2023 Returns: 2024 A list of metric result tensors. 2025 """ 2026 skip_target_indices = skip_target_indices or [] 2027 metric_results = [] 2028 with K.name_scope('metrics'): 2029 # Invoke all metrics added using `compile`. 2030 for i in range(len(outputs)): 2031 if i in skip_target_indices: 2032 continue 2033 output = outputs[i] if outputs else None 2034 target = targets[i] if targets else None 2035 output_mask = masks[i] if masks else None 2036 metric_results.extend( 2037 self._handle_per_output_metrics(self._per_output_metrics[i], target, 2038 output, output_mask)) 2039 metric_results.extend( 2040 self._handle_per_output_metrics( 2041 self._per_output_weighted_metrics[i], 2042 target, 2043 output, 2044 output_mask, 2045 weights=sample_weights[i])) 2046 2047 # Add metric results from the `add_metric` metrics in eager mode. 2048 if context.executing_eagerly(): 2049 for m in self.metrics: 2050 if m not in self._compile_metric_functions: 2051 metric_results.append(m.result()) 2052 return metric_results 2053 2054 def _check_trainable_weights_consistency(self): 2055 """Check trainable weights count consistency. 2056 2057 This will raise a warning if `trainable_weights` and 2058 `_collected_trainable_weights` are inconsistent (i.e. have different 2059 number of parameters). 2060 Inconsistency will typically arise when one modifies `model.trainable` 2061 without calling `model.compile` again. 2062 """ 2063 if not hasattr(self, '_collected_trainable_weights'): 2064 return 2065 2066 if len(self.trainable_weights) != len(self._collected_trainable_weights): 2067 logging.log_first_n( 2068 logging.WARN, 'Discrepancy between trainable weights and collected' 2069 ' trainable weights, did you set `model.trainable`' 2070 ' without calling `model.compile` after ?', 1) 2071 2072 def _make_train_function(self): 2073 metrics_tensors = [ 2074 self._all_metrics_tensors[m] for m in self.metrics_names[1:] 2075 ] 2076 if not self._is_compiled: 2077 raise RuntimeError('You must compile your model before using it.') 2078 self._check_trainable_weights_consistency() 2079 if getattr(self, 'train_function') is None: 2080 inputs = (self._feed_inputs + 2081 self._feed_targets + 2082 self._feed_sample_weights) 2083 if not isinstance(K.symbolic_learning_phase(), int): 2084 inputs += [K.symbolic_learning_phase()] 2085 2086 with K.get_graph().as_default(): 2087 with K.name_scope('training'): 2088 with K.name_scope(self.optimizer.__class__.__name__): 2089 # Training updates 2090 updates = self.optimizer.get_updates( 2091 params=self._collected_trainable_weights, loss=self.total_loss) 2092 # Unconditional updates 2093 updates += self.get_updates_for(None) 2094 # Conditional updates relevant to this model 2095 updates += self.get_updates_for(self.inputs) 2096 2097 with K.name_scope('training'): 2098 # Gets loss and metrics. Updates weights at each call. 2099 fn = K.function( 2100 inputs, [self.total_loss] + metrics_tensors, 2101 updates=updates, 2102 name='train_function', 2103 **self._function_kwargs) 2104 setattr(self, 'train_function', fn) 2105 2106 def _make_test_function(self): 2107 metrics_tensors = [ 2108 self._all_metrics_tensors[m] for m in self.metrics_names[1:] 2109 ] 2110 if not self._is_compiled: 2111 raise RuntimeError('You must compile your model before using it.') 2112 if getattr(self, 'test_function') is None: 2113 inputs = (self._feed_inputs + 2114 self._feed_targets + 2115 self._feed_sample_weights) 2116 2117 with K.name_scope('evaluation'): 2118 updates = self.state_updates 2119 # Return loss and metrics, no gradient updates. 2120 # Does update the network states. 2121 fn = K.function( 2122 inputs, [self.total_loss] + metrics_tensors, 2123 updates=updates, 2124 name='test_function', 2125 **self._function_kwargs) 2126 setattr(self, 'test_function', fn) 2127 2128 def _make_predict_function(self): 2129 if not hasattr(self, 'predict_function'): 2130 self.predict_function = None 2131 if self.predict_function is None: 2132 inputs = self._feed_inputs 2133 # Gets network outputs. Does not update weights. 2134 # Does update the network states. 2135 kwargs = getattr(self, '_function_kwargs', {}) 2136 with K.name_scope(ModeKeys.PREDICT): 2137 self.predict_function = K.function( 2138 inputs, 2139 self.outputs, 2140 updates=self.state_updates, 2141 name='predict_function', 2142 **kwargs) 2143 2144 def _make_execution_function(self, mode): 2145 if mode == ModeKeys.TRAIN: 2146 self._make_train_function() 2147 return self.train_function 2148 if mode == ModeKeys.TEST: 2149 self._make_test_function() 2150 return self.test_function 2151 if mode == ModeKeys.PREDICT: 2152 self._make_predict_function() 2153 return self.predict_function 2154 2155 def _distribution_standardize_user_data(self, 2156 x, 2157 y=None, 2158 sample_weight=None, 2159 class_weight=None, 2160 batch_size=None, 2161 validation_split=0, 2162 shuffle=False, 2163 repeat=False, 2164 allow_partial_batch=False): 2165 """Runs validation checks on input and target data passed by the user. 2166 2167 This is called when using DistributionStrategy to train, evaluate or serve 2168 the model. 2169 2170 Args: 2171 x: Input data. A numpy array or `tf.data` dataset. 2172 y: Target data. A numpy array or None if x is a `tf.data` dataset. 2173 sample_weight: An optional sample-weight array passed by the user to 2174 weight the importance of each sample in `x`. 2175 class_weight: An optional class-weight array by the user to 2176 weight the importance of samples in `x` based on the class they belong 2177 to, as conveyed by `y`. 2178 batch_size: Integer batch size. If provided, it is used to run additional 2179 validation checks on stateful models. 2180 validation_split: Float between 0 and 1. 2181 Fraction of the training data to be used as validation data. 2182 shuffle: Boolean whether to shuffle the training data before each epoch. 2183 repeat: Boolean whether to repeat the numpy training data when converting 2184 to training dataset. 2185 allow_partial_batch: Boolean whether to enforce that all batches have the 2186 same size. 2187 2188 Returns: 2189 Dataset instance. 2190 2191 Raises: 2192 ValueError: In case of invalid user-provided data. 2193 RuntimeError: If the model was never compiled. 2194 """ 2195 if class_weight: 2196 raise NotImplementedError('`class_weight` is currently not supported ' 2197 'when using DistributionStrategy.') 2198 2199 if (sample_weight is not None and sample_weight.all() and 2200 distributed_training_utils.is_tpu_strategy( 2201 self._distribution_strategy)): 2202 raise NotImplementedError('`sample_weight` is currently not supported ' 2203 'when using TPUStrategy.') 2204 2205 if (self.stateful and distributed_training_utils.is_tpu_strategy( 2206 self._distribution_strategy) and self._distribution_strategy. 2207 num_replicas_in_sync != 1): 2208 raise ValueError('Single core must be used for computation on ' 2209 'stateful models. Consider adding `device_assignment` ' 2210 'parameter to TPUStrategy using\n' 2211 'topology = tf.contrib.distribute.' 2212 'initialize_tpu_system()\n' 2213 'device_assignment = tf.contrib.tpu.DeviceAssignment(' 2214 'topology, core_assignment=tf.contrib.tpu.' 2215 'SINGLE_CORE_ASSIGNMENT)\n' 2216 'tpu_strategy = tf.contrib.distribute.TPUStrategy(' 2217 'device_assignment=device_assignment)') 2218 2219 # Validates `steps` and `shuffle` arguments right at the beginning 2220 # since we use it to construct the dataset object. 2221 # TODO(anjalisridhar): Remove this check once we refactor the 2222 # _standardize_user_data code path. This check is already present elsewhere 2223 # in the codebase. 2224 if isinstance(x, dataset_ops.DatasetV2): 2225 if shuffle: 2226 training_utils.verify_dataset_shuffled(x) 2227 2228 strategy = self._distribution_strategy 2229 with strategy.scope(): 2230 # We should be sure to call get_session() inside the strategy.scope() 2231 # so the strategy can affect the session options. 2232 if ops.executing_eagerly_outside_functions(): 2233 session = None 2234 else: 2235 session = K.get_session() 2236 2237 first_x_value = nest.flatten(x)[0] 2238 if isinstance(first_x_value, np.ndarray): 2239 x = distributed_training_utils.list_to_tuple(x) 2240 if y is not None: 2241 y = distributed_training_utils.list_to_tuple(y) 2242 if sample_weight is not None: 2243 sample_weight = distributed_training_utils.list_to_tuple( 2244 sample_weight) 2245 in_tuple = (x, y, sample_weight) 2246 else: 2247 in_tuple = (x, y) 2248 else: 2249 in_tuple = x 2250 2251 ds = strategy.extended.experimental_make_numpy_dataset(in_tuple, 2252 session=session) 2253 if shuffle: 2254 # We want a buffer size that is larger than the batch size provided by 2255 # the user and provides sufficient randomness. Note that larger 2256 # numbers introduce more memory usage based on the size of each 2257 # sample. 2258 ds = ds.shuffle(max(1024, batch_size * 8)) 2259 if repeat: 2260 ds = ds.repeat() 2261 2262 # We need to use the drop_remainder argument to get a known static 2263 # input shape which is required for TPUs. 2264 drop_remainder = (not allow_partial_batch and 2265 strategy.extended.experimental_require_static_shapes) 2266 x = ds.batch(batch_size, drop_remainder=drop_remainder) 2267 else: 2268 assert isinstance(x, dataset_ops.DatasetV2) 2269 training_utils.validate_dataset_input(x, y, sample_weight, 2270 validation_split) 2271 return x 2272 2273 def _standardize_user_data(self, 2274 x, 2275 y=None, 2276 sample_weight=None, 2277 class_weight=None, 2278 batch_size=None, 2279 check_steps=False, 2280 steps_name='steps', 2281 steps=None, 2282 validation_split=0, 2283 shuffle=False, 2284 extract_tensors_from_dataset=False): 2285 """Runs validation checks on input and target data passed by the user. 2286 2287 Also standardizes the data to lists of arrays, in order. 2288 2289 Also builds and compiles the model on the fly if it is a subclassed model 2290 that has never been called before (and thus has no inputs/outputs). 2291 2292 This is a purely internal method, subject to refactoring at any time. 2293 2294 Args: 2295 x: Input data. It could be: 2296 - A Numpy array (or array-like), or a list of arrays 2297 (in case the model has multiple inputs). 2298 - A TensorFlow tensor, or a list of tensors 2299 (in case the model has multiple inputs). 2300 - A dict mapping input names to the corresponding array/tensors, 2301 if the model has named inputs. 2302 - A `tf.data` dataset or a dataset iterator. 2303 y: Target data. Like the input data `x`, 2304 it could be either Numpy array(s) or TensorFlow tensor(s). 2305 It should be consistent with `x` (you cannot have Numpy inputs and 2306 tensor targets, or inversely). If `x` is a dataset or a 2307 dataset iterator, `y` should not be specified 2308 (since targets will be obtained from the iterator). 2309 sample_weight: An optional sample-weight array passed by the user to 2310 weight the importance of each sample in `x`. 2311 class_weight: An optional class-weight array by the user to 2312 weight the importance of samples in `x` based on the class they belong 2313 to, as conveyed by `y`. If both `sample_weight` and `class_weight` are 2314 provided, the weights are multiplied. 2315 batch_size: Integer batch size. If provided, it is used to run additional 2316 validation checks on stateful models. 2317 check_steps: boolean, True if we want to check for validity of `steps` and 2318 False, otherwise. For example, when we are standardizing one batch of 2319 data for train_on_batch/predict_on_batch/test_on_batch APIs, `steps` 2320 value is not required and we should not check for its validity in these 2321 cases. 2322 steps_name: The public API's parameter name for `steps`. 2323 steps: Integer or `None`. Total number of steps (batches of samples) to 2324 execute. 2325 validation_split: Float between 0 and 1. 2326 Fraction of the training data to be used as validation data. 2327 shuffle: Boolean whether to shuffle the training data before each epoch. 2328 extract_tensors_from_dataset: Boolean. When `x` is a dataset instance, 2329 this indicates whether to extract actual tensors from the dataset or 2330 instead output the dataset instance itself. 2331 Set to True when calling from `train_on_batch`/etc. 2332 2333 Returns: 2334 A tuple of 3: inputs (arrays or dicts, depending on whether `x` was a dict 2335 or not), target arrays, sample-weight arrays. 2336 If the model's input and targets are symbolic, these lists are empty 2337 (since the model takes no user-provided data, instead the data comes 2338 from the symbolic inputs/targets). 2339 2340 Raises: 2341 ValueError: In case of invalid user-provided data. 2342 RuntimeError: If the model was never compiled. 2343 """ 2344 if isinstance(x, (dataset_ops.DatasetV1, dataset_ops.DatasetV2)): 2345 # Graph mode dataset. We'll pass the dataset as-is (unless 2346 # `extract_tensors_from_dataset` is True, in which case we extract 2347 # the tensors from the dataset and we output them. 2348 training_utils.validate_dataset_input(x, y, sample_weight, 2349 validation_split) 2350 if shuffle: 2351 training_utils.verify_dataset_shuffled(x) 2352 2353 is_dataset = True 2354 if extract_tensors_from_dataset: 2355 # We do this for `train_on_batch`/etc. 2356 x, y, sample_weight = training_utils.extract_tensors_from_dataset(x) 2357 elif isinstance(x, iterator_ops.Iterator): 2358 # Graph mode iterator. We extract the symbolic tensors. 2359 training_utils.validate_dataset_input(x, y, sample_weight, 2360 validation_split) 2361 iterator = x 2362 x, y, sample_weight = training_utils.unpack_iterator_input(iterator) 2363 is_dataset = True 2364 else: 2365 is_dataset = False 2366 2367 # Validates `steps` argument based on x's type. 2368 if check_steps: 2369 training_utils.check_steps_argument(x, steps, steps_name) 2370 2371 # First, we build/compile the model on the fly if necessary. 2372 all_inputs = [] 2373 is_build_called = False 2374 is_compile_called = False 2375 # Whether this is a subclassed model that expects dictionary inputs 2376 # rather than list inputs (e.g. FeatureColumn-based models). 2377 dict_inputs = False 2378 if not self.inputs: 2379 # We need to use `x_input` to set the model inputs. 2380 2381 # If input data is a dataset iterator in graph mode or if it is an eager 2382 # iterator and only one batch of samples is required, we fetch the data 2383 # tensors from the iterator and then standardize them. 2384 if isinstance(x, (dataset_ops.DatasetV1, dataset_ops.DatasetV2)): 2385 x_input, y_input, _ = training_utils.extract_tensors_from_dataset(x) 2386 else: 2387 x_input = x 2388 y_input = y 2389 # We type-check that `x_input` and `y_input` are either single arrays 2390 # or lists of arrays. 2391 if isinstance(x_input, (list, tuple)): 2392 if not all(isinstance(v, np.ndarray) or 2393 tensor_util.is_tensor(v) for v in x_input): 2394 raise ValueError('Please provide as model inputs either a single ' 2395 'array or a list of arrays. You passed: x=' + str(x)) 2396 all_inputs += list(x_input) 2397 elif isinstance(x_input, dict): 2398 dict_inputs = True 2399 keys = sorted(x_input.keys()) 2400 all_inputs = [x_input[k] for k in keys] 2401 else: 2402 if (not isinstance(x_input, np.ndarray) and 2403 not tensor_util.is_tensor(x_input)): 2404 raise ValueError('Please provide as model inputs either a single ' 2405 'array or a list of arrays. You passed: x=' + str(x)) 2406 all_inputs.append(x_input) 2407 2408 # Build the model using the retrieved inputs (value or symbolic). 2409 # If values or generated from a dataset, then in symbolic-mode 2410 # placeholders will be created to match the value shapes. 2411 is_build_called = True 2412 if is_dataset: 2413 cast_inputs = nest.map_structure(lambda v: v.shape, x_input) 2414 elif training_utils.has_tensors(x_input): 2415 cast_inputs = training_utils.cast_if_floating_dtype(x_input) 2416 else: 2417 cast_inputs = x_input 2418 self._set_inputs(cast_inputs) 2419 else: 2420 y_input = y 2421 dict_inputs = isinstance(self.inputs, dict) 2422 2423 if y_input is not None: 2424 if not self.optimizer: 2425 raise RuntimeError('You must compile a model before ' 2426 'training/testing. ' 2427 'Use `model.compile(optimizer, loss)`.') 2428 if not self._is_compiled: 2429 # On-the-fly compilation of the model. 2430 # We need to use `y` to set the model targets. 2431 if training_utils.has_tensors(y_input): 2432 y_input = training_utils.cast_if_floating_dtype(y_input) 2433 if isinstance(y_input, (list, tuple)): 2434 if not all(isinstance(v, np.ndarray) or 2435 tensor_util.is_tensor(v) for v in y_input): 2436 raise ValueError('Please provide as model targets either a single ' 2437 'array or a list of arrays. ' 2438 'You passed: y=' + str(y)) 2439 all_inputs += list(y_input) 2440 elif isinstance(y_input, dict): 2441 raise ValueError('You cannot pass a dictionary as model targets.') 2442 else: 2443 if (not isinstance(y_input, np.ndarray) and 2444 not tensor_util.is_tensor(y_input)): 2445 raise ValueError('Please provide as model targets either a single ' 2446 'array or a list of arrays. ' 2447 'You passed: y=' + str(y)) 2448 all_inputs.append(y_input) 2449 2450 # Typecheck that all inputs are *either* value *or* symbolic. 2451 # TODO(fchollet): this check could be removed in Eager mode? 2452 if any(tensor_util.is_tensor(v) for v in all_inputs): 2453 if not all(tensor_util.is_tensor(v) for v in all_inputs): 2454 raise ValueError('Do not pass inputs that mix Numpy arrays and ' 2455 'TensorFlow tensors. ' 2456 'You passed: x=' + str(x) + '; y=' + str(y)) 2457 2458 if is_dataset or context.executing_eagerly(): 2459 target_tensors = None 2460 else: 2461 # Handle target tensors if any passed. 2462 if not isinstance(y_input, (list, tuple)): 2463 y_input = [y_input] 2464 target_tensors = [v for v in y_input if _is_symbolic_tensor(v)] 2465 is_compile_called = True 2466 self.compile( 2467 optimizer=self.optimizer, 2468 loss=self.loss, 2469 metrics=self._compile_metrics, 2470 weighted_metrics=self._compile_weighted_metrics, 2471 loss_weights=self.loss_weights, 2472 target_tensors=target_tensors, 2473 run_eagerly=self.run_eagerly) 2474 2475 # In graph mode, if we had just set inputs and targets as symbolic tensors 2476 # by invoking build and compile on the model respectively, we do not have to 2477 # feed anything to the model. Model already has input and target data as 2478 # part of the graph. 2479 # Note: in this case, `any` and `all` are equivalent since we disallow 2480 # mixed symbolic/value inputs. 2481 if (not self.run_eagerly and is_build_called and is_compile_called and 2482 not is_dataset and any(_is_symbolic_tensor(v) for v in all_inputs)): 2483 return [], [], [] 2484 2485 # What follows is input validation and standardization to list format, 2486 # in the case where all inputs are value arrays. 2487 2488 if self.run_eagerly: 2489 # In eager mode, do not do shape validation 2490 # since the network has no input nodes (placeholders) to be fed. 2491 feed_input_names = self.input_names 2492 feed_input_shapes = None 2493 elif not self._is_graph_network: 2494 # Case: symbolic-mode subclassed network. Do not do shape validation. 2495 feed_input_names = self._feed_input_names 2496 feed_input_shapes = None 2497 else: 2498 # Case: symbolic-mode graph network. 2499 # In this case, we run extensive shape validation checks. 2500 feed_input_names = self._feed_input_names 2501 feed_input_shapes = self._feed_input_shapes 2502 2503 # Standardize the inputs. 2504 if not isinstance(x, (dataset_ops.DatasetV1, dataset_ops.DatasetV2)): 2505 # TODO(fchollet): run static checks with dataset output shape(s). 2506 x = training_utils.standardize_input_data( 2507 x, 2508 feed_input_names, 2509 feed_input_shapes, 2510 check_batch_axis=False, # Don't enforce the batch size. 2511 exception_prefix='input') 2512 2513 if y is not None: 2514 if not self._is_graph_network: 2515 feed_output_names = self._feed_output_names 2516 feed_output_shapes = None 2517 # Sample weighting not supported in this case. 2518 # TODO(fchollet): consider supporting it. 2519 feed_sample_weight_modes = [None for _ in self.outputs] 2520 else: 2521 feed_output_names = self._feed_output_names 2522 feed_sample_weight_modes = self._feed_sample_weight_modes 2523 feed_output_shapes = [] 2524 for output_shape, loss_fn in zip(self._feed_output_shapes, 2525 self._feed_loss_fns): 2526 if ((isinstance(loss_fn, losses.LossFunctionWrapper) and 2527 loss_fn.fn == losses.sparse_categorical_crossentropy)) or ( 2528 isinstance(loss_fn, losses.SparseCategoricalCrossentropy)): 2529 if K.image_data_format() == 'channels_first': 2530 feed_output_shapes.append( 2531 (output_shape[0], 1) + output_shape[2:]) 2532 else: 2533 feed_output_shapes.append(output_shape[:-1] + (1,)) 2534 elif (not isinstance(loss_fn, losses.Loss) or 2535 (isinstance(loss_fn, losses.LossFunctionWrapper) and 2536 (getattr(losses, loss_fn.fn.__name__, None) is None))): 2537 # If the given loss is not an instance of the `Loss` class (custom 2538 # class) or if the loss function that is wrapped is not in the 2539 # `losses` module, then it is a user-defined loss and we make no 2540 # assumptions about it. 2541 feed_output_shapes.append(None) 2542 else: 2543 feed_output_shapes.append(output_shape) 2544 2545 # Standardize the outputs. 2546 y = training_utils.standardize_input_data( 2547 y, 2548 feed_output_names, 2549 # Don't enforce target shapes to match output shapes. 2550 # Precise checks will be run in `check_loss_and_target_compatibility`. 2551 shapes=None, 2552 check_batch_axis=False, # Don't enforce the batch size. 2553 exception_prefix='target') 2554 2555 # Generate sample-wise weight values given the `sample_weight` and 2556 # `class_weight` arguments. 2557 sample_weights = training_utils.standardize_sample_weights( 2558 sample_weight, feed_output_names) 2559 class_weights = training_utils.standardize_class_weights( 2560 class_weight, feed_output_names) 2561 sample_weights = [ 2562 training_utils.standardize_weights(ref, sw, cw, mode) 2563 for (ref, sw, cw, mode) in zip(y, sample_weights, class_weights, 2564 feed_sample_weight_modes) 2565 ] 2566 # Check that all arrays have the same length. 2567 if not self._distribution_strategy: 2568 training_utils.check_array_lengths(x, y, sample_weights) 2569 if self._is_graph_network and not self.run_eagerly: 2570 # Additional checks to avoid users mistakenly using improper loss fns. 2571 training_utils.check_loss_and_target_compatibility( 2572 y, self._feed_loss_fns, feed_output_shapes) 2573 else: 2574 y = [] 2575 sample_weights = [] 2576 2577 if self.stateful and batch_size: 2578 # Check that for stateful networks, number of samples is a multiple 2579 # of the static batch size. 2580 if x[0].shape[0] % batch_size != 0: 2581 raise ValueError('In a stateful network, ' 2582 'you should only pass inputs with ' 2583 'a number of samples that can be ' 2584 'divided by the batch size. Found: ' + 2585 str(x[0].shape[0]) + ' samples') 2586 2587 # If dictionary inputs were provided, we return a dictionary as well. 2588 if dict_inputs and not isinstance(x, (dataset_ops.DatasetV1, 2589 dataset_ops.DatasetV2)): 2590 x = dict(zip(feed_input_names, x)) 2591 return x, y, sample_weights 2592 2593 def _unpack_validation_data(self, validation_data): 2594 if (isinstance(validation_data, (iterator_ops.Iterator, 2595 iterator_ops.EagerIterator, 2596 dataset_ops.DatasetV2))): 2597 val_x = validation_data 2598 val_y = None 2599 val_sample_weight = None 2600 elif len(validation_data) == 2: 2601 val_x, val_y = validation_data # pylint: disable=unpacking-non-sequence 2602 val_sample_weight = None 2603 elif len(validation_data) == 3: 2604 val_x, val_y, val_sample_weight = validation_data # pylint: disable=unpacking-non-sequence 2605 else: 2606 raise ValueError( 2607 'When passing a `validation_data` argument, ' 2608 'it must contain either 2 items (x_val, y_val), ' 2609 'or 3 items (x_val, y_val, val_sample_weights), ' 2610 'or alternatively it could be a dataset or a ' 2611 'dataset or a dataset iterator. ' 2612 'However we received `validation_data=%s`' % validation_data) 2613 return val_x, val_y, val_sample_weight 2614 2615 # TODO(omalleyt): Consider changing to a more descriptive function name. 2616 def _set_inputs(self, inputs, outputs=None, training=None): 2617 """Set model's input and output specs based on the input data received. 2618 2619 This is to be used for Model subclasses, which do not know at instantiation 2620 time what their inputs look like. 2621 2622 Args: 2623 inputs: Single array, or list of arrays. The arrays could be placeholders, 2624 Numpy arrays, data tensors, or TensorShapes. 2625 - if placeholders: the model is built on top of these placeholders, 2626 and we expect Numpy data to be fed for them when calling `fit`/etc. 2627 - if Numpy data or TensorShapes: we create placeholders matching the 2628 TensorShapes or shapes of the Numpy arrays. We expect Numpy data to be 2629 fed for these placeholders when calling `fit`/etc. 2630 - if data tensors: the model is built on top of these tensors. 2631 We do not expect any Numpy data to be provided when calling `fit`/etc. 2632 outputs: None, a data tensor, or a list of tensors. If None, the 2633 outputs will be determined by invoking `self.call()`, otherwise the 2634 provided value will be used. 2635 training: Boolean or None. Only relevant in symbolic mode. Specifies 2636 whether to build the model's graph in inference mode (False), training 2637 mode (True), or using the Keras learning phase (None). 2638 Raises: 2639 ValueError: If dict inputs are passed to a Sequential Model where the 2640 first layer isn't FeatureLayer. 2641 """ 2642 inputs = self._set_input_attrs(inputs) 2643 2644 if outputs is None: 2645 kwargs = {'training': training} if self._expects_training_arg else {} 2646 try: 2647 outputs = self(inputs, **kwargs) 2648 except NotImplementedError: 2649 # This Model or a submodel is dynamic and hasn't overridden 2650 # `compute_output_shape`. 2651 outputs = None 2652 2653 self._set_output_attrs(outputs) 2654 2655 @trackable.no_automatic_dependency_tracking 2656 def _set_input_attrs(self, inputs): 2657 """Sets attributes related to the inputs of the Model.""" 2658 if self.inputs: 2659 raise ValueError('Model inputs are already set.') 2660 2661 if self.__class__.__name__ == 'Sequential' and not self.built: 2662 if tensor_util.is_tensor(inputs): 2663 input_shape = (None,) + tuple(inputs.shape.as_list()[1:]) 2664 elif isinstance(inputs, tensor_shape.TensorShape): 2665 input_shape = (None,) + tuple(inputs.as_list()[1:]) 2666 elif isinstance(inputs, dict): 2667 # We assert that the first layer is a FeatureLayer. 2668 if not training_utils.is_feature_layer(self.layers[0]): 2669 raise ValueError('Passing a dictionary input to a Sequential Model ' 2670 'which doesn\'t have FeatureLayer as the first layer' 2671 ' is an error.') 2672 input_shape = (None,) 2673 else: 2674 input_shape = (None,) + tuple(inputs.shape[1:]) 2675 self._build_input_shape = input_shape 2676 2677 # On-the-fly setting of symbolic model inputs (either by using the tensor 2678 # provided, or by creating a placeholder if Numpy data was provided). 2679 model_inputs = training_utils.ModelInputs(inputs) 2680 inputs = model_inputs.get_symbolic_inputs() 2681 self.inputs = model_inputs.get_symbolic_inputs(return_single_as_list=True) 2682 self.input_names = model_inputs.get_input_names() 2683 2684 self._feed_inputs = [] 2685 self._feed_input_names = [] 2686 self._feed_input_shapes = [] 2687 2688 for k, v in model_inputs.as_dict(): 2689 if K.is_placeholder(v): 2690 self._feed_input_names.append(k) 2691 self._feed_inputs.append(v) 2692 self._feed_input_shapes.append(K.int_shape(v)) 2693 2694 return inputs 2695 2696 @trackable.no_automatic_dependency_tracking 2697 def _set_output_attrs(self, outputs): 2698 """Sets attributes related to the outputs of the Model.""" 2699 outputs = nest.flatten(outputs) 2700 self.outputs = outputs 2701 self.output_names = training_utils.generic_output_names(outputs) 2702 self.built = True 2703 2704 2705class DistributedCallbackModel(Model): 2706 """Model that is used for callbacks with DistributionStrategy.""" 2707 2708 def __init__(self, model): 2709 super(DistributedCallbackModel, self).__init__() 2710 self.optimizer = model.optimizer 2711 2712 def set_original_model(self, orig_model): 2713 self._original_model = orig_model 2714 2715 def save_weights(self, filepath, overwrite=True, save_format=None): 2716 self._replicated_model.save_weights(filepath, overwrite=overwrite, 2717 save_format=save_format) 2718 2719 def save(self, filepath, overwrite=True, include_optimizer=True): 2720 # save weights from the distributed model to the original model 2721 distributed_model_weights = self.get_weights() 2722 self._original_model.set_weights(distributed_model_weights) 2723 # TODO(anjalisridhar): Do we need to save the original model here? 2724 # Saving the first replicated model works as well. 2725 self._original_model.save(filepath, overwrite=True, include_optimizer=False) 2726 2727 def load_weights(self, filepath, by_name=False): 2728 self._original_model.load_weights(filepath, by_name=False) 2729 # Copy the weights from the original model to each of the replicated models. 2730 orig_model_weights = self._original_model.get_weights() 2731 distributed_training_utils.set_weights( 2732 self._original_model._distribution_strategy, self, # pylint: disable=protected-access 2733 orig_model_weights) 2734 2735 def __getattr__(self, item): 2736 # Whitelisted atttributes of the model that can be accessed by the user 2737 # during a callback. 2738 if item not in ['_setattr_tracking']: 2739 logging.warning('You are accessing attribute ' + item + ' of the ' 2740 'DistributedCallbackModel that may not have been set ' 2741 'correctly.') 2742 2743 2744def _is_symbolic_tensor(x): 2745 return tensor_util.is_tensor(x) and not isinstance(x, ops.EagerTensor) 2746