1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Part of the Keras training engine related to plain array data. 16""" 17# pylint: disable=protected-access 18from __future__ import absolute_import 19from __future__ import division 20from __future__ import print_function 21 22import functools 23 24import numpy as np 25 26from tensorflow.python.data.ops import dataset_ops 27from tensorflow.python.data.ops import iterator_ops 28from tensorflow.python.eager import context 29from tensorflow.python.framework import errors 30from tensorflow.python.keras import backend as K 31from tensorflow.python.keras import callbacks as cbks 32from tensorflow.python.keras.distribute import distributed_training_utils_v1 33from tensorflow.python.keras.engine import training_utils_v1 34from tensorflow.python.keras.utils.generic_utils import make_batches 35from tensorflow.python.keras.utils.generic_utils import slice_arrays 36from tensorflow.python.keras.utils.mode_keys import ModeKeys 37from tensorflow.python.platform import tf_logging as logging 38from tensorflow.python.util import nest 39 40try: 41 from scipy.sparse import issparse # pylint: disable=g-import-not-at-top 42except ImportError: 43 issparse = None 44 45 46def model_iteration(model, 47 inputs, 48 targets=None, 49 sample_weights=None, 50 batch_size=None, 51 epochs=1, 52 verbose=1, 53 callbacks=None, 54 val_inputs=None, 55 val_targets=None, 56 val_sample_weights=None, 57 shuffle=True, 58 initial_epoch=0, 59 steps_per_epoch=None, 60 validation_steps=None, 61 validation_freq=1, 62 mode=ModeKeys.TRAIN, 63 validation_in_fit=False, 64 prepared_feed_values_from_dataset=False, 65 steps_name='steps', 66 **kwargs): 67 """Loop function for arrays of data with modes TRAIN/TEST/PREDICT. 68 69 Args: 70 model: Keras Model instance. 71 inputs: Either a list or dictionary of arrays, or a dataset instance. 72 targets: List/dictionary of input arrays. 73 sample_weights: Optional list of sample weight arrays. 74 batch_size: Integer batch size or None if unknown. 75 epochs: Number of times to iterate over the data 76 verbose: 0, 1, or 2. Verbosity mode. 77 0 = silent, 1 = progress bar, 2 = one line per epoch. 78 Note that the progress bar is not particularly useful when 79 logged to a file, so verbose=2 is recommended when not running 80 interactively (eg, in a production environment). 81 callbacks: List of callbacks to be called during training 82 val_inputs: Either a list or dictionary of arrays, or a dataset instance. 83 val_targets: List/dictionary of target arrays. 84 val_sample_weights: Optional list of sample weight arrays. 85 shuffle: Whether to shuffle the data at the beginning of each epoch 86 concatenation of list the display names of the outputs of `f` and the 87 list of display names of the outputs of `f_val`. 88 initial_epoch: Epoch at which to start training (useful for resuming a 89 previous training run) 90 steps_per_epoch: Total number of steps (batches of samples) before 91 declaring one epoch finished and starting the next epoch. Ignored with 92 the default value of `None`. 93 validation_steps: Number of steps to run validation for (only if doing 94 validation from data tensors). Ignored with the default value of 95 `None`. 96 validation_freq: Only relevant if validation data is provided. Integer or 97 `collections.abc.Container` instance (e.g. list, tuple, etc.). If an 98 integer, specifies how many training epochs to run before a new 99 validation run is performed, e.g. `validation_freq=2` runs 100 validation every 2 epochs. If a Container, specifies the epochs on 101 which to run validation, e.g. `validation_freq=[1, 2, 10]` runs 102 validation at the end of the 1st, 2nd, and 10th epochs. 103 mode: One of ModeKeys.TRAIN/ModeKeys.TEST/ModeKeys.PREDICT. 104 validation_in_fit: if true, then this method is invoked from within 105 training iteration (for validation). In the case where `val_inputs` is 106 a dataset, this flag indicates that its iterator and feed values are 107 already created so should properly reuse resources. 108 prepared_feed_values_from_dataset: if True, `inputs` is a list of feed 109 tensors returned from `_prepare_feed_values` call on the validation 110 dataset, so do not call it again on `inputs`. Should only be used for 111 inline validation (i.e., only if `validation_in_fit` is also True). 112 steps_name: The string name of the steps argument, either `steps`, 113 `validation_steps`, or `steps_per_epoch`. Only used for error message 114 formatting. 115 **kwargs: Additional arguments for backwards compatibility. 116 117 Returns: 118 - In TRAIN mode: `History` object. 119 - In TEST mode: Evaluation metrics. 120 - In PREDICT mode: Outputs of the Model called on inputs. 121 122 Raises: 123 ValueError: in case of invalid arguments. 124 """ 125 # Backwards compatibility. 126 if 'steps' in kwargs: 127 steps_per_epoch = kwargs.pop('steps') 128 if kwargs: 129 raise TypeError('Unknown arguments: %s' % (kwargs,)) 130 131 # In case we were passed a dataset, we extract symbolic tensors from it. 132 reset_dataset_after_each_epoch = False 133 input_iterator = None 134 is_dataset = isinstance(inputs, 135 (dataset_ops.DatasetV1, dataset_ops.DatasetV2)) 136 # TODO(fchollet): consider moving `steps_per_epoch` inference to 137 # _standardize_user_data and set reset_dataset_after_each_epoch as an 138 # attribute on the dataset instance. 139 if is_dataset: 140 if steps_per_epoch is None: 141 reset_dataset_after_each_epoch = True 142 steps_per_epoch = training_utils_v1.infer_steps_for_dataset( 143 model, inputs, steps_per_epoch, epochs=epochs, steps_name=steps_name) 144 input_iterator = _get_iterator(inputs, model._distribution_strategy) 145 146 # Enter tf.distribute.Strategy scope. 147 if model._distribution_strategy: 148 scope = distributed_training_utils_v1.distributed_scope( 149 strategy=model._distribution_strategy, 150 learning_phase=(1 if mode == ModeKeys.TRAIN else 0)) 151 scope.__enter__() 152 153 use_steps = is_dataset or steps_per_epoch is not None 154 do_validation = val_inputs is not None 155 156 # Prepare input data. 157 inputs = input_iterator or inputs 158 if validation_in_fit and prepared_feed_values_from_dataset: 159 # When invoking validation in training loop, avoid creating iterator and 160 # list of feed values for the same validation dataset multiple times (which 161 # essentially would call `iterator.get_next()` that slows down execution and 162 # leads to OOM errors eventually. 163 ins = inputs 164 else: 165 ins = _prepare_feed_values(model, inputs, targets, sample_weights, mode) 166 # `ins` is a function when a distribute strategy is used in Eager mode. In 167 # that case `is_dataset` is True. The code branches that have requirements 168 # about the type of `ins` do not trigger in the distributed case. 169 170 if not is_dataset: 171 num_samples_or_steps = _get_num_samples_or_steps(ins, batch_size, 172 steps_per_epoch) 173 else: 174 num_samples_or_steps = steps_per_epoch 175 176 # Update sample_weight_mode of the model if sample_weights is specified by the 177 # user. We need to call this function after we have a handle on the inputs 178 # (both numpy arrays and datasets) in order to determine if the user has 179 # specified sample_weights. 180 _update_sample_weight_mode(model, mode, ins) 181 182 # Get step function and loop type. As part of building the execution 183 # function we recompile the metrics based on the updated 184 # sample_weight_mode value. 185 f = _make_execution_function(model, mode) 186 187 # Prepare validation data. Hold references to the iterator and the input list 188 # to properly reinitialize and reuse in multiple validation passes. 189 val_iterator = None 190 if isinstance(val_inputs, (dataset_ops.DatasetV1, dataset_ops.DatasetV2)): 191 if validation_steps is None: 192 # Because we pass an iterator feed instead of a Dataset to the eval 193 # model_iteration() call, it will not trigger the dataset-input path 194 # that determines the number of steps required. To avoid this issue, 195 # set validation_steps here if validation_steps is None. 196 validation_steps = training_utils_v1.infer_steps_for_dataset( 197 model, 198 val_inputs, 199 validation_steps, 200 epochs=epochs, 201 steps_name='validation_steps') 202 val_iterator = _get_iterator(val_inputs, model._distribution_strategy) 203 val_inputs = _prepare_feed_values( 204 model, val_iterator, val_targets, val_sample_weights, ModeKeys.TEST) 205 # Get num steps for printing. 206 val_samples_or_steps = validation_steps 207 else: 208 # Get num samples for printing. 209 val_samples_or_steps = val_inputs and nest.flatten( 210 val_inputs)[0].shape[0] or None 211 212 if mode == ModeKeys.TRAIN and verbose: 213 _print_train_info(num_samples_or_steps, val_samples_or_steps, is_dataset) 214 215 # Configure callbacks. 216 count_mode = 'steps' if use_steps else 'samples' 217 callbacks = cbks.configure_callbacks( 218 callbacks, 219 model, 220 do_validation=do_validation, 221 batch_size=batch_size, 222 epochs=epochs, 223 steps_per_epoch=steps_per_epoch, 224 samples=num_samples_or_steps, 225 count_mode=count_mode, 226 verbose=verbose, 227 mode=mode) 228 229 # Find beforehand arrays that need sparse-to-dense conversion. 230 if issparse is not None and not use_steps: 231 indices_for_conversion_to_dense = [] 232 feed = _get_model_feed(model, mode) 233 for i, (input_data, feed_tensor) in enumerate(zip(ins, feed)): 234 if issparse(input_data) and not K.is_sparse(feed_tensor): 235 indices_for_conversion_to_dense.append(i) 236 237 # Select aggregation method. 238 if mode == ModeKeys.PREDICT: 239 aggregator = training_utils_v1.OutputsAggregator( 240 use_steps, 241 num_samples=None if steps_per_epoch else num_samples_or_steps, 242 steps=steps_per_epoch) 243 else: 244 aggregator = training_utils_v1.MetricsAggregator( 245 use_steps, 246 num_samples=None if steps_per_epoch else num_samples_or_steps, 247 steps=steps_per_epoch) 248 249 if model._compile_distribution: 250 distributed_training_utils_v1._copy_weights_to_distributed_model( 251 model, mode) 252 253 callbacks.model.stop_training = False 254 callbacks._call_begin_hook(mode) 255 256 initial_epoch = model._maybe_load_initial_epoch_from_ckpt(initial_epoch, mode) 257 258 for epoch in range(initial_epoch, epochs): 259 if callbacks.model.stop_training: 260 break 261 262 # Setup work for each epoch 263 epoch_logs = {} 264 if mode != ModeKeys.PREDICT: 265 # Collecting and resetting metrics has non-zero cost and will needlessly 266 # slow down model.predict. 267 model.reset_metrics() 268 if mode == ModeKeys.TRAIN: 269 callbacks.on_epoch_begin(epoch, epoch_logs) 270 271 if use_steps: 272 # Step-wise loop. 273 if steps_per_epoch is None: 274 # Loop over dataset until `OutOfRangeError` is raised. 275 target_steps = np.inf 276 else: 277 # Loop over dataset for the specified number of steps. 278 target_steps = steps_per_epoch 279 280 step = 0 281 while step < target_steps: 282 batch_logs = {'batch': step, 'size': 1} 283 callbacks._call_batch_hook(mode, 'begin', step, batch_logs) 284 285 # Get outputs. 286 try: 287 # `ins` can be callable in tf.distribute.Strategy + eager case. 288 if not callable(ins) or (model._distribution_strategy and 289 not distributed_training_utils_v1 290 .is_distributing_by_cloning(model)): 291 actual_inputs = ins 292 else: 293 actual_inputs = ins() 294 batch_outs = f(actual_inputs) 295 except errors.OutOfRangeError: 296 if is_dataset: 297 # The dataset passed by the user ran out of batches. 298 # Now we know the cardinality of the dataset. 299 # If steps_per_epoch was specified, then running out of data is 300 # unexpected, so we stop training and inform the user. 301 if steps_per_epoch: 302 callbacks.model.stop_training = True 303 logging.warning( 304 'Your dataset ran out of data; interrupting training. ' 305 'Make sure that your dataset can generate at least ' 306 '`%s * epochs` batches (in this case, %d batches). ' 307 'You may need to use the repeat() function when ' 308 'building your dataset.' 309 % (steps_name, steps_per_epoch * epochs)) 310 elif step > 0: 311 steps_per_epoch = step 312 aggregator.steps = steps_per_epoch 313 else: 314 # We ran out of batches while the user passed an iterator (legacy). 315 callbacks.model.stop_training = True 316 logging.warning( 317 'Your dataset iterator ran out of data; ' 318 'interrupting training. Make sure that your iterator ' 319 'can generate at least `%s * epochs` ' 320 'batches (in this case, %d batches). You may need to' 321 'use the repeat() function when building your ' 322 'dataset.' % (steps_name, steps_per_epoch * epochs)) 323 break 324 325 if not isinstance(batch_outs, list): 326 batch_outs = [batch_outs] 327 328 if model._distribution_strategy: 329 batch_outs = ( 330 distributed_training_utils_v1._per_replica_aggregate_batch( 331 model._distribution_strategy, batch_outs, model, mode)) 332 333 # Aggregate results. 334 if step == 0: 335 aggregator.create(batch_outs) 336 aggregator.aggregate(batch_outs) 337 338 # Callbacks batch end. 339 batch_logs = cbks.make_logs(model, batch_logs, batch_outs, mode) 340 callbacks._call_batch_hook(mode, 'end', step, batch_logs) 341 step += 1 342 343 if callbacks.model.stop_training: 344 break 345 else: 346 # Sample-wise loop. 347 index_array = np.arange(num_samples_or_steps) 348 if shuffle == 'batch': 349 index_array = training_utils_v1.batch_shuffle(index_array, batch_size) 350 elif shuffle: 351 np.random.shuffle(index_array) 352 batches = make_batches(num_samples_or_steps, batch_size) 353 for batch_index, (batch_start, batch_end) in enumerate(batches): 354 batch_ids = index_array[batch_start:batch_end] 355 # Slice into a batch. 356 if len(batches) == 1: 357 # If we only have one batch, do not slice. This takes care of 358 # composite tensors in non-Dataset modes; we currently don't support 359 # slicing them. 360 # TODO(b/133517906): Add slicing support. 361 ins_batch = ins 362 else: 363 try: 364 if ins and isinstance(ins[-1], int): 365 # Do not slice the training phase flag. 366 ins_batch = slice_arrays(ins[:-1], batch_ids) + [ins[-1]] 367 else: 368 ins_batch = slice_arrays(ins, batch_ids) 369 except TypeError: 370 raise TypeError('TypeError while preparing batch. ' 371 'If using HDF5 input data, ' 372 'pass shuffle="batch".') 373 374 # Sparse to dense conversion. 375 if issparse is not None: 376 for i in indices_for_conversion_to_dense: 377 ins_batch[i] = ins_batch[i].toarray() 378 379 # Callbacks batch_begin. 380 batch_logs = {'batch': batch_index, 'size': len(batch_ids)} 381 callbacks._call_batch_hook(mode, 'begin', batch_index, batch_logs) 382 383 # Get outputs. 384 batch_outs = f(ins_batch) 385 if not isinstance(batch_outs, list): 386 batch_outs = [batch_outs] 387 388 # Aggregate results. 389 if batch_index == 0: 390 aggregator.create(batch_outs) 391 aggregator.aggregate(batch_outs, batch_start, batch_end) 392 393 # Callbacks batch end. 394 batch_logs = cbks.make_logs(model, batch_logs, batch_outs, mode) 395 callbacks._call_batch_hook(mode, 'end', batch_index, batch_logs) 396 397 if callbacks.model.stop_training: 398 break 399 400 aggregator.finalize() 401 results = aggregator.results 402 epoch_logs = cbks.make_logs(model, epoch_logs, results, mode) 403 if len(results) == 1: 404 results = results[0] 405 406 # Run the test loop every `validation_freq` epochs during training. 407 if (do_validation and 408 training_utils_v1.should_run_validation(validation_freq, epoch) and 409 not callbacks.model.stop_training): 410 411 if model._compile_distribution: 412 # Since we create a new clone from the original model we need to copy 413 # the weights back to the original model before we can run validation. 414 distributed_training_utils_v1._copy_weights_to_original_model( 415 model, ModeKeys.TRAIN) 416 417 val_results = model_iteration( 418 model, 419 val_inputs, 420 targets=val_targets, 421 sample_weights=val_sample_weights, 422 batch_size=batch_size, 423 steps_per_epoch=validation_steps, 424 callbacks=callbacks, 425 verbose=0, 426 mode=ModeKeys.TEST, 427 validation_in_fit=True, 428 prepared_feed_values_from_dataset=(val_iterator is not None), 429 steps_name='validation_steps') 430 if not isinstance(val_results, list): 431 val_results = [val_results] 432 epoch_logs = cbks.make_logs( 433 model, epoch_logs, val_results, mode, prefix='val_') 434 if val_iterator and epoch < epochs - 1: 435 _reinitialize_iterator(val_iterator, model._distribution_strategy) 436 437 if mode == ModeKeys.TRAIN: 438 # Epochs only apply to `fit`. 439 callbacks.on_epoch_end(epoch, epoch_logs) 440 441 # Reinitialize dataset iterator for the next epoch. 442 if reset_dataset_after_each_epoch and epoch < epochs - 1: 443 _reinitialize_iterator(input_iterator, model._distribution_strategy) 444 445 model._successful_loop_finish = True 446 callbacks._call_end_hook(mode) 447 448 if model._distribution_strategy: 449 if model._compile_distribution: 450 # TODO(priyag, psv): Copy back metrics to the original model as well? 451 distributed_training_utils_v1._copy_weights_to_original_model(model, mode) 452 scope.__exit__(None, None, None) 453 454 if mode == ModeKeys.TRAIN: 455 return model.history 456 return results 457 458 459def _get_model_feed(model, mode): 460 if mode == ModeKeys.PREDICT: 461 feed = model._feed_inputs 462 else: 463 feed = ( 464 model._feed_inputs + model._feed_targets + model._feed_sample_weights) 465 return feed 466 467 468def _print_train_info(num_samples_or_steps, val_samples_or_steps, is_dataset): 469 increment = 'steps' if is_dataset else 'samples' 470 msg = 'Train on {0} {increment}'.format( 471 num_samples_or_steps, increment=increment) 472 if val_samples_or_steps: 473 msg += ', validate on {0} {increment}'.format( 474 val_samples_or_steps, increment=increment) 475 print(msg) 476 477 478def _get_num_samples_or_steps(ins, batch_size, steps_per_epoch): 479 """Returns total number of samples (when training in batch mode) or steps.""" 480 if steps_per_epoch: 481 return steps_per_epoch 482 return training_utils_v1.check_num_samples(ins, batch_size, steps_per_epoch, 483 'steps_per_epoch') 484 485 486def _prepare_feed_values(model, inputs, targets, sample_weights, mode): 487 """Prepare feed values to the model execution function. 488 489 Args: 490 model: Model to prepare feed values for. 491 inputs: List or dict of model inputs. 492 targets: Optional list of model targets. 493 sample_weights: Optional list of sample weight arrays. 494 mode: One of ModeKeys.TRAIN/ModeKeys.TEST/ModeKeys.PREDICT. 495 496 Returns: 497 Feed values for the model in the given mode. 498 """ 499 if model._distribution_strategy: 500 if isinstance(inputs, (dataset_ops.DatasetV1, dataset_ops.DatasetV2)): 501 inputs = distributed_training_utils_v1.get_iterator( 502 inputs, model._distribution_strategy) 503 504 def get_distributed_inputs(): 505 return distributed_training_utils_v1._prepare_feed_values( 506 model, inputs, targets, sample_weights, mode) 507 508 # In the eager case, we want to call the input method per step, so return 509 # a lambda from here that can be called. Note that this is applicable only 510 # in Distribution Strategy case as it follows the same code path for both 511 # eager and graph modes. 512 # TODO(priyag,omalleyt): Either we should move the training DS with 513 # IteratorBase to use training_generator code path, or figure out how to 514 # set a symbolic Iterator out of a Dataset when in eager mode. 515 if context.executing_eagerly(): 516 return get_distributed_inputs 517 else: 518 return get_distributed_inputs() 519 520 if isinstance(inputs, (dataset_ops.DatasetV1, dataset_ops.DatasetV2, 521 iterator_ops.Iterator)): 522 inputs, targets, sample_weights = model._standardize_user_data( 523 inputs, 524 extract_tensors_from_dataset=True) 525 526 inputs = training_utils_v1.ModelInputs(inputs).as_list() 527 targets = list(targets or []) 528 sample_weights = list(sample_weights or []) 529 ins = inputs + targets + sample_weights 530 if mode == ModeKeys.TRAIN and not isinstance(K.symbolic_learning_phase(), 531 int): 532 ins += [True] # Add learning phase value. 533 return ins 534 535 536def _get_iterator(inputs, distribution_strategy=None): 537 if distribution_strategy: 538 return distributed_training_utils_v1.get_iterator( 539 inputs, distribution_strategy) 540 return training_utils_v1.get_iterator(inputs) 541 542 543def _reinitialize_iterator(iterator, distribution_strategy=None): 544 if distribution_strategy: 545 distributed_training_utils_v1.initialize_iterator( 546 iterator, distribution_strategy) 547 else: 548 training_utils_v1.initialize_iterator(iterator) 549 550 551def _make_execution_function(model, mode): 552 """Makes function to run one step of model execution.""" 553 if model._distribution_strategy: 554 return distributed_training_utils_v1._make_execution_function(model, mode) 555 return model._make_execution_function(mode) 556 557 558def _update_sample_weight_mode(model, mode, inputs): 559 """Updates the sample_weight_mode of a given model.""" 560 # Add a quick return to prevent us from calling model._feed_targets that 561 # accesses certain model properties that may not be set in the `PREDICT` mode. 562 if mode == ModeKeys.PREDICT: 563 return 564 565 sample_weights = None 566 # `inputs` is the model's inputs + targets + sample_weights + 567 # learning phase placeholder if specified. To update the sample_weight_mode 568 # we need to determine if the user has passed sample weights as part of the 569 # input. 570 if not callable(inputs): 571 sample_weights = inputs[len(model._feed_inputs) + len(model._feed_targets):] 572 has_learning_phase_pl = (mode == ModeKeys.TRAIN and 573 not isinstance(K.symbolic_learning_phase(), int)) 574 if has_learning_phase_pl: 575 sample_weights = sample_weights[:-1] 576 model._update_sample_weight_modes(sample_weights=sample_weights) 577 578 # Call the DistributionStrategy specific function to update the 579 # sample_weight_mode on the model. 580 if model._distribution_strategy: 581 distributed_training_utils_v1._update_sample_weight_modes(model, mode, 582 sample_weights) 583 584# For backwards compatibility for internal users of these loops. 585fit_loop = functools.partial(model_iteration, mode=ModeKeys.TRAIN) 586test_loop = functools.partial( 587 model_iteration, mode=ModeKeys.TEST, shuffle=False) 588predict_loop = functools.partial( 589 model_iteration, mode=ModeKeys.PREDICT, shuffle=False) 590 591 592class ArrayLikeTrainingLoop(training_utils_v1.TrainingLoop): 593 """TrainingLoop that handle inputs like array. 594 595 This is the default handler for most of the input data types, includes 596 symbolic tensors or Numpy array-like, Datasets and iterators in graph mode 597 (since they generate symbolic tensors). This Function is used to handle model 598 with `run_eagerly` = False. 599 """ 600 601 def fit(self, 602 model, 603 x=None, 604 y=None, 605 batch_size=None, 606 epochs=1, 607 verbose=1, 608 callbacks=None, 609 validation_split=0., 610 validation_data=None, 611 shuffle=True, 612 class_weight=None, 613 sample_weight=None, 614 initial_epoch=0, 615 steps_per_epoch=None, 616 validation_steps=None, 617 validation_freq=1, 618 **kwargs): 619 batch_size = model._validate_or_infer_batch_size(batch_size, 620 steps_per_epoch, x) 621 622 x, y, sample_weights = model._standardize_user_data( 623 x, 624 y, 625 sample_weight=sample_weight, 626 class_weight=class_weight, 627 batch_size=batch_size, 628 check_steps=True, 629 steps_name='steps_per_epoch', 630 steps=steps_per_epoch, 631 validation_split=validation_split, 632 shuffle=shuffle) 633 634 if validation_data: 635 val_x, val_y, val_sample_weights = model._prepare_validation_data( 636 validation_data, batch_size, validation_steps) 637 elif validation_split and 0. < validation_split < 1.: 638 (x, y, sample_weights, val_x, val_y, val_sample_weights 639 ) = training_utils_v1.split_training_and_validation_data( 640 x, y, sample_weights, validation_split) 641 else: 642 if validation_steps: 643 raise ValueError('`validation_steps` should not be specified if ' 644 '`validation_data` is None.') 645 val_x, val_y, val_sample_weights = None, None, None 646 647 return fit_loop( 648 model, 649 inputs=x, 650 targets=y, 651 sample_weights=sample_weights, 652 batch_size=batch_size, 653 epochs=epochs, 654 verbose=verbose, 655 callbacks=callbacks, 656 val_inputs=val_x, 657 val_targets=val_y, 658 val_sample_weights=val_sample_weights, 659 shuffle=shuffle, 660 initial_epoch=initial_epoch, 661 steps_per_epoch=steps_per_epoch, 662 validation_steps=validation_steps, 663 validation_freq=validation_freq, 664 steps_name='steps_per_epoch') 665 666 def evaluate(self, 667 model, 668 x=None, 669 y=None, 670 batch_size=None, 671 verbose=1, 672 sample_weight=None, 673 steps=None, 674 callbacks=None, 675 **kwargs): 676 batch_size = model._validate_or_infer_batch_size(batch_size, steps, x) 677 x, y, sample_weights = model._standardize_user_data( 678 x, 679 y, 680 sample_weight=sample_weight, 681 batch_size=batch_size, 682 check_steps=True, 683 steps_name='steps', 684 steps=steps) 685 return test_loop( 686 model, 687 inputs=x, 688 targets=y, 689 sample_weights=sample_weights, 690 batch_size=batch_size, 691 verbose=verbose, 692 steps=steps, 693 callbacks=callbacks) 694 695 def predict(self, 696 model, 697 x, 698 batch_size=None, 699 verbose=0, 700 steps=None, 701 callbacks=None, 702 **kwargs): 703 batch_size = model._validate_or_infer_batch_size(batch_size, steps, x) 704 x, _, _ = model._standardize_user_data( 705 x, check_steps=True, steps_name='steps', steps=steps) 706 return predict_loop( 707 model, 708 x, 709 batch_size=batch_size, 710 verbose=verbose, 711 steps=steps, 712 callbacks=callbacks) 713