1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Utilities related to distributed training.""" 16# pylint:disable=protected-access 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import numpy as np 22 23from tensorflow.python.data.ops import dataset_ops 24from tensorflow.python.data.ops import iterator_ops 25from tensorflow.python.distribute import distribute_coordinator_context as dc_context 26from tensorflow.python.distribute import reduce_util 27from tensorflow.python.eager import context 28from tensorflow.python.framework import dtypes 29from tensorflow.python.framework import ops 30from tensorflow.python.framework import tensor_util 31from tensorflow.python.keras import backend as K 32from tensorflow.python.keras import callbacks 33from tensorflow.python.keras import metrics as metrics_module 34from tensorflow.python.keras import optimizers 35from tensorflow.python.keras.engine import training_utils 36from tensorflow.python.keras.optimizer_v2 import optimizer_v2 37from tensorflow.python.keras.utils.mode_keys import ModeKeys 38from tensorflow.python.ops import control_flow_ops 39from tensorflow.python.ops import math_ops 40from tensorflow.python.ops import variables 41from tensorflow.python.platform import tf_logging as logging 42from tensorflow.python.util import nest 43from tensorflow.python.util import tf_contextlib 44 45 46def set_weights(distribution_strategy, dist_model, weights): 47 """Sets the weights of the replicated models. 48 49 The weights of the replicated models are set to the weights of the original 50 model. The weights of the replicated model are Mirrored variables and hence 51 we need to use the `update` call within a DistributionStrategy scope. 52 53 Args: 54 distribution_strategy: DistributionStrategy used to distribute training 55 and validation. 56 dist_model: The replicated models on the different devices. 57 weights: The weights of the original model. 58 """ 59 assign_ops = [] 60 for layer in dist_model.layers: 61 num_param = len(layer.weights) 62 layer_weights = weights[:num_param] 63 for sw, w in zip(layer.weights, layer_weights): 64 if ops.executing_eagerly_outside_functions(): 65 sw.assign(w) 66 else: 67 assign_ops.append(distribution_strategy.unwrap(sw.assign(w))) 68 weights = weights[num_param:] 69 70 if not ops.executing_eagerly_outside_functions(): 71 K.get_session(assign_ops).run(assign_ops) 72 73 74def unwrap_values(distribution_strategy, grouped_inputs, grouped_outputs, 75 grouped_updates=None, grouped_session_args=None, 76 with_loss_tensor=False): 77 """Unwrap and return the list of values contained in the PerDevice parameters. 78 79 This function calls `flatten_perdevice_values` to parse each of the input 80 parameters into a list of values on the different devices. If we set 81 `with_loss_tensor` to be True, we also call `reduce` on the list of losses on 82 the different devices to give us one loss tensor. 83 84 Args: 85 distribution_strategy: DistributionStrategy used to distribute training and 86 validation. 87 grouped_inputs: PerDevice inputs returned from the train or test function 88 that we ran on each device. 89 grouped_outputs: PerDevice outputs returned from the train or test function 90 that we ran on each device. 91 grouped_updates: PerDevice updates returned from the train or test function 92 that we ran on each device. 93 grouped_session_args: PerDevice session args returned from the train or 94 test function that we ran on each device. 95 with_loss_tensor: Boolean that indicates if we need to add the reduced loss 96 tensor as one of the outputs. 97 98 Returns: 99 Values of each of the PerDevice parameters. 100 101 """ 102 # Unwrap per device values returned from each model's train function. 103 # This will be used to construct the main train function. 104 all_inputs = flatten_perdevice_values(distribution_strategy, 105 grouped_inputs) 106 if with_loss_tensor: 107 # reduce loss tensor before adding it to the list of fetches 108 loss = distribution_strategy.reduce(reduce_util.ReduceOp.SUM, 109 grouped_outputs[0]) 110 all_outputs = flatten_perdevice_values(distribution_strategy, 111 grouped_outputs[1:]) 112 all_outputs = [loss] + all_outputs 113 else: 114 all_outputs = flatten_perdevice_values(distribution_strategy, 115 grouped_outputs) 116 117 if grouped_updates: 118 all_updates = flatten_perdevice_values(distribution_strategy, 119 grouped_updates) 120 else: 121 all_updates = None 122 123 all_session_args = {} 124 if grouped_session_args: 125 grouped_feed_dict = grouped_session_args.get('feed_dict') 126 if grouped_feed_dict: 127 all_session_args['feed_dict'] = flatten_perdevice_values( 128 distribution_strategy, grouped_feed_dict) 129 130 grouped_fetches = grouped_session_args.get('fetches') 131 if grouped_fetches: 132 all_session_args['fetches'] = flatten_perdevice_values( 133 distribution_strategy, grouped_fetches) 134 135 # TODO(priyag): Return only non empty/None values 136 return all_inputs, all_outputs, all_updates, all_session_args 137 138 139def flatten_perdevice_values(distribution_strategy, perdevice_values): 140 """Unwraps and flattens a nest of PerDevice parameters. 141 142 PerDevice values have one value associated with each device. Each entry in 143 the PerDevice dict has a device `key` and the corresponding value on the 144 device as the `value`. In this function we take a PerDevice value or a list of 145 PerDevice values and return all the values in the PerDevice dict. 146 147 Args: 148 distribution_strategy: DistributionStrategy used to distribute training and 149 validation. 150 perdevice_values: List of PerDevice object or a single PerDevice object. 151 152 Returns: 153 List of values of all the PerDevice objects. 154 155 """ 156 # This function takes a PerDevice object or a list of PerDevice objects and 157 # returns all the values associated with it. 158 return [e for flattened in nest.flatten(perdevice_values) 159 for e in distribution_strategy.unwrap(flattened)] 160 161 162def validate_callbacks(input_callbacks, optimizer): 163 """Validate whether given callbacks are supported by DistributionStrategy. 164 165 Args: 166 input_callbacks: List of callbacks passed by the user to fit. 167 optimizer: Optimizer instance used to train the model. 168 169 Raises: 170 ValueError: If `LearningRateScheduler` or `ReduceLROnPlateau` is one of the 171 callbacks passed. 172 ValueError: If `histogram_freq` or `write_grads` is one of the parameters 173 passed as part of the TensorBoard callback. 174 """ 175 if input_callbacks: 176 for callback in input_callbacks: 177 if callback not in [callbacks.TensorBoard, callbacks.ReduceLROnPlateau, 178 callbacks.LearningRateScheduler, callbacks.CSVLogger, 179 callbacks.EarlyStopping, callbacks.ModelCheckpoint, 180 callbacks.TerminateOnNaN, callbacks.ProgbarLogger, 181 callbacks.History, callbacks.RemoteMonitor]: 182 logging.warning('Your input callback is not one of the predefined ' 183 'Callbacks that supports DistributionStrategy. You ' 184 'might encounter an error if you access one of the ' 185 'model\'s attributes as part of the callback since ' 186 'these attributes are not set. You can access each of ' 187 'the individual distributed models using the ' 188 '`_grouped_model` attribute of your original model.') 189 if isinstance(callback, (callbacks.LearningRateScheduler, 190 callbacks.ReduceLROnPlateau)): 191 192 if not isinstance(optimizer, optimizer_v2.OptimizerV2): 193 raise ValueError('You must specify a Keras Optimizer V2 when using ' 194 '%s callback with DistributionStrategy.' % callback) 195 196 # If users want to use the TensorBoard callback they cannot use certain 197 # features of the callback that involve accessing model attributes and 198 # running ops. 199 if isinstance(callback, callbacks.TensorBoard): 200 if getattr(callback, 'histogram_freq', False): 201 logging.warning( 202 UserWarning( 203 '`histogram_freq` in the TensorBoard callback is not ' 204 'supported when using DistributionStrategy. Setting ' 205 '`histogram_freq` to `0`.')) 206 callback.histogram_freq = 0 207 if getattr(callback, 'write_grads', False): 208 logging.warning( 209 UserWarning( 210 '`write_grads` in the TensorBoard callback is not supported ' 211 'when using DistributionStrategy. Setting `write_grads` ' 212 'to `False`.')) 213 callback.histogram_freq = False 214 215 216def validate_distributed_dataset_inputs(distribution_strategy, x, y, 217 sample_weights=None): 218 """Validate all the components of a DistributedValue Dataset input. 219 220 Args: 221 distribution_strategy: The current DistributionStrategy used to call 222 `fit`/`evaluate`. 223 x: Input Dataset DistributedValue object. For example, when we use 224 `MirroredStrategy` this is a PerDevice object with a tensor for each 225 device set in the dict. x can also be a tuple or dict. The keys of the 226 dict should match the names of the input layers of the model. 227 y: Target Dataset DistributedValue object. For example, when we use 228 `MirroredStrategy` this is a PerDevice object with a tensor for each 229 device set in the dict. y can also be a tuple or dict. The keys of the 230 dict should match the names of the output layers of the model. 231 sample_weights: Sample weights Dataset DistributedValue object. For example, 232 when we use `MirroredStrategy` this is a PerDevice object with a tensor 233 for each device set in the dict. 234 235 Returns: 236 The unwrapped values list of the x and y DistributedValues inputs. 237 238 Raises: 239 ValueError: If x and y do not have support for being evaluated as tensors. 240 or if x and y contain elements that are not tensors or if x and y 241 contain elements that have a shape or dtype mismatch. 242 """ 243 # If the input and target used to call the model are not dataset tensors, 244 # we need to raise an error. When using a DistributionStrategy, the input 245 # and targets to a model should be from a `tf.data.Dataset`. 246 247 # If each element of x and y are not tensors, we cannot standardize and 248 # validate the input and targets. 249 x_values_list = validate_per_device_inputs(distribution_strategy, x) 250 251 if y is not None: 252 y_values_list = validate_per_device_inputs(distribution_strategy, y) 253 else: 254 y_values_list = None 255 256 if sample_weights is not None: 257 sample_weights_list = validate_per_device_inputs(distribution_strategy, 258 sample_weights) 259 else: 260 sample_weights_list = None 261 262 # Return the unwrapped values to avoid calling `unwrap` a second time. 263 return x_values_list, y_values_list, sample_weights_list 264 265 266def validate_per_device_inputs(distribution_strategy, x): 267 """Validates PerDevice dataset input list. 268 269 Args: 270 distribution_strategy: The current DistributionStrategy used to call 271 `fit`, `evaluate` and `predict`. 272 x: A list of PerDevice objects that represent the input or 273 target values. 274 275 Returns: 276 List containing the first element of each of the PerDevice objects in 277 the input list. 278 279 Raises: 280 ValueError: If any of the objects in the `per_device_list` is not a tensor. 281 282 """ 283 # Convert the inputs and targets into a list of PerDevice objects. 284 per_device_list = nest.flatten(x) 285 x_values_list = [] 286 for x in per_device_list: 287 if not tensor_util.is_tensor(x): 288 raise ValueError('Dataset input to the model should be tensors instead ' 289 'they are of type {}'.format(type(x))) 290 291 # At this point both x and y contain tensors in the `DistributedValues` 292 # structure. 293 x_values = distribution_strategy.unwrap(x) 294 295 # Validate that the shape and dtype of all the elements in x are the same. 296 validate_all_tensor_shapes(x, x_values) 297 validate_all_tensor_types(x, x_values) 298 299 x_values_list.append(x_values[0]) 300 return x_values_list 301 302 303def validate_all_tensor_types(x, x_values): 304 x_dtype = x_values[0].dtype 305 for i in range(1, len(x_values)): 306 if x_dtype != x_values[i].dtype: 307 raise ValueError('Input tensor dtypes do not match for distributed tensor' 308 ' inputs {}'.format(x)) 309 310 311def validate_all_tensor_shapes(x, x_values): 312 # Validate that the shape of all the elements in x have the same shape 313 x_shape = x_values[0].get_shape().as_list() 314 for i in range(1, len(x_values)): 315 if x_shape != x_values[i].get_shape().as_list(): 316 raise ValueError('Input tensor shapes do not match for distributed tensor' 317 ' inputs {}'.format(x)) 318 319 320def _wait_for_variable_initialization(session): 321 """Utility to wait for variables to be initialized.""" 322 all_variables = K._get_variables(K.get_graph()) # pylint: disable=protected-access 323 candidate_vars = [] 324 for v in all_variables: 325 if not getattr(v, '_keras_initialized', False): 326 candidate_vars.append(v) 327 328 if not candidate_vars: 329 return 330 331 while True: 332 is_initialized = session.run( 333 [variables.is_variable_initialized(v) for v in candidate_vars]) 334 uninitialized_vars = [] 335 for flag, v in zip(is_initialized, candidate_vars): 336 if not flag: 337 uninitialized_vars.append(v) 338 v._keras_initialized = True # pylint: disable=protected-access 339 if not uninitialized_vars: 340 break 341 342 343def init_restore_or_wait_for_variables(): 344 """Initialize or restore variables or wait for variables to be initialized.""" 345 session = K._get_session() # pylint: disable=protected-access 346 worker_context = dc_context.get_current_worker_context() 347 if not worker_context or worker_context.experimental_should_init: 348 # TODO(yuefengz): if checkpoints exist, restore from checkpoint. 349 K._initialize_variables(session) # pylint: disable=protected-access 350 else: 351 _wait_for_variable_initialization(session) 352 353 354def validate_inputs(x, y, distribution_strategy, allow_partial_batch=False): 355 """Validate inputs when using DistributionStrategy. 356 357 Args: 358 x: Model Inputs. 359 y: Model Targets. 360 distribution_strategy: The DistributionStrategy with which the model is 361 compiled. 362 allow_partial_batch: Boolean. If false, datasets must have fully 363 defined shapes. 364 365 Raises: 366 ValueError: if input is not a Dataset or a numpy array(when we use 367 MirroredStrategy). 368 """ 369 if (isinstance(x, iterator_ops.Iterator) or 370 isinstance(y, iterator_ops.Iterator)): 371 raise ValueError('`DistributionStrategy` does not support inputs of type ' 372 'Iterator. You must pass a `tf.data.Dataset` object or a ' 373 'numpy array as input.') 374 375 if is_tpu_strategy(distribution_strategy): 376 for i in [x, y]: 377 if (isinstance(i, dataset_ops.DatasetV2) and not allow_partial_batch): 378 if not is_dataset_shape_fully_defined(i): 379 raise ValueError( 380 'Using TPUs currently requires fully defined shapes. Either use ' 381 'set_shape() on the input tensors or use ' 382 'dataset.batch(..., drop_remainder=True).' 383 'Found unknown shape in input {}.'.format(i)) 384 385 386# TODO(b/118776054): Currently we support global batch size for TPUStrategy and 387# core MirroredStrategy only. Remove this check when contrib MirroredStrategy is 388# no longer needed. 389def global_batch_size_supported(distribution_strategy): 390 return distribution_strategy.extended._global_batch_size # pylint: disable=protected-access 391 392 393# TODO(sourabhbajaj): Remove this once we use the same API for all strategies. 394def is_tpu_strategy(strategy): 395 """We're executing TPU Strategy.""" 396 return strategy is not None and strategy.__class__.__name__ == 'TPUStrategy' 397 398 399def is_dataset_shape_fully_defined(dataset): 400 """Returns whether a dataset contains a final partial batch.""" 401 shapes = nest.flatten(dataset_ops.get_legacy_output_shapes(dataset)) 402 unknown_shapes = [s for s in shapes if not s.is_fully_defined()] 403 return not unknown_shapes 404 405 406def get_input_params(distribution_strategy, first_x_value, steps, batch_size, 407 mode=None): 408 """Calculate the number of batches and steps/steps_per_epoch. 409 410 Args: 411 distribution_strategy: The DistributionStrategy used to compile the model. 412 first_x_value: This is the first input numpy array that is passed in as the 413 model input. 414 steps: The specified number of steps. 415 batch_size: The specified batch_size. 416 mode: ModeKey representing whether input will be used for training, 417 evaluation, or prediction. This is used to relax the constraints on 418 consuming all the training samples to keep compatibility till we 419 support partial batches. If none, then partial batches are not allowed. 420 421 Returns: 422 steps: The steps or steps_per_epoch argument depending on if a user is 423 calling `fit`, `evaluate` or `predict`. If the is_training flag is set 424 we don't require the number of samples to be used completely. 425 batch_size: The batch size to be used in model iterations. 426 427 Raises: 428 ValueError: If the number of batches or steps evaluates to 0. 429 430 """ 431 num_samples = first_x_value.shape[0] 432 # TODO(b/118776054): Use global batch size for Keras/DS support. 433 # Currently this is only supported in TPUStrategy and CoreMirroredStrategy. 434 use_per_replica_batch = not global_batch_size_supported( 435 distribution_strategy) 436 437 # Partial batches are allowed for training as we repeat the 438 # dataset when converting numpy arrays into a dataset. 439 # For other modes uneven batch sizes are not allowed except 440 # for `predict()` on TPUStrategy. 441 allow_partial_batch = (mode == ModeKeys.TRAIN or 442 (mode == ModeKeys.PREDICT 443 and is_tpu_strategy(distribution_strategy))) 444 445 if steps is None: 446 if batch_size is None: 447 # If neither the batch size or number of steps are set. We choose the 448 # global batch size as the minimum of number of samples and 32. 32 is 449 # chosen to provide backward compatibility. 450 global_batch_size = min(num_samples, 32) 451 else: 452 # If the user provided the batch size we need to handle the case 453 # between different strategies that use the global/per-replica batch size 454 global_batch_size = batch_size 455 if use_per_replica_batch: 456 global_batch_size *= distribution_strategy.num_replicas_in_sync 457 if allow_partial_batch: 458 steps = np.ceil(num_samples / global_batch_size).astype(int) 459 else: 460 if num_samples % global_batch_size: 461 raise ValueError('The number of samples %s is not divisible by ' 462 'batch size %s.' % (num_samples, global_batch_size)) 463 steps = num_samples // global_batch_size 464 else: 465 if batch_size is None: 466 # We calculate the batch size based on the number of steps specified 467 if num_samples % steps: 468 raise ValueError('The number of samples %s is not divisible by ' 469 'steps %s. Please change the number of steps to a ' 470 'value that can consume all the samples' % ( 471 num_samples, steps)) 472 global_batch_size = num_samples // steps 473 else: 474 # If the user provided the batch size we need to handle the case 475 # between different strategies that use the global/per-replica batch size 476 global_batch_size = batch_size 477 if use_per_replica_batch: 478 global_batch_size *= distribution_strategy.num_replicas_in_sync 479 480 min_num_samples = global_batch_size * steps 481 if allow_partial_batch: 482 min_num_samples = global_batch_size * (steps-1) + 1 if steps > 1 else 0 483 484 if num_samples < min_num_samples: 485 raise ValueError('Number of samples %s is less than samples required ' 486 'for specified batch_size %s and steps %s' % ( 487 num_samples, global_batch_size, steps)) 488 489 # We need to return the per replica or global batch size based on the strategy 490 if use_per_replica_batch: 491 if global_batch_size % distribution_strategy.num_replicas_in_sync: 492 raise ValueError( 493 'The batch size (%s) could not be sharded evenly across the sync ' 494 'replicas (%s) in the distribution strategy.' % ( 495 global_batch_size, distribution_strategy.num_replicas_in_sync)) 496 batch_size = global_batch_size // distribution_strategy.num_replicas_in_sync 497 else: 498 batch_size = global_batch_size 499 500 return steps, batch_size 501 502 503def get_batch_dimension(iterator): 504 shapes = nest.flatten(dataset_ops.get_legacy_output_shapes(iterator)) 505 # Take the batch size from the first element, as it should be the same for 506 # all. 507 dims = shapes[0].dims 508 return dims[0] if dims else None 509 510 511def list_to_tuple(maybe_list): 512 """Datasets treat lists specially, so switch them to tuples.""" 513 if isinstance(maybe_list, list): 514 return tuple(maybe_list) 515 return maybe_list 516 517 518def get_iterator(dataset, distribution_strategy): 519 with distribution_strategy.scope(): 520 iterator = distribution_strategy.make_dataset_iterator(dataset) 521 initialize_iterator(iterator, distribution_strategy) 522 return iterator 523 524 525def initialize_iterator(iterator, distribution_strategy): 526 with distribution_strategy.scope(): 527 init_op = control_flow_ops.group(iterator.initialize()) 528 if not context.executing_eagerly(): 529 K.get_session((init_op,)).run(init_op) 530 531 532def _get_input_from_iterator(iterator, model): 533 """Get elements from the iterator and verify the input shape and type.""" 534 next_element = iterator.get_next() 535 536 if len(nest.flatten(next_element)) == len(model.inputs): 537 x = next_element 538 y = None 539 sample_weights = None 540 elif len(nest.flatten(next_element)) == (len(model.inputs) + 541 len(model.outputs)): 542 x, y = next_element 543 sample_weights = None 544 else: 545 x, y, sample_weights = next_element 546 547 # Validate that all the elements in x and y are of the same type and shape. 548 validate_distributed_dataset_inputs( 549 model._distribution_strategy, x, y, sample_weights) 550 return x, y, sample_weights 551 552 553def _prepare_feed_values(model, inputs, targets, sample_weights, mode): 554 """Prepare feed values to the model execution function. 555 556 Arguments: 557 model: Model to prepare feed values for. 558 inputs: List or dict of model inputs. 559 targets: Optional list of model targets. 560 sample_weights: Optional list of sample weight arrays. 561 mode: One of ModeKeys.TRAIN/ModeKeys.TEST/ModeKeys.PREDICT. 562 563 Returns: 564 Feed values for the model in the given mode. 565 """ 566 strategy = model._distribution_strategy 567 inputs, targets, sample_weights = _get_input_from_iterator(inputs, model) 568 inputs = flatten_perdevice_values(strategy, inputs) 569 targets = flatten_perdevice_values(strategy, targets) 570 # Expand 1-dimensional inputs. 571 # TODO(b/124535720): Remove once this standarize data logic is shared with 572 # main flow. 573 inputs, targets = nest.map_structure(training_utils.standardize_single_array, 574 (inputs, targets)) 575 if mode == ModeKeys.PREDICT: 576 sample_weights = [] 577 targets = [] 578 else: 579 sample_weights = [ 580 None for _ in range(len(model.outputs) * strategy.num_replicas_in_sync) 581 ] 582 ins = inputs + targets + sample_weights 583 if mode == ModeKeys.TRAIN and not isinstance(K.symbolic_learning_phase(), 584 int): 585 ins += [True] 586 return ins 587 588 589def _custom_compile_for_predict(model): 590 """Custom compile for TPU predict mode.""" 591 if not model.built: 592 # Model is not compilable because it does not know its number of inputs 593 # and outputs, nor their shapes and names. We will compile after the first 594 # time the model gets called on training data. 595 return 596 model._is_compiled = True 597 model.total_loss = None 598 model.train_function = None 599 model.test_function = None 600 model.predict_function = None 601 602 603def _build_network_on_replica(model, mode, inputs=None, targets=None): 604 """Build an updated model on replicas. 605 606 We create a new Keras model while sharing the variables from the old graph. 607 Building a new sub-graph is required since the original keras model creates 608 placeholders for the input and the output that are not accessible till we 609 call iterator.get_next() inside the step_fn for `fit`/`evaluate`/`predict`. 610 611 The sharing of weights and layers between the old and the new model gaurantee 612 that we're using Strategy variables and any updates on either model are 613 reflected correctly in callbacks and loop iterations. 614 615 We need to make sure we share the optimizers between the old and the new model 616 as well so that optimizer state is not lost if the user is running fit 617 multiple times. 618 619 Args: 620 model: Model to be replicated across Replicas 621 mode: Which of fit/eval/predict is building the distributed network 622 inputs: Input variables to be passed to the model 623 targets: Target tensor to be passed to model.compile 624 625 Returns: 626 A new model with shared layers with the old model. 627 """ 628 # Need to do imports here since we run into a circular dependency error. 629 from tensorflow.python.keras import models # pylint: disable=g-import-not-at-top 630 from tensorflow.python.keras.engine import sequential # pylint: disable=g-import-not-at-top 631 632 # We rely on the internal methods to avoid having share_weights weights in the 633 # public API. 634 if isinstance(model, sequential.Sequential): 635 updated_model = models._clone_sequential_model(model, input_tensors=inputs, 636 share_weights=True) 637 else: 638 updated_model = models._clone_functional_model(model, input_tensors=inputs, 639 share_weights=True) 640 641 # Recast all low precision outputs back to float32 since we only casted 642 # the inputs to bfloat16 and not targets. This is done so that we can preserve 643 # precision when calculating the loss value. 644 def _upcast_low_precision_outputs(output): 645 if output.dtype == dtypes.bfloat16: 646 return math_ops.cast(output, dtypes.float32) 647 else: 648 return output 649 updated_model.outputs = [_upcast_low_precision_outputs(o) 650 for o in updated_model.outputs] 651 652 if isinstance(targets, tuple): 653 targets = nest.flatten(targets) 654 655 if mode == ModeKeys.PREDICT and inputs is not None: # TPU predict case 656 _custom_compile_for_predict(updated_model) 657 else: 658 updated_model.compile( 659 model.optimizer, 660 model.loss, 661 metrics=metrics_module.clone_metrics(model._compile_metrics), 662 loss_weights=model.loss_weights, 663 sample_weight_mode=model.sample_weight_mode, 664 weighted_metrics=metrics_module.clone_metrics( 665 model._compile_weighted_metrics), 666 target_tensors=targets) 667 return updated_model 668 669 670def _build_distributed_network(model, strategy, mode, inputs=None, 671 targets=None): 672 """Create a cloned model on each replica.""" 673 with K.get_graph().as_default(), strategy.scope(): 674 distributed_model = strategy.extended.call_for_each_replica( 675 _build_network_on_replica, 676 args=(model, mode, inputs, targets)) 677 set_distributed_model(model, mode, distributed_model) 678 679 680def _clone_and_build_model(model, mode, inputs=None, targets=None): 681 """Clone and build the given keras_model.""" 682 # We need to set the import here since we run into a circular dependency 683 # error. 684 from tensorflow.python.keras import models # pylint: disable=g-import-not-at-top 685 cloned_model = models.clone_model(model, input_tensors=inputs) 686 687 # Compile and build model. 688 if isinstance(model.optimizer, optimizers.TFOptimizer): 689 optimizer = model.optimizer 690 else: 691 optimizer_config = model.optimizer.get_config() 692 optimizer = model.optimizer.__class__.from_config(optimizer_config) 693 694 # Recast all low precision outputs back to float32 since we only casted 695 # the inputs to bfloat16 and not targets. This is done so that we can preserve 696 # precision when calculating the loss value. 697 def _upcast_low_precision_outputs(output): 698 if output.dtype == dtypes.bfloat16: 699 return math_ops.cast(output, dtypes.float32) 700 else: 701 return output 702 cloned_model.outputs = [_upcast_low_precision_outputs(o) 703 for o in cloned_model.outputs] 704 705 if isinstance(targets, tuple): 706 targets = nest.flatten(targets) 707 if mode == ModeKeys.PREDICT and inputs is not None: # TPU predict case 708 _custom_compile_for_predict(cloned_model) 709 else: 710 cloned_model.compile( 711 optimizer, 712 model.loss, 713 metrics=metrics_module.clone_metrics(model._compile_metrics), 714 loss_weights=model.loss_weights, 715 sample_weight_mode=model.sample_weight_mode, 716 weighted_metrics=metrics_module.clone_metrics( 717 model._compile_weighted_metrics), 718 target_tensors=targets) 719 return cloned_model 720 721 722def clone_model_on_replicas(model, strategy, mode, inputs=None, targets=None): 723 """Create a cloned model on each replica.""" 724 with K.get_graph().as_default(), strategy.scope(): 725 distributed_model = strategy.extended.call_for_each_replica( 726 _clone_and_build_model, args=(model, mode, inputs, targets)) 727 set_distributed_model(model, mode, distributed_model) 728 if mode == ModeKeys.TRAIN: 729 model._make_callback_model(distributed_model) 730 731 732def _make_execution_function(model, mode): 733 """Makes or reuses function to run one step of distributed model execution.""" 734 strategy = model._distribution_strategy 735 736 distributed_model = get_distributed_model(model, mode) 737 # If distributed model for a particular `mode` is already built, use the 738 # `_distribution_function` on that distributed model. 739 if distributed_model: 740 return distributed_model._distributed_function 741 742 # If distributed_model is not built, create one for `mode`. 743 if model._compile_distribution: 744 clone_model_on_replicas(model, strategy, mode) 745 else: 746 _build_distributed_network(model, strategy, mode) 747 748 # We've just created the distributed model. So `distributed_model` should be 749 # not None. 750 distributed_model = get_distributed_model(model, mode) 751 assert distributed_model 752 753 # Also create an execution fuction on that distributed model. 754 if context.executing_eagerly(): 755 distributed_function = _make_eager_execution_function(model, mode) 756 else: 757 distributed_function = _make_graph_execution_function(model, mode) 758 759 # We cache the distributed execution function on the model since creating 760 # distributed models and exection functions are expensive. 761 distributed_model._distributed_function = distributed_function 762 return distributed_function 763 764 765def _make_graph_execution_function(model, mode): 766 """Makes function to run one step of distributed model in graph mode.""" 767 768 def _per_device_function(model): 769 f = model._make_execution_function(mode) 770 return (f.inputs, f.outputs, f.updates_op, f.session_kwargs) 771 772 strategy = model._distribution_strategy 773 with strategy.scope(): 774 # Create train ops on each of the devices when we call 775 # `_per_device_fit_function`. 776 (grouped_inputs, grouped_outputs, grouped_updates, 777 grouped_session_args) = strategy.extended.call_for_each_replica( 778 _per_device_function, args=(get_distributed_model(model, mode),)) 779 780 # Initialize the variables in the replicated model. This is necessary for 781 # multi-worker training because on some workers, initialization is not 782 # needed. This method does initialization or waiting for initialization 783 # according to the context object of distribute coordinator. 784 init_restore_or_wait_for_variables() 785 786 # Unwrap all the per device values returned from `call_for_each_replica`. 787 # Unwrapping per device values gives you a list of values that can be 788 # used to construct a new train function that is composed of update ops on 789 # all the devices over which the model is distributed. 790 (all_inputs, all_outputs, all_updates, all_session_args) = unwrap_values( 791 strategy, 792 grouped_inputs, 793 grouped_outputs, 794 grouped_updates, 795 grouped_session_args, 796 with_loss_tensor=(mode != ModeKeys.PREDICT)) 797 798 return K.function( 799 all_inputs, 800 all_outputs, 801 updates=all_updates, 802 name='distributed_{}_function'.format(mode), 803 **all_session_args) 804 805 806def _make_eager_execution_function(model, mode): 807 """Makes function to run one step of distributed model eager execution.""" 808 def _per_device_function(model): 809 f = model._make_execution_function(mode) 810 return (f.inputs, f.outputs) 811 812 # NOTE(priyag): Try creating a new FuncGraph within DS scope instead of using 813 # the global one. 814 strategy = model._distribution_strategy 815 global_graph = K.get_graph() 816 817 with global_graph.as_default(), strategy.scope(): 818 # First we gather the relevant portions of the model across all replicas. 819 # `K._scratch_graph(global_graph)` signals to Keras that it should not 820 # lift to a separate graph when creating the per-replica functions. 821 with K._scratch_graph(global_graph): 822 # Create train ops on each of the devices when we call 823 # `_per_device_fit_function`. 824 grouped = strategy.extended.call_for_each_replica( 825 _per_device_function, args=(get_distributed_model(model, mode),)) 826 grouped_inputs, grouped_outputs = grouped 827 828 # Unwrap all the per device values returned from `call_for_each_replica`. 829 # Unwrapping per device values gives you a list of values that can be 830 # used to construct a new train function that is composed of 831 # inputs/outputs on all the devices over which the model is distributed. 832 (all_inputs, all_outputs, _, _) = unwrap_values( 833 strategy, 834 grouped_inputs, 835 grouped_outputs, 836 with_loss_tensor=(mode != ModeKeys.PREDICT)) 837 838 # Finally, a joint Keras function is created; this one will be created in 839 # a separate FuncGraph. 840 return K.function( 841 all_inputs, 842 all_outputs, 843 name='eager_distributed_{}_function'.format(mode)) 844 845 846def _copy_weights_to_distributed_model(original_model, mode): 847 """Copies weights from original model to distributed models.""" 848 strategy = original_model._distribution_strategy 849 distributed_model = get_distributed_model(original_model, mode) 850 if strategy: 851 # Copy the weights from the original model to each of the replicated 852 # models. 853 orig_model_weights = original_model.get_weights() 854 first_model = strategy.unwrap(distributed_model)[0] 855 set_weights(strategy, first_model, orig_model_weights) 856 857 858def _copy_weights_to_original_model(model, mode): 859 """Copies weights from first distributed model back to original model.""" 860 if model._distribution_strategy and mode == ModeKeys.TRAIN: 861 distributed_model = get_distributed_model(model, mode) 862 updated_weights = model._distribution_strategy.unwrap( 863 distributed_model)[0].get_weights() 864 model.set_weights(updated_weights) 865 866 867def _per_device_aggregate_batch(batch_outs, model, mode): 868 """Aggregates the per-device batch-level outputs from a distributed step.""" 869 if model._distribution_strategy is not None and mode == ModeKeys.PREDICT: 870 total_batch_outs = [] 871 for i in range(len(model.outputs)): 872 num_replicas = model._distribution_strategy.num_replicas_in_sync 873 nested_outs = batch_outs[i * num_replicas:i * num_replicas + num_replicas] 874 total_batch_outs.append(np.concatenate(nest.flatten(nested_outs))) 875 return total_batch_outs 876 return batch_outs 877 878 879def _reset_metrics(model): 880 if model._distribution_strategy: 881 for mode in [ModeKeys.TRAIN, ModeKeys.TEST, ModeKeys.PREDICT]: 882 distributed_model = get_distributed_model(model, mode) 883 if distributed_model: 884 first_model = model._distribution_strategy.unwrap(distributed_model)[0] 885 first_model.reset_metrics() 886 887 888def get_distributed_model(model, mode): 889 key = _generate_cache_key(mode) 890 return model._distributed_model_cache.get(key, None) 891 892 893def set_distributed_model(model, mode, distributed_model): 894 key = _generate_cache_key(mode) 895 model._distributed_model_cache[key] = distributed_model 896 897 898def _generate_cache_key(mode): 899 key = hash(mode) 900 return key 901 902 903@tf_contextlib.contextmanager 904def distributed_scope(strategy, learning_phase): 905 with strategy.scope(), K.learning_phase_scope(learning_phase): 906 yield 907 908 909def filter_distributed_callbacks(callbacks_list): 910 """Filter Callbacks based on the worker context when running multi-worker. 911 912 Arguments: 913 callbacks_list: A list of `Callback` instances. 914 915 Returns: 916 The list of `Callback` instances that should be run on this worker. 917 """ 918 919 if not K.in_multi_worker_mode(): 920 raise ValueError( 921 'filter_distributed_callbacks() should only be called when Keras ' 922 'is in multi worker mode.') 923 924 worker_context = dc_context.get_current_worker_context() 925 callbacks_list = callbacks_list or [] 926 if not [ 927 c for c in callbacks_list if isinstance(c, callbacks.ModelCheckpoint) 928 ]: 929 # TODO(rchao): Consider providing a ModelCheckpoint here if the user 930 # fails to. 931 logging.warning('ModelCheckpoint callback is not provided. ' 932 'Workers will need to restart training if any fails.') 933 # TODO(rchao): Add similar warning for restoring callback (to be designed). 934 935 if callbacks_list is None or worker_context.is_chief: 936 return callbacks_list 937 938 # Some Callbacks should only run on the chief worker. 939 return [ 940 callback for callback in callbacks_list if not callback._chief_worker_only 941 ] # pylint: disable=protected-access 942