# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TPU Strategy."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import atexit
import collections
import contextlib
import copy
import functools
import weakref

from absl import logging
import numpy as np

from tensorflow.compiler.xla.experimental.xla_sharding import xla_sharding
from tensorflow.python.autograph.core import ag_ctx as autograph_ctx
from tensorflow.python.autograph.impl import api as autograph
from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import input_lib
from tensorflow.python.distribute import numpy_dataset
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import tpu_util
from tensorflow.python.distribute import tpu_values
from tensorflow.python.distribute import values
from tensorflow.python.distribute.cluster_resolver import TPUClusterResolver
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import device_spec
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.tpu import device_assignment as device_assignment_lib  # pylint: disable=unused-import
from tensorflow.python.tpu import tpu
from tensorflow.python.tpu import tpu_strategy_util
from tensorflow.python.tpu import training_loop
from tensorflow.python.tpu.ops import tpu_ops
from tensorflow.python.util import deprecation
from tensorflow.python.util import nest
from tensorflow.python.util import tf_inspect
from tensorflow.python.util.tf_export import tf_export


_XLA_OP_BY_OP_INPUTS_LIMIT = 200


@contextlib.contextmanager
def maybe_init_scope():
  if ops.executing_eagerly_outside_functions():
    yield
  else:
    with ops.init_scope():
      yield


def validate_run_function(fn):
  """Validate the function passed into strategy.run."""

  # We allow three types of functions/objects passed into TPUStrategy
  # run in eager mode:
  #   1. a user annotated tf.function
  #   2. a ConcreteFunction, which is mostly what you get from loading a saved
  #      model.
  #   3. a callable object whose `__call__` method is a tf.function.
  #
  # Otherwise we raise an error, because we don't support eagerly running
  # `run` in TPUStrategy.

  if context.executing_eagerly() \
      and not isinstance(fn, def_function.Function) \
      and not isinstance(fn, function.ConcreteFunction) \
      and not (callable(fn) and isinstance(fn.__call__, def_function.Function)):
    raise NotImplementedError(
        "TPUStrategy.run(fn, ...) does not support pure eager "
        "execution. Please make sure the function passed into "
        "`strategy.run` is a `tf.function` or "
        "`strategy.run` is called inside a `tf.function` if "
        "eager behavior is enabled.")


def _maybe_partial_apply_variables(fn, args, kwargs):
  """Inspects arguments to partially apply any DistributedVariable.

  This avoids an automatic cast of the current variable value to tensor.

  Note that a variable may be captured implicitly with Python scope instead of
  passing it to run(), but supporting run() keeps behavior consistent
  with MirroredStrategy.

  Since positional arguments must be applied from left to right, this function
  does some tricky function inspection to move variable positional arguments
  into kwargs. As a result of this, we can't support passing Variables as *args,
  nor as args to functions which combine both explicit positional arguments and
  *args.

  Args:
    fn: The function to run, as passed to run().
    args: Positional arguments to fn, as passed to run().
    kwargs: Keyword arguments to fn, as passed to run().

  Returns:
    A tuple of the function (possibly wrapped), args, kwargs (both
    possibly filtered, with members of args possibly moved to kwargs).
    If no variables are found, this function is a no-op.

  Raises:
    ValueError: If the function signature makes unsupported use of *args, or if
      too many arguments are passed.
  """

  def is_distributed_var(x):
    flat = nest.flatten(x)
    return flat and isinstance(flat[0], values.DistributedVariable)

  # We will split kwargs into two dicts, one of which will be applied now.
  var_kwargs = {}
  nonvar_kwargs = {}

  if kwargs:
    var_kwargs = {k: v for k, v in kwargs.items() if is_distributed_var(v)}
  if var_kwargs:
    nonvar_kwargs = {
        k: v for k, v in kwargs.items() if not is_distributed_var(v)
    }

  # Dump the argument names of `fn` to a list. This will include both positional
  # and keyword arguments, but since positional arguments come first we can
  # look up names of positional arguments by index.
  positional_args = []
  index_of_star_args = None
  for i, p in enumerate(tf_inspect.signature(fn).parameters.values()):
    # Class methods define "self" as first argument, but we don't pass "self".
    # Note that this is a heuristic, as a method can name its first argument
    # something else, and a function can define a first argument "self" as well.
    # In both of these cases, using a Variable will fail with an unfortunate
    # error about the number of arguments.
    # inspect.is_method() seems not to work here, possibly due to the use of
    # tf.function().
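    # Illustrative sketch (hypothetical signature): for
    #   def fn(self, w, x, *args): ...
    # this loop records positional_args == ["w", "x"] and
    # index_of_star_args == 3, so a DistributedVariable passed as the first
    # positional argument can later be re-applied as the keyword argument `w`.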
    if i == 0 and p.name == "self":
      continue

    if p.kind == tf_inspect.Parameter.POSITIONAL_OR_KEYWORD:
      positional_args.append(p.name)

    elif p.kind == tf_inspect.Parameter.VAR_POSITIONAL:
      # We'll raise an error later if a variable is passed to *args, since we
      # can neither pass it by name nor partially apply it. This case only
      # happens once at most.
      index_of_star_args = i

    elif p.kind == tf_inspect.Parameter.POSITIONAL_ONLY:
      # This is a rare Python feature, indicating a / in the arg list.
      if var_kwargs or any(is_distributed_var(a) for a in args):
        raise ValueError(
            "Mixing Variables and positional-only parameters not supported by "
            "TPUStrategy.")
      return fn, args, kwargs

  star_args = []
  have_seen_var_arg = False

  for i, a in enumerate(args):
    if is_distributed_var(a):
      if index_of_star_args is not None and i >= index_of_star_args:
        raise ValueError(
            "TPUStrategy.run() cannot handle Variables passed to *args. "
            "Either name the function argument, or capture the Variable "
            "implicitly.")
      if len(positional_args) <= i:
        raise ValueError(
            "Too many positional arguments passed to call to TPUStrategy.run()."
        )
      var_kwargs[positional_args[i]] = a
      have_seen_var_arg = True
    else:
      if index_of_star_args is not None and i >= index_of_star_args:
        if have_seen_var_arg:
          raise ValueError(
              "TPUStrategy.run() cannot handle both Variables and a mix of "
              "positional args and *args. Either remove the *args, or capture "
              "the Variable implicitly.")
        else:
          star_args.append(a)
          continue

      if len(positional_args) <= i:
        raise ValueError(
            "Too many positional arguments passed to call to TPUStrategy.run()."
        )
      nonvar_kwargs[positional_args[i]] = a

  if var_kwargs:
    return functools.partial(fn, **var_kwargs), star_args, nonvar_kwargs
  return fn, args, kwargs


@tf_export("distribute.TPUStrategy", v1=[])
class TPUStrategyV2(distribute_lib.Strategy):
  """Synchronous training on TPUs and TPU Pods.

  To construct a TPUStrategy object, you need to run the
  initialization code as below:

  >>> resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
  >>> tf.config.experimental_connect_to_cluster(resolver)
  >>> tf.tpu.experimental.initialize_tpu_system(resolver)
  >>> strategy = tf.distribute.TPUStrategy(resolver)

  While using distribution strategies, the variables created within the
  strategy's scope will be replicated across all the replicas and can be kept in
  sync using all-reduce algorithms.

  To run TF2 programs on TPUs, you can either use `.compile` and
  `.fit` APIs in `tf.keras` with TPUStrategy, or write your own customized
  training loop by calling `strategy.run` directly. Note that
  TPUStrategy doesn't support pure eager execution, so please make sure the
  function passed into `strategy.run` is a `tf.function` or
  `strategy.run` is called inside a `tf.function` if eager
  behavior is enabled. See more details in https://www.tensorflow.org/guide/tpu.

  `distribute_datasets_from_function` and
  `experimental_distribute_dataset` APIs can be used to distribute the dataset
  across the TPU workers when writing your own training loop. If you are using
  `fit` and `compile` methods available in `tf.keras.Model`, then Keras will
  handle the distribution for you.

  An example of writing a customized training loop on TPUs:

  >>> with strategy.scope():
  ...   model = tf.keras.Sequential([
  ...     tf.keras.layers.Dense(2, input_shape=(5,)),
  ...   ])
  ...   optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)

  >>> def dataset_fn(ctx):
  ...   x = np.random.random((2, 5)).astype(np.float32)
  ...   y = np.random.randint(2, size=(2, 1))
  ...   dataset = tf.data.Dataset.from_tensor_slices((x, y))
  ...   return dataset.repeat().batch(1, drop_remainder=True)
  >>> dist_dataset = strategy.distribute_datasets_from_function(
  ...     dataset_fn)
  >>> iterator = iter(dist_dataset)

  >>> @tf.function()
  ... def train_step(iterator):
  ...
  ...   def step_fn(inputs):
  ...     features, labels = inputs
  ...     with tf.GradientTape() as tape:
  ...       logits = model(features, training=True)
  ...       loss = tf.keras.losses.sparse_categorical_crossentropy(
  ...           labels, logits)
  ...
  ...     grads = tape.gradient(loss, model.trainable_variables)
  ...     optimizer.apply_gradients(zip(grads, model.trainable_variables))
  ...
  ...   strategy.run(step_fn, args=(next(iterator),))

  >>> train_step(iterator)

  For advanced use cases like model parallelism, you can set the
  `experimental_device_assignment` argument when creating TPUStrategy to
  specify the number of replicas and the number of logical devices. Below is
  an example that initializes the TPU system with 2 logical devices and 1
  replica.

  >>> resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
  >>> tf.config.experimental_connect_to_cluster(resolver)
  >>> topology = tf.tpu.experimental.initialize_tpu_system(resolver)
  >>> device_assignment = tf.tpu.experimental.DeviceAssignment.build(
  ...     topology,
  ...     computation_shape=[1, 1, 1, 2],
  ...     num_replicas=1)
  >>> strategy = tf.distribute.TPUStrategy(
  ...     resolver, experimental_device_assignment=device_assignment)

  Then you can run a `tf.add` operation only on logical device 0.

  >>> @tf.function()
  ... def step_fn(inputs):
  ...   features, _ = inputs
  ...   output = tf.add(features, features)
  ...
  ...   # Add operation will be executed on logical device 0.
  ...   output = strategy.experimental_assign_to_logical_device(output, 0)
  ...   return output
  >>> dist_dataset = strategy.distribute_datasets_from_function(
  ...     dataset_fn)
  >>> iterator = iter(dist_dataset)
  >>> strategy.run(step_fn, args=(next(iterator),))
  """

  def __init__(self,
               tpu_cluster_resolver=None,
               experimental_device_assignment=None):
    """Synchronous training in TPU donuts or Pods.

    Args:
      tpu_cluster_resolver: A
        `tf.distribute.cluster_resolver.TPUClusterResolver` instance, which
        provides information about the TPU cluster. If None, it will assume
        running on a local TPU worker.
      experimental_device_assignment: Optional
        `tf.tpu.experimental.DeviceAssignment` to specify the placement of
        replicas on the TPU cluster.
333 """ 334 super(TPUStrategyV2, self).__init__(TPUExtended( 335 self, tpu_cluster_resolver, 336 device_assignment=experimental_device_assignment)) 337 distribute_lib.distribution_strategy_gauge.get_cell("V2").set("TPUStrategy") 338 distribute_lib.distribution_strategy_replica_gauge.get_cell( 339 "num_workers").set(self.extended.num_hosts) 340 distribute_lib.distribution_strategy_replica_gauge.get_cell( 341 "num_replicas_per_worker").set(self.extended.num_replicas_per_host) 342 # Packed variable is used to reduce the overhead of function execution. 343 # For a DistributedVariable, only one variable handle is captured into a 344 # function graph. It's only supported in eager mode. 345 self._enable_packed_variable_in_eager_mode = True 346 347 def run(self, fn, args=(), kwargs=None, options=None): 348 """Run the computation defined by `fn` on each TPU replica. 349 350 Executes ops specified by `fn` on each replica. If `args` or `kwargs` have 351 `tf.distribute.DistributedValues`, such as those produced by a 352 `tf.distribute.DistributedDataset` from 353 `tf.distribute.Strategy.experimental_distribute_dataset` or 354 `tf.distribute.Strategy.distribute_datasets_from_function`, 355 when `fn` is executed on a particular replica, it will be executed with the 356 component of `tf.distribute.DistributedValues` that correspond to that 357 replica. 358 359 `fn` may call `tf.distribute.get_replica_context()` to access members such 360 as `all_reduce`. 361 362 All arguments in `args` or `kwargs` should either be nest of tensors or 363 `tf.distribute.DistributedValues` containing tensors or composite tensors. 364 365 Example usage: 366 367 >>> resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='') 368 >>> tf.config.experimental_connect_to_cluster(resolver) 369 >>> tf.tpu.experimental.initialize_tpu_system(resolver) 370 >>> strategy = tf.distribute.TPUStrategy(resolver) 371 >>> @tf.function 372 ... def run(): 373 ... def value_fn(value_context): 374 ... return value_context.num_replicas_in_sync 375 ... distributed_values = ( 376 ... strategy.experimental_distribute_values_from_function(value_fn)) 377 ... def replica_fn(input): 378 ... return input * 2 379 ... return strategy.run(replica_fn, args=(distributed_values,)) 380 >>> result = run() 381 382 Args: 383 fn: The function to run. The output must be a `tf.nest` of `Tensor`s. 384 args: (Optional) Positional arguments to `fn`. 385 kwargs: (Optional) Keyword arguments to `fn`. 386 options: (Optional) An instance of `tf.distribute.RunOptions` specifying 387 the options to run `fn`. 388 389 Returns: 390 Merged return value of `fn` across replicas. The structure of the return 391 value is the same as the return value from `fn`. Each element in the 392 structure can either be `tf.distribute.DistributedValues`, `Tensor` 393 objects, or `Tensor`s (for example, if running on a single replica). 394 """ 395 validate_run_function(fn) 396 397 fn, args, kwargs = _maybe_partial_apply_variables(fn, args, kwargs) 398 399 # Note: the target function is converted to graph even when in Eager mode, 400 # so autograph is on by default here. 401 fn = autograph.tf_convert(fn, autograph_ctx.control_status_ctx()) 402 options = options or distribute_lib.RunOptions() 403 return self.extended.tpu_run(fn, args, kwargs, options) 404 405 def experimental_assign_to_logical_device(self, tensor, logical_device_id): 406 """Adds annotation that `tensor` will be assigned to a logical device. 

    This adds an annotation to `tensor` specifying that operations on
    `tensor` will be invoked on logical core device id `logical_device_id`.
    When model parallelism is used, the default behavior is that all ops
    are placed on the zeroth logical device.

    ```python

    # Initializing TPU system with 2 logical devices and 4 replicas.
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
    tf.config.experimental_connect_to_cluster(resolver)
    topology = tf.tpu.experimental.initialize_tpu_system(resolver)
    device_assignment = tf.tpu.experimental.DeviceAssignment.build(
        topology,
        computation_shape=[1, 1, 1, 2],
        num_replicas=4)
    strategy = tf.distribute.TPUStrategy(
        resolver, experimental_device_assignment=device_assignment)
    iterator = iter(inputs)

    @tf.function()
    def step_fn(inputs):
      output = tf.add(inputs, inputs)

      # Add operation will be executed on logical device 0.
      output = strategy.experimental_assign_to_logical_device(output, 0)
      return output

    strategy.run(step_fn, args=(next(iterator),))
    ```

    Args:
      tensor: Input tensor to annotate.
      logical_device_id: Id of the logical core to which the tensor will be
        assigned.

    Raises:
      ValueError: The logical device id presented is not consistent with the
        total number of partitions specified by the device assignment.

    Returns:
      Annotated tensor with identical value as `tensor`.
    """
    num_logical_devices_per_replica = self.extended._tpu_devices.shape[1]  # pylint: disable=protected-access
    if (logical_device_id < 0 or
        logical_device_id >= num_logical_devices_per_replica):
      raise ValueError("`logical_device_id` to assign must be lower than the "
                       "total number of logical devices per replica. Received "
                       "logical device id {} but there are only {} logical "
                       "devices per replica.".format(
                           logical_device_id, num_logical_devices_per_replica))
    return xla_sharding.assign_device(
        tensor, logical_device_id, use_sharding_op=True)

  def experimental_split_to_logical_devices(self, tensor, partition_dimensions):
    """Adds annotation that `tensor` will be split across logical devices.

    This adds an annotation to tensor `tensor` specifying that operations on
    `tensor` will be split among multiple logical devices. Tensor `tensor` will
    be split across dimensions specified by `partition_dimensions`.
    The dimensions of `tensor` must be divisible by the corresponding values in
    `partition_dimensions`.

    For example, for a system with 8 logical devices, if `tensor` is an image
    tensor with shape (batch_size, width, height, channel) and
    `partition_dimensions` is [1, 2, 4, 1], then `tensor` will be split
    2-way in the width dimension and 4-way in the height dimension, and the
    split tensor values will be fed into 8 logical devices.

    ```python
    # Initializing TPU system with 8 logical devices and 1 replica.
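    # Note: computation_shape=[1, 2, 2, 2] below requests 1 * 2 * 2 * 2 = 8
    # cores per replica, which is what gives the 8 logical devices assumed by
    # the [1, 2, 4, 1] split in step_fn.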
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
    tf.config.experimental_connect_to_cluster(resolver)
    topology = tf.tpu.experimental.initialize_tpu_system(resolver)
    device_assignment = tf.tpu.experimental.DeviceAssignment.build(
        topology,
        computation_shape=[1, 2, 2, 2],
        num_replicas=1)
    strategy = tf.distribute.TPUStrategy(
        resolver, experimental_device_assignment=device_assignment)

    iterator = iter(inputs)

    @tf.function()
    def step_fn(inputs):
      inputs = strategy.experimental_split_to_logical_devices(
          inputs, [1, 2, 4, 1])

      # model() function will be executed on 8 logical devices with `inputs`
      # split 2 * 4 ways.
      output = model(inputs)
      return output

    strategy.run(step_fn, args=(next(iterator),))
    ```
    Args:
      tensor: Input tensor to annotate.
      partition_dimensions: An unnested list of integers with the size equal to
        rank of `tensor` specifying how `tensor` will be partitioned. The
        product of all elements in `partition_dimensions` must be equal to the
        total number of logical devices per replica.

    Raises:
      ValueError: 1) If the size of partition_dimensions does not equal the
        rank of `tensor`, or 2) if the product of elements of
        `partition_dimensions` does not match the number of logical devices per
        replica defined by the implementing DistributionStrategy's device
        specification, or 3) if a known size of `tensor` is not divisible by
        the corresponding value in `partition_dimensions`.

    Returns:
      Annotated tensor with identical value as `tensor`.
    """
    num_logical_devices_per_replica = self.extended._tpu_devices.shape[1]  # pylint: disable=protected-access
    num_partition_splits = np.prod(partition_dimensions)
    input_shape = tensor.shape
    tensor_rank = len(input_shape)

    if tensor_rank != len(partition_dimensions):
      raise ValueError("Length of `partition_dimensions` ({}) must be "
                       "equal to the rank of `tensor` ({}).".format(
                           len(partition_dimensions), tensor_rank))

    for dim_index, dim_size in enumerate(input_shape):
      if dim_size is None:
        continue

      split_size = partition_dimensions[dim_index]
      if dim_size % split_size != 0:
        raise ValueError("Tensor shape at dimension {} ({}) must be "
                         "divisible by corresponding value specified "
                         "by `partition_dimensions` ({}).".format(
                             dim_index, dim_size, split_size))

    if num_partition_splits != num_logical_devices_per_replica:
      raise ValueError("Number of logical devices ({}) does not match the "
                       "number of partition splits specified ({}).".format(
                           num_logical_devices_per_replica,
                           num_partition_splits))

    tile_assignment = np.arange(num_partition_splits).reshape(
        partition_dimensions)
    return xla_sharding.tile(tensor, tile_assignment, use_sharding_op=True)

  def experimental_replicate_to_logical_devices(self, tensor):
    """Adds annotation that `tensor` will be replicated to all logical devices.

    This adds an annotation to tensor `tensor` specifying that operations on
    `tensor` will be invoked on all logical devices.

    ```python
    # Initializing TPU system with 8 logical devices and 1 replica.
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
    tf.config.experimental_connect_to_cluster(resolver)
    topology = tf.tpu.experimental.initialize_tpu_system(resolver)
    device_assignment = tf.tpu.experimental.DeviceAssignment.build(
        topology,
        computation_shape=[1, 2, 2, 2],
        num_replicas=1)
    strategy = tf.distribute.TPUStrategy(
        resolver, experimental_device_assignment=device_assignment)

    iterator = iter(inputs)

    @tf.function()
    def step_fn(inputs):
      images, labels = inputs
      images = strategy.experimental_split_to_logical_devices(
          images, [1, 2, 4, 1])

      # model() function will be executed on 8 logical devices with `images`
      # split 2 * 4 ways.
      output = model(images)

      # For loss calculation, all logical devices share the same logits
      # and labels.
      labels = strategy.experimental_replicate_to_logical_devices(labels)
      output = strategy.experimental_replicate_to_logical_devices(output)
      loss = loss_fn(labels, output)

      return loss

    strategy.run(step_fn, args=(next(iterator),))
    ```
    Args:
      tensor: Input tensor to annotate.

    Returns:
      Annotated tensor with identical value as `tensor`.
    """
    return xla_sharding.replicate(tensor, use_sharding_op=True)


@tf_export("distribute.experimental.TPUStrategy", v1=[])
@deprecation.deprecated_endpoints("distribute.experimental.TPUStrategy")
class TPUStrategy(distribute_lib.Strategy):
  """Synchronous training on TPUs and TPU Pods.

  To construct a TPUStrategy object, you need to run the
  initialization code as below:

  >>> resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
  >>> tf.config.experimental_connect_to_cluster(resolver)
  >>> tf.tpu.experimental.initialize_tpu_system(resolver)
  >>> strategy = tf.distribute.experimental.TPUStrategy(resolver)

  While using distribution strategies, the variables created within the
  strategy's scope will be replicated across all the replicas and can be kept in
  sync using all-reduce algorithms.

  To run TF2 programs on TPUs, you can either use `.compile` and
  `.fit` APIs in `tf.keras` with TPUStrategy, or write your own customized
  training loop by calling `strategy.run` directly. Note that
  TPUStrategy doesn't support pure eager execution, so please make sure the
  function passed into `strategy.run` is a `tf.function` or
  `strategy.run` is called inside a `tf.function` if eager
  behavior is enabled.
  """

  def __init__(self,
               tpu_cluster_resolver=None,
               device_assignment=None):
    """Synchronous training in TPU donuts or Pods.

    Args:
      tpu_cluster_resolver: A tf.distribute.cluster_resolver.TPUClusterResolver,
        which provides information about the TPU cluster.
      device_assignment: Optional `tf.tpu.experimental.DeviceAssignment` to
        specify the placement of replicas on the TPU cluster.
636 """ 637 logging.warning( 638 "`tf.distribute.experimental.TPUStrategy` is deprecated, please use " 639 " the non experimental symbol `tf.distribute.TPUStrategy` instead.") 640 641 super(TPUStrategy, self).__init__(TPUExtended( 642 self, tpu_cluster_resolver, device_assignment=device_assignment)) 643 distribute_lib.distribution_strategy_gauge.get_cell("V2").set("TPUStrategy") 644 distribute_lib.distribution_strategy_replica_gauge.get_cell( 645 "num_workers").set(self.extended.num_hosts) 646 distribute_lib.distribution_strategy_replica_gauge.get_cell( 647 "num_replicas_per_worker").set(self.extended.num_replicas_per_host) 648 # Packed variable is used to reduce the overhead of function execution. 649 # For a DistributedVariable, only one variable handle is captured into a 650 # function graph. It's only supported in eager mode. 651 self._enable_packed_variable_in_eager_mode = True 652 653 # TODO(cjfj): Modify `_call_for_each_replica` in `TPUExtended` such that this 654 # can use the default implementation. 655 # This implementation runs a single step. It does not use infeed or outfeed. 656 def run(self, fn, args=(), kwargs=None, options=None): 657 """See base class.""" 658 validate_run_function(fn) 659 660 fn, args, kwargs = _maybe_partial_apply_variables(fn, args, kwargs) 661 662 # Note: the target function is converted to graph even when in Eager mode, 663 # so autograph is on by default here. 664 fn = autograph.tf_convert(fn, autograph_ctx.control_status_ctx()) 665 options = options or distribute_lib.RunOptions() 666 return self.extended.tpu_run(fn, args, kwargs, options) 667 668 @property 669 def cluster_resolver(self): 670 """Returns the cluster resolver associated with this strategy. 671 672 `tf.distribute.experimental.TPUStrategy` provides the 673 associated `tf.distribute.cluster_resolver.ClusterResolver`. If the user 674 provides one in `__init__`, that instance is returned; if the user does 675 not, a default 676 `tf.distribute.cluster_resolver.TPUClusterResolver` is provided. 677 """ 678 return self.extended._tpu_cluster_resolver # pylint: disable=protected-access 679 680 681@tf_export(v1=["distribute.experimental.TPUStrategy"]) 682class TPUStrategyV1(distribute_lib.StrategyV1): 683 """TPU distribution strategy implementation.""" 684 685 def __init__(self, 686 tpu_cluster_resolver=None, 687 steps_per_run=None, 688 device_assignment=None): 689 """Initializes the TPUStrategy object. 690 691 Args: 692 tpu_cluster_resolver: A tf.distribute.cluster_resolver.TPUClusterResolver, 693 which provides information about the TPU cluster. 694 steps_per_run: Number of steps to run on device before returning to the 695 host. Note that this can have side-effects on performance, hooks, 696 metrics, summaries etc. 697 This parameter is only used when Distribution Strategy is used with 698 estimator or keras. 699 device_assignment: Optional `tf.tpu.experimental.DeviceAssignment` to 700 specify the placement of replicas on the TPU cluster. Currently only 701 supports the usecase of using a single core within a TPU cluster. 
702 """ 703 super(TPUStrategyV1, self).__init__(TPUExtended( 704 self, tpu_cluster_resolver, steps_per_run, device_assignment)) 705 distribute_lib.distribution_strategy_gauge.get_cell("V1").set("TPUStrategy") 706 distribute_lib.distribution_strategy_replica_gauge.get_cell( 707 "num_workers").set(self.extended.num_hosts) 708 distribute_lib.distribution_strategy_replica_gauge.get_cell( 709 "num_replicas_per_worker").set(self.extended.num_replicas_per_host) 710 # Packed variable is used to reduce the overhead of function execution. 711 # For a DistributedVariable, only one variable handle is captured into a 712 # function graph. It's only supported in eager mode. 713 self._enable_packed_variable_in_eager_mode = True 714 715 @property 716 def steps_per_run(self): 717 """DEPRECATED: use .extended.steps_per_run instead.""" 718 return self._extended.steps_per_run 719 720 # TODO(cjfj): Modify `_call_for_each_replica` in `TPUExtended` such that this 721 # can use the default implementation. 722 # This implementation runs a single step. It does not use infeed or outfeed. 723 def run(self, fn, args=(), kwargs=None, options=None): 724 """Run `fn` on each replica, with the given arguments. 725 726 Executes ops specified by `fn` on each replica. If `args` or `kwargs` have 727 "per-replica" values, such as those produced by a "distributed `Dataset`", 728 when `fn` is executed on a particular replica, it will be executed with the 729 component of those "per-replica" values that correspond to that replica. 730 731 `fn` may call `tf.distribute.get_replica_context()` to access members such 732 as `all_reduce`. 733 734 All arguments in `args` or `kwargs` should either be nest of tensors or 735 per-replica objects containing tensors or composite tensors. 736 737 Users can pass strategy specific options to `options` argument. An example 738 to enable bucketizing dynamic shapes in `TPUStrategy.run` 739 is: 740 741 >>> resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='') 742 >>> tf.config.experimental_connect_to_cluster(resolver) 743 >>> tf.tpu.experimental.initialize_tpu_system(resolver) 744 >>> strategy = tf.distribute.experimental.TPUStrategy(resolver) 745 746 >>> options = tf.distribute.RunOptions( 747 ... experimental_bucketizing_dynamic_shape=True) 748 749 >>> dataset = tf.data.Dataset.range( 750 ... strategy.num_replicas_in_sync, output_type=dtypes.float32).batch( 751 ... strategy.num_replicas_in_sync, drop_remainder=True) 752 >>> input_iterator = iter(strategy.experimental_distribute_dataset(dataset)) 753 754 >>> @tf.function() 755 ... def step_fn(inputs): 756 ... output = tf.reduce_sum(inputs) 757 ... return output 758 759 >>> strategy.run(step_fn, args=(next(input_iterator),), options=options) 760 761 Args: 762 fn: The function to run. The output must be a `tf.nest` of `Tensor`s. 763 args: (Optional) Positional arguments to `fn`. 764 kwargs: (Optional) Keyword arguments to `fn`. 765 options: (Optional) An instance of `tf.distribute.RunOptions` specifying 766 the options to run `fn`. 767 768 Returns: 769 Merged return value of `fn` across replicas. The structure of the return 770 value is the same as the return value from `fn`. Each element in the 771 structure can either be "per-replica" `Tensor` objects or `Tensor`s 772 (for example, if running on a single replica). 
773 """ 774 validate_run_function(fn) 775 776 fn, args, kwargs = _maybe_partial_apply_variables(fn, args, kwargs) 777 778 fn = autograph.tf_convert(fn, autograph_ctx.control_status_ctx()) 779 options = options or distribute_lib.RunOptions() 780 return self.extended.tpu_run(fn, args, kwargs, options) 781 782 783# TODO(josh11b): Switch to V2 when we no longer need to support tf.compat.v1. 784class TPUExtended(distribute_lib.StrategyExtendedV1): 785 """Implementation of TPUStrategy.""" 786 787 def __init__(self, 788 container_strategy, 789 tpu_cluster_resolver=None, 790 steps_per_run=None, 791 device_assignment=None): 792 super(TPUExtended, self).__init__(container_strategy) 793 794 if tpu_cluster_resolver is None: 795 tpu_cluster_resolver = TPUClusterResolver("") 796 797 if steps_per_run is None: 798 # TODO(frankchn): Warn when we are being used by DS/Keras and this is 799 # not specified. 800 steps_per_run = 1 801 802 # `self._tpu_function_cache` is a dict of `tf.function`s, thus if a 803 # `tf.function` is passed into `strategy.run` in eager mode, the 804 # `tf.function` won't get retraced. 805 self._tpu_function_cache = weakref.WeakKeyDictionary() 806 807 self._tpu_cluster_resolver = tpu_cluster_resolver 808 self._tpu_metadata = self._tpu_cluster_resolver.get_tpu_system_metadata() 809 self._device_assignment = device_assignment 810 811 tpu_devices_flat = [ 812 d.name for d in self._tpu_metadata.devices if "device:TPU:" in d.name] 813 814 # `self._tpu_devices` is a two-dimensional NumPy array of strings. It is 815 # indexed using `[replica_id][logical_device_id]`. 816 if device_assignment is None: 817 self._tpu_devices = np.array( 818 [[d] for d in tpu_devices_flat], dtype=object) 819 else: 820 job_name = device_spec.DeviceSpecV2.from_string(tpu_devices_flat[0]).job 821 822 tpu_devices = [] 823 for replica_id in range(device_assignment.num_replicas): 824 replica_devices = [] 825 826 for logical_core in range(device_assignment.num_cores_per_replica): 827 replica_devices.append( 828 device_util.canonicalize( 829 device_assignment.tpu_device( 830 replica=replica_id, 831 logical_core=logical_core, 832 job=job_name))) 833 834 tpu_devices.append(replica_devices) 835 self._tpu_devices = np.array(tpu_devices, dtype=object) 836 837 self._host_device = device_util.get_host_for_device(self._tpu_devices[0][0]) 838 839 # Preload the data onto the TPUs. Currently we always preload onto logical 840 # device 0 for each replica. 841 # TODO(cjfj): Create `InputWorkers` lazily, allowing users to place the 842 # input onto a different logical device? 843 self._device_input_worker_devices = collections.OrderedDict() 844 self._host_input_worker_devices = collections.OrderedDict() 845 for tpu_device in self._tpu_devices[:, 0]: 846 host_device = device_util.get_host_for_device(tpu_device) 847 self._device_input_worker_devices.setdefault(host_device, []) 848 self._device_input_worker_devices[host_device].append(tpu_device) 849 self._host_input_worker_devices.setdefault(host_device, []) 850 self._host_input_worker_devices[host_device].append(host_device) 851 852 # TODO(sourabhbajaj): Remove this once performance of running one step 853 # at a time is comparable to multiple steps. 854 self.steps_per_run = steps_per_run 855 self._require_static_shapes = True 856 857 self.experimental_enable_get_next_as_optional = True 858 859 self._logical_device_stack = [0] 860 861 if context.executing_eagerly(): 862 # In async remote eager, we want to sync the executors before exiting the 863 # program. 
      atexit.register(context.async_wait)

    # Flag to turn on VariablePolicy.
    self._use_var_policy = True

    # Flag to enable XLA SPMD partitioning.
    # TODO(b/170873313): Enable XLA SPMD partitioning in TPUStrategy.
    self._use_spmd_for_xla_partitioning = False

  def _validate_colocate_with_variable(self, colocate_with_variable):
    distribute_utils.validate_colocate(colocate_with_variable, self)

  def _make_dataset_iterator(self, dataset):
    """Make iterators for each of the TPU hosts."""
    input_workers = input_lib.InputWorkers(
        tuple(self._device_input_worker_devices.items()))
    return input_lib.DatasetIterator(
        dataset,
        input_workers,
        self._container_strategy(),
        num_replicas_in_sync=self._num_replicas_in_sync)

  def _make_input_fn_iterator(
      self,
      input_fn,
      replication_mode=distribute_lib.InputReplicationMode.PER_WORKER):
    input_contexts = []
    input_workers = input_lib.InputWorkers(
        tuple(self._device_input_worker_devices.items()))
    num_workers = input_workers.num_workers
    for i in range(num_workers):
      input_contexts.append(distribute_lib.InputContext(
          num_input_pipelines=num_workers,
          input_pipeline_id=i,
          num_replicas_in_sync=self._num_replicas_in_sync))
    return input_lib.InputFunctionIterator(
        input_fn,
        input_workers,
        input_contexts,
        self._container_strategy())

  def _experimental_make_numpy_dataset(self, numpy_input, session):
    return numpy_dataset.one_host_numpy_dataset(
        numpy_input, numpy_dataset.SingleDevice(self._host_device),
        session)

  def _get_input_workers(self, options):
    if not options or options.experimental_fetch_to_device:
      return input_lib.InputWorkers(
          tuple(self._device_input_worker_devices.items()))
    else:
      return input_lib.InputWorkers(
          tuple(self._host_input_worker_devices.items()))

  def _check_spec(self, element_spec):
    if isinstance(element_spec, values.PerReplicaSpec):
      element_spec = element_spec._component_specs  # pylint: disable=protected-access
    specs = nest.flatten_with_joined_string_paths(element_spec)
    for path, spec in specs:
      if isinstance(spec, (sparse_tensor.SparseTensorSpec,
                           ragged_tensor.RaggedTensorSpec)):
        raise ValueError(
            "Found tensor {} with spec {}. TPUStrategy does not support "
            "distributed datasets with device prefetch when using sparse or "
            "ragged tensors. If you intend to use sparse or ragged tensors, "
            "please pass a tf.distribute.InputOptions object with "
            "experimental_fetch_to_device set to False to your dataset "
            "distribution function.".format(path, type(spec)))

  def _experimental_distribute_dataset(self, dataset, options):
    if (options and options.experimental_replication_mode ==
        distribute_lib.InputReplicationMode.PER_REPLICA):
      raise NotImplementedError(
          "InputReplicationMode.PER_REPLICA "
          "is only supported in "
          "`experimental_distribute_datasets_from_function`."
      )
    if options is None or options.experimental_fetch_to_device:
      self._check_spec(dataset.element_spec)

    return input_lib.get_distributed_dataset(
        dataset,
        self._get_input_workers(options),
        self._container_strategy(),
        num_replicas_in_sync=self._num_replicas_in_sync,
        options=options)

  def _distribute_datasets_from_function(self, dataset_fn, options):
    if (options and options.experimental_replication_mode ==
        distribute_lib.InputReplicationMode.PER_REPLICA):
      raise NotImplementedError(
          "InputReplicationMode.PER_REPLICA "
          "is only supported in "
          "`experimental_distribute_datasets_from_function` "
          "of tf.distribute.MirroredStrategy")
    input_workers = self._get_input_workers(options)
    input_contexts = []
    num_workers = input_workers.num_workers
    for i in range(num_workers):
      input_contexts.append(distribute_lib.InputContext(
          num_input_pipelines=num_workers,
          input_pipeline_id=i,
          num_replicas_in_sync=self._num_replicas_in_sync))

    distributed_dataset = input_lib.get_distributed_datasets_from_function(
        dataset_fn,
        input_workers,
        input_contexts,
        self._container_strategy(),
        options=options)

    # We can only check after the dataset_fn is called.
    if options is None or options.experimental_fetch_to_device:
      self._check_spec(distributed_dataset.element_spec)
    return distributed_dataset

  def _experimental_distribute_values_from_function(self, value_fn):
    per_replica_values = []
    for replica_id in range(self._num_replicas_in_sync):
      per_replica_values.append(
          value_fn(distribute_lib.ValueContext(replica_id,
                                               self._num_replicas_in_sync)))
    return distribute_utils.regroup(per_replica_values, always_wrap=True)

  # TODO(priyag): Deal with OutOfRange errors once b/111349762 is fixed.
  # TODO(sourabhbajaj): Remove the initial_loop_values parameter when we have
  # a mechanism to infer the outputs of `fn`. Pending b/110550782.
  def _experimental_run_steps_on_iterator(
      self, fn, multi_worker_iterator, iterations, initial_loop_values=None):
    # Wrap `fn` for repeat.
    if initial_loop_values is None:
      initial_loop_values = {}
    initial_loop_values = nest.flatten(initial_loop_values)
    ctx = input_lib.MultiStepContext()

    def run_fn(inputs):
      """Single step on the TPU device."""
      fn_result = fn(ctx, inputs)
      flat_last_step_outputs = nest.flatten(ctx.last_step_outputs)
      if flat_last_step_outputs:
        with ops.control_dependencies([fn_result]):
          return [array_ops.identity(f) for f in flat_last_step_outputs]
      else:
        return fn_result

    # We capture the control_flow_context at this point, before we run `fn`
    # inside a while_loop and TPU replicate context. This is useful in cases
    # where we might need to exit these contexts and get back to the outer
    # context to do some things, e.g. to create an op which should be
    # evaluated only once at the end of the loop on the host. One such usage
    # is in creating metrics' value op.
    self._outer_control_flow_context = (
        ops.get_default_graph()._get_control_flow_context())  # pylint: disable=protected-access

    def rewrite_fn(*args):
      """The rewritten step fn running on TPU."""
      del args

      per_replica_inputs = multi_worker_iterator.get_next()
      replicate_inputs = []
      for replica_id in range(self._num_replicas_in_sync):
        select_replica = lambda x: distribute_utils.select_replica(  # pylint: disable=g-long-lambda
            replica_id, x)  # pylint: disable=cell-var-from-loop
        replicate_inputs.append((nest.map_structure(
            select_replica, per_replica_inputs),))

      replicate_outputs = tpu.replicate(
          run_fn,
          replicate_inputs,
          device_assignment=self._device_assignment,
          xla_options=tpu.XLAOptions(use_spmd_for_xla_partitioning=self
                                     ._use_spmd_for_xla_partitioning))
      # If run_fn has tensor outputs, tpu.replicate returns a list of lists. We
      # will flatten it in this case. If run_fn has no tensor outputs,
      # tpu.replicate returns a list of no_ops, so we will keep the output as
      # it is.
      if isinstance(replicate_outputs[0], list):
        replicate_outputs = nest.flatten(replicate_outputs)

      return replicate_outputs

    # TODO(sourabhbajaj): The input to while loop should be based on the
    # output type of the step_fn
    assert isinstance(initial_loop_values, list)
    initial_loop_values = initial_loop_values * self._num_replicas_in_sync

    # Put the while loop op on TPU host 0.
    with ops.device(self._host_device):
      if self.steps_per_run == 1:
        replicate_outputs = rewrite_fn()
      else:
        replicate_outputs = training_loop.repeat(iterations, rewrite_fn,
                                                 initial_loop_values)

    del self._outer_control_flow_context
    ctx.run_op = control_flow_ops.group(replicate_outputs)

    if isinstance(replicate_outputs, list):
      # Filter out any ops from the outputs, typically this would be the case
      # when there were no tensor outputs.
      last_step_tensor_outputs = [
          x for x in replicate_outputs if not isinstance(x, ops.Operation)
      ]

      # Outputs are currently of the structure (flattened)
      # [output0_device0, output1_device0, output2_device0,
      #  output0_device1, output1_device1, output2_device1,
      #  ...]
      # Convert this to the following structure instead: (grouped by output)
      # [[output0_device0, output0_device1],
      #  [output1_device0, output1_device1],
      #  [output2_device0, output2_device1]]
      output_num = len(last_step_tensor_outputs) // self._num_replicas_in_sync
      last_step_tensor_outputs = [
          last_step_tensor_outputs[i::output_num] for i in range(output_num)
      ]
    else:
      # No tensors returned.
      last_step_tensor_outputs = []

    _set_last_step_outputs(ctx, last_step_tensor_outputs)
    return ctx

  def _call_for_each_replica(self, fn, args, kwargs):
    # TODO(jhseu): Consider making it so call_for_each_replica implies that
    # we're in a tpu.rewrite(), and update TPUMirroredVariable accordingly.
    with _TPUReplicaContext(self._container_strategy()):
      return fn(*args, **kwargs)

  @contextlib.contextmanager
  def experimental_logical_device(self, logical_device_id):
    """Places variables and ops on the specified logical device."""
    num_logical_devices_per_replica = self._tpu_devices.shape[1]
    if logical_device_id >= num_logical_devices_per_replica:
      raise ValueError(
          "`logical_device_id` not in range (was {}, but there are only {} "
          "logical devices per replica).".format(
              logical_device_id, num_logical_devices_per_replica))

    self._logical_device_stack.append(logical_device_id)
    try:
      if tpu_util.enclosing_tpu_context() is None:
        yield
      else:
        with ops.device(tpu.core(logical_device_id)):
          yield
    finally:
      self._logical_device_stack.pop()

  def _experimental_initialize_system(self):
    """Experimental method added to be used by Estimator.

    This is a private method only to be used by Estimator. Other frameworks
    should call `tf.tpu.experimental.initialize_tpu_system` directly.
    """
    tpu_strategy_util.initialize_tpu_system(self._tpu_cluster_resolver)

  def _create_variable(self, next_creator, **kwargs):
    """Create a TPUMirroredVariable. See `DistributionStrategy.scope`."""
    if kwargs.pop("skip_mirrored_creator", False):
      return next_creator(**kwargs)

    colocate_with = kwargs.pop("colocate_with", None)
    if colocate_with is None:
      devices = self._tpu_devices[:, self._logical_device_stack[-1]]
    elif isinstance(colocate_with, numpy_dataset.SingleDevice):
      with ops.device(colocate_with.device):
        return next_creator(**kwargs)
    else:
      devices = colocate_with._devices  # pylint: disable=protected-access

    def _real_mirrored_creator(**kwargs):  # pylint: disable=g-missing-docstring
      initial_value = None
      value_list = []
      for i, d in enumerate(devices):
        with ops.device(d):
          if i == 0:
            initial_value = kwargs["initial_value"]
            # Note: some v1 code expects variable initializer creation to happen
            # inside an init_scope.
            with maybe_init_scope():
              initial_value = initial_value() if callable(
                  initial_value) else initial_value

          if i > 0:
            # Give replicas meaningful distinct names:
            var0name = value_list[0].name.split(":")[0]
            # We append a / to variable names created on replicas with id > 0 to
            # ensure that we ignore the name scope and instead use the given
            # name as the absolute name of the variable.
1154 kwargs["name"] = "%s/replica_%d/" % (var0name, i) 1155 kwargs["initial_value"] = initial_value 1156 1157 with context.device_policy(context.DEVICE_PLACEMENT_SILENT): 1158 v = next_creator(**kwargs) 1159 1160 assert not isinstance(v, tpu_values.TPUMirroredVariable) 1161 value_list.append(v) 1162 return value_list 1163 1164 return distribute_utils.create_mirrored_variable( 1165 self._container_strategy(), _real_mirrored_creator, 1166 distribute_utils.TPU_VARIABLE_CLASS_MAPPING, 1167 distribute_utils.TPU_VARIABLE_POLICY_MAPPING, **kwargs) 1168 1169 def _resource_creator_scope(self): 1170 1171 def lookup_creator(next_creator, *args, **kwargs): 1172 host_to_table = collections.OrderedDict() 1173 for host_device in self._device_input_worker_devices.keys(): 1174 with ops.device(host_device): 1175 host_to_table[host_device] = next_creator(*args, **kwargs) 1176 1177 return values.PerWorkerResource(self._container_strategy(), host_to_table) 1178 1179 # TODO(b/194362531): Define creator(s) for other resources. 1180 return ops.resource_creator_scope("StaticHashTable", lookup_creator) 1181 1182 def _gather_to_implementation(self, value, destinations, axis, options): 1183 if not isinstance(value, values.DistributedValues): 1184 return value 1185 1186 value_list = list(value.values) 1187 # pylint: disable=protected-access 1188 if isinstance( 1189 value, 1190 values.DistributedVariable) and value._packed_variable is not None: 1191 value_list = list( 1192 value._packed_variable.on_device(d) 1193 for d in value._packed_variable.devices) 1194 # pylint: enable=protected-access 1195 1196 # Currently XLA op by op mode has a limit for the number of inputs for a 1197 # single op, thus we break one `add_n` op into a group of `add_n` ops to 1198 # work around the constraint. 1199 if len(value.values) <= _XLA_OP_BY_OP_INPUTS_LIMIT: 1200 output = array_ops.concat(value_list, axis=axis) 1201 else: 1202 output = array_ops.concat( 1203 value_list[:_XLA_OP_BY_OP_INPUTS_LIMIT], axis=axis) 1204 for i in range(_XLA_OP_BY_OP_INPUTS_LIMIT, len(value_list), 1205 _XLA_OP_BY_OP_INPUTS_LIMIT - 1): 1206 output = array_ops.concat( 1207 [output] + value_list[i:i + _XLA_OP_BY_OP_INPUTS_LIMIT - 1], 1208 axis=axis) 1209 1210 output = self._broadcast_output(destinations, output) 1211 return output 1212 1213 def _broadcast_output(self, destinations, output): 1214 devices = cross_device_ops_lib.get_devices_from(destinations) 1215 1216 if len(devices) == 1: 1217 # If necessary, copy to requested destination. 1218 dest_canonical = device_util.canonicalize(devices[0]) 1219 host_canonical = device_util.canonicalize(self._host_device) 1220 1221 if dest_canonical != host_canonical: 1222 with ops.device(dest_canonical): 1223 output = array_ops.identity(output) 1224 else: 1225 output = cross_device_ops_lib.simple_broadcast(output, destinations) 1226 1227 return output 1228 1229 def _reduce_to(self, reduce_op, value, destinations, options): 1230 if (isinstance(value, values.DistributedValues) or 1231 tensor_util.is_tf_type(value) 1232 ) and tpu_util.enclosing_tpu_context() is not None: 1233 if reduce_op == reduce_util.ReduceOp.MEAN: 1234 # TODO(jhseu): Revisit once we support model-parallelism. 1235 # scalar_mul maintains the type of value: tensor or IndexedSlices. 
        value = math_ops.scalar_mul((1./self._num_replicas_in_sync), value)
      elif reduce_op != reduce_util.ReduceOp.SUM:
        raise NotImplementedError(
            "Currently only support sum & mean in TPUStrategy.")
      return tpu_ops.cross_replica_sum(value)

    if not isinstance(value, values.DistributedValues):
      # This function handles reducing values that are not PerReplica or
      # Mirrored values. For example, the same value could be present on all
      # replicas in which case `value` would be a single value or value could
      # be 0.
      return cross_device_ops_lib.reduce_non_distributed_value(
          reduce_op, value, destinations, self._num_replicas_in_sync)

    value_list = value.values
    # pylint: disable=protected-access
    if isinstance(
        value,
        values.DistributedVariable) and value._packed_variable is not None:
      value_list = tuple(
          value._packed_variable.on_device(d)
          for d in value._packed_variable.devices)
    # pylint: enable=protected-access

    # Currently XLA op by op mode has a limit for the number of inputs for a
    # single op, thus we break one `add_n` op into a group of `add_n` ops to
    # work around the constraint.
    # TODO(cjfj): Detect when it is possible to use `cross_replica_sum`.
    if len(value.values) <= _XLA_OP_BY_OP_INPUTS_LIMIT:
      output = math_ops.add_n(value_list)
    else:
      output = array_ops.zeros_like(value_list[0], dtype=value_list[0].dtype)
      for i in range(0, len(value_list), _XLA_OP_BY_OP_INPUTS_LIMIT):
        output += math_ops.add_n(value_list[i:i + _XLA_OP_BY_OP_INPUTS_LIMIT])

    if reduce_op == reduce_util.ReduceOp.MEAN:
      output *= (1. / len(value_list))

    output = self._broadcast_output(destinations, output)
    return output

  def _update(self, var, fn, args, kwargs, group):
    assert isinstance(var, tpu_values.TPUVariableMixin) or isinstance(
        var, resource_variable_ops.BaseResourceVariable)
    if tpu_util.enclosing_tpu_context() is not None:
      if group:
        return fn(var, *args, **kwargs)
      else:
        return (fn(var, *args, **kwargs),)

    # Otherwise, we revert to MirroredStrategy behavior and update the variable
    # on each replica directly.
    updates = []
    values_and_devices = []
    packed_var = var._packed_variable  # pylint: disable=protected-access
    if packed_var is not None:
      for device in packed_var.devices:
        values_and_devices.append((packed_var, device))
    else:
      for value in var.values:
        values_and_devices.append((value, value.device))

    if (var.synchronization != variables_lib.VariableSynchronization.ON_READ and
        var.aggregation != variables_lib.VariableAggregation.NONE):
      distribute_utils.assert_mirrored(args)
      distribute_utils.assert_mirrored(kwargs)
    for i, value_and_device in enumerate(values_and_devices):
      value = value_and_device[0]
      device = value_and_device[1]
      name = "update_%d" % i
      with ops.device(device), \
           distribute_lib.UpdateContext(i), \
           ops.name_scope(name):
        # If args and kwargs are not mirrored, the value is returned as is.
        updates.append(
            fn(value, *distribute_utils.select_replica(i, args),
               **distribute_utils.select_replica(i, kwargs)))
    return distribute_utils.update_regroup(self, updates, group)

  def read_var(self, var):
    assert isinstance(var, tpu_values.TPUVariableMixin) or isinstance(
        var, resource_variable_ops.BaseResourceVariable)
    return var.read_value()

  def value_container(self, value):
    return value

  def _broadcast_to(self, tensor, destinations):
    del destinations
    # This is both a fast path for Python constants, and a way to delay
    # converting Python values to a tensor until we know what type it
    # should be converted to. Otherwise we have trouble with:
    #   global_step.assign_add(1)
    # since the `1` gets broadcast as an int32 but global_step is int64.
    if isinstance(tensor, (float, int)):
      return tensor
    if tpu_util.enclosing_tpu_context() is not None:
      broadcast_tensor = [tensor for _ in range(self._num_replicas_in_sync)]
      result = tpu_ops.all_to_all(
          broadcast_tensor,
          concat_dimension=0,
          split_dimension=0,
          split_count=self._num_replicas_in_sync)

      # This uses the broadcasted value from the first replica because the only
      # caller of this is for ONLY_FIRST_REPLICA variables aggregation.
      return result[0]
    return tensor

  @property
  def num_hosts(self):
    if self._device_assignment is None:
      return self._tpu_metadata.num_hosts

    return len(set([self._device_assignment.host_device(r)
                    for r in range(self._device_assignment.num_replicas)]))

  @property
  def num_replicas_per_host(self):
    if self._device_assignment is None:
      return self._tpu_metadata.num_of_cores_per_host

    # TODO(sourabhbajaj): Remove this method once we use inputs and remove
    # infeed, as the computation of num_replicas_per_host is not a constant
    # when using device_assignment. This is a temporary workaround to support
    # StatefulRNN as everything is 1 in that case.
    # This method needs to take host_id as input for correct computation.
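    # Illustrative example (assumed numbers): with 8 cores per host and a
    # device assignment that uses 2 cores per replica, max_models_per_host
    # below is 4, so at most 4 replicas are reported per host.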
    max_models_per_host = (self._tpu_metadata.num_of_cores_per_host //
                           self._device_assignment.num_cores_per_replica)
    return min(self._device_assignment.num_replicas, max_models_per_host)

  @property
  def _num_replicas_in_sync(self):
    if self._device_assignment is None:
      return self._tpu_metadata.num_cores
    return self._device_assignment.num_replicas

  @property
  def experimental_between_graph(self):
    return False

  @property
  def experimental_should_init(self):
    return True

  @property
  def should_checkpoint(self):
    return True

  @property
  def should_save_summary(self):
    return True

  @property
  def worker_devices(self):
    return tuple(self._tpu_devices[:, self._logical_device_stack[-1]])

  @property
  def parameter_devices(self):
    return self.worker_devices

  def non_slot_devices(self, var_list):
    return self._host_device

  def _update_non_slot(self, colocate_with, fn, args, kwargs, group):
    del colocate_with
    with ops.device(self._host_device), distribute_lib.UpdateContext(None):
      result = fn(*args, **kwargs)
      if group:
        return result
      else:
        return nest.map_structure(self._local_results, result)

  def _configure(self,
                 session_config=None,
                 cluster_spec=None,
                 task_type=None,
                 task_id=None):
    del cluster_spec, task_type, task_id
    if session_config:
      session_config.CopyFrom(self._update_config_proto(session_config))

  def _update_config_proto(self, config_proto):
    updated_config = copy.deepcopy(config_proto)
    updated_config.isolate_session_state = True
    cluster_spec = self._tpu_cluster_resolver.cluster_spec()
    if cluster_spec:
      updated_config.cluster_def.CopyFrom(cluster_spec.as_cluster_def())
    return updated_config

  # TODO(priyag): Delete this once all strategies use global batch size.
  @property
  def _global_batch_size(self):
    """`make_dataset_iterator` and `make_numpy_iterator` use global batch size.

    `make_input_fn_iterator` assumes per-replica batching.

    Returns:
      Boolean.
    """
    return True

  def tpu_run(self, fn, args, kwargs, options=None):
    func = self._tpu_function_creator(fn, options)
    return func(args, kwargs)

  def _tpu_function_creator(self, fn, options):
    if context.executing_eagerly() and fn in self._tpu_function_cache:
      return self._tpu_function_cache[fn]

    strategy = self._container_strategy()

    def tpu_function(args, kwargs):
      """TF Function used to replicate the user computation."""
      if kwargs is None:
        kwargs = {}

      # Used to re-structure flattened output tensors from `tpu.replicate()`
      # into a structured format.
      result = [[]]

      def replicated_fn(replica_id, replica_args, replica_kwargs):
        """Wraps user function to provide replica ID and `Tensor` inputs."""
        with _TPUReplicaContext(strategy, replica_id_in_sync_group=replica_id):
          result[0] = fn(*replica_args, **replica_kwargs)
        return result[0]

      replicate_inputs = []  # By replica.
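      # Each entry appended below is [replica_id tensor, positional args for
      # that replica, keyword args for that replica]; `tpu.replicate` passes
      # one such entry to `replicated_fn` per replica.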
      for i in range(strategy.num_replicas_in_sync):
        replicate_inputs.append(
            [constant_op.constant(i, dtype=dtypes.int32),
             distribute_utils.select_replica(i, args),
             distribute_utils.select_replica(i, kwargs)])

      # Construct and pass `maximum_shapes` so that we can support dynamic
      # shapes using the dynamic padder.
      if options.experimental_enable_dynamic_batch_size and replicate_inputs:
        maximum_shapes = []
        flattened_list = nest.flatten(replicate_inputs[0])
        for input_tensor in flattened_list:
          if tensor_util.is_tf_type(input_tensor):
            rank = input_tensor.shape.rank
          else:
            rank = np.ndim(input_tensor)
          if rank is None:
            raise ValueError(
                "input tensor {} to TPUStrategy.run() has unknown rank, "
                "which is not allowed".format(input_tensor))
          maximum_shape = tensor_shape.TensorShape([None] * rank)
          maximum_shapes.append(maximum_shape)
        maximum_shapes = nest.pack_sequence_as(replicate_inputs[0],
                                               maximum_shapes)
      else:
        maximum_shapes = None

      if options.experimental_bucketizing_dynamic_shape:
        padding_spec = tpu.PaddingSpec.POWER_OF_TWO
      else:
        padding_spec = None

      with strategy.scope():
        xla_options = options.experimental_xla_options or tpu.XLAOptions(
            use_spmd_for_xla_partitioning=self._use_spmd_for_xla_partitioning)
        replicate_outputs = tpu.replicate(
            replicated_fn,
            replicate_inputs,
            device_assignment=self._device_assignment,
            maximum_shapes=maximum_shapes,
            padding_spec=padding_spec,
            xla_options=xla_options)

      # Remove all no-ops that may have been added during `tpu.replicate()`.
      filter_ops = lambda x: [o for o in x if not isinstance(o, ops.Operation)]
      if isinstance(result[0], list):
        result[0] = filter_ops(result[0])

      # Workaround for `tpu.replicate` behaviour when a single `Tensor` is
      # returned.
      if result[0] is None or isinstance(result[0], ops.Operation):
        replicate_outputs = [None] * len(replicate_outputs)
      else:
        replicate_outputs = [
            nest.pack_sequence_as(result[0], filter_ops(nest.flatten(output)))
            for output in replicate_outputs
        ]
      return distribute_utils.regroup(replicate_outputs)

    if context.executing_eagerly():
      tpu_function = def_function.function(tpu_function)
      self._tpu_function_cache[fn] = tpu_function
    return tpu_function

  def _in_multi_worker_mode(self):
    """Whether this strategy indicates working in multi-worker settings."""
    # TPUStrategy has a different distributed training structure: the whole
    # cluster should be treated as a single worker from a higher-level
    # (e.g. Keras) library's point of view.
    # TODO(rchao): Revisit this as we design a fault-tolerance solution for
    # TPUStrategy.
    return False

  def _get_local_replica_id(self, replica_id_in_sync_group):
    return replica_id_in_sync_group


class _TPUReplicaContext(distribute_lib.ReplicaContext):
  """Replication Context class for TPU Strategy."""

  # TODO(sourabhbajaj): Call for each replica should be updating this.
  # TODO(b/118385803): Always properly initialize replica_id.
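  # Usage sketch (hedged; illustrative only and limited to public TF 2.x
  # symbols, with `strategy` assumed to be an already-initialized
  # tf.distribute.TPUStrategy). Code run via `strategy.run` executes under an
  # instance of this context, so it can query its replica id:
  #
  #   @tf.function
  #   def step_fn():
  #     ctx = tf.distribute.get_replica_context()
  #     return ctx.replica_id_in_sync_group
  #
  #   per_replica_ids = strategy.run(step_fn)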
  def __init__(self, strategy, replica_id_in_sync_group=0):
    distribute_lib.ReplicaContext.__init__(
        self, strategy, replica_id_in_sync_group=replica_id_in_sync_group)

  @property
  def devices(self):
    distribute_lib.require_replica_context(self)
    ds = self._strategy
    replica_id = tensor_util.constant_value(self.replica_id_in_sync_group)

    if replica_id is None:  # Non-constant `Tensor` inside `tpu.replicate`.
      # TODO(cjfj): Return other devices when model parallelism is supported.
      return (tpu.core(0),)
    else:
      return (ds.extended.worker_devices[replica_id],)

  def experimental_logical_device(self, logical_device_id):
    """Places variables and ops on the specified logical device."""
    return self.strategy.extended.experimental_logical_device(logical_device_id)

  # TODO(wxinyi): Investigate whether to use cross_replica_sum to optimize it.
  def all_gather(self, value, axis, experimental_hints=None):
    del experimental_hints
    for v in nest.flatten(value):
      if isinstance(v, ops.IndexedSlices):
        raise NotImplementedError("all_gather does not support IndexedSlices")

    def _all_to_all(value, axis):
      # The underlying AllToAllOp first does a split of the input value and
      # then cross-replica communication and concatenation of the result, so
      # we concatenate the local tensor here first.
      inputs = array_ops.concat(
          [value for _ in range(self.num_replicas_in_sync)], axis=0)
      unordered_output = tpu_ops.all_to_all(
          inputs,
          concat_dimension=axis,
          split_dimension=0,
          split_count=self.num_replicas_in_sync)

      # Re-order since xla.replica_id and ReplicaContext.replica_id do not
      # match.
      # xla_id = xla.replica_id()
      concat_replica_id = array_ops.concat([
          array_ops.expand_dims_v2(self.replica_id_in_sync_group, 0)
          for _ in range(self.num_replicas_in_sync)
      ], axis=0)
      replica_ids = tpu_ops.all_to_all(
          concat_replica_id,
          concat_dimension=0,
          split_dimension=0,
          split_count=self.num_replicas_in_sync)

      split_unordered = array_ops.split(
          unordered_output,
          num_or_size_splits=self.num_replicas_in_sync,
          axis=axis)
      sorted_with_extra_dim = math_ops.unsorted_segment_sum(
          array_ops.concat([
              array_ops.expand_dims(replica, axis=0)
              for replica in split_unordered
          ], axis=0),
          replica_ids,
          num_segments=self.num_replicas_in_sync)

      split_with_extra_dim = array_ops.split(
          sorted_with_extra_dim,
          num_or_size_splits=self.num_replicas_in_sync,
          axis=0)
      squeezed = [
          array_ops.squeeze(replica, axis=0)
          for replica in split_with_extra_dim
      ]
      result = array_ops.concat(squeezed, axis=axis)
      return result

    ys = [_all_to_all(t, axis=axis) for t in nest.flatten(value)]
    return nest.pack_sequence_as(value, ys)


def _set_last_step_outputs(ctx, last_step_tensor_outputs):
  """Sets the last step outputs on the given context."""
  # Convert replicate_outputs to the original dict structure of
  # last_step_outputs.
  last_step_tensor_outputs_dict = nest.pack_sequence_as(
      ctx.last_step_outputs, last_step_tensor_outputs)

  for name, reduce_op in ctx._last_step_outputs_reduce_ops.items():  # pylint: disable=protected-access
    output = last_step_tensor_outputs_dict[name]
    # For outputs that aren't reduced, return a PerReplica of all values.
    # Else, take the first value from the list, as each value should be the
    # same.
    if reduce_op is None:
      last_step_tensor_outputs_dict[name] = values.PerReplica(output)
    else:
      # TODO(priyag): Should this return the element or a list with 1 element?
      last_step_tensor_outputs_dict[name] = output[0]
  ctx._set_last_step_outputs(last_step_tensor_outputs_dict)  # pylint: disable=protected-access
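

# Example sketch (not part of this module's API; hedged and illustrative only,
# using public TF 2.x entry points and assuming a reachable TPU worker):
#
#   import tensorflow as tf
#
#   resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="")
#   tf.config.experimental_connect_to_cluster(resolver)
#   tf.tpu.experimental.initialize_tpu_system(resolver)
#   strategy = tf.distribute.TPUStrategy(resolver)
#
#   @tf.function
#   def step_fn(x):
#     return x * 2.0
#
#   per_replica_result = strategy.run(step_fn, args=(tf.constant(1.0),))
#   total = strategy.reduce(
#       tf.distribute.ReduceOp.SUM, per_replica_result, axis=None)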