# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Part of the Keras training engine related to distributed training."""
# pylint: disable=protected-access
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.python.distribute import distribute_coordinator as dc
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.distribute import input_lib
from tensorflow.python.distribute import reduce_util as ds_reduce_util
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.keras import backend as K
from tensorflow.python.keras import callbacks as cbks
from tensorflow.python.keras.distribute import distributed_training_utils_v1 as dist_utils
from tensorflow.python.keras.engine import partial_batch_padding_handler as padding_util
from tensorflow.python.keras.engine import training_arrays_v1
from tensorflow.python.keras.engine import training_utils_v1
from tensorflow.python.keras.utils.generic_utils import Progbar
from tensorflow.python.keras.utils.mode_keys import ModeKeys
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.platform import tf_logging as logging


def _per_replica_execution_function(model, mode):
  exec_func = model._make_execution_function(mode)
  return (exec_func.inputs, exec_func.outputs, exec_func.updates_op,
          exec_func.session_kwargs)


def _build_model(strategy, model, mode, inputs, targets=None):
  if model._compile_distribution:
    dist_utils.clone_model_on_replicas(
        model, strategy, mode, inputs=inputs, targets=targets)
  else:
    dist_utils._build_distributed_network(model, strategy, mode, inputs,
                                          targets)
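

# Descriptive note (added commentary, not original to this module):
# `_per_replica_execution_function` unpacks the Keras backend execution
# function for one replica into (feed tensors, fetch tensors, grouped update
# op, extra session kwargs). `_build_model` picks between two replication
# paths: cloning the model onto each replica when it was compiled with
# distribution (`_compile_distribution`), or building the distributed
# network in place otherwise.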


def _make_train_step_fn(model, mode, strategy, output_labels):
  """Create step fn.

  Args:
    model: a Keras Model instance.
    mode: One of ModeKeys.TRAIN/ModeKeys.TEST/ModeKeys.PREDICT.
    strategy: a `tf.distribute.Strategy` instance.
    output_labels: the output labels for the step function.

  Returns:
    A step function to be run by `tf.distribute.Strategy`.
  """

  def _step_fn(ctx, inputs):
    """A step fn that returns update ops."""
    if isinstance(inputs, (tuple, list)) and len(inputs) == 2:
      inputs, targets = inputs
    else:
      targets = None

    # When the input feature is a dictionary of tensors, the dictionary is
    # flattened to an array and passed as a model input. This results in an
    # input mismatch when the model input layer names are not sorted in
    # alphabetical order, since `nest.flatten()` sorts dictionary elements by
    # key. Therefore, transform the input tensors into an array ordered by
    # `model._feed_input_names`.
    if isinstance(inputs, dict):
      inputs = [inputs[input_name] for input_name in model._feed_input_names]

    _build_model(strategy, model, mode, inputs, targets)

    (grouped_inputs, grouped_outputs, grouped_updates,
     grouped_session_args) = strategy.extended.call_for_each_replica(
         _per_replica_execution_function,
         args=(dist_utils.get_distributed_model(model, mode), mode))
    (all_inputs, all_outputs, all_updates,
     all_session_args) = dist_utils.unwrap_values(strategy, grouped_inputs,
                                                  grouped_outputs,
                                                  grouped_updates,
                                                  grouped_session_args)
    combined_fn = K.function(
        all_inputs,
        all_outputs,
        updates=all_updates,
        name='distributed_' + str(mode) + '_function',
        **all_session_args)

    for label, output in zip(output_labels, combined_fn.outputs):
      if label == 'loss':
        reduce_op = ds_reduce_util.ReduceOp.SUM
      else:
        # We reduce all other metrics using mean for now. This is a temporary
        # workaround until new metrics are in place.
        reduce_op = ds_reduce_util.ReduceOp.MEAN
      ctx.set_last_step_output(label, output, reduce_op)

    # TODO(priyag, sourabhbajaj): Ignoring these things from the combined_fn:
    # feed_dict, session kwargs, run options, run_metadata for now. These
    # should be handled appropriately.
    return combined_fn.updates_op

  return _step_fn
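

# Illustrative sketch (mirrors the wiring in `experimental_tpu_fit_loop`
# below; not executed here): the step function built above is driven by the
# strategy's host loop, which runs it `iterations` times per session call:
#
#   step_fn = _make_train_step_fn(model, ModeKeys.TRAIN, strategy, out_labels)
#   ctx = strategy.extended.experimental_run_steps_on_iterator(
#       step_fn, iterator, iterations=steps_per_run,
#       initial_loop_values={'loss': constant_op.constant(1e7)})
#   train_op = ctx.run_op            # runs `steps_per_run` steps on device
#   outputs = ctx.last_step_outputs  # reduced tensors keyed by label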


def experimental_tpu_fit_loop(model,
                              dataset,
                              epochs=100,
                              verbose=1,
                              callbacks=None,
                              initial_epoch=0,
                              steps_per_epoch=None,
                              val_dataset=None,
                              validation_steps=None,
                              validation_freq=1):
  """Fit loop for training with TPU tf.distribute.Strategy.

  Args:
    model: Keras Model instance.
    dataset: Dataset that returns inputs and targets.
    epochs: Number of times to iterate over the data.
    verbose: Integer, Verbosity mode, 0, 1 or 2.
    callbacks: List of callbacks to be called during training.
    initial_epoch: Epoch at which to start training
        (useful for resuming a previous training run).
    steps_per_epoch: Total number of steps (batches of samples)
        before declaring one epoch finished and starting the
        next epoch. Ignored with the default value of `None`.
    val_dataset: Dataset for validation data.
    validation_steps: Number of steps to run validation for
        (only if doing validation from data tensors).
        Ignored with the default value of `None`.
    validation_freq: Only relevant if validation data is provided. Integer or
        `collections.abc.Container` instance (e.g. list, tuple, etc.). If an
        integer, specifies how many training epochs to run before a new
        validation run is performed, e.g. `validation_freq=2` runs
        validation every 2 epochs. If a Container, specifies the epochs on
        which to run validation, e.g. `validation_freq=[1, 2, 10]` runs
        validation at the end of the 1st, 2nd, and 10th epochs.

  Returns:
    The model's `History` object.

  Raises:
    ValueError: in case of invalid arguments.
  """
  mode = ModeKeys.TRAIN

  current_strategy = model._distribution_strategy
  iteration_value = min(steps_per_epoch,
                        current_strategy.extended.steps_per_run)
  steps_per_run = K.variable(
      value=iteration_value,
      dtype='int32',
      name='steps_per_run')

  # TODO(fchollet): add support for `steps_per_epoch=None` in TPU loops.
  iterator = dist_utils.get_iterator(dataset, current_strategy)

  scope = dist_utils.distributed_scope(
      strategy=current_strategy, learning_phase=1)
  scope.__enter__()

  out_labels = model.metrics_names or []

  step_fn = _make_train_step_fn(model, ModeKeys.TRAIN, current_strategy,
                                out_labels)

  # Add initial dummy values for loss and other metric tensors.
  initial_loop_values = {}
  initial_loop_values['loss'] = constant_op.constant(1e7)
  for m in model._get_training_eval_metrics():
    tensor = m.result()
    initial_loop_values[m.name] = array_ops.zeros(tensor.shape, tensor.dtype)

  ctx = current_strategy.extended.experimental_run_steps_on_iterator(
      step_fn, iterator, iterations=steps_per_run,
      initial_loop_values=initial_loop_values)
  train_op = ctx.run_op
  output_tensors = ctx.last_step_outputs

  do_validation = bool(validation_steps)

  if model._compile_distribution:
    dist_utils._copy_weights_to_distributed_model(model, mode)

  callbacks = cbks.configure_callbacks(
      callbacks,
      model,
      do_validation=do_validation,
      epochs=epochs,
      steps_per_epoch=steps_per_epoch,
      verbose=verbose,
      count_mode='steps',
      mode=mode)

  # Calculate the steps each time on the device.
  steps_to_run = ([current_strategy.extended.steps_per_run] *
                  (steps_per_epoch //
                   current_strategy.extended.steps_per_run))
  if steps_per_epoch % current_strategy.extended.steps_per_run:
    steps_to_run.append(
        steps_per_epoch % current_strategy.extended.steps_per_run)
  target_steps = len(steps_to_run)

  callbacks._call_begin_hook(mode)

  initial_epoch = model._maybe_load_initial_epoch_from_ckpt(initial_epoch, mode)

  for epoch in range(initial_epoch, epochs):
    dist_utils._reset_metrics(model)
    callbacks.on_epoch_begin(epoch)
    epoch_logs = {}
    step_index = 0
    prev_step_count = None
    current_step = 0
    while current_step < target_steps:
      step_count = steps_to_run[current_step]
      batch_logs = {'batch': step_index, 'size': 1, 'num_steps': step_count}
      callbacks._call_batch_hook(mode, 'begin', step_index, batch_logs)
      if prev_step_count is None or step_count != prev_step_count:
        K.get_session().run(steps_per_run.assign(step_count))
        prev_step_count = step_count
      try:
        _, outputs = K.batch_get_value([train_op, output_tensors])
      except errors.OutOfRangeError:
        logging.warning('Your dataset iterator ran out of data; '
                        'interrupting training. Make sure that your dataset '
                        'can generate at least `steps_per_epoch * epochs` '
                        'batches (in this case, %d batches).' %
                        (steps_per_epoch * epochs))
        break

      batch_logs.update(outputs)
      callbacks._call_batch_hook(mode, 'end', step_index, batch_logs)
      step_index = step_index + step_count
      current_step += 1

      if callbacks.model.stop_training:
        break

    if (do_validation and
        training_utils_v1.should_run_validation(validation_freq, epoch)):
      logging.info('Running validation at fit epoch: %s', epoch)

      if model._compile_distribution:
        # Since we create a new clone from the original model, we need to copy
        # the weights back to the original model before we can run validation.
        dist_utils._copy_weights_to_original_model(model, ModeKeys.TRAIN)

      val_outs = experimental_tpu_test_loop(  # pylint: disable=undefined-variable
          model,
          val_dataset,
          steps=validation_steps,
          verbose=verbose,
          callbacks=callbacks)
      if not isinstance(val_outs, list):
        val_outs = [val_outs]
      # Same labels assumed.
      for label, val_out in zip(out_labels, val_outs):
        epoch_logs['val_' + label] = val_out

    callbacks.on_epoch_end(epoch, epoch_logs)
    if callbacks.model.stop_training:
      break
  model._successful_loop_finish = True
  callbacks._call_end_hook(mode)

  if model._compile_distribution:
    # Copy the weights back from the replicated model to the original model.
    dist_utils._copy_weights_to_original_model(model, ModeKeys.TRAIN)
  scope.__exit__(None, None, None)
  return model.history
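

# Worked example (added for clarity; the numbers are hypothetical): with
# `steps_per_epoch=100` and `strategy.extended.steps_per_run=32`, the fit
# loop above schedules
#
#   steps_to_run = [32] * (100 // 32)   # -> [32, 32, 32]
#   steps_to_run.append(100 % 32)       # remainder 4 is nonzero, so appended
#   target_steps = len(steps_to_run)    # -> 4 host-loop iterations
#
# and `steps_per_run` is reassigned only when the step count changes (here,
# once, for the final partial run of 4 steps).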


def experimental_tpu_test_loop(model,
                               dataset,
                               verbose=0,
                               steps=None,
                               callbacks=None):
  """Test loop for evaluating with TPU tf.distribute.Strategy.

  Args:
    model: Keras Model instance.
    dataset: Dataset for input data.
    verbose: Integer, Verbosity mode 0 or 1.
    steps: Total number of steps (batches of samples)
        before declaring predictions finished.
        Ignored with the default value of `None`.
    callbacks: List of callbacks to be called during evaluation.

  Returns:
    Scalar loss (if the model has a single output and no metrics)
    or list of scalars (if the model has multiple outputs
    and/or metrics). The attribute `model.metrics_names` will give you
    the display labels for the outputs.
  """
  mode = ModeKeys.TEST
  current_strategy = model._distribution_strategy
  iterator = dist_utils.get_iterator(dataset, current_strategy)

  scope = dist_utils.distributed_scope(
      strategy=current_strategy, learning_phase=0)
  scope.__enter__()

  out_labels = model.metrics_names

  def _test_step_fn(inputs):
    """A fn that returns output of single test step."""
    if isinstance(inputs, (tuple, list)) and len(inputs) == 2:
      inputs, targets = inputs
    else:
      targets = None

    (distribution_strategy_context.get_replica_context().merge_call(
        _build_model, args=(model, mode, inputs, targets)))

    (_, outputs, updates, _) = _per_replica_execution_function(
        dist_utils.get_distributed_model(model, mode), mode)
    with ops.control_dependencies([updates]):
      return [array_ops.identity(out) for out in outputs]

  test_input_data = iterator.get_next()
  per_replica_outputs = current_strategy.run(
      _test_step_fn, args=(test_input_data,))
  output_tensors = {}
  for label, output in zip(out_labels, per_replica_outputs):
    if label == 'loss':
      reduce_op = ds_reduce_util.ReduceOp.SUM
    else:
      # We reduce all other metrics using mean for now. This is a temporary
      # workaround until new metrics are in place.
      reduce_op = ds_reduce_util.ReduceOp.MEAN
    output_tensors[label] = current_strategy.reduce(reduce_op, output,
                                                    axis=None)
  test_op = control_flow_ops.group(list(output_tensors.values()))

  if verbose >= 1:
    progbar = Progbar(target=steps)

  if model._compile_distribution:
    dist_utils._copy_weights_to_distributed_model(model, mode)

  dist_utils._reset_metrics(model)

  callbacks = cbks.configure_callbacks(
      callbacks,
      model,
      do_validation=False,
      epochs=1,
      steps_per_epoch=steps,
      verbose=verbose,
      count_mode='steps',
      mode=ModeKeys.TEST)
  callbacks._call_begin_hook(mode)

  outs = [0.] * len(model.metrics_names)
  if steps is not None:
    target_steps = steps
  else:
    raise ValueError('Number of steps could not be inferred from the data, '
                     'please pass the steps argument.')

  current_step = 0
  while current_step < target_steps:
    batch_logs = {'batch': current_step, 'size': 1}
    callbacks._call_batch_hook(mode, 'begin', current_step, batch_logs)
    try:
      _, batch_outs = K.batch_get_value([test_op, output_tensors])
    except errors.OutOfRangeError:
      warning_msg = (
          'Make sure that your dataset can generate at least '
          '`steps` batches (in this case, {} batches).'.format(steps))

      logging.warning('Your dataset iterator ran out of data; '
                      'interrupting evaluation. ' + warning_msg)
      target_steps = current_step
      break
    for i, label in enumerate(model.metrics_names):
      if i == 0:
        # Loss is a stateless metric, so we accumulate it manually across
        # steps and average at the end.
        outs[i] += batch_outs[label]
      else:
        # For all stateful metrics, the aggregation is handled by mirrored
        # variables, so the latest value is already cumulative.
        outs[i] = batch_outs[label]

    batch_logs = cbks.make_logs(model, batch_logs, outs, mode)
    callbacks._call_batch_hook(mode, 'end', current_step, batch_logs)
    if verbose == 1:
      progbar.update(current_step + 1)
    current_step += 1

  if verbose >= 1:
    # Progress bar finishes at the end.
    progbar.update(target_steps)
  callbacks._call_end_hook(mode)

  scope.__exit__(None, None, None)
  if outs:
    outs[0] /= target_steps

  if len(outs) == 1:
    return outs[0]
  return outs
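

# Worked example (added for clarity; the numbers and the 'acc' label are
# hypothetical): in the test loop above, per-step losses are summed on the
# host and averaged at the end, while stateful metrics aggregate across
# steps via mirrored variables. With `target_steps=4` and per-step losses
# [0.9, 0.7, 0.6, 0.6]:
#
#   outs[0] = 0.9 + 0.7 + 0.6 + 0.6   # running sum      -> 2.8
#   outs[0] /= target_steps           # final mean loss  -> 0.7
#   outs[1] = batch_outs['acc']       # stateful metric: last value wins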


def experimental_tpu_predict_loop(model,
                                  dataset,
                                  verbose=0,
                                  steps=None,
                                  callbacks=None):
  """Predict loop for predicting with TPU tf.distribute.Strategy.

  Args:
    model: Keras Model instance.
    dataset: Dataset for input data.
    verbose: Integer, Verbosity mode 0 or 1.
    steps: Total number of steps (batches of samples)
        before declaring `_predict_loop` finished.
        Ignored with the default value of `None`.
    callbacks: List of callbacks to be called during prediction.

  Returns:
    Array of predictions (if the model has a single output)
    or list of arrays of predictions
    (if the model has multiple outputs).
  """
  mode = ModeKeys.PREDICT
  dataset_fully_shaped = dist_utils.is_dataset_shape_fully_defined(dataset)
  padding_handler = None
  if not dataset_fully_shaped:
    # TODO(hongjunchoi): Investigate whether operations from
    # PartialBatchPaddingHandler are unnecessarily pruned out
    # during graph optimization.
    padding_handler = padding_util.PartialBatchPaddingHandler(
        model._feed_output_shapes)
    batch_size, _, prefetch_buffer = input_lib._get_dataset_attributes(dataset)
    padding_handler.padded_batch_size = batch_size
    padding_handler.padding_mask = dataset.reduce(padding_handler.padding_mask,
                                                  padding_handler.update_mask)

    dataset = dataset.map(padding_handler.pad_batch)
    dataset = dataset.unbatch()
    # At this point, it is guaranteed that the dataset does not
    # have partial batches. Thus, we set `drop_remainder=True` to
    # get static shape information about the elements in the dataset.
    dataset = dataset.batch(batch_size, drop_remainder=True)

    if prefetch_buffer is not None:
      dataset = dataset.prefetch(prefetch_buffer)

  current_strategy = model._distribution_strategy
  iterator = dist_utils.get_iterator(dataset, current_strategy)

  scope = dist_utils.distributed_scope(
      strategy=current_strategy, learning_phase=0)
  scope.__enter__()

  def _predict_step_fn(inputs):
    """A fn that returns output of single prediction step."""

    (distribution_strategy_context.get_replica_context().merge_call(
        _build_model, args=(model, mode, inputs)))

    (_, outputs, updates, _) = _per_replica_execution_function(
        dist_utils.get_distributed_model(model, mode), mode)

    with ops.control_dependencies([updates]):
      return [array_ops.identity(out) for out in outputs]

  # TODO(hongjunchoi): When a numpy array is passed as an input to
  # `predict()`, use the numpy arrays directly to avoid accumulating
  # unnecessary input pipeline ops.
  predict_input_data = iterator.get_next()
  per_replica_outputs = current_strategy.run(
      _predict_step_fn, args=(predict_input_data,))
  output_tensors = dist_utils.flatten_per_replica_values(
      current_strategy, per_replica_outputs)

  if verbose >= 1:
    progbar = Progbar(target=steps)

  if model._compile_distribution:
    dist_utils._copy_weights_to_distributed_model(model, mode)

  dist_utils._reset_metrics(model)

  callbacks = cbks.configure_callbacks(
      callbacks,
      model,
      do_validation=False,
      epochs=1,
      steps_per_epoch=steps,
      verbose=verbose,
      count_mode='steps',
      mode=mode)
  callbacks._call_begin_hook(mode)

  # Since we do not know how many samples we will see, we cannot pre-allocate
  # the returned Numpy arrays. Instead, we store one array per batch seen
  # and concatenate them upon returning.
  num_model_outputs = len(model.output_names)
  unconcatenated_outs = [[] for _ in range(num_model_outputs)]
  if steps is not None:
    target_steps = steps
  else:
    raise ValueError('Number of steps could not be inferred from the data, '
                     'please pass the steps argument.')

  current_step = 0
  while current_step < target_steps:
    batch_logs = {'batch': current_step, 'size': 1}
    callbacks._call_batch_hook(mode, 'begin', current_step, batch_logs)
    try:
      predict_ops = control_flow_ops.group(output_tensors)
      _, batch_outs = K.batch_get_value([predict_ops, output_tensors])

    except errors.OutOfRangeError:
      warning_msg = (
          'Make sure that your dataset can generate at least '
          '`steps` batches (in this case, {} batches).'.format(steps))

      logging.warning('Your dataset iterator ran out of data; '
                      'interrupting prediction. ' + warning_msg)
      break

    # TODO(priyag): maybe need to unwrap the outputs first for
    # MirroredStrategy.
    for i in range(num_model_outputs):
      output_start_index = i * current_strategy.num_replicas_in_sync
      output_end_index = (
          output_start_index + current_strategy.num_replicas_in_sync)
      single_model_output = batch_outs[output_start_index:output_end_index]
      unconcatenated_outs[i].extend(single_model_output)

    batch_logs = cbks.make_logs(model, batch_logs, batch_outs, mode)
    callbacks._call_batch_hook(mode, 'end', current_step, batch_logs)
    if verbose == 1:
      progbar.update(current_step + 1)
    current_step += 1

  if verbose >= 1:
    # Progress bar finishes at the end.
    progbar.update(current_step)

  callbacks._call_end_hook(mode)

  scope.__exit__(None, None, None)

  if len(unconcatenated_outs) == 1:
    prediction_result = np.concatenate(unconcatenated_outs[0], axis=0)
  else:
    prediction_result = [
        np.concatenate(out, axis=0) for out in unconcatenated_outs
    ]

  if padding_handler:
    prediction_result = padding_handler.apply_mask(prediction_result)

  return prediction_result
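

# Layout note (added commentary; the shapes are hypothetical): in the predict
# loop above, `dist_utils.flatten_per_replica_values` orders the fetched
# tensors as all replicas of output 0, then all replicas of output 1, and so
# on. With 2 model outputs and `num_replicas_in_sync=8`:
#
#   batch_outs[0:8]   # replicas 0..7 of model output 0
#   batch_outs[8:16]  # replicas 0..7 of model output 1
#
# which is exactly the slicing done with `output_start_index` and
# `output_end_index`.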


class DistributionSingleWorkerTrainingLoop(training_utils_v1.TrainingLoop):
  """Training loop for distribution strategy with a single worker."""

  def fit(self,
          model,
          x=None,
          y=None,
          batch_size=None,
          epochs=1,
          verbose=1,
          callbacks=None,
          validation_split=0.,
          validation_data=None,
          shuffle=True,
          class_weight=None,
          sample_weight=None,
          initial_epoch=0,
          steps_per_epoch=None,
          validation_steps=None,
          validation_freq=1,
          **kwargs):
    """Fit loop for Distribution Strategies."""
    dist_utils.validate_callbacks(input_callbacks=callbacks,
                                  optimizer=model.optimizer)
    dist_utils.validate_inputs(x, y)

    batch_size, steps_per_epoch = dist_utils.process_batch_and_step_size(
        model._distribution_strategy,
        x,
        batch_size,
        steps_per_epoch,
        ModeKeys.TRAIN,
        validation_split=validation_split)
    batch_size = model._validate_or_infer_batch_size(
        batch_size, steps_per_epoch, x)
    dataset = model._distribution_standardize_user_data(
        x, y,
        sample_weight=sample_weight,
        class_weight=class_weight,
        batch_size=batch_size,
        validation_split=validation_split,
        shuffle=shuffle,
        epochs=epochs)
    if not dist_utils.is_distributing_by_cloning(model):
      with model._distribution_strategy.scope():
        (dataset, _, _) = model._standardize_user_data(
            dataset,
            sample_weight=sample_weight,
            class_weight=class_weight,
            batch_size=batch_size,
            validation_split=validation_split,
            shuffle=shuffle)

    val_dataset = None
    if validation_data:
      val_x, val_y, val_sample_weights = (
          training_utils_v1.unpack_validation_data(validation_data))
      dist_utils.validate_inputs(val_x, val_y)
      _, validation_steps = dist_utils.process_batch_and_step_size(
          model._distribution_strategy, val_x, batch_size, validation_steps,
          ModeKeys.TEST)

      val_dataset = model._distribution_standardize_user_data(
          val_x, val_y,
          sample_weight=val_sample_weights,
          class_weight=None,
          batch_size=batch_size,
          validation_split=validation_split,
          shuffle=shuffle,
          allow_partial_batch=True)
    elif validation_split:
      raise ValueError('validation_split argument is not supported with '
                       'distribution strategies.')

    if K.is_tpu_strategy(model._distribution_strategy):
      steps_per_epoch = training_utils_v1.infer_steps_for_dataset(
          model, dataset, steps_per_epoch, epochs,
          steps_name='steps_per_epoch')
      if steps_per_epoch is None:
        raise ValueError('Number of steps could not be inferred from the '
                         'data, please pass the steps_per_epoch argument.')

      if not context.executing_eagerly():
        # Run TPU training in a custom loop in graph mode.
        return experimental_tpu_fit_loop(
            model,
            dataset,
            epochs=epochs,
            verbose=verbose,
            callbacks=callbacks,
            val_dataset=val_dataset,
            initial_epoch=initial_epoch,
            steps_per_epoch=steps_per_epoch,
            validation_steps=validation_steps,
            validation_freq=validation_freq)

    return training_arrays_v1.fit_loop(
        model,
        dataset,
        batch_size=batch_size,
        epochs=epochs,
        verbose=verbose,
        callbacks=callbacks,
        val_inputs=val_dataset,
        shuffle=shuffle,
        initial_epoch=initial_epoch,
        steps_per_epoch=steps_per_epoch,
        validation_steps=validation_steps,
        validation_freq=validation_freq,
        steps_name='steps_per_epoch')
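
  # Usage sketch (assumed public-API entry point, added commentary;
  # `dataset` is a placeholder `tf.data.Dataset`): a model compiled under a
  # distribution strategy scope routes `model.fit(...)` through the `fit`
  # method above:
  #
  #   strategy = tf.distribute.MirroredStrategy()
  #   with strategy.scope():
  #     model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
  #     model.compile(optimizer='sgd', loss='mse')
  #   model.fit(dataset, epochs=2, steps_per_epoch=100)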

  def evaluate(self,
               model,
               x=None,
               y=None,
               batch_size=None,
               verbose=1,
               sample_weight=None,
               steps=None,
               callbacks=None,
               **kwargs):
    """Evaluate loop for Distribution Strategies."""
    dist_utils.validate_inputs(x, y)
    batch_size, steps = dist_utils.process_batch_and_step_size(
        model._distribution_strategy, x, batch_size, steps, ModeKeys.TEST)
    batch_size = model._validate_or_infer_batch_size(batch_size, steps, x)
    dataset = model._distribution_standardize_user_data(
        x, y,
        sample_weight=sample_weight,
        batch_size=batch_size,
        allow_partial_batch=True)

    if K.is_tpu_strategy(model._distribution_strategy):
      steps = training_utils_v1.infer_steps_for_dataset(
          model, dataset, steps, steps_name='steps')
      if steps is None:
        raise ValueError('Number of steps could not be inferred from the '
                         'data, please pass the steps argument.')

      if not context.executing_eagerly():
        # Run TPU evaluation in a custom loop in graph mode.
        return experimental_tpu_test_loop(
            model, dataset, verbose=verbose, steps=steps, callbacks=callbacks)

    return training_arrays_v1.test_loop(
        model,
        inputs=dataset,
        batch_size=batch_size,
        verbose=verbose,
        steps=steps,
        callbacks=callbacks)

  def predict(self,
              model,
              x,
              batch_size=None,
              verbose=0,
              steps=None,
              callbacks=None,
              **kwargs):
    """Predict loop for Distribution Strategies."""
    dist_utils.validate_inputs(x=x, y=None)
    batch_size, steps = dist_utils.process_batch_and_step_size(
        model._distribution_strategy, x, batch_size, steps, ModeKeys.PREDICT)
    batch_size = model._validate_or_infer_batch_size(batch_size, steps, x)
    dataset = model._distribution_standardize_user_data(
        x,
        batch_size=batch_size,
        allow_partial_batch=True)
    if K.is_tpu_strategy(model._distribution_strategy):
      steps = training_utils_v1.infer_steps_for_dataset(
          model, dataset, steps, steps_name='steps')
      if steps is None:
        raise ValueError('Number of steps could not be inferred from the '
                         'data, please pass the steps argument.')
      if not context.executing_eagerly():
        # Run TPU prediction in a custom loop in graph mode.
        return experimental_tpu_predict_loop(
            model, dataset, verbose=verbose, steps=steps, callbacks=callbacks)
    return training_arrays_v1.predict_loop(
        model,
        dataset,
        batch_size=batch_size,
        verbose=verbose,
        steps=steps,
        callbacks=callbacks)


def _train_with_multi_worker(method):
  """Decorator that handles multi-worker training with distribution strategy."""

  def wrapper(model, **kwargs):
    def _worker_fn(_):
      callbacks = kwargs.pop('callbacks', None)
      filtered_callbacks = dist_utils.filter_distributed_callbacks(
          callbacks, model)
      kwargs['callbacks'] = filtered_callbacks
      return method(model, **kwargs)

    return dc.run_distribute_coordinator(
        _worker_fn,
        model._distribution_strategy,
        mode='independent_worker')

  return wrapper


class DistributionMultiWorkerTrainingLoop(training_utils_v1.TrainingLoop):
  """Training loop for distribution strategy with multiple workers."""

  def __init__(self, single_worker_loop):
    self._single_worker_loop = single_worker_loop

  def fit(self, *args, **kwargs):
    return _train_with_multi_worker(self._single_worker_loop.fit)(
        *args, **kwargs)

  def evaluate(self, *args, **kwargs):
    return _train_with_multi_worker(self._single_worker_loop.evaluate)(
        *args, **kwargs)

  def predict(self, *args, **kwargs):
    # Currently, predict still uses the single-worker implementation.
    return self._single_worker_loop.predict(*args, **kwargs)
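

# Usage sketch (assumed wiring, added commentary): the multi-worker loop
# wraps a single-worker loop and re-runs its `fit`/`evaluate` under the
# distribute coordinator, once per worker. Note that `wrapper` above takes
# the model positionally and everything else as keyword arguments:
#
#   loop = DistributionMultiWorkerTrainingLoop(
#       DistributionSingleWorkerTrainingLoop())
#   loop.fit(model, x=dataset, epochs=2, steps_per_epoch=100)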