1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Utilities related to distributed training.""" 16# pylint:disable=protected-access 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import functools 22 23import numpy as np 24 25from tensorflow.python.data.ops import dataset_ops 26from tensorflow.python.data.ops import iterator_ops 27from tensorflow.python.distribute import distribute_coordinator_context as dc_context 28from tensorflow.python.distribute import multi_worker_util 29from tensorflow.python.distribute import reduce_util 30from tensorflow.python.eager import context 31from tensorflow.python.eager import def_function 32from tensorflow.python.framework import dtypes 33from tensorflow.python.framework import ops 34from tensorflow.python.framework import sparse_tensor 35from tensorflow.python.framework import tensor_util 36from tensorflow.python.keras import backend as K 37from tensorflow.python.keras import callbacks 38from tensorflow.python.keras import metrics as metrics_module 39from tensorflow.python.keras import optimizers 40from tensorflow.python.keras.distribute import distributed_training_utils as dist_utils 41from tensorflow.python.keras.engine import training_utils_v1 42from tensorflow.python.keras.optimizer_v2 import optimizer_v2 43from tensorflow.python.keras.utils import tf_contextlib 44from tensorflow.python.keras.utils.mode_keys import ModeKeys 45from tensorflow.python.ops import array_ops 46from tensorflow.python.ops import control_flow_ops 47from tensorflow.python.ops import math_ops 48from tensorflow.python.ops import sparse_ops 49from tensorflow.python.ops import variables 50from tensorflow.python.ops.ragged import ragged_tensor 51from tensorflow.python.platform import tf_logging as logging 52from tensorflow.python.util import nest 53 54 55def set_weights(distribution_strategy, dist_model, weights): 56 """Sets the weights of the replicated models. 57 58 The weights of the replicated models are set to the weights of the original 59 model. The weights of the replicated model are Mirrored variables and hence 60 we need to use the `update` call within a DistributionStrategy scope. 61 62 Args: 63 distribution_strategy: DistributionStrategy used to distribute training 64 and validation. 65 dist_model: The replicated models on the different devices. 66 weights: The weights of the original model. 67 """ 68 assign_ops = [] 69 for layer in dist_model.layers: 70 num_param = len(layer.weights) 71 layer_weights = weights[:num_param] 72 for sw, w in zip(layer.weights, layer_weights): 73 if ops.executing_eagerly_outside_functions(): 74 sw.assign(w) 75 else: 76 assign_ops.append(distribution_strategy.unwrap(sw.assign(w))) 77 weights = weights[num_param:] 78 79 if not ops.executing_eagerly_outside_functions(): 80 K.get_session(assign_ops).run(assign_ops) 81 82 83def unwrap_values(distribution_strategy, grouped_inputs, grouped_outputs, 84 grouped_updates=None, grouped_session_args=None, 85 with_loss_tensor=False): 86 """Unwrap the list of values contained in the PerReplica parameters. 87 88 This function calls `flatten_per_replica_values` to parse each of the input 89 parameters into a list of values on the different devices. If we set 90 `with_loss_tensor` to be True, we also call `reduce` on the list of losses on 91 the different devices to give us one loss tensor. 92 93 Args: 94 distribution_strategy: DistributionStrategy used to distribute training and 95 validation. 96 grouped_inputs: PerReplica inputs returned from the train or test function 97 that we ran on each device. 98 grouped_outputs: PerReplica outputs returned from the train or test function 99 that we ran on each device. 100 grouped_updates: PerReplica updates returned from the train or test function 101 that we ran on each device. 102 grouped_session_args: PerReplica session args returned from the train or 103 test function that we ran on each device. 104 with_loss_tensor: Boolean that indicates if we need to add the reduced loss 105 tensor as one of the outputs. 106 107 Returns: 108 Values of each of the PerReplica parameters. 109 110 """ 111 # Unwrap per device values returned from each model's train function. 112 # This will be used to construct the main train function. 113 all_inputs = flatten_per_replica_values(distribution_strategy, 114 grouped_inputs) 115 all_outputs = unwrap_outputs(distribution_strategy, grouped_outputs, 116 with_loss_tensor) 117 118 if grouped_updates: 119 all_updates = flatten_per_replica_values(distribution_strategy, 120 grouped_updates) 121 else: 122 all_updates = None 123 124 all_session_args = {} 125 if grouped_session_args: 126 grouped_feed_dict = grouped_session_args.get('feed_dict') 127 if grouped_feed_dict: 128 all_session_args['feed_dict'] = flatten_per_replica_values( 129 distribution_strategy, grouped_feed_dict) 130 131 grouped_fetches = grouped_session_args.get('fetches') 132 if grouped_fetches: 133 all_session_args['fetches'] = flatten_per_replica_values( 134 distribution_strategy, grouped_fetches) 135 136 # TODO(priyag): Return only non empty/None values 137 return all_inputs, all_outputs, all_updates, all_session_args 138 139 140def unwrap_output_dict(strategy, grouped_outputs, mode): 141 """Unwrap the list of outputs contained in the PerReplica parameters.""" 142 if mode == ModeKeys.PREDICT: 143 return flatten_per_replica_values(strategy, grouped_outputs) 144 145 # In the case of fit/eval, the grouped_outputs is a dict, whereas in predict, 146 # the output is as same structure as model output. They need to be treated 147 # differently 148 total_loss = strategy.reduce(reduce_util.ReduceOp.SUM, 149 grouped_outputs['total_loss'][0], axis=None) 150 output_losses = flatten_per_replica_values(strategy, 151 grouped_outputs['output_losses']) 152 metrics = flatten_per_replica_values(strategy, 153 grouped_outputs['metrics']) 154 batch_size = strategy.reduce(reduce_util.ReduceOp.SUM, 155 grouped_outputs['batch_size'], axis=None) 156 if (K.is_tpu_strategy(strategy) and 157 ops.executing_eagerly_outside_functions()): 158 # Choose 1 value per replica in the TPU case since all replicas produce the 159 # same output. 160 # We only do this in eager mode for now since this function is used in 161 # both graph and eager mode and in the graph case we currently don't use 162 # experimental_run so would need to be removed when we converge the graph 163 # code path as well. 164 output_losses = output_losses[::strategy.num_replicas_in_sync] 165 metrics = metrics[::strategy.num_replicas_in_sync] 166 return {'total_loss': [total_loss], 167 'output_losses': output_losses, 168 'metrics': metrics, 169 'batch_size': batch_size} 170 171 172def unwrap_outputs(distribution_strategy, grouped_outputs, 173 with_loss_tensor=False): 174 """Unwrap the list of outputs contained in the PerReplica parameters. 175 176 This function calls `flatten_per_replica_values` to parse each of the input 177 parameters into a list of outputs on the different devices. If we set 178 `with_loss_tensor` to be True, we also call `reduce` on the list of losses on 179 the different devices to give us one loss tensor. 180 181 Args: 182 distribution_strategy: DistributionStrategy used to distribute training and 183 validation. 184 grouped_outputs: PerReplica outputs returned from the train or test function 185 that we ran on each device. 186 with_loss_tensor: Boolean that indicates if we need to add the reduced loss 187 tensor as one of the outputs. 188 189 Returns: 190 Values of each of the PerReplica outputs. 191 192 """ 193 if not with_loss_tensor: 194 return flatten_per_replica_values(distribution_strategy, 195 grouped_outputs) 196 197 if not isinstance(grouped_outputs, list): 198 grouped_outputs = [grouped_outputs] 199 # reduce loss tensor before adding it to the list of fetches 200 loss = distribution_strategy.reduce(reduce_util.ReduceOp.SUM, 201 grouped_outputs[0], axis=None) 202 all_outputs = flatten_per_replica_values(distribution_strategy, 203 grouped_outputs[1:]) 204 if (K.is_tpu_strategy(distribution_strategy) and 205 ops.executing_eagerly_outside_functions()): 206 # Choose 1 value per replica in the TPU case since all replicas produce the 207 # same output. 208 # We only do this in eager mode for now since this function is used in 209 # both graph and eager mode and in the graph case we currently don't use 210 # experimental_run so would need to be removed when we converge the graph 211 # code path as well. 212 all_outputs = all_outputs[::distribution_strategy.num_replicas_in_sync] 213 return [loss] + all_outputs 214 215 216def flatten_per_replica_values(distribution_strategy, per_replica_values): 217 """Unwraps and flattens a nest of PerReplica parameters. 218 219 PerReplica values have one value associated with each device. Each entry in 220 the PerReplica dict has a device `key` and the corresponding value on the 221 device as the `value`. In this function we take a PerReplica value or a list 222 of PerReplica values and return all the values in the PerReplica dict. 223 224 Args: 225 distribution_strategy: DistributionStrategy used to distribute training and 226 validation. 227 per_replica_values: List of PerReplica object or a single PerReplica object. 228 229 Returns: 230 List of values of all the PerReplica objects. 231 232 """ 233 # pylint: disable=g-complex-comprehension 234 # This function takes a PerReplica object or a list of PerReplica objects and 235 # returns all the values associated with it. 236 return [e for flattened in nest.flatten(per_replica_values) 237 for e in distribution_strategy.unwrap(flattened)] 238 239 240def validate_callbacks(input_callbacks, optimizer): 241 """Validate whether given callbacks are supported by DistributionStrategy. 242 243 Args: 244 input_callbacks: List of callbacks passed by the user to fit. 245 optimizer: Optimizer instance used to train the model. 246 247 Raises: 248 ValueError: If `LearningRateScheduler` or `ReduceLROnPlateau` is one of the 249 callbacks passed. 250 ValueError: If `write_grads` is one of the parameters passed as part of the 251 TensorBoard callback. 252 """ 253 if input_callbacks: 254 for callback in input_callbacks: 255 if isinstance(callback, (callbacks.LearningRateScheduler, 256 callbacks.ReduceLROnPlateau)): 257 258 if not isinstance(optimizer, optimizer_v2.OptimizerV2): 259 raise ValueError('You must specify a Keras Optimizer V2 when using ' 260 '%s callback with DistributionStrategy.' % callback) 261 262 # If users want to use the TensorBoard callback they cannot use certain 263 # features of the callback that involve accessing model attributes and 264 # running ops. 265 if isinstance(callback, callbacks.TensorBoard): 266 if getattr(callback, 'write_grads', False): 267 logging.warning( 268 UserWarning( 269 '`write_grads` in the TensorBoard callback is not supported ' 270 'when using DistributionStrategy. Setting `write_grads` ' 271 'to `False`.')) 272 callback.write_grads = False 273 274 275def validate_distributed_dataset_inputs(distribution_strategy, x, y, 276 sample_weights=None): 277 """Validate all the components of a DistributedValue Dataset input. 278 279 Args: 280 distribution_strategy: The current DistributionStrategy used to call 281 `fit`/`evaluate`. 282 x: Input Dataset DistributedValue object. For example, when we use 283 `MirroredStrategy` this is a PerReplica object with a tensor for each 284 device set in the dict. x can also be a tuple or dict. The keys of the 285 dict should match the names of the input layers of the model. 286 y: Target Dataset DistributedValue object. For example, when we use 287 `MirroredStrategy` this is a PerReplica object with a tensor for each 288 device set in the dict. y can also be a tuple or dict. The keys of the 289 dict should match the names of the output layers of the model. 290 sample_weights: Sample weights Dataset DistributedValue object. For example, 291 when we use `MirroredStrategy` this is a PerReplica object with a tensor 292 for each device set in the dict. 293 294 Returns: 295 The unwrapped values list of the x and y DistributedValues inputs. 296 297 Raises: 298 ValueError: If x and y do not have support for being evaluated as tensors. 299 or if x and y contain elements that are not tensors or if x and y 300 contain elements that have a shape or dtype mismatch. 301 """ 302 # If the input and target used to call the model are not dataset tensors, 303 # we need to raise an error. When using a DistributionStrategy, the input 304 # and targets to a model should be from a `tf.data.Dataset`. 305 306 # If each element of x and y are not tensors, we cannot standardize and 307 # validate the input and targets. 308 x_values_list = validate_per_replica_inputs(distribution_strategy, x) 309 310 if y is not None: 311 y_values_list = validate_per_replica_inputs(distribution_strategy, y) 312 else: 313 y_values_list = None 314 315 if sample_weights is not None: 316 sample_weights_list = validate_per_replica_inputs(distribution_strategy, 317 sample_weights) 318 else: 319 sample_weights_list = None 320 321 # Return the unwrapped values to avoid calling `unwrap` a second time. 322 return x_values_list, y_values_list, sample_weights_list 323 324 325def validate_per_replica_inputs(distribution_strategy, x): 326 """Validates PerReplica dataset input list. 327 328 Args: 329 distribution_strategy: The current DistributionStrategy used to call 330 `fit`, `evaluate` and `predict`. 331 x: A list of PerReplica objects that represent the input or 332 target values. 333 334 Returns: 335 List containing the first element of each of the PerReplica objects in 336 the input list. 337 338 Raises: 339 ValueError: If any of the objects in the `per_replica_list` is not a tensor. 340 341 """ 342 # Convert the inputs and targets into a list of PerReplica objects. 343 per_replica_list = nest.flatten(x, expand_composites=True) 344 x_values_list = [] 345 for x in per_replica_list: 346 # At this point x should contain only tensors. 347 x_values = distribution_strategy.unwrap(x) 348 for value in x_values: 349 if not tensor_util.is_tf_type(value): 350 raise ValueError('Dataset input to the model should be tensors instead ' 351 'they are of type {}'.format(type(value))) 352 353 if not context.executing_eagerly(): 354 # Validate that the shape and dtype of all the elements in x are the same. 355 validate_all_tensor_shapes(x, x_values) 356 validate_all_tensor_types(x, x_values) 357 358 x_values_list.append(x_values[0]) 359 return x_values_list 360 361 362def validate_all_tensor_types(x, x_values): 363 x_dtype = x_values[0].dtype 364 for i in range(1, len(x_values)): 365 if x_dtype != x_values[i].dtype: 366 raise ValueError('Input tensor dtypes do not match for distributed tensor' 367 ' inputs {}'.format(x)) 368 369 370def validate_all_tensor_shapes(x, x_values): 371 # Validate that the shape of all the elements in x have the same shape 372 x_shape = x_values[0].shape.as_list() 373 for i in range(1, len(x_values)): 374 if x_shape != x_values[i].shape.as_list(): 375 raise ValueError('Input tensor shapes do not match for distributed tensor' 376 ' inputs {}'.format(x)) 377 378 379def _wait_for_variable_initialization(session): 380 """Utility to wait for variables to be initialized.""" 381 all_variables = K._get_variables(K.get_graph()) # pylint: disable=protected-access 382 candidate_vars = [] 383 for v in all_variables: 384 if not getattr(v, '_keras_initialized', False): 385 candidate_vars.append(v) 386 387 if not candidate_vars: 388 return 389 390 while True: 391 is_initialized = session.run( 392 [variables.is_variable_initialized(v) for v in candidate_vars]) 393 uninitialized_vars = [] 394 for flag, v in zip(is_initialized, candidate_vars): 395 if not flag: 396 uninitialized_vars.append(v) 397 v._keras_initialized = True # pylint: disable=protected-access 398 if not uninitialized_vars: 399 break 400 401 402def init_restore_or_wait_for_variables(): 403 """Initialize or restore variables or wait for variables to be initialized.""" 404 session = K._get_session() # pylint: disable=protected-access 405 if not multi_worker_util.has_worker_context( 406 ) or multi_worker_util.should_load_checkpoint(): 407 # TODO(yuefengz): if checkpoints exist, restore from checkpoint. 408 K._initialize_variables(session) # pylint: disable=protected-access 409 else: 410 _wait_for_variable_initialization(session) 411 412 413def validate_inputs(x, y): 414 """Validate inputs when using DistributionStrategy. 415 416 Args: 417 x: Model Inputs. 418 y: Model Targets. 419 420 Raises: 421 ValueError: if input is not a Dataset or a numpy array(when we use 422 MirroredStrategy). 423 """ 424 if (isinstance(x, iterator_ops.Iterator) or 425 isinstance(y, iterator_ops.Iterator)): 426 raise ValueError('`DistributionStrategy` does not support inputs of type ' 427 'Iterator. You must pass a `tf.data.Dataset` object or a ' 428 'numpy array as input.') 429 430 431def is_dataset_shape_fully_defined(dataset): 432 """Returns whether a dataset contains a final partial batch.""" 433 shapes = nest.flatten(dataset_ops.get_legacy_output_shapes(dataset)) 434 unknown_shapes = [s for s in shapes if not s.is_fully_defined()] 435 return not unknown_shapes 436 437 438def process_batch_and_step_size(strategy, 439 inputs, 440 batch_size, 441 steps_per_epoch, 442 mode, 443 validation_split=0.): 444 """Process the batch size and step size based on input and dist strategy.""" 445 first_x_value = nest.flatten(inputs)[0] 446 if isinstance(first_x_value, np.ndarray): 447 num_samples = first_x_value.shape[0] 448 if validation_split and 0. < validation_split < 1.: 449 num_samples = int(num_samples * (1 - validation_split)) 450 # Until support for partial batch is implemented across all 451 # functions and distribution strategy, we pass `mode` to selectively 452 # relax the constraint to consume all the training samples. 453 steps_per_epoch, batch_size = get_input_params( 454 strategy, num_samples, steps_per_epoch, batch_size, mode=mode) 455 return batch_size, steps_per_epoch 456 457 458def get_input_params(distribution_strategy, 459 num_samples, 460 steps, 461 batch_size, 462 mode=None): 463 """Calculate the number of batches and steps/steps_per_epoch. 464 465 Args: 466 distribution_strategy: The DistributionStrategy used to compile the model. 467 num_samples: The number of samples from which we determine the batch size 468 and steps. 469 steps: The specified number of steps. 470 batch_size: The specified batch_size. 471 mode: ModeKey representing whether input will be used for training, 472 evaluation, or prediction. This is used to relax the constraints on 473 consuming all the training samples to keep compatibility till we support 474 partial batches. If none, then partial batches are not allowed. 475 476 Returns: 477 steps: The steps or steps_per_epoch argument depending on if a user is 478 calling `fit`, `evaluate` or `predict`. If the is_training flag is set 479 we don't require the number of samples to be used completely. 480 batch_size: The batch size to be used in model iterations. 481 482 Raises: 483 ValueError: If the number of batches or steps evaluates to 0. 484 485 """ 486 # TODO(b/118776054): Use global batch size for Keras/DS support. 487 # Currently this is only supported in TPUStrategy and CoreMirroredStrategy. 488 use_per_replica_batch = not dist_utils.global_batch_size_supported( 489 distribution_strategy) 490 491 # TODO(b/128995245): In eager mode, uneven batch sizes are allowed except for 492 # `fit()` on TPUStrategy. 493 # In graph mode, the zero batch case in batch norm is not handled due to 494 # XLA-GPU regression. Uneven batch sizes are not allowed except 495 # for `test()` and `predict()` on TPUStrategy. 496 if context.executing_eagerly(): 497 allow_partial_batch = ( 498 mode != ModeKeys.TRAIN or 499 not K.is_tpu_strategy(distribution_strategy)) 500 else: 501 allow_partial_batch = ( 502 mode == ModeKeys.TRAIN or 503 ((mode == ModeKeys.PREDICT or mode == ModeKeys.TEST) and 504 K.is_tpu_strategy(distribution_strategy))) 505 506 if steps is None: 507 if batch_size is None: 508 # If neither the batch size or number of steps are set. We choose the 509 # global batch size as the minimum of number of samples and 32. 32 is 510 # chosen to provide backward compatibility. 511 global_batch_size = min(num_samples, 32) 512 else: 513 # If the user provided the batch size we need to handle the case 514 # between different strategies that use the global/per-replica batch size 515 global_batch_size = batch_size 516 if use_per_replica_batch: 517 global_batch_size *= distribution_strategy.num_replicas_in_sync 518 if allow_partial_batch: 519 steps = np.ceil(num_samples / global_batch_size).astype(int) 520 else: 521 if num_samples % global_batch_size: 522 raise ValueError('The number of samples %s is not divisible by ' 523 'batch size %s.' % (num_samples, global_batch_size)) 524 steps = num_samples // global_batch_size 525 else: 526 if batch_size is None: 527 # We calculate the batch size based on the number of steps specified 528 if num_samples % steps: 529 raise ValueError('The number of samples %s is not divisible by ' 530 'steps %s. Please change the number of steps to a ' 531 'value that can consume all the samples' % ( 532 num_samples, steps)) 533 global_batch_size = num_samples // steps 534 else: 535 # If the user provided the batch size we need to handle the case 536 # between different strategies that use the global/per-replica batch size 537 global_batch_size = batch_size 538 if use_per_replica_batch: 539 global_batch_size *= distribution_strategy.num_replicas_in_sync 540 541 min_num_samples = global_batch_size * steps 542 if allow_partial_batch: 543 min_num_samples = global_batch_size * (steps-1) + 1 if steps > 1 else 0 544 545 if num_samples < min_num_samples: 546 raise ValueError('Number of samples %s is less than samples required ' 547 'for specified batch_size %s and steps %s' % ( 548 num_samples, global_batch_size, steps)) 549 550 # We need to return the per replica or global batch size based on the strategy 551 if use_per_replica_batch: 552 if global_batch_size % distribution_strategy.num_replicas_in_sync: 553 raise ValueError( 554 'The batch size (%s) could not be sharded evenly across the sync ' 555 'replicas (%s) in the distribution strategy.' % ( 556 global_batch_size, distribution_strategy.num_replicas_in_sync)) 557 batch_size = global_batch_size // distribution_strategy.num_replicas_in_sync 558 else: 559 batch_size = global_batch_size 560 561 return steps, batch_size 562 563 564def get_batch_dimension(iterator): 565 shapes = nest.flatten(dataset_ops.get_legacy_output_shapes(iterator)) 566 # Take the batch size from the first element, as it should be the same for 567 # all. 568 dims = shapes[0].dims 569 return dims[0] if dims else None 570 571 572def get_iterator(dataset, distribution_strategy): 573 with distribution_strategy.scope(): 574 iterator = distribution_strategy.make_dataset_iterator(dataset) 575 initialize_iterator(iterator, distribution_strategy) 576 return iterator 577 578 579def initialize_iterator(iterator, distribution_strategy): 580 with distribution_strategy.scope(): 581 init_op = control_flow_ops.group(iterator.initializer) 582 if not context.executing_eagerly(): 583 K.get_session((init_op,)).run(init_op) 584 585 586def _get_input_from_iterator(iterator, model): 587 """Get elements from the iterator and verify the input shape and type.""" 588 next_element = iterator.get_next() 589 590 # `len(nest.flatten(x))` is going to not count empty elements such as {}. 591 # len(nest.flatten([[0,1,2], {}])) is 3 and not 4. The `next_element` is 592 # going to get flattened in `_prepare_feed_values` to work around that. Empty 593 # elements are going to get filtered out as part of the flattening. 594 if len(nest.flatten(next_element)) == len(model.inputs): 595 x = next_element 596 y = None 597 sample_weights = None 598 elif len(nest.flatten(next_element)) == (len(model.inputs) + 599 len(model.outputs)): 600 x, y = next_element 601 sample_weights = None 602 else: 603 x, y, sample_weights = next_element 604 605 # Validate that all the elements in x and y are of the same type and shape. 606 validate_distributed_dataset_inputs( 607 model._distribution_strategy, x, y, sample_weights) 608 return x, y, sample_weights 609 610 611def _prepare_feed_values(model, inputs, targets, sample_weights, mode): 612 """Prepare feed values to the model execution function. 613 614 Args: 615 model: Model to prepare feed values for. 616 inputs: List or dict of model inputs. 617 targets: Optional list of model targets. 618 sample_weights: Optional list of sample weight arrays. 619 mode: One of ModeKeys.TRAIN/ModeKeys.TEST/ModeKeys.PREDICT. 620 621 Returns: 622 Feed values for the model in the given mode. 623 """ 624 strategy = model._distribution_strategy 625 inputs, targets, sample_weights = _get_input_from_iterator(inputs, model) 626 if K.is_tpu_strategy(strategy): 627 if sample_weights is not None: 628 raise ValueError('TPUStrategy does not support sample weights.') 629 630 # When the inputs are dict, then we want to flatten it in the same order as 631 # the input layers, such that the data are fed into the input layers in the 632 # correct order. 633 if isinstance(inputs, dict): 634 inputs = [inputs[key] for key in model._feed_input_names] 635 if is_distributing_by_cloning(model): 636 inputs = flatten_per_replica_values(strategy, inputs) 637 targets = flatten_per_replica_values(strategy, targets) 638 # Expand 1-dimensional inputs. 639 # TODO(b/124535720): Remove once this standarize data logic is shared with 640 # main flow. 641 inputs, targets = nest.map_structure( 642 training_utils_v1.standardize_single_array, (inputs, targets)) 643 else: 644 inputs = training_utils_v1.ModelInputs(inputs).as_list() 645 646 if mode == ModeKeys.PREDICT: 647 sample_weights = [] 648 targets = [] 649 elif sample_weights is not None and is_distributing_by_cloning(model): 650 if context.executing_eagerly() and not model._compile_distribution: 651 raise NotImplementedError('`sample_weight` is not supported when using ' 652 'tf.distribute.Strategy in eager mode and ' 653 'cloning=True.') 654 sample_weights = flatten_per_replica_values(strategy, sample_weights) 655 656 ins = [inputs, targets, sample_weights] 657 return tuple(ins) 658 659 660def is_distributing_by_cloning(model): 661 """Decide whether this model is going to be distributed via cloning. 662 663 We are going to distribute the model by cloning in graph mode. 664 665 Args: 666 model: Keras model to distribute. 667 668 Returns: 669 True if the `model` is going to be distributed using cloning and False 670 otherwise. 671 """ 672 if (K.is_tpu_strategy(model._distribution_strategy) and 673 context.executing_eagerly): # b/137580852 674 return False 675 elif ops.executing_eagerly_outside_functions(): 676 return bool(model._compile_distribution) 677 return True 678 679 680def _custom_compile_for_predict(model): 681 """Custom compile for TPU predict mode.""" 682 if not model.built: 683 # Model is not compilable because it does not know its number of inputs 684 # and outputs, nor their shapes and names. We will compile after the first 685 # time the model gets called on training data. 686 return 687 model._is_compiled = True 688 model.total_loss = None 689 model.train_function = None 690 model.test_function = None 691 model.predict_function = None 692 693 694def _build_network_on_replica(model, mode, inputs=None, targets=None): 695 """Build an updated model on replicas. 696 697 We create a new Keras model while sharing the variables from the old graph. 698 Building a new sub-graph is required since the original keras model creates 699 placeholders for the input and the output that are not accessible till we 700 call iterator.get_next() inside the step_fn for `fit`/`evaluate`/`predict`. 701 702 The sharing of weights and layers between the old and the new model guarantee 703 that we're using Strategy variables and any updates on either model are 704 reflected correctly in callbacks and loop iterations. 705 706 We need to make sure we share the optimizers between the old and the new model 707 as well so that optimizer state is not lost if the user is running fit 708 multiple times. 709 710 Args: 711 model: Model to be replicated across Replicas 712 mode: Which of fit/eval/predict is building the distributed network 713 inputs: Input variables to be passed to the model 714 targets: Target tensor to be passed to model.compile 715 716 Returns: 717 A new model with shared layers with the old model. 718 """ 719 # Need to do imports here since we run into a circular dependency error. 720 from tensorflow.python.keras import models # pylint: disable=g-import-not-at-top 721 from tensorflow.python.keras.engine import sequential # pylint: disable=g-import-not-at-top 722 723 # We rely on the internal methods to avoid having share_weights weights in the 724 # public API. 725 if isinstance(model, sequential.Sequential): 726 updated_model = models._clone_sequential_model( 727 model, input_tensors=inputs, layer_fn=models.share_weights) 728 else: 729 updated_model = models._clone_functional_model( 730 model, input_tensors=inputs, layer_fn=models.share_weights) 731 # Callable losses added directly to a functional Model need to be added 732 # here. 733 updated_model._callable_losses = model._callable_losses 734 735 # Recast all low precision outputs back to float32 since we only casted 736 # the inputs to bfloat16 and not targets. This is done so that we can preserve 737 # precision when calculating the loss value. 738 def _upcast_low_precision_outputs(output): 739 if output.dtype == dtypes.bfloat16: 740 return math_ops.cast(output, dtypes.float32) 741 else: 742 return output 743 updated_model.outputs = [_upcast_low_precision_outputs(o) 744 for o in updated_model.outputs] 745 746 if isinstance(targets, tuple): 747 targets = nest.flatten(targets) 748 749 if mode == ModeKeys.PREDICT and inputs is not None: # TPU predict case 750 _custom_compile_for_predict(updated_model) 751 else: 752 updated_model.compile( 753 model.optimizer, 754 model.loss, 755 metrics=metrics_module.clone_metrics(model._compile_metrics), 756 loss_weights=model.loss_weights, 757 sample_weight_mode=model.sample_weight_mode, 758 weighted_metrics=metrics_module.clone_metrics( 759 model._compile_weighted_metrics), 760 target_tensors=targets) 761 return updated_model 762 763 764def _build_distributed_network(model, strategy, mode, inputs=None, 765 targets=None): 766 """Create a cloned model on each replica.""" 767 with K.get_graph().as_default(), strategy.scope(): 768 distributed_model = strategy.extended.call_for_each_replica( 769 _build_network_on_replica, 770 args=(model, mode, inputs, targets)) 771 set_distributed_model(model, mode, distributed_model) 772 773 774def _clone_and_build_model(model, mode, inputs=None, targets=None): 775 """Clone and build the given keras_model.""" 776 # We need to set the import here since we run into a circular dependency 777 # error. 778 from tensorflow.python.keras import models # pylint: disable=g-import-not-at-top 779 cloned_model = models.clone_model(model, input_tensors=inputs) 780 781 # Compile and build model. 782 if isinstance(model.optimizer, optimizers.TFOptimizer): 783 optimizer = model.optimizer 784 else: 785 optimizer_config = model.optimizer.get_config() 786 optimizer = model.optimizer.__class__.from_config(optimizer_config) 787 788 # Recast all low precision outputs back to float32 since we only casted 789 # the inputs to bfloat16 and not targets. This is done so that we can preserve 790 # precision when calculating the loss value. 791 def _upcast_low_precision_outputs(output): 792 if output.dtype == dtypes.bfloat16: 793 return math_ops.cast(output, dtypes.float32) 794 else: 795 return output 796 cloned_model.outputs = [_upcast_low_precision_outputs(o) 797 for o in cloned_model.outputs] 798 799 if isinstance(targets, tuple): 800 targets = nest.flatten(targets) 801 if mode == ModeKeys.PREDICT and inputs is not None: # TPU predict case 802 _custom_compile_for_predict(cloned_model) 803 else: 804 cloned_model.compile( 805 optimizer, 806 model.loss, 807 metrics=metrics_module.clone_metrics(model._compile_metrics), 808 loss_weights=model.loss_weights, 809 sample_weight_mode=model.sample_weight_mode, 810 weighted_metrics=metrics_module.clone_metrics( 811 model._compile_weighted_metrics), 812 target_tensors=targets) 813 return cloned_model 814 815 816def clone_model_on_replicas(model, strategy, mode, inputs=None, targets=None): 817 """Create a cloned model on each replica.""" 818 with K.get_graph().as_default(), strategy.scope(): 819 distributed_model = strategy.extended.call_for_each_replica( 820 _clone_and_build_model, args=(model, mode, inputs, targets)) 821 set_distributed_model(model, mode, distributed_model) 822 if mode == ModeKeys.TRAIN: 823 model._make_callback_model(distributed_model) 824 825 826def _make_execution_function(model, mode): 827 """Makes or reuses function to run one step of distributed model execution.""" 828 if is_distributing_by_cloning(model): 829 return _make_execution_function_with_cloning(model, mode) 830 831 distributed_function = get_distributed_function(model, mode) 832 if distributed_function: 833 return distributed_function 834 835 distribution_function = _make_execution_function_without_cloning(model, mode) 836 set_distributed_function(model, mode, distribution_function) 837 return distribution_function 838 839 840def _make_execution_function_without_cloning(model, mode): 841 """Creates a function to run one step of distributed model execution.""" 842 strategy = model._distribution_strategy 843 844 with strategy.scope(): 845 per_replica_function = _make_replica_execution_function(model, mode) 846 847 def distributed_function(input_fn): 848 """A single step of the distributed execution across replicas.""" 849 x, y, sample_weights = input_fn() 850 # Call `Model.{train,test,predict}_on_batch` on every replica passing 851 # PerReplicas as arguments. On every replica inside this call, each 852 # PerReplica object will return the value for that replica. The outputs 853 # are PerReplicas too. 854 outputs = strategy.run(per_replica_function, args=(x, y, sample_weights)) 855 # Out of PerReplica outputs reduce or pick values to return. 856 all_outputs = unwrap_outputs( 857 strategy, outputs, with_loss_tensor=(mode != ModeKeys.PREDICT)) 858 return all_outputs 859 860 if not model.run_eagerly: 861 distributed_function = def_function.function(distributed_function) 862 def execution_function(input_fn): 863 # `numpy` translates Tensors to values in Eager mode. 864 return [out.numpy() for out in distributed_function(input_fn)] 865 else: 866 execution_function = distributed_function 867 868 return execution_function 869 870 871def _make_replica_execution_function(model, mode): 872 """A single step of the distributed execution on a replica.""" 873 if mode == ModeKeys.TRAIN: 874 func = model.train_on_batch 875 elif mode == ModeKeys.TEST: 876 func = model.test_on_batch 877 else: 878 879 def predict_on_batch(x, y=None, sample_weights=None): 880 del y, sample_weights 881 return model.predict_on_batch(x) 882 883 func = predict_on_batch 884 885 if mode != ModeKeys.PREDICT: 886 # `reset_metrics` is set to False to maintain stateful metrics across 887 # batch-level calls. 888 func = functools.partial(func, reset_metrics=False) 889 890 return func 891 892 893def _make_replicated_models_with_cloning(model, mode): 894 """Build models on each replica.""" 895 strategy = model._distribution_strategy 896 897 # If distributed_model is not built, create one for `mode`. 898 if model._compile_distribution: 899 clone_model_on_replicas(model, strategy, mode) 900 else: 901 _build_distributed_network(model, strategy, mode) 902 903 904def _make_execution_function_with_cloning(model, mode): 905 """Clones or re-uses models to run one step of distributed model execution.""" 906 distributed_model = get_distributed_model(model, mode) 907 # TODO(b/134069401): Create a cache for the distributed model and exec 908 # function that incorporates additional attributes to be part of the cache key 909 # than just the mode. 910 # If distributed model for a particular `mode` is already built, use the 911 # `_distribution_function` on that distributed model. 912 # If you have updated the sample_weight_mode on the model, then you will need 913 # to recompile metrics and recreate the execution function. This is indicated 914 # by the `_recompile_exec_function` property. 915 if (distributed_model and hasattr(distributed_model, '_distribution_function') 916 and not (hasattr(distributed_model, '_recompile_exec_function') and 917 distributed_model._recompile_exec_function)): 918 return distributed_model._distributed_function 919 920 if not distributed_model: 921 _make_replicated_models_with_cloning(model, mode) 922 distributed_model = get_distributed_model(model, mode) 923 assert distributed_model 924 925 # Also create an execution function on that distributed model. 926 if context.executing_eagerly(): 927 distributed_function = _make_eager_execution_function(model, mode) 928 else: 929 distributed_function = _make_graph_execution_function(model, mode) 930 931 # We cache the distributed execution function on the model since creating 932 # distributed models and execution functions are expensive. 933 distributed_model._distributed_function = distributed_function 934 distributed_model._recompile_exec_function = False 935 return distributed_function 936 937 938def _make_graph_execution_function(model, mode): 939 """Makes function to run one step of distributed model in graph mode.""" 940 941 def _per_replica_function(model): 942 f = model._make_execution_function(mode) 943 return (f.inputs, f.outputs, f.updates_op, f.session_kwargs) 944 945 strategy = model._distribution_strategy 946 with strategy.scope(): 947 # Create train ops on each of the devices when we call 948 # `_per_replica_fit_function`. 949 (grouped_inputs, grouped_outputs, grouped_updates, 950 grouped_session_args) = strategy.extended.call_for_each_replica( 951 _per_replica_function, args=(get_distributed_model(model, mode),)) 952 953 # Initialize the variables in the replicated model. This is necessary for 954 # multi-worker training because on some workers, initialization is not 955 # needed. This method does initialization or waiting for initialization 956 # according to the context object of distribute coordinator. 957 init_restore_or_wait_for_variables() 958 959 # Unwrap all the per device values returned from `call_for_each_replica`. 960 # Unwrapping per device values gives you a list of values that can be 961 # used to construct a new train function that is composed of update ops on 962 # all the devices over which the model is distributed. 963 (all_inputs, all_outputs, all_updates, all_session_args) = unwrap_values( 964 strategy, 965 grouped_inputs, 966 grouped_outputs, 967 grouped_updates, 968 grouped_session_args, 969 with_loss_tensor=(mode != ModeKeys.PREDICT)) 970 971 return K.function( 972 all_inputs, 973 all_outputs, 974 updates=all_updates, 975 name='distributed_{}_function'.format(mode), 976 **all_session_args) 977 978 979def _make_eager_execution_function(model, mode): 980 """Makes function to run one step of distributed model eager execution.""" 981 def _per_replica_function(model): 982 f = model._make_execution_function(mode) 983 return (f.inputs, f.outputs) 984 985 # NOTE(priyag): Try creating a new FuncGraph within DS scope instead of using 986 # the global one. 987 strategy = model._distribution_strategy 988 global_graph = K.get_graph() 989 990 with global_graph.as_default(), strategy.scope(): 991 # First we gather the relevant portions of the model across all replicas. 992 # `K._scratch_graph(global_graph)` signals to Keras that it should not 993 # lift to a separate graph when creating the per-replica functions. 994 with K._scratch_graph(global_graph): 995 # Create train ops on each of the devices when we call 996 # `_per_replica_fit_function`. 997 grouped = strategy.extended.call_for_each_replica( 998 _per_replica_function, args=(get_distributed_model(model, mode),)) 999 grouped_inputs, grouped_outputs = grouped 1000 1001 # Unwrap all the per device values returned from `call_for_each_replica`. 1002 # Unwrapping per device values gives you a list of values that can be 1003 # used to construct a new train function that is composed of 1004 # inputs/outputs on all the devices over which the model is distributed. 1005 (all_inputs, all_outputs, _, _) = unwrap_values( 1006 strategy, 1007 grouped_inputs, 1008 grouped_outputs, 1009 with_loss_tensor=(mode != ModeKeys.PREDICT)) 1010 1011 # Finally, a joint Keras function is created; this one will be created in 1012 # a separate FuncGraph. 1013 return K.function( 1014 all_inputs, 1015 all_outputs, 1016 name='eager_distributed_{}_function'.format(mode)) 1017 1018 1019def _copy_weights_to_distributed_model(original_model, mode): 1020 """Copies weights from original model to distributed models.""" 1021 strategy = original_model._distribution_strategy 1022 distributed_model = get_distributed_model(original_model, mode) 1023 if strategy: 1024 # Copy the weights from the original model to each of the replicated 1025 # models. 1026 orig_model_weights = original_model.get_weights() 1027 first_model = strategy.unwrap(distributed_model)[0] 1028 set_weights(strategy, first_model, orig_model_weights) 1029 1030 1031def _copy_weights_to_original_model(model, mode): 1032 """Copies weights from first distributed model back to original model.""" 1033 if model._distribution_strategy and mode == ModeKeys.TRAIN: 1034 distributed_model = get_distributed_model(model, mode) 1035 updated_weights = model._distribution_strategy.unwrap( 1036 distributed_model)[0].get_weights() 1037 model.set_weights(updated_weights) 1038 1039 1040def _per_replica_aggregate_batch(strategy, batch_outs, model, mode): 1041 """Aggregates the per-replica batch-level outputs from a distributed step.""" 1042 if strategy is not None and mode == ModeKeys.PREDICT: 1043 total_batch_outs = [] 1044 for i in range(len(model.outputs)): 1045 num_replicas = strategy.num_replicas_in_sync 1046 nested_outs = batch_outs[i * num_replicas:i * num_replicas + num_replicas] 1047 total_batch_outs.append( 1048 concat_along_batch_dimension(nest.flatten(nested_outs))) 1049 return total_batch_outs 1050 return batch_outs 1051 1052 1053def _reset_metrics(model): 1054 if model._distribution_strategy: 1055 for mode in [ModeKeys.TRAIN, ModeKeys.TEST, ModeKeys.PREDICT]: 1056 distributed_model = get_distributed_model(model, mode) 1057 if distributed_model: 1058 first_model = model._distribution_strategy.unwrap(distributed_model)[0] 1059 first_model.reset_metrics() 1060 1061 1062def get_distributed_model(model, mode): 1063 key = _generate_cache_key(mode) 1064 return model._distributed_model_cache.get(key, None) 1065 1066 1067def set_distributed_model(model, mode, distributed_model): 1068 key = _generate_cache_key(mode) 1069 model._distributed_model_cache[key] = distributed_model 1070 1071 1072def get_distributed_function(model, mode): 1073 key = _generate_cache_key(mode) 1074 return model._distributed_function_cache.get(key, None) 1075 1076 1077def set_distributed_function(model, mode, distributed_function): 1078 key = _generate_cache_key(mode) 1079 model._distributed_function_cache[key] = distributed_function 1080 1081 1082def _generate_cache_key(mode): 1083 key = hash(mode) 1084 return key 1085 1086 1087@tf_contextlib.contextmanager 1088def distributed_scope(strategy, learning_phase): 1089 with strategy.scope(), K.learning_phase_scope(learning_phase): 1090 yield 1091 1092 1093def is_current_worker_chief(): 1094 return dc_context.get_current_worker_context().is_chief 1095 1096 1097def filter_distributed_callbacks(callbacks_list, model): 1098 """Filter Callbacks based on the worker context when running multi-worker. 1099 1100 Args: 1101 callbacks_list: A list of `Callback` instances. 1102 model: Keras model instance. 1103 1104 Returns: 1105 The list of `Callback` instances that should be run on this worker. 1106 """ 1107 1108 if not model._in_multi_worker_mode(): 1109 raise ValueError( 1110 'filter_distributed_callbacks() should only be called when Keras ' 1111 'is in multi worker mode.') 1112 1113 callbacks_list = callbacks_list or [] 1114 if not [ 1115 c for c in callbacks_list if isinstance(c, callbacks.ModelCheckpoint) 1116 ]: 1117 # TODO(rchao): Consider providing a ModelCheckpoint here if the user 1118 # fails to (possibly with tempfile directory). 1119 logging.warning('ModelCheckpoint callback is not provided. ' 1120 'Workers will need to restart training if any fails.') 1121 1122 if callbacks_list is None or is_current_worker_chief(): 1123 return callbacks_list 1124 1125 # Some Callbacks should only run on the chief worker. 1126 return [ 1127 callback for callback in callbacks_list if not callback._chief_worker_only 1128 ] # pylint: disable=protected-access 1129 1130 1131def _update_sample_weight_modes(model, mode, sample_weights): 1132 """Update sample_weight_mode of the distributed model.""" 1133 if is_distributing_by_cloning(model): 1134 distributed_model = get_distributed_model(model, mode) 1135 if not distributed_model: 1136 _make_replicated_models_with_cloning(model, mode) 1137 distributed_model = get_distributed_model(model, mode) 1138 distributed_model._recompile_exec_function = any( 1139 [e.sample_weights_mismatch() for e in model._training_endpoints]) 1140 1141 if sample_weights: 1142 distributed_models = flatten_per_replica_values( 1143 model._distribution_strategy, distributed_model) 1144 # sample_weights is a tuple of 1 list where the number of elements in the 1145 # list is equal to the number of replicas in sync. 1146 sample_weights = sample_weights[0] 1147 if sample_weights and None not in sample_weights: 1148 for m, sw in zip(distributed_models, sample_weights): 1149 m._update_sample_weight_modes(sample_weights=[sw]) 1150 1151 1152def concat_along_batch_dimension(outputs): 1153 """Concats prediction outputs along the batch dimension.""" 1154 if isinstance(outputs[0], sparse_tensor.SparseTensor): 1155 return sparse_ops.sparse_concat_v2(axis=0, sp_inputs=outputs) 1156 if isinstance(outputs[0], ragged_tensor.RaggedTensor): 1157 return array_ops.concat(outputs, axis=0) 1158 return np.concatenate(outputs) 1159