1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""TPU Strategy.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import atexit 22import collections 23import contextlib 24import copy 25import functools 26import weakref 27 28from absl import logging 29import numpy as np 30 31from tensorflow.compiler.xla.experimental.xla_sharding import xla_sharding 32from tensorflow.python.autograph.core import ag_ctx as autograph_ctx 33from tensorflow.python.autograph.impl import api as autograph 34from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib 35from tensorflow.python.distribute import device_util 36from tensorflow.python.distribute import distribute_lib 37from tensorflow.python.distribute import distribute_utils 38from tensorflow.python.distribute import input_lib 39from tensorflow.python.distribute import numpy_dataset 40from tensorflow.python.distribute import reduce_util 41from tensorflow.python.distribute import tpu_util 42from tensorflow.python.distribute import tpu_values 43from tensorflow.python.distribute import values 44from tensorflow.python.distribute.cluster_resolver import TPUClusterResolver 45from tensorflow.python.eager import context 46from tensorflow.python.eager import def_function 47from tensorflow.python.eager import function 48from tensorflow.python.framework import constant_op 49from tensorflow.python.framework import device_spec 50from tensorflow.python.framework import dtypes 51from tensorflow.python.framework import ops 52from tensorflow.python.framework import sparse_tensor 53from tensorflow.python.framework import tensor_shape 54from tensorflow.python.framework import tensor_util 55from tensorflow.python.ops import array_ops 56from tensorflow.python.ops import control_flow_ops 57from tensorflow.python.ops import math_ops 58from tensorflow.python.ops import resource_variable_ops 59from tensorflow.python.ops import variables as variables_lib 60from tensorflow.python.ops.ragged import ragged_tensor 61from tensorflow.python.tpu import device_assignment as device_assignment_lib # pylint: disable=unused-import 62from tensorflow.python.tpu import tpu 63from tensorflow.python.tpu import tpu_strategy_util 64from tensorflow.python.tpu import training_loop 65from tensorflow.python.tpu.ops import tpu_ops 66from tensorflow.python.util import deprecation 67from tensorflow.python.util import nest 68from tensorflow.python.util import tf_inspect 69from tensorflow.python.util.tf_export import tf_export 70 71_XLA_OP_BY_OP_INPUTS_LIMIT = 200 72 73 74@contextlib.contextmanager 75def maybe_init_scope(): 76 if ops.executing_eagerly_outside_functions(): 77 yield 78 else: 79 with ops.init_scope(): 80 yield 81 82 83def validate_run_function(fn): 84 """Validate the function passed into strategy.run.""" 85 86 # We allow three types of 
functions/objects passed into TPUStrategy 87 # run in eager mode: 88 # 1. a user annotated tf.function 89 # 2. a ConcreteFunction, this is mostly what you get from loading a saved 90 # model. 91 # 3. a callable object and the `__call__` method itself is a tf.function. 92 # 93 # Otherwise we return an error, because we don't support eagerly running 94 # run in TPUStrategy. 95 96 if context.executing_eagerly() \ 97 and not isinstance(fn, def_function.Function) \ 98 and not isinstance(fn, function.ConcreteFunction) \ 99 and not (callable(fn) and isinstance(fn.__call__, def_function.Function)): 100 raise NotImplementedError( 101 "TPUStrategy.run(fn, ...) does not support pure eager " 102 "execution. please make sure the function passed into " 103 "`strategy.run` is a `tf.function` or " 104 "`strategy.run` is called inside a `tf.function` if " 105 "eager behavior is enabled.") 106 107 108def _maybe_partial_apply_variables(fn, args, kwargs): 109 """Inspects arguments to partially apply any DistributedVariable. 110 111 This avoids an automatic cast of the current variable value to tensor. 112 113 Note that a variable may be captured implicitly with Python scope instead of 114 passing it to run(), but supporting run() keeps behavior consistent 115 with MirroredStrategy. 116 117 Since positional arguments must be applied from left to right, this function 118 does some tricky function inspection to move variable positional arguments 119 into kwargs. As a result of this, we can't support passing Variables as *args, 120 nor as args to functions which combine both explicit positional arguments and 121 *args. 122 123 Args: 124 fn: The function to run, as passed to run(). 125 args: Positional arguments to fn, as passed to run(). 126 kwargs: Keyword arguments to fn, as passed to run(). 127 128 Returns: 129 A tuple of the function (possibly wrapped), args, kwargs (both 130 possibly filtered, with members of args possibly moved to kwargs). 131 If no variables are found, this function is a noop. 132 133 Raises: 134 ValueError: If the function signature makes unsupported use of *args, or if 135 too many arguments are passed. 136 """ 137 138 def is_distributed_var(x): 139 flat = nest.flatten(x) 140 return flat and isinstance(flat[0], values.DistributedVariable) 141 142 # We will split kwargs into two dicts, one of which will be applied now. 143 var_kwargs = {} 144 nonvar_kwargs = {} 145 146 if kwargs: 147 var_kwargs = {k: v for k, v in kwargs.items() if is_distributed_var(v)} 148 if var_kwargs: 149 nonvar_kwargs = { 150 k: v for k, v in kwargs.items() if not is_distributed_var(v) 151 } 152 153 # Dump the argument names of `fn` to a list. This will include both positional 154 # and keyword arguments, but since positional arguments come first we can 155 # look up names of positional arguments by index. 156 positional_args = [] 157 index_of_star_args = None 158 for i, p in enumerate(tf_inspect.signature(fn).parameters.values()): 159 # Class methods define "self" as first argument, but we don't pass "self". 160 # Note that this is a heuristic, as a method can name its first argument 161 # something else, and a function can define a first argument "self" as well. 162 # In both of these cases, using a Variable will fail with an unfortunate 163 # error about the number of arguments. 164 # inspect.is_method() seems not to work here, possibly due to the use of 165 # tf.function(). 
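    # For example (illustrative): for a method `def call(self, w, b, *args)`,
    # this loop skips `self`, records `w` and `b` as named positional
    # parameters, and remembers the position of `*args` so that a Variable
    # passed through `*args` can be rejected later.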
166 if i == 0 and p.name == "self": 167 continue 168 169 if p.kind == tf_inspect.Parameter.POSITIONAL_OR_KEYWORD: 170 positional_args.append(p.name) 171 172 elif p.kind == tf_inspect.Parameter.VAR_POSITIONAL: 173 # We'll raise an error later if a variable is passed to *args, since we 174 # can neither pass it by name nor partially apply it. This case only 175 # happens once at most. 176 index_of_star_args = i 177 178 elif p.kind == tf_inspect.Parameter.POSITIONAL_ONLY: 179 # This is a rare Python feature, indicating a / in the arg list. 180 if var_kwargs or any(is_distributed_var(a) for a in args): 181 raise ValueError( 182 "Mixing Variables and positional-only parameters not supported by " 183 "TPUStrategy.") 184 return fn, args, kwargs 185 186 star_args = [] 187 have_seen_var_arg = False 188 189 for i, a in enumerate(args): 190 if is_distributed_var(a): 191 if index_of_star_args is not None and i >= index_of_star_args: 192 raise ValueError( 193 "TPUStrategy.run() cannot handle Variables passed to *args. " 194 "Either name the function argument, or capture the Variable " 195 "implicitly.") 196 if len(positional_args) <= i: 197 raise ValueError( 198 "Too many positional arguments passed to call to TPUStrategy.run()." 199 ) 200 var_kwargs[positional_args[i]] = a 201 have_seen_var_arg = True 202 else: 203 if index_of_star_args is not None and i >= index_of_star_args: 204 if have_seen_var_arg: 205 raise ValueError( 206 "TPUStrategy.run() cannot handle both Variables and a mix of " 207 "positional args and *args. Either remove the *args, or capture " 208 "the Variable implicitly.") 209 else: 210 star_args.append(a) 211 continue 212 213 if len(positional_args) <= i: 214 raise ValueError( 215 "Too many positional arguments passed to call to TPUStrategy.run()." 216 ) 217 nonvar_kwargs[positional_args[i]] = a 218 219 if var_kwargs: 220 return functools.partial(fn, **var_kwargs), star_args, nonvar_kwargs 221 return fn, args, kwargs 222 223 224@tf_export("distribute.TPUStrategy", v1=[]) 225class TPUStrategyV2(distribute_lib.Strategy): 226 """Synchronous training on TPUs and TPU Pods. 227 228 To construct a TPUStrategy object, you need to run the 229 initialization code as below: 230 231 >>> resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='') 232 >>> tf.config.experimental_connect_to_cluster(resolver) 233 >>> tf.tpu.experimental.initialize_tpu_system(resolver) 234 >>> strategy = tf.distribute.TPUStrategy(resolver) 235 236 While using distribution strategies, the variables created within the 237 strategy's scope will be replicated across all the replicas and can be kept in 238 sync using all-reduce algorithms. 239 240 To run TF2 programs on TPUs, you can either use `.compile` and 241 `.fit` APIs in `tf.keras` with TPUStrategy, or write your own customized 242 training loop by calling `strategy.run` directly. Note that 243 TPUStrategy doesn't support pure eager execution, so please make sure the 244 function passed into `strategy.run` is a `tf.function` or 245 `strategy.run` is called inside a `tf.function` if eager 246 behavior is enabled. See more details in https://www.tensorflow.org/guide/tpu. 247 248 `distribute_datasets_from_function` and 249 `experimental_distribute_dataset` APIs can be used to distribute the dataset 250 across the TPU workers when writing your own training loop. If you are using 251 `fit` and `compile` methods available in `tf.keras.Model`, then Keras will 252 handle the distribution for you. 
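
  For instance, an in-memory `tf.data.Dataset` can be sharded across replicas
  with `experimental_distribute_dataset` (a minimal sketch; the array shape
  and the global batch size below are illustrative):

  >>> dataset = tf.data.Dataset.from_tensor_slices(
  ...     np.random.random((8, 5)).astype(np.float32)).repeat().batch(
  ...         8, drop_remainder=True)
  >>> dist_dataset = strategy.experimental_distribute_dataset(dataset)
  >>> per_replica_batch = next(iter(dist_dataset))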
253 254 An example of writing customized training loop on TPUs: 255 256 >>> with strategy.scope(): 257 ... model = tf.keras.Sequential([ 258 ... tf.keras.layers.Dense(2, input_shape=(5,)), 259 ... ]) 260 ... optimizer = tf.keras.optimizers.SGD(learning_rate=0.1) 261 262 >>> def dataset_fn(ctx): 263 ... x = np.random.random((2, 5)).astype(np.float32) 264 ... y = np.random.randint(2, size=(2, 1)) 265 ... dataset = tf.data.Dataset.from_tensor_slices((x, y)) 266 ... return dataset.repeat().batch(1, drop_remainder=True) 267 >>> dist_dataset = strategy.distribute_datasets_from_function( 268 ... dataset_fn) 269 >>> iterator = iter(dist_dataset) 270 271 >>> @tf.function() 272 ... def train_step(iterator): 273 ... 274 ... def step_fn(inputs): 275 ... features, labels = inputs 276 ... with tf.GradientTape() as tape: 277 ... logits = model(features, training=True) 278 ... loss = tf.keras.losses.sparse_categorical_crossentropy( 279 ... labels, logits) 280 ... 281 ... grads = tape.gradient(loss, model.trainable_variables) 282 ... optimizer.apply_gradients(zip(grads, model.trainable_variables)) 283 ... 284 ... strategy.run(step_fn, args=(next(iterator),)) 285 286 >>> train_step(iterator) 287 288 For the advanced use cases like model parallelism, you can set 289 `experimental_device_assignment` argument when creating TPUStrategy to specify 290 number of replicas and number of logical devices. Below is an example to 291 initialize TPU system with 2 logical devices and 1 replica. 292 293 >>> resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='') 294 >>> tf.config.experimental_connect_to_cluster(resolver) 295 >>> topology = tf.tpu.experimental.initialize_tpu_system(resolver) 296 >>> device_assignment = tf.tpu.experimental.DeviceAssignment.build( 297 ... topology, 298 ... computation_shape=[1, 1, 1, 2], 299 ... num_replicas=1) 300 >>> strategy = tf.distribute.TPUStrategy( 301 ... resolver, experimental_device_assignment=device_assignment) 302 303 Then you can run a `tf.add` operation only on logical device 0. 304 305 >>> @tf.function() 306 ... def step_fn(inputs): 307 ... features, _ = inputs 308 ... output = tf.add(features, features) 309 ... 310 ... # Add operation will be executed on logical device 0. 311 ... output = strategy.experimental_assign_to_logical_device(output, 0) 312 ... return output 313 >>> dist_dataset = strategy.distribute_datasets_from_function( 314 ... dataset_fn) 315 >>> iterator = iter(dist_dataset) 316 >>> strategy.run(step_fn, args=(next(iterator),)) 317 """ 318 319 def __init__(self, 320 tpu_cluster_resolver=None, 321 experimental_device_assignment=None): 322 """Synchronous training in TPU donuts or Pods. 323 324 Args: 325 tpu_cluster_resolver: A 326 `tf.distribute.cluster_resolver.TPUClusterResolver` instance, which 327 provides information about the TPU cluster. If None, it will assume 328 running on a local TPU worker. 329 experimental_device_assignment: Optional 330 `tf.tpu.experimental.DeviceAssignment` to specify the placement of 331 replicas on the TPU cluster. 
332 """ 333 super(TPUStrategyV2, self).__init__(TPUExtended( 334 self, tpu_cluster_resolver, 335 device_assignment=experimental_device_assignment)) 336 distribute_lib.distribution_strategy_gauge.get_cell("V2").set("TPUStrategy") 337 distribute_lib.distribution_strategy_replica_gauge.get_cell( 338 "num_workers").set(self.extended.num_hosts) 339 distribute_lib.distribution_strategy_replica_gauge.get_cell( 340 "num_replicas_per_worker").set(self.extended.num_replicas_per_host) 341 # Packed variable is used to reduce the overhead of function execution. 342 # For a DistributedVariable, only one variable handle is captured into a 343 # function graph. It's only supported in eager mode. 344 self._enable_packed_variable_in_eager_mode = True 345 346 def run(self, fn, args=(), kwargs=None, options=None): 347 """Run the computation defined by `fn` on each TPU replica. 348 349 Executes ops specified by `fn` on each replica. If `args` or `kwargs` have 350 `tf.distribute.DistributedValues`, such as those produced by a 351 `tf.distribute.DistributedDataset` from 352 `tf.distribute.Strategy.experimental_distribute_dataset` or 353 `tf.distribute.Strategy.distribute_datasets_from_function`, 354 when `fn` is executed on a particular replica, it will be executed with the 355 component of `tf.distribute.DistributedValues` that correspond to that 356 replica. 357 358 `fn` may call `tf.distribute.get_replica_context()` to access members such 359 as `all_reduce`. 360 361 All arguments in `args` or `kwargs` should either be nest of tensors or 362 `tf.distribute.DistributedValues` containing tensors or composite tensors. 363 364 Example usage: 365 366 >>> resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='') 367 >>> tf.config.experimental_connect_to_cluster(resolver) 368 >>> tf.tpu.experimental.initialize_tpu_system(resolver) 369 >>> strategy = tf.distribute.TPUStrategy(resolver) 370 >>> @tf.function 371 ... def run(): 372 ... def value_fn(value_context): 373 ... return value_context.num_replicas_in_sync 374 ... distributed_values = ( 375 ... strategy.experimental_distribute_values_from_function(value_fn)) 376 ... def replica_fn(input): 377 ... return input * 2 378 ... return strategy.run(replica_fn, args=(distributed_values,)) 379 >>> result = run() 380 381 Args: 382 fn: The function to run. The output must be a `tf.nest` of `Tensor`s. 383 args: (Optional) Positional arguments to `fn`. 384 kwargs: (Optional) Keyword arguments to `fn`. 385 options: (Optional) An instance of `tf.distribute.RunOptions` specifying 386 the options to run `fn`. 387 388 Returns: 389 Merged return value of `fn` across replicas. The structure of the return 390 value is the same as the return value from `fn`. Each element in the 391 structure can either be `tf.distribute.DistributedValues`, `Tensor` 392 objects, or `Tensor`s (for example, if running on a single replica). 393 """ 394 validate_run_function(fn) 395 396 fn, args, kwargs = _maybe_partial_apply_variables(fn, args, kwargs) 397 398 # Note: the target function is converted to graph even when in Eager mode, 399 # so autograph is on by default here. 400 fn = autograph.tf_convert(fn, autograph_ctx.control_status_ctx()) 401 options = options or distribute_lib.RunOptions() 402 return self.extended.tpu_run(fn, args, kwargs, options) 403 404 def experimental_assign_to_logical_device(self, tensor, logical_device_id): 405 """Adds annotation that `tensor` will be assigned to a logical device. 

    This adds an annotation to `tensor` specifying that operations on
    `tensor` will be invoked on the logical core device with id
    `logical_device_id`. When model parallelism is used, the default behavior
    is that all ops are placed on the zeroth logical device.

    ```python

    # Initializing TPU system with 2 logical devices and 4 replicas.
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
    tf.config.experimental_connect_to_cluster(resolver)
    topology = tf.tpu.experimental.initialize_tpu_system(resolver)
    device_assignment = tf.tpu.experimental.DeviceAssignment.build(
        topology,
        computation_shape=[1, 1, 1, 2],
        num_replicas=4)
    strategy = tf.distribute.TPUStrategy(
        resolver, experimental_device_assignment=device_assignment)
    iterator = iter(inputs)

    @tf.function()
    def step_fn(inputs):
      output = tf.add(inputs, inputs)

      # Add operation will be executed on logical device 0.
      output = strategy.experimental_assign_to_logical_device(output, 0)
      return output

    strategy.run(step_fn, args=(next(iterator),))
    ```

    Args:
      tensor: Input tensor to annotate.
      logical_device_id: Id of the logical core to which the tensor will be
        assigned.

    Raises:
      ValueError: The logical device id presented is not consistent with the
        total number of partitions specified by the device assignment.

    Returns:
      Annotated tensor with identical value as `tensor`.
    """
    num_logical_devices_per_replica = self.extended._tpu_devices.shape[1]  # pylint: disable=protected-access
    if (logical_device_id < 0 or
        logical_device_id >= num_logical_devices_per_replica):
      raise ValueError(
          "`logical_device_id` to assign must be lower than the total number "
          "of logical devices per replica. Received logical device id {} but "
          "there are only {} logical devices per replica.".format(
              logical_device_id, num_logical_devices_per_replica))
    return xla_sharding.assign_device(
        tensor, logical_device_id, use_sharding_op=True)

  def experimental_split_to_logical_devices(self, tensor, partition_dimensions):
    """Adds annotation that `tensor` will be split across logical devices.

    This adds an annotation to `tensor` specifying that operations on
    `tensor` will be split among multiple logical devices. Tensor `tensor`
    will be split across the dimensions specified by `partition_dimensions`.
    The dimensions of `tensor` must be divisible by the corresponding values
    in `partition_dimensions`.

    For example, for a system with 8 logical devices, if `tensor` is an image
    tensor with shape (batch_size, width, height, channel) and
    `partition_dimensions` is [1, 2, 4, 1], then `tensor` will be split
    2 ways in the width dimension and 4 ways in the height dimension, and the
    resulting split tensor values will be fed into 8 logical devices.

    ```python
    # Initializing TPU system with 8 logical devices and 1 replica.
477 resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='') 478 tf.config.experimental_connect_to_cluster(resolver) 479 topology = tf.tpu.experimental.initialize_tpu_system(resolver) 480 device_assignment = tf.tpu.experimental.DeviceAssignment.build( 481 topology, 482 computation_shape=[1, 2, 2, 2], 483 num_replicas=1) 484 strategy = tf.distribute.TPUStrategy( 485 resolver, experimental_device_assignment=device_assignment) 486 487 iterator = iter(inputs) 488 489 @tf.function() 490 def step_fn(inputs): 491 inputs = strategy.experimental_split_to_logical_devices( 492 inputs, [1, 2, 4, 1]) 493 494 # model() function will be executed on 8 logical devices with `inputs` 495 # split 2 * 4 ways. 496 output = model(inputs) 497 return output 498 499 strategy.run(step_fn, args=(next(iterator),)) 500 ``` 501 Args: 502 tensor: Input tensor to annotate. 503 partition_dimensions: An unnested list of integers with the size equal to 504 rank of `tensor` specifying how `tensor` will be partitioned. The 505 product of all elements in `partition_dimensions` must be equal to the 506 total number of logical devices per replica. 507 508 Raises: 509 ValueError: 1) If the size of partition_dimensions does not equal to rank 510 of `tensor` or 2) if product of elements of `partition_dimensions` does 511 not match the number of logical devices per replica defined by the 512 implementing DistributionStrategy's device specification or 513 3) if a known size of `tensor` is not divisible by corresponding 514 value in `partition_dimensions`. 515 516 Returns: 517 Annotated tensor with identical value as `tensor`. 518 """ 519 num_logical_devices_per_replica = self.extended._tpu_devices.shape[1] # pylint: disable=protected-access 520 num_partition_splits = np.prod(partition_dimensions) 521 input_shape = tensor.shape 522 tensor_rank = len(input_shape) 523 524 if tensor_rank != len(partition_dimensions): 525 raise ValueError("Length of `partition_dimensions` ({}) must be " 526 "equal to the rank of `x` ({}).".format( 527 len(partition_dimensions), tensor_rank)) 528 529 for dim_index, dim_size in enumerate(input_shape): 530 if dim_size is None: 531 continue 532 533 split_size = partition_dimensions[dim_index] 534 if dim_size % split_size != 0: 535 raise ValueError("Tensor shape at dimension ({}) must be " 536 "divisible by corresponding value specified " 537 "by `partition_dimensions` ({}).".format( 538 dim_index, split_size)) 539 540 if num_partition_splits != num_logical_devices_per_replica: 541 raise ValueError("Number of logical devices ({}) does not match the " 542 "number of partition splits specified ({}).".format( 543 num_logical_devices_per_replica, 544 num_partition_splits)) 545 546 tile_assignment = np.arange(num_partition_splits).reshape( 547 partition_dimensions) 548 return xla_sharding.tile(tensor, tile_assignment, use_sharding_op=True) 549 550 def experimental_replicate_to_logical_devices(self, tensor): 551 """Adds annotation that `tensor` will be replicated to all logical devices. 552 553 This adds an annotation to tensor `tensor` specifying that operations on 554 `tensor` will be invoked on all logical devices. 555 556 ```python 557 # Initializing TPU system with 2 logical devices and 4 replicas. 
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
    tf.config.experimental_connect_to_cluster(resolver)
    topology = tf.tpu.experimental.initialize_tpu_system(resolver)
    device_assignment = tf.tpu.experimental.DeviceAssignment.build(
        topology,
        computation_shape=[1, 1, 1, 2],
        num_replicas=4)
    strategy = tf.distribute.TPUStrategy(
        resolver, experimental_device_assignment=device_assignment)

    iterator = iter(inputs)

    @tf.function()
    def step_fn(inputs):
      images, labels = inputs
      images = strategy.experimental_split_to_logical_devices(
          images, [1, 2, 1, 1])

      # model() function will be executed on 2 logical devices with `images`
      # split 2 ways in the width dimension.
      output = model(images)

      # For loss calculation, all logical devices share the same logits
      # and labels.
      labels = strategy.experimental_replicate_to_logical_devices(labels)
      output = strategy.experimental_replicate_to_logical_devices(output)
      loss = loss_fn(labels, output)

      return loss

    strategy.run(step_fn, args=(next(iterator),))
    ```

    Args:
      tensor: Input tensor to annotate.

    Returns:
      Annotated tensor with identical value as `tensor`.
    """
    return xla_sharding.replicate(tensor, use_sharding_op=True)


@tf_export("distribute.experimental.TPUStrategy", v1=[])
@deprecation.deprecated_endpoints("distribute.experimental.TPUStrategy")
class TPUStrategy(distribute_lib.Strategy):
  """Synchronous training on TPUs and TPU Pods.

  To construct a TPUStrategy object, you need to run the
  initialization code as below:

  >>> resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
  >>> tf.config.experimental_connect_to_cluster(resolver)
  >>> tf.tpu.experimental.initialize_tpu_system(resolver)
  >>> strategy = tf.distribute.experimental.TPUStrategy(resolver)

  While using distribution strategies, the variables created within the
  strategy's scope will be replicated across all the replicas and can be kept
  in sync using all-reduce algorithms.

  To run TF2 programs on TPUs, you can either use the `.compile` and
  `.fit` APIs in `tf.keras` with TPUStrategy, or write your own customized
  training loop by calling `strategy.run` directly. Note that
  TPUStrategy doesn't support pure eager execution, so please make sure the
  function passed into `strategy.run` is a `tf.function`, or that
  `strategy.run` is called inside a `tf.function` if eager
  behavior is enabled.
  """

  def __init__(self,
               tpu_cluster_resolver=None,
               device_assignment=None):
    """Synchronous training in TPU donuts or Pods.

    Args:
      tpu_cluster_resolver: A tf.distribute.cluster_resolver.TPUClusterResolver,
        which provides information about the TPU cluster.
      device_assignment: Optional `tf.tpu.experimental.DeviceAssignment` to
        specify the placement of replicas on the TPU cluster.
    """
    logging.warning(
        "`tf.distribute.experimental.TPUStrategy` is deprecated, please use "
        "the non-experimental symbol `tf.distribute.TPUStrategy` instead.")

    super(TPUStrategy, self).__init__(TPUExtended(
        self, tpu_cluster_resolver, device_assignment=device_assignment))
    distribute_lib.distribution_strategy_gauge.get_cell("V2").set("TPUStrategy")
    distribute_lib.distribution_strategy_replica_gauge.get_cell(
        "num_workers").set(self.extended.num_hosts)
    distribute_lib.distribution_strategy_replica_gauge.get_cell(
        "num_replicas_per_worker").set(self.extended.num_replicas_per_host)
    # Packed variable is used to reduce the overhead of function execution.
    # For a DistributedVariable, only one variable handle is captured into a
    # function graph. It's only supported in eager mode.
    self._enable_packed_variable_in_eager_mode = True

  # TODO(cjfj): Modify `_call_for_each_replica` in `TPUExtended` such that this
  # can use the default implementation.
  # This implementation runs a single step. It does not use infeed or outfeed.
  def run(self, fn, args=(), kwargs=None, options=None):
    """See base class."""
    validate_run_function(fn)

    fn, args, kwargs = _maybe_partial_apply_variables(fn, args, kwargs)

    # Note: the target function is converted to graph even when in Eager mode,
    # so autograph is on by default here.
    fn = autograph.tf_convert(fn, autograph_ctx.control_status_ctx())
    options = options or distribute_lib.RunOptions()
    return self.extended.tpu_run(fn, args, kwargs, options)

  @property
  def cluster_resolver(self):
    """Returns the cluster resolver associated with this strategy.

    `tf.distribute.experimental.TPUStrategy` provides the
    associated `tf.distribute.cluster_resolver.ClusterResolver`. If the user
    provides one in `__init__`, that instance is returned; if the user does
    not, a default
    `tf.distribute.cluster_resolver.TPUClusterResolver` is provided.
    """
    return self.extended._tpu_cluster_resolver  # pylint: disable=protected-access


@tf_export(v1=["distribute.experimental.TPUStrategy"])
class TPUStrategyV1(distribute_lib.StrategyV1):
  """TPU distribution strategy implementation."""

  def __init__(self,
               tpu_cluster_resolver=None,
               steps_per_run=None,
               device_assignment=None):
    """Initializes the TPUStrategy object.

    Args:
      tpu_cluster_resolver: A tf.distribute.cluster_resolver.TPUClusterResolver,
        which provides information about the TPU cluster.
      steps_per_run: Number of steps to run on device before returning to the
        host. Note that this can have side effects on performance, hooks,
        metrics, summaries, etc. This parameter is only used when Distribution
        Strategy is used with Estimator or Keras.
      device_assignment: Optional `tf.tpu.experimental.DeviceAssignment` to
        specify the placement of replicas on the TPU cluster. Currently only
        supports the use case of using a single core within a TPU cluster.
701 """ 702 super(TPUStrategyV1, self).__init__(TPUExtended( 703 self, tpu_cluster_resolver, steps_per_run, device_assignment)) 704 distribute_lib.distribution_strategy_gauge.get_cell("V1").set("TPUStrategy") 705 distribute_lib.distribution_strategy_replica_gauge.get_cell( 706 "num_workers").set(self.extended.num_hosts) 707 distribute_lib.distribution_strategy_replica_gauge.get_cell( 708 "num_replicas_per_worker").set(self.extended.num_replicas_per_host) 709 # Packed variable is used to reduce the overhead of function execution. 710 # For a DistributedVariable, only one variable handle is captured into a 711 # function graph. It's only supported in eager mode. 712 self._enable_packed_variable_in_eager_mode = True 713 714 @property 715 def steps_per_run(self): 716 """DEPRECATED: use .extended.steps_per_run instead.""" 717 return self._extended.steps_per_run 718 719 # TODO(cjfj): Modify `_call_for_each_replica` in `TPUExtended` such that this 720 # can use the default implementation. 721 # This implementation runs a single step. It does not use infeed or outfeed. 722 def run(self, fn, args=(), kwargs=None, options=None): 723 """Run `fn` on each replica, with the given arguments. 724 725 Executes ops specified by `fn` on each replica. If `args` or `kwargs` have 726 "per-replica" values, such as those produced by a "distributed `Dataset`", 727 when `fn` is executed on a particular replica, it will be executed with the 728 component of those "per-replica" values that correspond to that replica. 729 730 `fn` may call `tf.distribute.get_replica_context()` to access members such 731 as `all_reduce`. 732 733 All arguments in `args` or `kwargs` should either be nest of tensors or 734 per-replica objects containing tensors or composite tensors. 735 736 Users can pass strategy specific options to `options` argument. An example 737 to enable bucketizing dynamic shapes in `TPUStrategy.run` 738 is: 739 740 >>> resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='') 741 >>> tf.config.experimental_connect_to_cluster(resolver) 742 >>> tf.tpu.experimental.initialize_tpu_system(resolver) 743 >>> strategy = tf.distribute.experimental.TPUStrategy(resolver) 744 745 >>> options = tf.distribute.RunOptions( 746 ... experimental_bucketizing_dynamic_shape=True) 747 748 >>> dataset = tf.data.Dataset.range( 749 ... strategy.num_replicas_in_sync, output_type=dtypes.float32).batch( 750 ... strategy.num_replicas_in_sync, drop_remainder=True) 751 >>> input_iterator = iter(strategy.experimental_distribute_dataset(dataset)) 752 753 >>> @tf.function() 754 ... def step_fn(inputs): 755 ... output = tf.reduce_sum(inputs) 756 ... return output 757 758 >>> strategy.run(step_fn, args=(next(input_iterator),), options=options) 759 760 Args: 761 fn: The function to run. The output must be a `tf.nest` of `Tensor`s. 762 args: (Optional) Positional arguments to `fn`. 763 kwargs: (Optional) Keyword arguments to `fn`. 764 options: (Optional) An instance of `tf.distribute.RunOptions` specifying 765 the options to run `fn`. 766 767 Returns: 768 Merged return value of `fn` across replicas. The structure of the return 769 value is the same as the return value from `fn`. Each element in the 770 structure can either be "per-replica" `Tensor` objects or `Tensor`s 771 (for example, if running on a single replica). 
772 """ 773 validate_run_function(fn) 774 775 fn, args, kwargs = _maybe_partial_apply_variables(fn, args, kwargs) 776 777 fn = autograph.tf_convert(fn, autograph_ctx.control_status_ctx()) 778 options = options or distribute_lib.RunOptions() 779 return self.extended.tpu_run(fn, args, kwargs, options) 780 781 782# TODO(josh11b): Switch to V2 when we no longer need to support tf.compat.v1. 783class TPUExtended(distribute_lib.StrategyExtendedV1): 784 """Implementation of TPUStrategy.""" 785 786 def __init__(self, 787 container_strategy, 788 tpu_cluster_resolver=None, 789 steps_per_run=None, 790 device_assignment=None): 791 super(TPUExtended, self).__init__(container_strategy) 792 793 if tpu_cluster_resolver is None: 794 tpu_cluster_resolver = TPUClusterResolver("") 795 796 if steps_per_run is None: 797 # TODO(frankchn): Warn when we are being used by DS/Keras and this is 798 # not specified. 799 steps_per_run = 1 800 801 # `self._tpu_function_cache` is a dict of `tf.function`s, thus if a 802 # `tf.function` is passed into `strategy.run` in eager mode, the 803 # `tf.function` won't get retraced. 804 self._tpu_function_cache = weakref.WeakKeyDictionary() 805 806 self._tpu_cluster_resolver = tpu_cluster_resolver 807 self._tpu_metadata = self._tpu_cluster_resolver.get_tpu_system_metadata() 808 self._device_assignment = device_assignment 809 810 tpu_devices_flat = [ 811 d.name for d in self._tpu_metadata.devices if "device:TPU:" in d.name] 812 813 # `self._tpu_devices` is a two-dimensional NumPy array of strings. It is 814 # indexed using `[replica_id][logical_device_id]`. 815 if device_assignment is None: 816 self._tpu_devices = np.array( 817 [[d] for d in tpu_devices_flat], dtype=object) 818 else: 819 job_name = device_spec.DeviceSpecV2.from_string(tpu_devices_flat[0]).job 820 821 tpu_devices = [] 822 for replica_id in range(device_assignment.num_replicas): 823 replica_devices = [] 824 825 for logical_core in range(device_assignment.num_cores_per_replica): 826 replica_devices.append( 827 device_util.canonicalize( 828 device_assignment.tpu_device( 829 replica=replica_id, 830 logical_core=logical_core, 831 job=job_name))) 832 833 tpu_devices.append(replica_devices) 834 self._tpu_devices = np.array(tpu_devices, dtype=object) 835 836 self._host_device = device_util.get_host_for_device(self._tpu_devices[0][0]) 837 838 # Preload the data onto the TPUs. Currently we always preload onto logical 839 # device 0 for each replica. 840 # TODO(cjfj): Create `InputWorkers` lazily, allowing users to place the 841 # input onto a different logical device? 842 self._device_input_worker_devices = collections.OrderedDict() 843 self._host_input_worker_devices = collections.OrderedDict() 844 for tpu_device in self._tpu_devices[:, 0]: 845 host_device = device_util.get_host_for_device(tpu_device) 846 self._device_input_worker_devices.setdefault(host_device, []) 847 self._device_input_worker_devices[host_device].append(tpu_device) 848 self._host_input_worker_devices.setdefault(host_device, []) 849 self._host_input_worker_devices[host_device].append(host_device) 850 851 # TODO(sourabhbajaj): Remove this once performance of running one step 852 # at a time is comparable to multiple steps. 853 self.steps_per_run = steps_per_run 854 self._require_static_shapes = True 855 856 self.experimental_enable_get_next_as_optional = True 857 858 self._logical_device_stack = [0] 859 860 if context.executing_eagerly(): 861 # In async remote eager, we want to sync the executors before exiting the 862 # program. 
863 def async_wait(): 864 if context.context()._context_handle is not None: # pylint: disable=protected-access 865 context.async_wait() 866 atexit.register(async_wait) 867 868 # Flag to turn on VariablePolicy 869 self._use_var_policy = True 870 871 # Flag to enable TF2 SPMD 872 self._use_spmd_for_xla_partitioning = False 873 874 def _validate_colocate_with_variable(self, colocate_with_variable): 875 distribute_utils. validate_colocate(colocate_with_variable, self) 876 877 def _make_dataset_iterator(self, dataset): 878 """Make iterators for each of the TPU hosts.""" 879 input_workers = input_lib.InputWorkers( 880 tuple(self._device_input_worker_devices.items())) 881 return input_lib.DatasetIterator( 882 dataset, 883 input_workers, 884 self._container_strategy(), 885 num_replicas_in_sync=self._num_replicas_in_sync) 886 887 def _make_input_fn_iterator( 888 self, 889 input_fn, 890 replication_mode=distribute_lib.InputReplicationMode.PER_WORKER): 891 input_contexts = [] 892 input_workers = input_lib.InputWorkers( 893 tuple(self._device_input_worker_devices.items())) 894 num_workers = input_workers.num_workers 895 for i in range(num_workers): 896 input_contexts.append(distribute_lib.InputContext( 897 num_input_pipelines=num_workers, 898 input_pipeline_id=i, 899 num_replicas_in_sync=self._num_replicas_in_sync)) 900 return input_lib.InputFunctionIterator( 901 input_fn, 902 input_workers, 903 input_contexts, 904 self._container_strategy()) 905 906 def _experimental_make_numpy_dataset(self, numpy_input, session): 907 return numpy_dataset.one_host_numpy_dataset( 908 numpy_input, numpy_dataset.SingleDevice(self._host_device), 909 session) 910 911 def _get_input_workers(self, options): 912 if not options or options.experimental_prefetch_to_device: 913 return input_lib.InputWorkers( 914 tuple(self._device_input_worker_devices.items())) 915 else: 916 return input_lib.InputWorkers( 917 tuple(self._host_input_worker_devices.items())) 918 919 def _check_spec(self, element_spec): 920 if isinstance(element_spec, values.PerReplicaSpec): 921 element_spec = element_spec._component_specs # pylint: disable=protected-access 922 specs = nest.flatten_with_joined_string_paths(element_spec) 923 for path, spec in specs: 924 if isinstance(spec, (sparse_tensor.SparseTensorSpec, 925 ragged_tensor.RaggedTensorSpec)): 926 raise ValueError( 927 "Found tensor {} with spec {}. TPUStrategy does not support " 928 "distributed datasets with device prefetch when using sparse or " 929 "ragged tensors. If you intend to use sparse or ragged tensors, " 930 "please pass a tf.distribute.InputOptions object with " 931 "experimental_prefetch_to_device set to False to your dataset " 932 "distribution function.".format(path, type(spec))) 933 934 def _experimental_distribute_dataset(self, dataset, options): 935 if (options and options.experimental_replication_mode == 936 distribute_lib.InputReplicationMode.PER_REPLICA): 937 raise NotImplementedError( 938 "InputReplicationMode.PER_REPLICA " 939 "is only supported in " 940 "`experimental_distribute_datasets_from_function`." 
941 ) 942 if options is None or options.experimental_prefetch_to_device: 943 self._check_spec(dataset.element_spec) 944 945 return input_lib.get_distributed_dataset( 946 dataset, 947 self._get_input_workers(options), 948 self._container_strategy(), 949 num_replicas_in_sync=self._num_replicas_in_sync) 950 951 def _distribute_datasets_from_function(self, dataset_fn, options): 952 if (options and options.experimental_replication_mode == 953 distribute_lib.InputReplicationMode.PER_REPLICA): 954 raise NotImplementedError( 955 "InputReplicationMode.PER_REPLICA " 956 "is only supported in " 957 " `experimental_distribute_datasets_from_function` " 958 "of tf.distribute.MirroredStrategy") 959 input_workers = self._get_input_workers(options) 960 input_contexts = [] 961 num_workers = input_workers.num_workers 962 for i in range(num_workers): 963 input_contexts.append(distribute_lib.InputContext( 964 num_input_pipelines=num_workers, 965 input_pipeline_id=i, 966 num_replicas_in_sync=self._num_replicas_in_sync)) 967 968 distributed_dataset = input_lib.get_distributed_datasets_from_function( 969 dataset_fn, 970 input_workers, 971 input_contexts, 972 self._container_strategy()) 973 974 # We can only check after the dataset_fn is called. 975 if options is None or options.experimental_prefetch_to_device: 976 self._check_spec(distributed_dataset.element_spec) 977 return distributed_dataset 978 979 def _experimental_distribute_values_from_function(self, value_fn): 980 per_replica_values = [] 981 for replica_id in range(self._num_replicas_in_sync): 982 per_replica_values.append( 983 value_fn(distribute_lib.ValueContext(replica_id, 984 self._num_replicas_in_sync))) 985 return distribute_utils.regroup(per_replica_values, always_wrap=True) 986 987 # TODO(priyag): Deal with OutOfRange errors once b/111349762 is fixed. 988 # TODO(sourabhbajaj): Remove the initial_loop_values parameter when we have 989 # a mechanism to infer the outputs of `fn`. Pending b/110550782. 990 def _experimental_run_steps_on_iterator( 991 self, fn, multi_worker_iterator, iterations, initial_loop_values=None): 992 # Wrap `fn` for repeat. 993 if initial_loop_values is None: 994 initial_loop_values = {} 995 initial_loop_values = nest.flatten(initial_loop_values) 996 ctx = input_lib.MultiStepContext() 997 998 def run_fn(inputs): 999 """Single step on the TPU device.""" 1000 fn_result = fn(ctx, inputs) 1001 flat_last_step_outputs = nest.flatten(ctx.last_step_outputs) 1002 if flat_last_step_outputs: 1003 with ops.control_dependencies([fn_result]): 1004 return [array_ops.identity(f) for f in flat_last_step_outputs] 1005 else: 1006 return fn_result 1007 1008 # We capture the control_flow_context at this point, before we run `fn` 1009 # inside a while_loop and TPU replicate context. This is useful in cases 1010 # where we might need to exit these contexts and get back to the outer 1011 # context to do some things, for e.g. create an op which should be 1012 # evaluated only once at the end of the loop on the host. One such usage 1013 # is in creating metrics' value op. 
1014 self._outer_control_flow_context = ( 1015 ops.get_default_graph()._get_control_flow_context()) # pylint: disable=protected-access 1016 1017 def rewrite_fn(*args): 1018 """The rewritten step fn running on TPU.""" 1019 del args 1020 1021 per_replica_inputs = multi_worker_iterator.get_next() 1022 replicate_inputs = [] 1023 for replica_id in range(self._num_replicas_in_sync): 1024 select_replica = lambda x: distribute_utils.select_replica( # pylint: disable=g-long-lambda 1025 replica_id, x) # pylint: disable=cell-var-from-loop 1026 replicate_inputs.append((nest.map_structure( 1027 select_replica, per_replica_inputs),)) 1028 1029 replicate_outputs = tpu.replicate( 1030 run_fn, 1031 replicate_inputs, 1032 device_assignment=self._device_assignment, 1033 xla_options=tpu.XLAOptions(use_spmd_for_xla_partitioning=self 1034 ._use_spmd_for_xla_partitioning)) 1035 # If run_fn has tensor outputs, tpu.replicate returns a list of list. We 1036 # will flatten it in this case. If run_fn has no tensor outputs, 1037 # tpu.replicate returns a list of no_ops, we will keep the output as it 1038 # is. 1039 if isinstance(replicate_outputs[0], list): 1040 replicate_outputs = nest.flatten(replicate_outputs) 1041 1042 return replicate_outputs 1043 1044 # TODO(sourabhbajaj): The input to while loop should be based on the 1045 # output type of the step_fn 1046 assert isinstance(initial_loop_values, list) 1047 initial_loop_values = initial_loop_values * self._num_replicas_in_sync 1048 1049 # Put the while loop op on TPU host 0. 1050 with ops.device(self._host_device): 1051 if self.steps_per_run == 1: 1052 replicate_outputs = rewrite_fn() 1053 else: 1054 replicate_outputs = training_loop.repeat(iterations, rewrite_fn, 1055 initial_loop_values) 1056 1057 del self._outer_control_flow_context 1058 ctx.run_op = control_flow_ops.group(replicate_outputs) 1059 1060 if isinstance(replicate_outputs, list): 1061 # Filter out any ops from the outputs, typically this would be the case 1062 # when there were no tensor outputs. 1063 last_step_tensor_outputs = [ 1064 x for x in replicate_outputs if not isinstance(x, ops.Operation) 1065 ] 1066 1067 # Outputs are currently of the structure (flattened) 1068 # [output0_device0, output1_device0, output2_device0, 1069 # output0_device1, output1_device1, output2_device1, 1070 # ...] 1071 # Convert this to the following structure instead: (grouped by output) 1072 # [[output0_device0, output0_device1], 1073 # [output1_device0, output1_device1], 1074 # [output2_device0, output2_device1]] 1075 output_num = len(last_step_tensor_outputs) // self._num_replicas_in_sync 1076 last_step_tensor_outputs = [ 1077 last_step_tensor_outputs[i::output_num] for i in range(output_num) 1078 ] 1079 else: 1080 # no tensors returned. 1081 last_step_tensor_outputs = [] 1082 1083 _set_last_step_outputs(ctx, last_step_tensor_outputs) 1084 return ctx 1085 1086 def _call_for_each_replica(self, fn, args, kwargs): 1087 # TODO(jhseu): Consider making it so call_for_each_replica implies that 1088 # we're in a tpu.rewrite(), and update TPUMirroredVariable accordingly. 
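    # Note that `fn` is invoked exactly once here; replication itself happens
    # elsewhere (e.g. via `tpu.replicate` in `tpu_run`). The replica context
    # below mainly lets `tf.distribute.get_replica_context()` resolve to this
    # strategy inside `fn`.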
1089 with _TPUReplicaContext(self._container_strategy()): 1090 return fn(*args, **kwargs) 1091 1092 @contextlib.contextmanager 1093 def experimental_logical_device(self, logical_device_id): 1094 """Places variables and ops on the specified logical device.""" 1095 num_logical_devices_per_replica = self._tpu_devices.shape[1] 1096 if logical_device_id >= num_logical_devices_per_replica: 1097 raise ValueError( 1098 "`logical_device_id` not in range (was {}, but there are only {} " 1099 "logical devices per replica).".format( 1100 logical_device_id, num_logical_devices_per_replica)) 1101 1102 self._logical_device_stack.append(logical_device_id) 1103 try: 1104 if tpu_util.enclosing_tpu_context() is None: 1105 yield 1106 else: 1107 with ops.device(tpu.core(logical_device_id)): 1108 yield 1109 finally: 1110 self._logical_device_stack.pop() 1111 1112 def _experimental_initialize_system(self): 1113 """Experimental method added to be used by Estimator. 1114 1115 This is a private method only to be used by Estimator. Other frameworks 1116 should directly be calling `tf.tpu.experimental.initialize_tpu_system` 1117 """ 1118 tpu_strategy_util.initialize_tpu_system(self._tpu_cluster_resolver) 1119 1120 def _create_variable(self, next_creator, **kwargs): 1121 """Create a TPUMirroredVariable. See `DistributionStrategy.scope`.""" 1122 if kwargs.pop("skip_mirrored_creator", False): 1123 return next_creator(**kwargs) 1124 1125 colocate_with = kwargs.pop("colocate_with", None) 1126 if colocate_with is None: 1127 devices = self._tpu_devices[:, self._logical_device_stack[-1]] 1128 elif isinstance(colocate_with, numpy_dataset.SingleDevice): 1129 with ops.device(colocate_with.device): 1130 return next_creator(**kwargs) 1131 else: 1132 devices = colocate_with._devices # pylint: disable=protected-access 1133 1134 def _real_mirrored_creator(**kwargs): # pylint: disable=g-missing-docstring 1135 initial_value = None 1136 value_list = [] 1137 for i, d in enumerate(devices): 1138 with ops.device(d): 1139 if i == 0: 1140 initial_value = kwargs["initial_value"] 1141 # Note: some v1 code expects variable initializer creation to happen 1142 # inside a init_scope. 1143 with maybe_init_scope(): 1144 initial_value = initial_value() if callable( 1145 initial_value) else initial_value 1146 1147 if i > 0: 1148 # Give replicas meaningful distinct names: 1149 var0name = value_list[0].name.split(":")[0] 1150 # We append a / to variable names created on replicas with id > 0 to 1151 # ensure that we ignore the name scope and instead use the given 1152 # name as the absolute name of the variable. 
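            # For example (illustrative): if the replica-0 variable is named
            # "dense/kernel:0", the copy for replica 1 is created with the
            # name "dense/kernel/replica_1/", so its full name becomes
            # "dense/kernel/replica_1:0" regardless of the enclosing scope.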
1153 kwargs["name"] = "%s/replica_%d/" % (var0name, i) 1154 kwargs["initial_value"] = initial_value 1155 1156 with context.device_policy(context.DEVICE_PLACEMENT_SILENT): 1157 v = next_creator(**kwargs) 1158 1159 assert not isinstance(v, tpu_values.TPUMirroredVariable) 1160 value_list.append(v) 1161 return value_list 1162 1163 return distribute_utils.create_mirrored_variable( 1164 self._container_strategy(), _real_mirrored_creator, 1165 distribute_utils.TPU_VARIABLE_CLASS_MAPPING, 1166 distribute_utils.TPU_VARIABLE_POLICY_MAPPING, **kwargs) 1167 1168 def _gather_to_implementation(self, value, destinations, axis, options): 1169 if not isinstance(value, values.DistributedValues): 1170 return value 1171 1172 value_list = value.values 1173 # pylint: disable=protected-access 1174 if isinstance( 1175 value, 1176 values.DistributedVariable) and value._packed_variable is not None: 1177 value_list = tuple( 1178 value._packed_variable.on_device(d) 1179 for d in value._packed_variable.devices) 1180 # pylint: enable=protected-access 1181 1182 # Currently XLA op by op mode has a limit for the number of inputs for a 1183 # single op, thus we break one `add_n` op into a group of `add_n` ops to 1184 # work around the constraint. 1185 if len(value.values) <= _XLA_OP_BY_OP_INPUTS_LIMIT: 1186 output = array_ops.concat(value_list, axis=axis) 1187 else: 1188 output = array_ops.concat( 1189 value_list[:_XLA_OP_BY_OP_INPUTS_LIMIT], axis=axis) 1190 for i in range(_XLA_OP_BY_OP_INPUTS_LIMIT, len(value_list), 1191 _XLA_OP_BY_OP_INPUTS_LIMIT - 1): 1192 output = array_ops.concat( 1193 [output] + value_list[i:i + _XLA_OP_BY_OP_INPUTS_LIMIT - 1], 1194 axis=axis) 1195 1196 output = self._broadcast_output(destinations, output) 1197 return output 1198 1199 def _broadcast_output(self, destinations, output): 1200 devices = cross_device_ops_lib.get_devices_from(destinations) 1201 1202 if len(devices) == 1: 1203 # If necessary, copy to requested destination. 1204 dest_canonical = device_util.canonicalize(devices[0]) 1205 host_canonical = device_util.canonicalize(self._host_device) 1206 1207 if dest_canonical != host_canonical: 1208 with ops.device(dest_canonical): 1209 output = array_ops.identity(output) 1210 else: 1211 output = cross_device_ops_lib.simple_broadcast(output, destinations) 1212 1213 return output 1214 1215 def _reduce_to(self, reduce_op, value, destinations, options): 1216 if (isinstance(value, values.DistributedValues) or 1217 tensor_util.is_tf_type(value) 1218 ) and tpu_util.enclosing_tpu_context() is not None: 1219 if reduce_op == reduce_util.ReduceOp.MEAN: 1220 # TODO(jhseu): Revisit once we support model-parallelism. 1221 value *= (1. / self._num_replicas_in_sync) 1222 elif reduce_op != reduce_util.ReduceOp.SUM: 1223 raise NotImplementedError( 1224 "Currently only support sum & mean in TPUStrategy.") 1225 return tpu_ops.cross_replica_sum(value) 1226 1227 if not isinstance(value, values.DistributedValues): 1228 # This function handles reducing values that are not PerReplica or 1229 # Mirrored values. For example, the same value could be present on all 1230 # replicas in which case `value` would be a single value or value could 1231 # be 0. 
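      # For example (illustrative): a plain Python scalar or an ordinary
      # tensor passed to `strategy.reduce` outside a replica context takes
      # this branch and is handled by the generic reduction helper below.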
1232 return cross_device_ops_lib.reduce_non_distributed_value( 1233 reduce_op, value, destinations, self._num_replicas_in_sync) 1234 1235 value_list = value.values 1236 # pylint: disable=protected-access 1237 if isinstance( 1238 value, 1239 values.DistributedVariable) and value._packed_variable is not None: 1240 value_list = tuple( 1241 value._packed_variable.on_device(d) 1242 for d in value._packed_variable.devices) 1243 # pylint: enable=protected-access 1244 1245 # Currently XLA op by op mode has a limit for the number of inputs for a 1246 # single op, thus we break one `add_n` op into a group of `add_n` ops to 1247 # work around the constraint. 1248 # TODO(cjfj): Detect when it is possible to use `cross_replica_sum`. 1249 if len(value.values) <= _XLA_OP_BY_OP_INPUTS_LIMIT: 1250 output = math_ops.add_n(value_list) 1251 else: 1252 output = array_ops.zeros_like(value_list[0], dtype=value_list[0].dtype) 1253 for i in range(0, len(value_list), _XLA_OP_BY_OP_INPUTS_LIMIT): 1254 output += math_ops.add_n(value_list[i:i + _XLA_OP_BY_OP_INPUTS_LIMIT]) 1255 1256 if reduce_op == reduce_util.ReduceOp.MEAN: 1257 output *= (1. / len(value_list)) 1258 1259 output = self._broadcast_output(destinations, output) 1260 return output 1261 1262 def _update(self, var, fn, args, kwargs, group): 1263 assert isinstance(var, tpu_values.TPUVariableMixin) or isinstance( 1264 var, resource_variable_ops.BaseResourceVariable) 1265 if tpu_util.enclosing_tpu_context() is not None: 1266 if group: 1267 return fn(var, *args, **kwargs) 1268 else: 1269 return (fn(var, *args, **kwargs),) 1270 1271 # Otherwise, we revert to MirroredStrategy behavior and update the variable 1272 # on each replica directly. 1273 updates = [] 1274 values_and_devices = [] 1275 packed_var = var._packed_variable # pylint: disable=protected-access 1276 if packed_var is not None: 1277 for device in packed_var.devices: 1278 values_and_devices.append((packed_var, device)) 1279 else: 1280 for value in var.values: 1281 values_and_devices.append((value, value.device)) 1282 1283 if (var.synchronization != variables_lib.VariableSynchronization.ON_READ and 1284 var.aggregation != variables_lib.VariableAggregation.NONE): 1285 distribute_utils.assert_mirrored(args) 1286 distribute_utils.assert_mirrored(kwargs) 1287 for i, value_and_device in enumerate(values_and_devices): 1288 value = value_and_device[0] 1289 device = value_and_device[1] 1290 name = "update_%d" % i 1291 with ops.device(device), \ 1292 distribute_lib.UpdateContext(i), \ 1293 ops.name_scope(name): 1294 # If args and kwargs are not mirrored, the value is returned as is. 1295 updates.append( 1296 fn(value, *distribute_utils.select_replica(i, args), 1297 **distribute_utils.select_replica(i, kwargs))) 1298 return distribute_utils.update_regroup(self, updates, group) 1299 1300 def read_var(self, var): 1301 assert isinstance(var, tpu_values.TPUVariableMixin) or isinstance( 1302 var, resource_variable_ops.BaseResourceVariable) 1303 return var.read_value() 1304 1305 def _local_results(self, val): 1306 if isinstance(val, values.DistributedValues): 1307 return val.values 1308 return (val,) 1309 1310 def value_container(self, value): 1311 return value 1312 1313 def _broadcast_to(self, tensor, destinations): 1314 del destinations 1315 # This is both a fast path for Python constants, and a way to delay 1316 # converting Python values to a tensor until we know what type it 1317 # should be converted to. 
Otherwise we have trouble with: 1318 # global_step.assign_add(1) 1319 # since the `1` gets broadcast as an int32 but global_step is int64. 1320 if isinstance(tensor, (float, int)): 1321 return tensor 1322 if tpu_util.enclosing_tpu_context() is not None: 1323 broadcast_tensor = [tensor for _ in range(self._num_replicas_in_sync)] 1324 result = tpu_ops.all_to_all( 1325 broadcast_tensor, 1326 concat_dimension=0, 1327 split_dimension=0, 1328 split_count=self._num_replicas_in_sync) 1329 1330 # This uses the broadcasted value from the first replica because the only 1331 # caller of this is for ONLY_FIRST_REPLICA variables aggregation. 1332 return result[0] 1333 return tensor 1334 1335 @property 1336 def num_hosts(self): 1337 if self._device_assignment is None: 1338 return self._tpu_metadata.num_hosts 1339 1340 return len(set([self._device_assignment.host_device(r) 1341 for r in range(self._device_assignment.num_replicas)])) 1342 1343 @property 1344 def num_replicas_per_host(self): 1345 if self._device_assignment is None: 1346 return self._tpu_metadata.num_of_cores_per_host 1347 1348 # TODO(sourabhbajaj): Remove this method we use inputs and remove infeed 1349 # as the computation of num_replicas_per_host is not a constant 1350 # when using device_assignment. This is a temporary workaround to support 1351 # StatefulRNN as everything is 1 in that case. 1352 # This method needs to take host_id as input for correct computation. 1353 max_models_per_host = (self._tpu_metadata.num_of_cores_per_host // 1354 self._device_assignment.num_cores_per_replica) 1355 return min(self._device_assignment.num_replicas, max_models_per_host) 1356 1357 @property 1358 def _num_replicas_in_sync(self): 1359 if self._device_assignment is None: 1360 return self._tpu_metadata.num_cores 1361 return self._device_assignment.num_replicas 1362 1363 @property 1364 def experimental_between_graph(self): 1365 return False 1366 1367 @property 1368 def experimental_should_init(self): 1369 return True 1370 1371 @property 1372 def should_checkpoint(self): 1373 return True 1374 1375 @property 1376 def should_save_summary(self): 1377 return True 1378 1379 @property 1380 def worker_devices(self): 1381 return tuple(self._tpu_devices[:, self._logical_device_stack[-1]]) 1382 1383 @property 1384 def parameter_devices(self): 1385 return self.worker_devices 1386 1387 def non_slot_devices(self, var_list): 1388 return self._host_device 1389 1390 def _update_non_slot(self, colocate_with, fn, args, kwargs, group): 1391 del colocate_with 1392 with ops.device(self._host_device), distribute_lib.UpdateContext(None): 1393 result = fn(*args, **kwargs) 1394 if group: 1395 return result 1396 else: 1397 return nest.map_structure(self._local_results, result) 1398 1399 def _configure(self, 1400 session_config=None, 1401 cluster_spec=None, 1402 task_type=None, 1403 task_id=None): 1404 del cluster_spec, task_type, task_id 1405 if session_config: 1406 session_config.CopyFrom(self._update_config_proto(session_config)) 1407 1408 def _update_config_proto(self, config_proto): 1409 updated_config = copy.deepcopy(config_proto) 1410 updated_config.isolate_session_state = True 1411 cluster_spec = self._tpu_cluster_resolver.cluster_spec() 1412 if cluster_spec: 1413 updated_config.cluster_def.CopyFrom(cluster_spec.as_cluster_def()) 1414 return updated_config 1415 1416 # TODO(priyag): Delete this once all strategies use global batch size. 1417 @property 1418 def _global_batch_size(self): 1419 """`make_dataset_iterator` and `make_numpy_iterator` use global batch size. 

  # TODO(priyag): Delete this once all strategies use global batch size.
  @property
  def _global_batch_size(self):
    """`make_dataset_iterator` and `make_numpy_iterator` use global batch size.

    `make_input_fn_iterator` assumes per-replica batching.

    Returns:
      Boolean.
    """
    return True

  def tpu_run(self, fn, args, kwargs, options=None):
    func = self._tpu_function_creator(fn, options)
    return func(args, kwargs)

  def _tpu_function_creator(self, fn, options):
    if context.executing_eagerly() and fn in self._tpu_function_cache:
      return self._tpu_function_cache[fn]

    strategy = self._container_strategy()

    def tpu_function(args, kwargs):
      """TF Function used to replicate the user computation."""
      if kwargs is None:
        kwargs = {}

      # Used to re-structure flattened output tensors from `tpu.replicate()`
      # into a structured format.
      result = [[]]

      def replicated_fn(replica_id, replica_args, replica_kwargs):
        """Wraps user function to provide replica ID and `Tensor` inputs."""
        with _TPUReplicaContext(strategy, replica_id_in_sync_group=replica_id):
          result[0] = fn(*replica_args, **replica_kwargs)
        return result[0]

      replicate_inputs = []  # By replica.
      for i in range(strategy.num_replicas_in_sync):
        replicate_inputs.append(
            [constant_op.constant(i, dtype=dtypes.int32),
             distribute_utils.select_replica(i, args),
             distribute_utils.select_replica(i, kwargs)])

      # Construct and pass `maximum_shapes` so that we can support dynamic
      # shapes using the dynamic padder.
      if options.experimental_enable_dynamic_batch_size and replicate_inputs:
        maximum_shapes = []
        flattened_list = nest.flatten(replicate_inputs[0])
        for input_tensor in flattened_list:
          if tensor_util.is_tf_type(input_tensor):
            rank = input_tensor.shape.rank
          else:
            rank = np.ndim(input_tensor)
          maximum_shape = tensor_shape.TensorShape([None] * rank)
          maximum_shapes.append(maximum_shape)
        maximum_shapes = nest.pack_sequence_as(replicate_inputs[0],
                                               maximum_shapes)
      else:
        maximum_shapes = None
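
      # For example (illustrative only): a rank-2 per-replica input yields a
      # maximum shape of TensorShape([None, None]), so every dimension of that
      # input may be padded dynamically by the dynamic padder.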

      if options.experimental_bucketizing_dynamic_shape:
        padding_spec = tpu.PaddingSpec.POWER_OF_TWO
      else:
        padding_spec = None

      with strategy.scope():
        replicate_outputs = tpu.replicate(
            replicated_fn,
            replicate_inputs,
            device_assignment=self._device_assignment,
            maximum_shapes=maximum_shapes,
            padding_spec=padding_spec,
            xla_options=tpu.XLAOptions(use_spmd_for_xla_partitioning=self
                                       ._use_spmd_for_xla_partitioning))

      # Remove any no-ops that may have been added during `tpu.replicate()`.
      if isinstance(result[0], list):
        result[0] = [
            output for output in result[0] if not isinstance(
                output, ops.Operation)
        ]

      # Workaround for `tpu.replicate` behaviour when a single `Tensor` is
      # returned.
      if result[0] is None or isinstance(result[0], ops.Operation):
        replicate_outputs = [None] * len(replicate_outputs)
      else:
        replicate_outputs = [
            nest.pack_sequence_as(result[0], nest.flatten(replica_output))
            for replica_output in replicate_outputs
        ]
      return distribute_utils.regroup(replicate_outputs)

    if context.executing_eagerly():
      tpu_function = def_function.function(tpu_function)
      self._tpu_function_cache[fn] = tpu_function
    return tpu_function

  def _in_multi_worker_mode(self):
    """Whether this strategy indicates working in multi-worker settings."""
    # TPUStrategy has a different distributed training structure: the whole
    # cluster should be treated as a single worker from a higher-level (e.g.
    # Keras) library's point of view.
    # TODO(rchao): Revisit this as we design a fault-tolerance solution for
    # TPUStrategy.
    return False

  def _get_local_replica_id(self, replica_id_in_sync_group):
    return replica_id_in_sync_group


class _TPUReplicaContext(distribute_lib.ReplicaContext):
  """Replication Context class for TPU Strategy."""

  # TODO(sourabhbajaj): Call for each replica should be updating this.
  # TODO(b/118385803): Always properly initialize replica_id.
  def __init__(self, strategy, replica_id_in_sync_group=0):
    distribute_lib.ReplicaContext.__init__(
        self, strategy, replica_id_in_sync_group=replica_id_in_sync_group)

  @property
  def devices(self):
    distribute_lib.require_replica_context(self)
    ds = self._strategy
    replica_id = tensor_util.constant_value(self.replica_id_in_sync_group)

    if replica_id is None:  # Non-constant `Tensor` inside `tpu.replicate`.
      # TODO(cjfj): Return other devices when model parallelism is supported.
      return (tpu.core(0),)
    else:
      return (ds.extended.worker_devices[replica_id],)

  def experimental_logical_device(self, logical_device_id):
    """Places variables and ops on the specified logical device."""
    return self.strategy.extended.experimental_logical_device(logical_device_id)

  # TODO(wxinyi): Investigate whether to use cross_replica_sum to optimize it.
  def all_gather(self, value, axis, experimental_hints=None):
    del experimental_hints
    for v in nest.flatten(value):
      if isinstance(v, ops.IndexedSlices):
        raise NotImplementedError("all_gather does not support IndexedSlices")
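
    # Illustrative example (hypothetical shapes): with two replicas that each
    # hold a [2, 3] tensor and axis=0, every replica receives a [4, 3] tensor
    # whose leading slices are ordered by replica id.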

    def _all_to_all(value, axis):
      # The underlying AllToAllOp first does a split of the input value and
      # then cross-replica communication and concatenation of the result, so
      # we concatenate the local tensor here first.
      inputs = array_ops.concat(
          [value for _ in range(self.num_replicas_in_sync)], axis=0)
      unordered_output = tpu_ops.all_to_all(
          inputs,
          concat_dimension=axis,
          split_dimension=0,
          split_count=self.num_replicas_in_sync)

      # Re-order since xla.replica_id and ReplicaContext.replica_id mismatch.
      # xla_id = xla.replica_id()
      concat_replica_id = array_ops.concat([
          array_ops.expand_dims_v2(self.replica_id_in_sync_group, 0)
          for _ in range(self.num_replicas_in_sync)
      ], axis=0)
      replica_ids = tpu_ops.all_to_all(
          concat_replica_id,
          concat_dimension=0,
          split_dimension=0,
          split_count=self.num_replicas_in_sync)

      splited_unordered = array_ops.split(
          unordered_output,
          num_or_size_splits=self.num_replicas_in_sync,
          axis=axis)
      sorted_with_extra_dim = math_ops.unsorted_segment_sum(
          array_ops.concat([
              array_ops.expand_dims(replica, axis=0)
              for replica in splited_unordered
          ], axis=0),
          replica_ids,
          num_segments=self.num_replicas_in_sync)

      splited_with_extra_dim = array_ops.split(
          sorted_with_extra_dim,
          num_or_size_splits=self.num_replicas_in_sync,
          axis=0)
      squeezed = [
          array_ops.squeeze(replica, axis=0)
          for replica in splited_with_extra_dim
      ]
      result = array_ops.concat(squeezed, axis=axis)
      return result

    ys = [_all_to_all(t, axis=axis) for t in nest.flatten(value)]
    return nest.pack_sequence_as(value, ys)


def _set_last_step_outputs(ctx, last_step_tensor_outputs):
  """Sets the last step outputs on the given context."""
  # Convert replicate_outputs to the original dict structure of
  # last_step_outputs.
  last_step_tensor_outputs_dict = nest.pack_sequence_as(
      ctx.last_step_outputs, last_step_tensor_outputs)

  for name, reduce_op in ctx._last_step_outputs_reduce_ops.items():  # pylint: disable=protected-access
    output = last_step_tensor_outputs_dict[name]
    # For outputs that aren't reduced, return a PerReplica of all values. Else
    # take the first value from the list as each value should be the same.
    if reduce_op is None:
      last_step_tensor_outputs_dict[name] = values.PerReplica(output)
    else:
      # TODO(priyag): Should this return the element or a list with 1 element?
      last_step_tensor_outputs_dict[name] = output[0]
  ctx._set_last_step_outputs(last_step_tensor_outputs_dict)  # pylint: disable=protected-access
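

# A minimal usage sketch (illustrative only; it assumes a reachable TPU worker
# and uses the public `tf.distribute.TPUStrategy` API backed by this module):
#
#   resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="")
#   tf.config.experimental_connect_to_cluster(resolver)
#   tf.tpu.experimental.initialize_tpu_system(resolver)
#   strategy = tf.distribute.TPUStrategy(resolver)
#
#   @tf.function
#   def step_fn(x):
#     return strategy.run(lambda t: t * 2.0, args=(x,))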