# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities related to distributed training."""
# pylint:disable=protected-access
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools

import numpy as np

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.distribute import distribute_coordinator_context as dc_context
from tensorflow.python.distribute import distribution_strategy_context as ds_context
from tensorflow.python.distribute import multi_worker_util
from tensorflow.python.distribute import reduce_util
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_util
from tensorflow.python.keras import backend as K
from tensorflow.python.keras import callbacks
from tensorflow.python.keras import metrics as metrics_module
from tensorflow.python.keras import optimizers
from tensorflow.python.keras.engine import training_utils
from tensorflow.python.keras.optimizer_v2 import optimizer_v2
from tensorflow.python.keras.utils.mode_keys import ModeKeys
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops.ragged import ragged_concat_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest
from tensorflow.python.util import tf_contextlib


def set_weights(distribution_strategy, dist_model, weights):
  """Sets the weights of the replicated models.

  The weights of the replicated models are set to the weights of the original
  model. The weights of the replicated model are Mirrored variables and hence
  we need to use the `update` call within a DistributionStrategy scope.

  Args:
    distribution_strategy: DistributionStrategy used to distribute training
      and validation.
    dist_model: The replicated models on the different devices.
    weights: The weights of the original model.
67 """ 68 assign_ops = [] 69 for layer in dist_model.layers: 70 num_param = len(layer.weights) 71 layer_weights = weights[:num_param] 72 for sw, w in zip(layer.weights, layer_weights): 73 if ops.executing_eagerly_outside_functions(): 74 sw.assign(w) 75 else: 76 assign_ops.append(distribution_strategy.unwrap(sw.assign(w))) 77 weights = weights[num_param:] 78 79 if not ops.executing_eagerly_outside_functions(): 80 K.get_session(assign_ops).run(assign_ops) 81 82 83def unwrap_values(distribution_strategy, grouped_inputs, grouped_outputs, 84 grouped_updates=None, grouped_session_args=None, 85 with_loss_tensor=False): 86 """Unwrap the list of values contained in the PerReplica parameters. 87 88 This function calls `flatten_per_replica_values` to parse each of the input 89 parameters into a list of values on the different devices. If we set 90 `with_loss_tensor` to be True, we also call `reduce` on the list of losses on 91 the different devices to give us one loss tensor. 92 93 Args: 94 distribution_strategy: DistributionStrategy used to distribute training and 95 validation. 96 grouped_inputs: PerReplica inputs returned from the train or test function 97 that we ran on each device. 98 grouped_outputs: PerReplica outputs returned from the train or test function 99 that we ran on each device. 100 grouped_updates: PerReplica updates returned from the train or test function 101 that we ran on each device. 102 grouped_session_args: PerReplica session args returned from the train or 103 test function that we ran on each device. 104 with_loss_tensor: Boolean that indicates if we need to add the reduced loss 105 tensor as one of the outputs. 106 107 Returns: 108 Values of each of the PerReplica parameters. 109 110 """ 111 # Unwrap per device values returned from each model's train function. 112 # This will be used to construct the main train function. 113 all_inputs = flatten_per_replica_values(distribution_strategy, 114 grouped_inputs) 115 all_outputs = unwrap_outputs(distribution_strategy, grouped_outputs, 116 with_loss_tensor) 117 118 if grouped_updates: 119 all_updates = flatten_per_replica_values(distribution_strategy, 120 grouped_updates) 121 else: 122 all_updates = None 123 124 all_session_args = {} 125 if grouped_session_args: 126 grouped_feed_dict = grouped_session_args.get('feed_dict') 127 if grouped_feed_dict: 128 all_session_args['feed_dict'] = flatten_per_replica_values( 129 distribution_strategy, grouped_feed_dict) 130 131 grouped_fetches = grouped_session_args.get('fetches') 132 if grouped_fetches: 133 all_session_args['fetches'] = flatten_per_replica_values( 134 distribution_strategy, grouped_fetches) 135 136 # TODO(priyag): Return only non empty/None values 137 return all_inputs, all_outputs, all_updates, all_session_args 138 139 140def unwrap_output_dict(strategy, grouped_outputs, mode): 141 """Unwrap the list of outputs contained in the PerReplica parameters.""" 142 if mode == ModeKeys.PREDICT: 143 return flatten_per_replica_values(strategy, grouped_outputs) 144 145 # In the case of fit/eval, the grouped_outputs is a dict, whereas in predict, 146 # the output is as same structure as model output. 
  # the output has the same structure as the model output. They need to be
  # treated differently.
  total_loss = strategy.reduce(reduce_util.ReduceOp.SUM,
                               grouped_outputs['total_loss'][0], axis=None)
  output_losses = flatten_per_replica_values(strategy,
                                             grouped_outputs['output_losses'])
  metrics = flatten_per_replica_values(strategy,
                                       grouped_outputs['metrics'])
  batch_size = strategy.reduce(reduce_util.ReduceOp.SUM,
                               grouped_outputs['batch_size'], axis=None)
  if (is_tpu_strategy(strategy) and
      ops.executing_eagerly_outside_functions()):
    # Choose 1 value per replica in the TPU case since all replicas produce the
    # same output.
    # We only do this in eager mode for now since this function is used in
    # both graph and eager mode, and in the graph case we currently don't use
    # experimental_run, so this would need to be removed when we converge the
    # graph code path as well.
    output_losses = output_losses[::strategy.num_replicas_in_sync]
    metrics = metrics[::strategy.num_replicas_in_sync]
  return {'total_loss': [total_loss],
          'output_losses': output_losses,
          'metrics': metrics,
          'batch_size': batch_size}


def unwrap_outputs(distribution_strategy, grouped_outputs,
                   with_loss_tensor=False):
  """Unwrap the list of outputs contained in the PerReplica parameters.

  This function calls `flatten_per_replica_values` to parse each of the input
  parameters into a list of outputs on the different devices. If we set
  `with_loss_tensor` to be True, we also call `reduce` on the list of losses on
  the different devices to give us one loss tensor.

  Args:
    distribution_strategy: DistributionStrategy used to distribute training and
      validation.
    grouped_outputs: PerReplica outputs returned from the train or test function
      that we ran on each device.
    with_loss_tensor: Boolean that indicates if we need to add the reduced loss
      tensor as one of the outputs.

  Returns:
    Values of each of the PerReplica outputs.

  """
  if not with_loss_tensor:
    return flatten_per_replica_values(distribution_strategy,
                                      grouped_outputs)

  if not isinstance(grouped_outputs, list):
    grouped_outputs = [grouped_outputs]
  # Reduce the loss tensor before adding it to the list of fetches.
  loss = distribution_strategy.reduce(reduce_util.ReduceOp.SUM,
                                      grouped_outputs[0], axis=None)
  all_outputs = flatten_per_replica_values(distribution_strategy,
                                           grouped_outputs[1:])
  if (is_tpu_strategy(distribution_strategy) and
      ops.executing_eagerly_outside_functions()):
    # Choose 1 value per replica in the TPU case since all replicas produce the
    # same output.
    # We only do this in eager mode for now since this function is used in
    # both graph and eager mode, and in the graph case we currently don't use
    # experimental_run, so this would need to be removed when we converge the
    # graph code path as well.
    all_outputs = all_outputs[::distribution_strategy.num_replicas_in_sync]
  return [loss] + all_outputs


def flatten_per_replica_values(distribution_strategy, per_replica_values):
  """Unwraps and flattens a nest of PerReplica parameters.

  PerReplica values have one value associated with each device. Each entry in
  the PerReplica dict has a device `key` and the corresponding value on the
  device as the `value`. In this function we take a PerReplica value or a list
  of PerReplica values and return all the values in the PerReplica dict.

  Args:
    distribution_strategy: DistributionStrategy used to distribute training and
      validation.
    per_replica_values: List of PerReplica objects or a single PerReplica
      object.

  Returns:
    List of values of all the PerReplica objects.

  """
  # pylint: disable=g-complex-comprehension
  # This function takes a PerReplica object or a list of PerReplica objects and
  # returns all the values associated with it.
  return [e for flattened in nest.flatten(per_replica_values)
          for e in distribution_strategy.unwrap(flattened)]


def validate_callbacks(input_callbacks, optimizer):
  """Validate whether given callbacks are supported by DistributionStrategy.

  Args:
    input_callbacks: List of callbacks passed by the user to fit.
    optimizer: Optimizer instance used to train the model.

  Raises:
    ValueError: If `LearningRateScheduler` or `ReduceLROnPlateau` is one of the
      callbacks passed.
    ValueError: If `write_grads` is one of the parameters passed as part of the
      TensorBoard callback.
  """
  if input_callbacks:
    for callback in input_callbacks:
      if isinstance(callback, (callbacks.LearningRateScheduler,
                               callbacks.ReduceLROnPlateau)):

        if not isinstance(optimizer, optimizer_v2.OptimizerV2):
          raise ValueError('You must specify a Keras Optimizer V2 when using '
                           '%s callback with DistributionStrategy.' % callback)

      # If users want to use the TensorBoard callback they cannot use certain
      # features of the callback that involve accessing model attributes and
      # running ops.
      if isinstance(callback, callbacks.TensorBoard):
        if getattr(callback, 'write_grads', False):
          logging.warning(
              UserWarning(
                  '`write_grads` in the TensorBoard callback is not supported '
                  'when using DistributionStrategy. Setting `write_grads` '
                  'to `False`.'))
          callback.write_grads = False


def validate_distributed_dataset_inputs(distribution_strategy, x, y,
                                        sample_weights=None):
  """Validate all the components of a DistributedValue Dataset input.

  Args:
    distribution_strategy: The current DistributionStrategy used to call
      `fit`/`evaluate`.
    x: Input Dataset DistributedValue object. For example, when we use
      `MirroredStrategy` this is a PerReplica object with a tensor for each
      device set in the dict. x can also be a tuple or dict. The keys of the
      dict should match the names of the input layers of the model.
    y: Target Dataset DistributedValue object. For example, when we use
      `MirroredStrategy` this is a PerReplica object with a tensor for each
      device set in the dict. y can also be a tuple or dict. The keys of the
      dict should match the names of the output layers of the model.
    sample_weights: Sample weights Dataset DistributedValue object. For example,
      when we use `MirroredStrategy` this is a PerReplica object with a tensor
      for each device set in the dict.

  Returns:
    The unwrapped values list of the x and y DistributedValues inputs.

  Raises:
    ValueError: If x and y do not have support for being evaluated as tensors,
      or if x and y contain elements that are not tensors, or if x and y
      contain elements that have a shape or dtype mismatch.
  """
  # If the input and target used to call the model are not dataset tensors,
  # we need to raise an error. When using a DistributionStrategy, the input
  # and targets to a model should be from a `tf.data.Dataset`.

  # If the elements of x and y are not tensors, we cannot standardize and
  # validate the input and targets.
  x_values_list = validate_per_replica_inputs(distribution_strategy, x)

  if y is not None:
    y_values_list = validate_per_replica_inputs(distribution_strategy, y)
  else:
    y_values_list = None

  if sample_weights is not None:
    sample_weights_list = validate_per_replica_inputs(distribution_strategy,
                                                      sample_weights)
  else:
    sample_weights_list = None

  # Return the unwrapped values to avoid calling `unwrap` a second time.
  return x_values_list, y_values_list, sample_weights_list


def validate_per_replica_inputs(distribution_strategy, x):
  """Validates PerReplica dataset input list.

  Args:
    distribution_strategy: The current DistributionStrategy used to call
      `fit`, `evaluate` and `predict`.
    x: A list of PerReplica objects that represent the input or
      target values.

  Returns:
    List containing the first element of each of the PerReplica objects in
    the input list.

  Raises:
    ValueError: If any of the objects in the `per_replica_list` is not a tensor.

  """
  # Convert the inputs and targets into a list of PerReplica objects.
  per_replica_list = nest.flatten(x, expand_composites=True)
  x_values_list = []
  for x in per_replica_list:
    if not tensor_util.is_tensor(x):
      raise ValueError('Dataset input to the model should be tensors; instead '
                       'they are of type {}'.format(type(x)))

    # At this point both x and y contain tensors in the `DistributedValues`
    # structure.
    x_values = distribution_strategy.unwrap(x)

    if not context.executing_eagerly():
      # Validate that the shape and dtype of all the elements in x are the
      # same.
      validate_all_tensor_shapes(x, x_values)
      validate_all_tensor_types(x, x_values)

    x_values_list.append(x_values[0])
  return x_values_list


def validate_all_tensor_types(x, x_values):
  x_dtype = x_values[0].dtype
  for i in range(1, len(x_values)):
    if x_dtype != x_values[i].dtype:
      raise ValueError('Input tensor dtypes do not match for distributed tensor'
                       ' inputs {}'.format(x))


def validate_all_tensor_shapes(x, x_values):
  # Validate that all the elements in x have the same shape.
  x_shape = x_values[0].shape.as_list()
  for i in range(1, len(x_values)):
    if x_shape != x_values[i].shape.as_list():
      raise ValueError('Input tensor shapes do not match for distributed tensor'
                       ' inputs {}'.format(x))


def _wait_for_variable_initialization(session):
  """Utility to wait for variables to be initialized."""
  all_variables = K._get_variables(K.get_graph())  # pylint: disable=protected-access
  candidate_vars = []
  for v in all_variables:
    if not getattr(v, '_keras_initialized', False):
      candidate_vars.append(v)

  if not candidate_vars:
    return

  while True:
    is_initialized = session.run(
        [variables.is_variable_initialized(v) for v in candidate_vars])
    uninitialized_vars = []
    for flag, v in zip(is_initialized, candidate_vars):
      if not flag:
        uninitialized_vars.append(v)
      v._keras_initialized = True  # pylint: disable=protected-access
    if not uninitialized_vars:
      break


def init_restore_or_wait_for_variables():
  """Initialize or restore variables or wait for variables to be initialized."""
  session = K._get_session()  # pylint: disable=protected-access
  if not multi_worker_util.has_worker_context(
  ) or multi_worker_util.should_load_checkpoint():
    # TODO(yuefengz): if checkpoints exist, restore from checkpoint.
    K._initialize_variables(session)  # pylint: disable=protected-access
  else:
    _wait_for_variable_initialization(session)


def validate_inputs(x, y):
  """Validate inputs when using DistributionStrategy.

  Args:
    x: Model inputs.
    y: Model targets.

  Raises:
    ValueError: if input is not a Dataset or a numpy array (when we use
      MirroredStrategy).
  """
  if (isinstance(x, iterator_ops.Iterator) or
      isinstance(y, iterator_ops.Iterator)):
    raise ValueError('`DistributionStrategy` does not support inputs of type '
                     'Iterator. You must pass a `tf.data.Dataset` object or a '
                     'numpy array as input.')


# TODO(b/118776054): Currently we support global batch size for TPUStrategy and
# core MirroredStrategy only. Remove this check when contrib MirroredStrategy is
# no longer needed.
def global_batch_size_supported(distribution_strategy):
  return distribution_strategy.extended._global_batch_size  # pylint: disable=protected-access


# TODO(sourabhbajaj): Remove this once we use the same API for all strategies.
def is_tpu_strategy(strategy):
  """Returns whether `strategy` is a TPUStrategy."""
  return (strategy is not None and
          strategy.__class__.__name__.startswith('TPUStrategy'))


def is_dataset_shape_fully_defined(dataset):
  """Returns whether a dataset's shapes are fully defined (no partial batch)."""
  shapes = nest.flatten(dataset_ops.get_legacy_output_shapes(dataset))
  unknown_shapes = [s for s in shapes if not s.is_fully_defined()]
  return not unknown_shapes


def process_batch_and_step_size(strategy,
                                inputs,
                                batch_size,
                                steps_per_epoch,
                                mode,
                                validation_split=0.):
  """Process the batch size and step size based on input and dist strategy."""
  first_x_value = nest.flatten(inputs)[0]
  if isinstance(first_x_value, np.ndarray):
    num_samples = first_x_value.shape[0]
    if validation_split and 0. < validation_split < 1.:
      num_samples = int(num_samples * (1 - validation_split))
    # Until support for partial batch is implemented across all
    # functions and distribution strategy, we pass `mode` to selectively
    # relax the constraint to consume all the training samples.
    steps_per_epoch, batch_size = get_input_params(
        strategy, num_samples, steps_per_epoch, batch_size, mode=mode)
  return batch_size, steps_per_epoch


def get_input_params(distribution_strategy,
                     num_samples,
                     steps,
                     batch_size,
                     mode=None):
  """Calculate the number of batches and steps/steps_per_epoch.

  Args:
    distribution_strategy: The DistributionStrategy used to compile the model.
    num_samples: The number of samples from which we determine the batch size
      and steps.
    steps: The specified number of steps.
    batch_size: The specified batch_size.
    mode: ModeKey representing whether input will be used for training,
      evaluation, or prediction. This is used to relax the constraints on
      consuming all the training samples to keep compatibility till we support
      partial batches. If None, then partial batches are not allowed.

  Returns:
    steps: The steps or steps_per_epoch argument depending on whether a user is
      calling `fit`, `evaluate` or `predict`. If the is_training flag is set
      we don't require the number of samples to be used completely.
    batch_size: The batch size to be used in model iterations.

  Raises:
    ValueError: If the number of batches or steps evaluates to 0.

  """
  # TODO(b/118776054): Use global batch size for Keras/DS support.
  # Currently this is only supported in TPUStrategy and CoreMirroredStrategy.
  use_per_replica_batch = not global_batch_size_supported(
      distribution_strategy)

  # TODO(b/128995245): In eager mode, uneven batch sizes are allowed except for
  # `fit()` on TPUStrategy.
  # In graph mode, the zero batch case in batch norm is not handled due to
  # XLA-GPU regression. Uneven batch sizes are not allowed except
  # for `test()` and `predict()` on TPUStrategy.
  if context.executing_eagerly():
    allow_partial_batch = (mode != ModeKeys.TRAIN or
                           not is_tpu_strategy(distribution_strategy))
  else:
    allow_partial_batch = (mode == ModeKeys.TRAIN or
                           ((mode == ModeKeys.PREDICT or mode == ModeKeys.TEST)
                            and is_tpu_strategy(distribution_strategy)))

  if steps is None:
    if batch_size is None:
      # If neither the batch size nor the number of steps is set, we choose
      # the global batch size as the minimum of the number of samples and 32.
      # 32 is chosen to provide backward compatibility.
      global_batch_size = min(num_samples, 32)
    else:
      # If the user provided the batch size we need to handle the case
      # between different strategies that use the global/per-replica batch
      # size.
      global_batch_size = batch_size
      if use_per_replica_batch:
        global_batch_size *= distribution_strategy.num_replicas_in_sync
    if allow_partial_batch:
      steps = np.ceil(num_samples / global_batch_size).astype(int)
    else:
      if num_samples % global_batch_size:
        raise ValueError('The number of samples %s is not divisible by '
                         'batch size %s.' % (num_samples, global_batch_size))
      steps = num_samples // global_batch_size
  else:
    if batch_size is None:
      # We calculate the batch size based on the number of steps specified.
      if num_samples % steps:
        raise ValueError('The number of samples %s is not divisible by '
                         'steps %s. Please change the number of steps to a '
                         'value that can consume all the samples.' % (
                             num_samples, steps))
      global_batch_size = num_samples // steps
    else:
      # If the user provided the batch size we need to handle the case
      # between different strategies that use the global/per-replica batch
      # size.
      global_batch_size = batch_size
      if use_per_replica_batch:
        global_batch_size *= distribution_strategy.num_replicas_in_sync

    min_num_samples = global_batch_size * steps
    if allow_partial_batch:
      min_num_samples = global_batch_size * (steps - 1) + 1 if steps > 1 else 0

    if num_samples < min_num_samples:
      raise ValueError('Number of samples %s is less than samples required '
                       'for specified batch_size %s and steps %s' % (
                           num_samples, global_batch_size, steps))

  # We need to return the per-replica or global batch size based on the
  # strategy.
  if use_per_replica_batch:
    if global_batch_size % distribution_strategy.num_replicas_in_sync:
      raise ValueError(
          'The batch size (%s) could not be sharded evenly across the sync '
          'replicas (%s) in the distribution strategy.' % (
              global_batch_size, distribution_strategy.num_replicas_in_sync))
    batch_size = global_batch_size // distribution_strategy.num_replicas_in_sync
  else:
    batch_size = global_batch_size

  return steps, batch_size


def get_batch_dimension(iterator):
  shapes = nest.flatten(dataset_ops.get_legacy_output_shapes(iterator))
  # Take the batch size from the first element, as it should be the same for
  # all.
  dims = shapes[0].dims
  return dims[0] if dims else None


def get_iterator(dataset, distribution_strategy):
  with distribution_strategy.scope():
    iterator = distribution_strategy.make_dataset_iterator(dataset)
  initialize_iterator(iterator, distribution_strategy)
  return iterator


def initialize_iterator(iterator, distribution_strategy):
  with distribution_strategy.scope():
    init_op = control_flow_ops.group(iterator.initializer)
    if not context.executing_eagerly():
      K.get_session((init_op,)).run(init_op)


def _get_input_from_iterator(iterator, model):
  """Get elements from the iterator and verify the input shape and type."""
  next_element = iterator.get_next()

  # `len(nest.flatten(x))` does not count empty elements such as {}:
  # len(nest.flatten([[0, 1, 2], {}])) is 3 and not 4. `next_element` is
  # flattened in `_prepare_feed_values` to work around that, and empty
  # elements are filtered out as part of the flattening.
  if len(nest.flatten(next_element)) == len(model.inputs):
    x = next_element
    y = None
    sample_weights = None
  elif len(nest.flatten(next_element)) == (len(model.inputs) +
                                           len(model.outputs)):
    x, y = next_element
    sample_weights = None
  else:
    x, y, sample_weights = next_element

  # Validate that all the elements in x and y are of the same type and shape.
  validate_distributed_dataset_inputs(
      model._distribution_strategy, x, y, sample_weights)
  return x, y, sample_weights


def _prepare_feed_values(model, inputs, targets, sample_weights, mode):
  """Prepare feed values to the model execution function.

  Arguments:
    model: Model to prepare feed values for.
    inputs: List or dict of model inputs.
    targets: Optional list of model targets.
    sample_weights: Optional list of sample weight arrays.
    mode: One of ModeKeys.TRAIN/ModeKeys.TEST/ModeKeys.PREDICT.

  Returns:
    Feed values for the model in the given mode.
  """
  strategy = model._distribution_strategy
  inputs, targets, sample_weights = _get_input_from_iterator(inputs, model)
  if is_tpu_strategy(strategy):
    if sample_weights is not None:
      raise ValueError('TPUStrategy does not support sample weights.')

  # When the inputs are a dict, we want to flatten them in the same order as
  # the input layers, such that the data are fed into the input layers in the
  # correct order.
  if isinstance(inputs, dict):
    inputs = [inputs[key] for key in model._feed_input_names]
  if is_distributing_by_cloning(model):
    inputs = flatten_per_replica_values(strategy, inputs)
    targets = flatten_per_replica_values(strategy, targets)
    # Expand 1-dimensional inputs.
    # TODO(b/124535720): Remove once this standardize data logic is shared with
    # main flow.
    inputs, targets = nest.map_structure(
        training_utils.standardize_single_array, (inputs, targets))
  else:
    inputs = training_utils.ModelInputs(inputs).as_list()

  if mode == ModeKeys.PREDICT:
    sample_weights = []
    targets = []
  elif sample_weights is not None and is_distributing_by_cloning(model):
    if context.executing_eagerly() and not model._compile_distribution:
      raise NotImplementedError('`sample_weight` is not supported when using '
                                'tf.distribute.Strategy in eager mode and '
                                'cloning=True.')
    sample_weights = flatten_per_replica_values(strategy, sample_weights)

  ins = [inputs, targets, sample_weights]
  return tuple(ins)


def is_distributing_by_cloning(model):
  """Decide whether this model is going to be distributed via cloning.

  We are going to distribute the model by cloning in graph mode.

  Args:
    model: Keras model to distribute.

  Returns:
    True if the `model` is going to be distributed using cloning and False
    otherwise.
  """
  if (is_tpu_strategy(model._distribution_strategy) and
      context.executing_eagerly):  # b/137580852
    return False
  elif ops.executing_eagerly_outside_functions():
    return bool(model._compile_distribution)
  return True


def _custom_compile_for_predict(model):
  """Custom compile for TPU predict mode."""
  if not model.built:
    # Model is not compilable because it does not know its number of inputs
    # and outputs, nor their shapes and names. We will compile after the first
    # time the model gets called on training data.
    return
  model._is_compiled = True
  model.total_loss = None
  model.train_function = None
  model.test_function = None
  model.predict_function = None


def _build_network_on_replica(model, mode, inputs=None, targets=None):
  """Build an updated model on replicas.

  We create a new Keras model while sharing the variables from the old graph.
  Building a new sub-graph is required since the original Keras model creates
  placeholders for the input and the output that are not accessible until we
  call iterator.get_next() inside the step_fn for `fit`/`evaluate`/`predict`.

  The sharing of weights and layers between the old and the new model
  guarantees that we're using Strategy variables and any updates on either
  model are reflected correctly in callbacks and loop iterations.

  We need to make sure we share the optimizers between the old and the new
  model as well so that optimizer state is not lost if the user is running fit
  multiple times.

  Args:
    model: Model to be replicated across replicas.
    mode: Which of fit/eval/predict is building the distributed network.
    inputs: Input variables to be passed to the model.
    targets: Target tensor to be passed to model.compile.

  Returns:
    A new model with shared layers with the old model.
  """
  # Need to do imports here since we run into a circular dependency error.
  from tensorflow.python.keras import models  # pylint: disable=g-import-not-at-top
  from tensorflow.python.keras.engine import sequential  # pylint: disable=g-import-not-at-top

  # We rely on the internal methods to avoid having `share_weights` in the
  # public API.
  if isinstance(model, sequential.Sequential):
    updated_model = models._clone_sequential_model(
        model, input_tensors=inputs, layer_fn=models.share_weights)
  else:
    updated_model = models._clone_functional_model(
        model, input_tensors=inputs, layer_fn=models.share_weights)
    # Callable losses added directly to a functional Model need to be added
    # here.
    updated_model._callable_losses = model._callable_losses

  # Recast all low precision outputs back to float32 since we only cast
  # the inputs to bfloat16 and not targets. This is done so that we can
  # preserve precision when calculating the loss value.
  def _upcast_low_precision_outputs(output):
    if output.dtype == dtypes.bfloat16:
      return math_ops.cast(output, dtypes.float32)
    else:
      return output
  updated_model.outputs = [_upcast_low_precision_outputs(o)
                           for o in updated_model.outputs]

  if isinstance(targets, tuple):
    targets = nest.flatten(targets)

  if mode == ModeKeys.PREDICT and inputs is not None:  # TPU predict case
    _custom_compile_for_predict(updated_model)
  else:
    updated_model.compile(
        model.optimizer,
        model.loss,
        metrics=metrics_module.clone_metrics(model._compile_metrics),
        loss_weights=model.loss_weights,
        sample_weight_mode=model.sample_weight_mode,
        weighted_metrics=metrics_module.clone_metrics(
            model._compile_weighted_metrics),
        target_tensors=targets)
  return updated_model


def _build_distributed_network(model, strategy, mode, inputs=None,
                               targets=None):
  """Create a cloned model on each replica."""
  with K.get_graph().as_default(), strategy.scope():
    distributed_model = strategy.extended.call_for_each_replica(
        _build_network_on_replica,
        args=(model, mode, inputs, targets))
    set_distributed_model(model, mode, distributed_model)


def _clone_and_build_model(model, mode, inputs=None, targets=None):
  """Clone and build the given keras_model."""
  # We need to do the import here since we run into a circular dependency
  # error.
  from tensorflow.python.keras import models  # pylint: disable=g-import-not-at-top
  cloned_model = models.clone_model(model, input_tensors=inputs)

  # Compile and build model.
  if isinstance(model.optimizer, optimizers.TFOptimizer):
    optimizer = model.optimizer
  else:
    optimizer_config = model.optimizer.get_config()
    optimizer = model.optimizer.__class__.from_config(optimizer_config)

  # Recast all low precision outputs back to float32 since we only cast
  # the inputs to bfloat16 and not targets. This is done so that we can
  # preserve precision when calculating the loss value.
  def _upcast_low_precision_outputs(output):
    if output.dtype == dtypes.bfloat16:
      return math_ops.cast(output, dtypes.float32)
    else:
      return output
  cloned_model.outputs = [_upcast_low_precision_outputs(o)
                          for o in cloned_model.outputs]

  if isinstance(targets, tuple):
    targets = nest.flatten(targets)
  if mode == ModeKeys.PREDICT and inputs is not None:  # TPU predict case
    _custom_compile_for_predict(cloned_model)
  else:
    cloned_model.compile(
        optimizer,
        model.loss,
        metrics=metrics_module.clone_metrics(model._compile_metrics),
        loss_weights=model.loss_weights,
        sample_weight_mode=model.sample_weight_mode,
        weighted_metrics=metrics_module.clone_metrics(
            model._compile_weighted_metrics),
        target_tensors=targets)
  return cloned_model


def clone_model_on_replicas(model, strategy, mode, inputs=None, targets=None):
  """Create a cloned model on each replica."""
  with K.get_graph().as_default(), strategy.scope():
    distributed_model = strategy.extended.call_for_each_replica(
        _clone_and_build_model, args=(model, mode, inputs, targets))
    set_distributed_model(model, mode, distributed_model)
  if mode == ModeKeys.TRAIN:
    model._make_callback_model(distributed_model)


def _make_execution_function(model, mode):
  """Makes or reuses function to run one step of distributed model execution."""
  if is_distributing_by_cloning(model):
    return _make_execution_function_with_cloning(model, mode)

  distributed_function = get_distributed_function(model, mode)
  if distributed_function:
    return distributed_function

  distribution_function = _make_execution_function_without_cloning(model, mode)
  set_distributed_function(model, mode, distribution_function)
  return distribution_function


def _make_execution_function_without_cloning(model, mode):
  """Creates a function to run one step of distributed model execution."""
  strategy = model._distribution_strategy

  with strategy.scope():
    per_replica_function = _make_replica_execution_function(model, mode)

    def distributed_function(input_fn):
      """A single step of the distributed execution across replicas."""
      x, y, sample_weights = input_fn()
      # Call `Model.{train,test,predict}_on_batch` on every replica passing
      # PerReplicas as arguments. On every replica inside this call, each
      # PerReplica object will return the value for that replica. The outputs
      # are PerReplicas too.
      outputs = strategy.experimental_run_v2(
          per_replica_function, args=(x, y, sample_weights))
      # Out of PerReplica outputs reduce or pick values to return.
      all_outputs = unwrap_outputs(
          strategy, outputs, with_loss_tensor=(mode != ModeKeys.PREDICT))
      return all_outputs

    if not model.run_eagerly:
      distributed_function = def_function.function(distributed_function)

      def execution_function(input_fn):
        # `numpy` translates Tensors to values in Eager mode.
        return [out.numpy() for out in distributed_function(input_fn)]
    else:
      execution_function = distributed_function

    return execution_function


def _make_replica_execution_function(model, mode):
  """A single step of the distributed execution on a replica."""
  if mode == ModeKeys.TRAIN:
    func = model.train_on_batch
  elif mode == ModeKeys.TEST:
    func = model.test_on_batch
  else:

    def predict_on_batch(x, y=None, sample_weights=None):
      del y, sample_weights
      return model.predict_on_batch(x)

    func = predict_on_batch

  if mode != ModeKeys.PREDICT:
    # `reset_metrics` is set to False to maintain stateful metrics across
    # batch-level calls.
    func = functools.partial(func, reset_metrics=False)

  return func


def _make_replicated_models_with_cloning(model, mode):
  """Build models on each replica."""
  strategy = model._distribution_strategy

  # If distributed_model is not built, create one for `mode`.
  if model._compile_distribution:
    clone_model_on_replicas(model, strategy, mode)
  else:
    _build_distributed_network(model, strategy, mode)


def _make_execution_function_with_cloning(model, mode):
  """Clones or re-uses models to run one step of distributed model execution."""
  distributed_model = get_distributed_model(model, mode)
  # TODO(b/134069401): Create a cache for the distributed model and exec
  # function that incorporates additional attributes beyond just the mode as
  # part of the cache key.
  # If the distributed model for a particular `mode` is already built, use the
  # `_distributed_function` on that distributed model.
  # If you have updated the sample_weight_mode on the model, then you will need
  # to recompile metrics and recreate the execution function. This is indicated
  # by the `_recompile_exec_function` property.
  if (distributed_model and hasattr(distributed_model, '_distributed_function')
      and not (hasattr(distributed_model, '_recompile_exec_function') and
               distributed_model._recompile_exec_function)):
    return distributed_model._distributed_function

  if not distributed_model:
    _make_replicated_models_with_cloning(model, mode)
    distributed_model = get_distributed_model(model, mode)
  assert distributed_model

  # Also create an execution function on that distributed model.
  if context.executing_eagerly():
    distributed_function = _make_eager_execution_function(model, mode)
  else:
    distributed_function = _make_graph_execution_function(model, mode)

  # We cache the distributed execution function on the model since creating
  # distributed models and execution functions is expensive.
  distributed_model._distributed_function = distributed_function
  distributed_model._recompile_exec_function = False
  return distributed_function


def _make_graph_execution_function(model, mode):
  """Makes function to run one step of distributed model in graph mode."""

  def _per_replica_function(model):
    f = model._make_execution_function(mode)
    return (f.inputs, f.outputs, f.updates_op, f.session_kwargs)

  strategy = model._distribution_strategy
  with strategy.scope():
    # Create train ops on each of the devices when we call
    # `_per_replica_fit_function`.
    (grouped_inputs, grouped_outputs, grouped_updates,
     grouped_session_args) = strategy.extended.call_for_each_replica(
         _per_replica_function, args=(get_distributed_model(model, mode),))

    # Initialize the variables in the replicated model. This is necessary for
    # multi-worker training because on some workers, initialization is not
    # needed. This method does initialization or waiting for initialization
    # according to the context object of distribute coordinator.
    init_restore_or_wait_for_variables()

    # Unwrap all the per device values returned from `call_for_each_replica`.
    # Unwrapping per device values gives you a list of values that can be
    # used to construct a new train function that is composed of update ops on
    # all the devices over which the model is distributed.
    (all_inputs, all_outputs, all_updates, all_session_args) = unwrap_values(
        strategy,
        grouped_inputs,
        grouped_outputs,
        grouped_updates,
        grouped_session_args,
        with_loss_tensor=(mode != ModeKeys.PREDICT))

  return K.function(
      all_inputs,
      all_outputs,
      updates=all_updates,
      name='distributed_{}_function'.format(mode),
      **all_session_args)


def _make_eager_execution_function(model, mode):
  """Makes function to run one step of distributed model eager execution."""
  def _per_replica_function(model):
    f = model._make_execution_function(mode)
    return (f.inputs, f.outputs)

  # NOTE(priyag): Try creating a new FuncGraph within DS scope instead of using
  # the global one.
  strategy = model._distribution_strategy
  global_graph = K.get_graph()

  with global_graph.as_default(), strategy.scope():
    # First we gather the relevant portions of the model across all replicas.
    # `K._scratch_graph(global_graph)` signals to Keras that it should not
    # lift to a separate graph when creating the per-replica functions.
    with K._scratch_graph(global_graph):
      # Create train ops on each of the devices when we call
      # `_per_replica_fit_function`.
      grouped = strategy.extended.call_for_each_replica(
          _per_replica_function, args=(get_distributed_model(model, mode),))
      grouped_inputs, grouped_outputs = grouped

      # Unwrap all the per device values returned from `call_for_each_replica`.
      # Unwrapping per device values gives you a list of values that can be
      # used to construct a new train function that is composed of
      # inputs/outputs on all the devices over which the model is distributed.
      (all_inputs, all_outputs, _, _) = unwrap_values(
          strategy,
          grouped_inputs,
          grouped_outputs,
          with_loss_tensor=(mode != ModeKeys.PREDICT))

    # Finally, a joint Keras function is created; this one will be created in
    # a separate FuncGraph.
    return K.function(
        all_inputs,
        all_outputs,
        name='eager_distributed_{}_function'.format(mode))


def _copy_weights_to_distributed_model(original_model, mode):
  """Copies weights from original model to distributed models."""
  strategy = original_model._distribution_strategy
  distributed_model = get_distributed_model(original_model, mode)
  if strategy:
    # Copy the weights from the original model to each of the replicated
    # models.
    orig_model_weights = original_model.get_weights()
    first_model = strategy.unwrap(distributed_model)[0]
    set_weights(strategy, first_model, orig_model_weights)


def _copy_weights_to_original_model(model, mode):
  """Copies weights from first distributed model back to original model."""
  if model._distribution_strategy and mode == ModeKeys.TRAIN:
    distributed_model = get_distributed_model(model, mode)
    updated_weights = model._distribution_strategy.unwrap(
        distributed_model)[0].get_weights()
    model.set_weights(updated_weights)


def _per_replica_aggregate_batch(strategy, batch_outs, model, mode):
  """Aggregates the per-replica batch-level outputs from a distributed step."""
  if strategy is not None and mode == ModeKeys.PREDICT:
    total_batch_outs = []
    for i in range(len(model.outputs)):
      num_replicas = strategy.num_replicas_in_sync
      nested_outs = batch_outs[i * num_replicas:i * num_replicas + num_replicas]
      total_batch_outs.append(
          concat_along_batch_dimension(nest.flatten(nested_outs)))
    return total_batch_outs
  return batch_outs


def _reset_metrics(model):
  if model._distribution_strategy:
    for mode in [ModeKeys.TRAIN, ModeKeys.TEST, ModeKeys.PREDICT]:
      distributed_model = get_distributed_model(model, mode)
      if distributed_model:
        first_model = model._distribution_strategy.unwrap(distributed_model)[0]
        first_model.reset_metrics()


def get_distributed_model(model, mode):
  key = _generate_cache_key(mode)
  return model._distributed_model_cache.get(key, None)


def set_distributed_model(model, mode, distributed_model):
  key = _generate_cache_key(mode)
  model._distributed_model_cache[key] = distributed_model


def get_distributed_function(model, mode):
  key = _generate_cache_key(mode)
  return model._distributed_function_cache.get(key, None)


def set_distributed_function(model, mode, distributed_function):
  key = _generate_cache_key(mode)
  model._distributed_function_cache[key] = distributed_function


def _generate_cache_key(mode):
  key = hash(mode)
  return key


@tf_contextlib.contextmanager
def distributed_scope(strategy, learning_phase):
  with strategy.scope(), K.learning_phase_scope(learning_phase):
    yield


def call_replica_local_fn(fn, *args, **kwargs):
  """Call a function that uses replica-local variables.

  This function correctly handles calling `fn` in a cross-replica
  context.

  Arguments:
    fn: The function to call.
    *args: Positional arguments to `fn`.
    **kwargs: Keyword arguments to `fn`.

  Returns:
    The result of calling `fn`.
  """
  # TODO(b/132666209): Remove this function when we support assign_*
  # for replica-local variables.
  strategy = None
  if 'strategy' in kwargs:
    strategy = kwargs.pop('strategy')
  else:
    if ds_context.has_strategy():
      strategy = ds_context.get_strategy()

  # TODO(b/120571621): TPUStrategy does not implement replica-local variables.
  is_tpu = is_tpu_strategy(strategy)
  if ((not is_tpu) and strategy and ds_context.in_cross_replica_context()):
    with strategy.scope():
      return strategy.extended.call_for_each_replica(fn, args, kwargs)
  return fn(*args, **kwargs)


def is_current_worker_chief():
  return dc_context.get_current_worker_context().is_chief


def filter_distributed_callbacks(callbacks_list, model):
  """Filter Callbacks based on the worker context when running multi-worker.

  Arguments:
    callbacks_list: A list of `Callback` instances.
    model: Keras model instance.

  Returns:
    The list of `Callback` instances that should be run on this worker.
  """

  if not model._in_multi_worker_mode():
    raise ValueError(
        'filter_distributed_callbacks() should only be called when Keras '
        'is in multi worker mode.')

  callbacks_list = callbacks_list or []
  if not [
      c for c in callbacks_list if isinstance(c, callbacks.ModelCheckpoint)
  ]:
    # TODO(rchao): Consider providing a ModelCheckpoint here if the user
    # fails to (possibly with tempfile directory).
    logging.warning('ModelCheckpoint callback is not provided. '
                    'Workers will need to restart training if any fails.')

  if callbacks_list is None or is_current_worker_chief():
    return callbacks_list

  # Some Callbacks should only run on the chief worker.
  return [
      callback for callback in callbacks_list if not callback._chief_worker_only
  ]  # pylint: disable=protected-access


def _update_sample_weight_modes(model, mode, sample_weights):
  """Update sample_weight_mode of the distributed model."""
  if is_distributing_by_cloning(model):
    distributed_model = get_distributed_model(model, mode)
    if not distributed_model:
      _make_replicated_models_with_cloning(model, mode)
      distributed_model = get_distributed_model(model, mode)
    distributed_model._recompile_exec_function = any(
        [e.sample_weights_mismatch() for e in model._training_endpoints])

    if sample_weights:
      distributed_models = flatten_per_replica_values(
          model._distribution_strategy, distributed_model)
      # sample_weights is a tuple of 1 list where the number of elements in the
      # list is equal to the number of replicas in sync.
      sample_weights = sample_weights[0]
      if sample_weights and None not in sample_weights:
        for m, sw in zip(distributed_models, sample_weights):
          m._update_sample_weight_modes(sample_weights=[sw])


def concat_along_batch_dimension(outputs):
  """Concats prediction outputs along the batch dimension."""
  if isinstance(outputs[0], sparse_tensor.SparseTensor):
    return sparse_ops.sparse_concat_v2(axis=0, sp_inputs=outputs)
  if isinstance(outputs[0], ragged_tensor.RaggedTensor):
    return ragged_concat_ops.concat(outputs, axis=0)
  return np.concatenate(outputs)
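

# Illustrative usage sketch (not part of the library API). The classes below
# are hypothetical stand-ins for a real `tf.distribute.Strategy`; they only
# expose the two attributes `get_input_params` reads. Running this module
# directly shows how the helper derives the step count and splits the global
# batch across replicas: with 128 samples, no explicit batch size and no
# explicit steps, the default global batch size is min(128, 32) = 32, giving
# 4 steps and a per-replica batch size of 32 // 2 = 16 for 2 replicas in sync.
if __name__ == '__main__':

  class _FakeStrategyExtended(object):
    # `False` means the strategy does not use global batch size semantics, so
    # `get_input_params` treats `batch_size` as per-replica.
    _global_batch_size = False

  class _FakeStrategy(object):
    extended = _FakeStrategyExtended()
    num_replicas_in_sync = 2

  example_steps, example_batch_size = get_input_params(
      _FakeStrategy(), num_samples=128, steps=None, batch_size=None, mode=None)
  print('steps:', example_steps, 'per-replica batch size:', example_batch_size)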