# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TPU Strategy."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import contextlib
import copy
import weakref

import numpy as np

from tensorflow.python.autograph.core import ag_ctx as autograph_ctx
from tensorflow.python.autograph.impl import api as autograph
from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import input_lib
from tensorflow.python.distribute import numpy_dataset
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import values
from tensorflow.python.distribute.cluster_resolver import TPUClusterResolver
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import device_spec
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.tpu import device_assignment as device_assignment_lib  # pylint: disable=unused-import
from tensorflow.python.tpu import tpu
from tensorflow.python.tpu import tpu_strategy_util
from tensorflow.python.tpu import tpu_system_metadata as tpu_system_metadata_lib
from tensorflow.python.tpu import training_loop
from tensorflow.python.tpu.ops import tpu_ops
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export


def get_tpu_system_metadata(tpu_cluster_resolver):
  """Retrieves TPU system metadata given a TPUClusterResolver."""
  master = tpu_cluster_resolver.master()

  # pylint: disable=protected-access
  cluster_spec = tpu_cluster_resolver.cluster_spec()
  cluster_def = cluster_spec.as_cluster_def() if cluster_spec else None
  tpu_system_metadata = (
      tpu_system_metadata_lib._query_tpu_system_metadata(
          master,
          cluster_def=cluster_def,
          query_topology=False))

  return tpu_system_metadata


@contextlib.contextmanager
def maybe_init_scope():
  if ops.executing_eagerly_outside_functions():
    yield
  else:
    with ops.init_scope():
      yield


def validate_experimental_run_function(fn):
  """Validate the function passed into strategy.experimental_run_v2."""

  # We allow three types of functions/objects passed into TPUStrategy
  # `experimental_run_v2` in eager mode:
  # 1. a user-annotated `tf.function`
  # 2. a ConcreteFunction, which is mostly what you get from loading a saved
  #    model.
  # 3. a callable object whose `__call__` method is a `tf.function`.
  #
  # Otherwise we raise an error, because we don't support eagerly running
  # `experimental_run_v2` in TPUStrategy.

  if context.executing_eagerly() and not isinstance(
      fn, def_function.Function) and not isinstance(
          fn, function.ConcreteFunction) and not (callable(fn) and isinstance(
              fn.__call__, def_function.Function)):
    raise NotImplementedError(
        "TPUStrategy.experimental_run_v2(fn, ...) does not support pure eager "
        "execution. Please make sure the function passed into "
        "`strategy.experimental_run_v2` is a `tf.function`, or that "
        "`strategy.experimental_run_v2` is called inside a `tf.function` if "
        "eager behavior is enabled.")


@tf_export("distribute.experimental.TPUStrategy", v1=[])
class TPUStrategy(distribute_lib.Strategy):
  """TPU distribution strategy implementation."""

  def __init__(self,
               tpu_cluster_resolver=None,
               device_assignment=None):
    """Synchronous training on TPU donuts or TPU Pods.

    To construct a TPUStrategy object, you need to run the
    initialization code below:

    ```python
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=FLAGS.tpu)
    tf.config.experimental_connect_to_cluster(resolver)
    tf.tpu.experimental.initialize_tpu_system(resolver)
    strategy = tf.distribute.experimental.TPUStrategy(resolver)
    ```

    While using distribution strategies, the variables created within the
    strategy's scope will be replicated across all of the replicas and can be
    kept in sync using all-reduce algorithms.

    To run TF2 programs on TPUs, you can either use the `.compile` and
    `.fit` APIs in `tf.keras` with TPUStrategy, or write your own customized
    training loop by calling `strategy.experimental_run_v2` directly. Note that
    TPUStrategy doesn't support pure eager execution, so please make sure the
    function passed into `strategy.experimental_run_v2` is a `tf.function`, or
    that `strategy.experimental_run_v2` is called inside a `tf.function` if
    eager behavior is enabled.
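
    For example, a custom training step run under this strategy might look
    like the following sketch (`model`, `optimizer` and `iterator` are
    placeholders for user-defined objects, not part of this API):

    ```python
    @tf.function
    def train_step(iterator):

      def step_fn(inputs):
        features, labels = inputs
        with tf.GradientTape() as tape:
          logits = model(features, training=True)
          loss = tf.reduce_mean(
              tf.keras.losses.sparse_categorical_crossentropy(labels, logits))
        grads = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))

      strategy.experimental_run_v2(step_fn, args=(next(iterator),))
    ```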

    Args:
      tpu_cluster_resolver: A tf.distribute.cluster_resolver.TPUClusterResolver,
        which provides information about the TPU cluster.
      device_assignment: Optional `tf.tpu.experimental.DeviceAssignment` to
        specify the placement of replicas on the TPU cluster. Currently only
        supports the use case of using a single core within a TPU cluster.
    """
    super(TPUStrategy, self).__init__(TPUExtended(
        self, tpu_cluster_resolver, device_assignment=device_assignment))
    distribute_lib.distribution_strategy_gauge.get_cell("V2").set("TPUStrategy")
    distribute_lib.distribution_strategy_replica_gauge.get_cell(
        "num_workers").set(self.extended.num_hosts)
    distribute_lib.distribution_strategy_replica_gauge.get_cell(
        "num_replicas_per_worker").set(self.extended.num_replicas_per_host)

  # TODO(cjfj): Modify `_call_for_each_replica` in `TPUExtended` such that this
  # can use the default implementation.
  # This implementation runs a single step. It does not use infeed or outfeed.
  def experimental_run_v2(self, fn, args=(), kwargs=None):
    """See base class."""
    validate_experimental_run_function(fn)

    # Note: the target function is converted to a graph even when in eager
    # mode, so autograph is on by default here.
    fn = autograph.tf_convert(fn, autograph_ctx.control_status_ctx())
    return self.extended.tpu_run(fn, args, kwargs)


@tf_export(v1=["distribute.experimental.TPUStrategy"])
class TPUStrategyV1(distribute_lib.StrategyV1):
  """TPU distribution strategy implementation."""

  def __init__(self,
               tpu_cluster_resolver=None,
               steps_per_run=None,
               device_assignment=None):
    """Initializes the TPUStrategy object.
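
    Construction mirrors the V2 `TPUStrategy` example above; for instance (a
    sketch, where `FLAGS.tpu` and `steps_per_run=100` are illustrative):

    ```python
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=FLAGS.tpu)
    tf.tpu.experimental.initialize_tpu_system(resolver)
    strategy = tf.distribute.experimental.TPUStrategy(
        resolver, steps_per_run=100)
    ```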

    Args:
      tpu_cluster_resolver: A tf.distribute.cluster_resolver.TPUClusterResolver,
        which provides information about the TPU cluster.
      steps_per_run: Number of steps to run on device before returning to the
        host. Note that this can have side effects on performance, hooks,
        metrics, summaries, etc. This parameter is only used when Distribution
        Strategy is used with estimator or keras.
      device_assignment: Optional `tf.tpu.experimental.DeviceAssignment` to
        specify the placement of replicas on the TPU cluster. Currently only
        supports the use case of using a single core within a TPU cluster.
    """
    super(TPUStrategyV1, self).__init__(TPUExtended(
        self, tpu_cluster_resolver, steps_per_run, device_assignment))
    distribute_lib.distribution_strategy_gauge.get_cell("V1").set("TPUStrategy")
    distribute_lib.distribution_strategy_replica_gauge.get_cell(
        "num_workers").set(self.extended.num_hosts)
    distribute_lib.distribution_strategy_replica_gauge.get_cell(
        "num_replicas_per_worker").set(self.extended.num_replicas_per_host)

  @property
  def steps_per_run(self):
    """DEPRECATED: use .extended.steps_per_run instead."""
    return self._extended.steps_per_run

  # TODO(cjfj): Modify `_call_for_each_replica` in `TPUExtended` such that this
  # can use the default implementation.
  # This implementation runs a single step. It does not use infeed or outfeed.
  def experimental_run_v2(self, fn, args=(), kwargs=None):
    """See base class."""
    validate_experimental_run_function(fn)

    fn = autograph.tf_convert(fn, autograph_ctx.control_status_ctx())
    return self.extended.tpu_run(fn, args, kwargs)


# TODO(josh11b): Switch to V2 when we no longer need to support tf.compat.v1.
class TPUExtended(distribute_lib.StrategyExtendedV1):
  """Implementation of TPUStrategy."""

  def __init__(self,
               container_strategy,
               tpu_cluster_resolver=None,
               steps_per_run=None,
               device_assignment=None):
    super(TPUExtended, self).__init__(container_strategy)

    if tpu_cluster_resolver is None:
      tpu_cluster_resolver = TPUClusterResolver("")

    if steps_per_run is None:
      # TODO(frankchn): Warn when we are being used by DS/Keras and this is
      # not specified.
      steps_per_run = 1

    self._tpu_function_cache = weakref.WeakKeyDictionary()
    self._tpu_cluster_resolver = tpu_cluster_resolver
    self._tpu_metadata = get_tpu_system_metadata(self._tpu_cluster_resolver)
    self._device_assignment = device_assignment

    tpu_devices_flat = [
        d.name for d in self._tpu_metadata.devices if "device:TPU:" in d.name]

    # `self._tpu_devices` is a two-dimensional NumPy array of strings. It is
    # indexed using `[replica_id][logical_device_id]`.
    if device_assignment is None:
      self._tpu_devices = np.array(
          [[d] for d in tpu_devices_flat], dtype=object)
    else:
      job_name = device_spec.DeviceSpecV2.from_string(tpu_devices_flat[0]).job

      tpu_devices = []
      for replica_id in range(device_assignment.num_replicas):
        replica_devices = []

        for logical_core in range(device_assignment.num_cores_per_replica):
          replica_devices.append(
              device_util.canonicalize(
                  device_assignment.tpu_device(
                      replica=replica_id,
                      logical_core=logical_core,
                      job=job_name)))

        tpu_devices.append(replica_devices)
      self._tpu_devices = np.array(tpu_devices, dtype=object)

    self._host_device = device_util.get_host_for_device(self._tpu_devices[0][0])

    # Preload the data onto the TPUs. Currently we always preload onto logical
    # device 0 for each replica.
    # TODO(cjfj): Create `InputWorkers` lazily, allowing users to place the
    # input onto a different logical device?
    input_worker_devices = collections.OrderedDict()
    for tpu_device in self._tpu_devices[:, 0]:
      host_device = device_util.get_host_for_device(tpu_device)
      input_worker_devices.setdefault(host_device, [])
      input_worker_devices[host_device].append(tpu_device)
    self._input_worker_devices = tuple(input_worker_devices.items())
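    # For illustration only (device names are hypothetical): with eight TPU
    # cores attached to a single host, `self._input_worker_devices` would
    # resemble
    #   (("/job:worker/replica:0/task:0/device:CPU:0",
    #     ["/job:worker/replica:0/task:0/device:TPU:0",
    #      "/job:worker/replica:0/task:0/device:TPU:1",
    #      ...]),)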
    self._input_workers_obj = None

    # TODO(sourabhbajaj): Remove this once the performance of running one step
    # at a time is comparable to running multiple steps.
    self.steps_per_run = steps_per_run
    self._require_static_shapes = True

    # TPUStrategy handles the graph replication in the TF-XLA bridge, so we
    # don't need to retrace functions for each device.
    self._retrace_functions_for_each_device = False

    self.experimental_enable_get_next_as_optional = True
    self.experimental_enable_dynamic_batch_size = True
    self._prefetch_on_host = False

    self._logical_device_stack = [0]

  # TODO(bfontain): Remove once a proper dataset API for prefetching a dataset
  # to multiple devices exists.
  # If value is true, this forces prefetch of data to the host's memory rather
  # than the individual TPU devices' memory. This is needed when using TPU
  # Embeddings as a) sparse tensors cannot be prefetched to the TPU device
  # memory and b) TPU Embedding enqueue operations are CPU ops, so this avoids
  # a copy back to the host for dense tensors.
  def _set_prefetch_on_host(self, value):
    if self._prefetch_on_host == value:
      return
    if self._input_workers_obj is not None:
      raise RuntimeError("Unable to change prefetch on host behavior as "
                         "InputWorkers are already created.")
    self._prefetch_on_host = value
    if value:
      # To prefetch on the host, we must set all the input worker devices to
      # the corresponding host devices.
      self._input_worker_devices = tuple([
          tuple([host,
                 [device_util.get_host_for_device(d) for d in devices]])
          for host, devices in self._input_worker_devices])
      # Force creation of the workers.
      workers = self._input_workers
      del workers

  @property
  def _input_workers(self):
    if self._input_workers_obj is None:
      self._input_workers_obj = input_lib.InputWorkers(
          self._input_worker_devices)
    return self._input_workers_obj

  def _validate_colocate_with_variable(self, colocate_with_variable):
    values.validate_colocate(colocate_with_variable, self)

  def _make_dataset_iterator(self, dataset):
    """Make iterators for each of the TPU hosts."""
    return input_lib.DatasetIterator(
        dataset,
        self._input_workers,
        self._container_strategy(),
        split_batch_by=self._num_replicas_in_sync)

  def _make_input_fn_iterator(
      self,
      input_fn,
      replication_mode=distribute_lib.InputReplicationMode.PER_WORKER):
    input_contexts = []
    num_workers = self._input_workers.num_workers
    for i in range(num_workers):
      input_contexts.append(distribute_lib.InputContext(
          num_input_pipelines=num_workers,
          input_pipeline_id=i,
          num_replicas_in_sync=self._num_replicas_in_sync))
    return input_lib.InputFunctionIterator(
        input_fn,
        self._input_workers,
        input_contexts,
        self._container_strategy())

  def _experimental_make_numpy_dataset(self, numpy_input, session):
    return numpy_dataset.one_host_numpy_dataset(
        numpy_input, numpy_dataset.SingleDevice(self._host_device),
        session)

  def _experimental_distribute_dataset(self, dataset):
    return input_lib.get_distributed_dataset(
        dataset,
        self._input_workers,
        self._container_strategy(),
        split_batch_by=self._num_replicas_in_sync)

  def _experimental_distribute_datasets_from_function(self, dataset_fn):
    input_contexts = []
    num_workers = self._input_workers.num_workers
    for i in range(num_workers):
      input_contexts.append(distribute_lib.InputContext(
          num_input_pipelines=num_workers,
          input_pipeline_id=i,
          num_replicas_in_sync=self._num_replicas_in_sync))

    return input_lib.get_distributed_datasets_from_function(
        dataset_fn,
        self._input_workers,
        input_contexts,
        self._container_strategy())
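
  # A minimal sketch of a `dataset_fn` as consumed above (`make_dataset` is a
  # hypothetical user function returning a `tf.data.Dataset`):
  #
  #   def dataset_fn(input_context):
  #     return make_dataset().shard(input_context.num_input_pipelines,
  #                                 input_context.input_pipeline_id)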

  # TODO(priyag): Deal with OutOfRange errors once b/111349762 is fixed.
  # TODO(sourabhbajaj): Remove the initial_loop_values parameter when we have
  # a mechanism to infer the outputs of `fn`. Pending b/110550782.
  def _experimental_run_steps_on_iterator(
      self, fn, multi_worker_iterator, iterations, initial_loop_values=None):
    # Wrap `fn` for repeat.
    if initial_loop_values is None:
      initial_loop_values = {}
    initial_loop_values = nest.flatten(initial_loop_values)
    ctx = input_lib.MultiStepContext()

    def run_fn(inputs):
      """Single step on the TPU device."""
      fn_result = fn(ctx, inputs)
      flat_last_step_outputs = nest.flatten(ctx.last_step_outputs)
      if flat_last_step_outputs:
        with ops.control_dependencies([fn_result]):
          return [array_ops.identity(f) for f in flat_last_step_outputs]
      else:
        return fn_result

    # We capture the control_flow_context at this point, before we run `fn`
    # inside a while_loop and TPU replicate context. This is useful in cases
    # where we might need to exit these contexts and get back to the outer
    # context to do some things, e.g. to create an op which should be
    # evaluated only once at the end of the loop on the host. One such usage
    # is in creating metrics' value ops.
    self._outer_control_flow_context = (
        ops.get_default_graph()._get_control_flow_context())  # pylint: disable=protected-access

    def rewrite_fn(*args):
      """The rewritten step fn running on TPU."""
      del args

      per_replica_inputs = multi_worker_iterator.get_next()
      replicate_inputs = []
      for replica_id in range(self._num_replicas_in_sync):
        select_replica = lambda x: values.select_replica(replica_id, x)  # pylint: disable=cell-var-from-loop
        replicate_inputs.append((nest.map_structure(
            select_replica, per_replica_inputs),))

      replicate_outputs = tpu.replicate(
          run_fn, replicate_inputs, device_assignment=self._device_assignment)

      # If run_fn has tensor outputs, tpu.replicate returns a list of lists,
      # which we flatten in this case. If run_fn has no tensor outputs,
      # tpu.replicate returns a list of no_ops, and we keep the output as it
      # is.
      if isinstance(replicate_outputs[0], list):
        replicate_outputs = nest.flatten(replicate_outputs)

      return replicate_outputs

    # TODO(sourabhbajaj): The input to the while loop should be based on the
    # output type of the step_fn.
    assert isinstance(initial_loop_values, list)
    initial_loop_values = initial_loop_values * self._num_replicas_in_sync

    # Put the while loop op on TPU host 0.
    with ops.device(self._host_device):
      if self.steps_per_run == 1:
        replicate_outputs = rewrite_fn()
      else:
        replicate_outputs = training_loop.repeat(iterations, rewrite_fn,
                                                 initial_loop_values)

    del self._outer_control_flow_context
    ctx.run_op = control_flow_ops.group(replicate_outputs)

    if isinstance(replicate_outputs, list):
      # Filter out any ops from the outputs, typically this would be the case
      # when there were no tensor outputs.
      last_step_tensor_outputs = [
          x for x in replicate_outputs if not isinstance(x, ops.Operation)
      ]

      # Outputs are currently of the structure (flattened)
      # [output0_device0, output1_device0, output2_device0,
      #  output0_device1, output1_device1, output2_device1,
      #  ...]
      # Convert this to the following structure instead: (grouped by output)
      # [[output0_device0, output0_device1],
      #  [output1_device0, output1_device1],
      #  [output2_device0, output2_device1]]
      output_num = len(last_step_tensor_outputs) // self._num_replicas_in_sync
      last_step_tensor_outputs = [
          last_step_tensor_outputs[i::output_num] for i in range(output_num)
      ]
    else:
      # No tensors returned.
      last_step_tensor_outputs = []

    _set_last_step_outputs(ctx, last_step_tensor_outputs)
    return ctx

  def _call_for_each_replica(self, fn, args, kwargs):
    # TODO(jhseu): Consider making it so call_for_each_replica implies that
    # we're in a tpu.rewrite(), and update TPUMirroredVariable accordingly.
    with _TPUReplicaContext(self._container_strategy()):
      return fn(*args, **kwargs)

  @contextlib.contextmanager
  def experimental_logical_device(self, logical_device_id):
    """Places variables and ops on the specified logical device."""
    num_logical_devices_per_replica = self._tpu_devices.shape[1]
    if logical_device_id >= num_logical_devices_per_replica:
      raise ValueError(
          "`logical_device_id` not in range (was {}, but there are only {} "
          "logical devices per replica).".format(
              logical_device_id, num_logical_devices_per_replica))

    self._logical_device_stack.append(logical_device_id)
    try:
      if values._enclosing_tpu_context() is None:  # pylint: disable=protected-access
        yield
      else:
        with ops.device(tpu.core(logical_device_id)):
          yield
    finally:
      self._logical_device_stack.pop()
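
  # A minimal usage sketch (assumes a `DeviceAssignment` with at least two
  # logical cores per replica; `x` and `y` are placeholder tensors):
  #
  #   with strategy.extended.experimental_logical_device(1):
  #     z = x + y  # Placed on logical core 1 of each replica.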

  def _experimental_initialize_system(self):
    """Experimental method added to be used by Estimator.

    This is a private method only to be used by Estimator. Other frameworks
    should directly be calling `tf.tpu.experimental.initialize_tpu_system`.
    """
    tpu_strategy_util.initialize_tpu_system(self._tpu_cluster_resolver)

  def _create_variable(self, next_creator, **kwargs):
    """Create a TPUMirroredVariable. See `DistributionStrategy.scope`."""
    if kwargs.pop("skip_mirrored_creator", False):
      return next_creator(**kwargs)

    colocate_with = kwargs.pop("colocate_with", None)
    if colocate_with is None:
      devices = self._tpu_devices[:, self._logical_device_stack[-1]]
    elif isinstance(colocate_with, numpy_dataset.SingleDevice):
      with ops.device(colocate_with.device):
        return next_creator(**kwargs)
    else:
      devices = colocate_with.devices

    def _real_mirrored_creator(**kwargs):  # pylint: disable=g-missing-docstring
      initial_value = None
      value_list = []
      for i, d in enumerate(devices):
        with ops.device(d):
          if i == 0:
            initial_value = kwargs["initial_value"]
            # Note: some v1 code expects variable initializer creation to
            # happen inside an init_scope.
            with maybe_init_scope():
              initial_value = initial_value() if callable(
                  initial_value) else initial_value

          if i > 0:
            # Give replicas meaningful distinct names:
            var0name = value_list[0].name.split(":")[0]
            # We append a / to variable names created on replicas with id > 0
            # to ensure that we ignore the name scope and instead use the
            # given name as the absolute name of the variable.
            kwargs["name"] = "%s/replica_%d/" % (var0name, i)
          kwargs["initial_value"] = initial_value

          with context.device_policy(context.DEVICE_PLACEMENT_SILENT):
            v = next_creator(**kwargs)

          assert not isinstance(v, values.TPUMirroredVariable)
          value_list.append(v)
      return value_list

    return values.create_mirrored_variable(self._container_strategy(),
                                           _real_mirrored_creator,
                                           values.TPUMirroredVariable,
                                           values.TPUSyncOnReadVariable,
                                           **kwargs)

  def _reduce_to(self, reduce_op, value, destinations):
    if (isinstance(value, values.DistributedValues) or
        tensor_util.is_tensor(value)
       ) and values._enclosing_tpu_context() is not None:  # pylint: disable=protected-access
      if reduce_op == reduce_util.ReduceOp.MEAN:
        # TODO(jhseu): Revisit once we support model-parallelism.
        value *= (1. / self._num_replicas_in_sync)
      elif reduce_op != reduce_util.ReduceOp.SUM:
        raise NotImplementedError(
            "Currently only SUM and MEAN are supported in TPUStrategy.")
      return tpu_ops.cross_replica_sum(value)

    if not isinstance(value, values.DistributedValues):
      # This function handles reducing values that are not PerReplica or
      # Mirrored values. For example, the same value could be present on all
      # replicas, in which case `value` would be a single value, or value
      # could be 0.
      return cross_device_ops_lib.reduce_non_distributed_value(
          reduce_op, value, destinations, self._num_replicas_in_sync)

    # TODO(cjfj): Detect when it is possible to use `cross_replica_sum`.
    # Always performs the reduction on the TPU host.
    with ops.device(self._host_device):
      output = math_ops.add_n(value.values)
      if reduce_op == reduce_util.ReduceOp.MEAN:
        output *= (1. / len(value.values))

    devices = cross_device_ops_lib.get_devices_from(destinations)

    if len(devices) == 1:
      # If necessary, copy to requested destination.
      dest_canonical = device_util.canonicalize(devices[0])
      host_canonical = device_util.canonicalize(self._host_device)

      if dest_canonical != host_canonical:
        with ops.device(dest_canonical):
          output = array_ops.identity(output)
    else:
      output = cross_device_ops_lib.simple_broadcast(output, destinations)

    return output
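
  # For example, with 8 replicas inside the TPU context, a MEAN reduction
  # above computes cross_replica_sum(value / 8), which equals the mean of
  # `value` across replicas.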

  def _update(self, var, fn, args, kwargs, group):
    assert isinstance(var, values.TPUVariableMixin) or isinstance(
        var, resource_variable_ops.BaseResourceVariable)
    if values._enclosing_tpu_context() is not None:  # pylint: disable=protected-access
      if group:
        return fn(var, *args, **kwargs)
      else:
        return (fn(var, *args, **kwargs),)

    # Otherwise, we revert to MirroredStrategy behavior and update each
    # variable directly.
    updates = []
    for i, v in enumerate(var.values):
      name = "update_%d" % i
      with ops.device(v.device), \
           distribute_lib.UpdateContext(i), \
           ops.name_scope(name):
        # If args and kwargs are not mirrored, the value is returned as is.
        updates.append(fn(v,
                          *values.select_replica_mirrored(i, args),
                          **values.select_replica_mirrored(i, kwargs)))
    return values.update_regroup(self, updates, group)

  def read_var(self, var):
    assert isinstance(var, values.TPUVariableMixin) or isinstance(
        var, resource_variable_ops.BaseResourceVariable)
    return var.read_value()

  def _local_results(self, val):
    if isinstance(val, values.DistributedValues):
      return val.values
    return (val,)

  def value_container(self, value):
    return value

  def _broadcast_to(self, tensor, destinations):
    del destinations
    # This is both a fast path for Python constants, and a way to delay
    # converting Python values to a tensor until we know what type it
    # should be converted to. Otherwise we have trouble with:
    #   global_step.assign_add(1)
    # since the `1` gets broadcast as an int32 but global_step is int64.
    if isinstance(tensor, (float, int)):
      return tensor
    if values._enclosing_tpu_context() is not None:  # pylint: disable=protected-access
      broadcast_tensor = [tensor for _ in range(self._num_replicas_in_sync)]
      result = tpu_ops.all_to_all(
          broadcast_tensor,
          concat_dimension=0,
          split_dimension=0,
          split_count=self._num_replicas_in_sync)

      # This uses the broadcasted value from the first replica because the
      # only caller of this is for ONLY_FIRST_REPLICA variables aggregation.
      return result[0]
    return tensor
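
  # Shape sketch for the TPU-context branch above (illustrative, with 4
  # replicas): each replica stacks 4 identical copies of `tensor` into shape
  # [4, ...]; `all_to_all` then exchanges row i with replica i, so row j of
  # the result holds the copy contributed by replica j, and `result[0]` is
  # replica 0's value.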
750 """ 751 return True 752 753 def tpu_run(self, fn, args, kwargs): 754 func = self._tpu_function_creator(fn) 755 return func(args, kwargs) 756 757 def _tpu_function_creator(self, fn): 758 if fn in self._tpu_function_cache: 759 return self._tpu_function_cache[fn] 760 761 strategy = self._container_strategy() 762 763 def tpu_function(args, kwargs): 764 """TF Function used to replicate the user computation.""" 765 if kwargs is None: 766 kwargs = {} 767 768 # Remove None at the end of args as they are not replicatable 769 # If there are None in the middle we can't do anything about it 770 # so let those cases fail. 771 # For example when Keras model predict is used they pass the targets as 772 # None. We want to handle it here so all client libraries don't have to 773 # do this as other strategies can handle None values better. 774 while args and args[-1] is None: 775 args = args[:-1] 776 777 # Used to re-structure flattened output tensors from `tpu.replicate()` 778 # into a structured format. 779 result = [[]] 780 781 def replicated_fn(replica_id, replica_args, replica_kwargs): 782 """Wraps user function to provide replica ID and `Tensor` inputs.""" 783 with _TPUReplicaContext(strategy, replica_id_in_sync_group=replica_id): 784 result[0] = fn(*replica_args, **replica_kwargs) 785 return result[0] 786 787 replicate_inputs = [] # By replica. 788 for i in range(strategy.num_replicas_in_sync): 789 replicate_inputs.append( 790 [constant_op.constant(i, dtype=dtypes.int32), 791 values.select_replica(i, args), 792 values.select_replica(i, kwargs)]) 793 794 # Construct and pass `maximum_shapes` so that we could support dynamic 795 # shapes using dynamic padder. 796 if self.experimental_enable_dynamic_batch_size and replicate_inputs: 797 maximum_shapes = [] 798 flattened_list = nest.flatten(replicate_inputs[0]) 799 for input_tensor in flattened_list: 800 if tensor_util.is_tensor(input_tensor): 801 rank = input_tensor.get_shape().rank 802 else: 803 rank = np.rank(input_tensor) 804 maximum_shape = tensor_shape.TensorShape([None] * rank) 805 maximum_shapes.append(maximum_shape) 806 maximum_shapes = nest.pack_sequence_as(replicate_inputs[0], 807 maximum_shapes) 808 else: 809 maximum_shapes = None 810 811 with strategy.scope(): 812 replicate_outputs = tpu.replicate( 813 replicated_fn, 814 replicate_inputs, 815 device_assignment=self._device_assignment, 816 maximum_shapes=maximum_shapes) 817 818 # Remove all no ops that may have been added during 'tpu.replicate()' 819 if isinstance(result[0], list): 820 result[0] = [ 821 output for output in result[0] if not isinstance( 822 output, ops.Operation) 823 ] 824 825 # Workaround for `tpu.replicate` behaviour when single `Tensor` returned. 826 if result[0] is None or isinstance(result[0], ops.Operation): 827 replicate_outputs = [None] * len(replicate_outputs) 828 else: 829 replicate_outputs = [ 830 nest.pack_sequence_as(result[0], nest.flatten(replica_output)) 831 for replica_output in replicate_outputs 832 ] 833 return values.regroup(replicate_outputs) 834 835 if context.executing_eagerly(): 836 tpu_function = def_function.function(tpu_function) 837 838 self._tpu_function_cache[fn] = tpu_function 839 return tpu_function 840 841 def _in_multi_worker_mode(self): 842 """Whether this strategy indicates working in multi-worker settings.""" 843 # TPUStrategy has different distributed training structure that the whole 844 # cluster should be treated as single worker from higher-level (e.g. Keras) 845 # library's point of view. 

      with strategy.scope():
        replicate_outputs = tpu.replicate(
            replicated_fn,
            replicate_inputs,
            device_assignment=self._device_assignment,
            maximum_shapes=maximum_shapes)

      # Remove all no-ops that may have been added during `tpu.replicate()`.
      if isinstance(result[0], list):
        result[0] = [
            output for output in result[0] if not isinstance(
                output, ops.Operation)
        ]

      # Workaround for `tpu.replicate` behaviour when a single `Tensor` is
      # returned.
      if result[0] is None or isinstance(result[0], ops.Operation):
        replicate_outputs = [None] * len(replicate_outputs)
      else:
        replicate_outputs = [
            nest.pack_sequence_as(result[0], nest.flatten(replica_output))
            for replica_output in replicate_outputs
        ]
      return values.regroup(replicate_outputs)

    if context.executing_eagerly():
      tpu_function = def_function.function(tpu_function)

    self._tpu_function_cache[fn] = tpu_function
    return tpu_function

  def _in_multi_worker_mode(self):
    """Whether this strategy indicates working in multi-worker settings."""
    # TPUStrategy has a different distributed training structure: the whole
    # cluster should be treated as a single worker from a higher-level (e.g.
    # Keras) library's point of view.
    # TODO(rchao): Revisit this as we design a fault-tolerance solution for
    # TPUStrategy.
    return False


class _TPUReplicaContext(distribute_lib.ReplicaContext):
  """Replication Context class for TPU Strategy."""

  # TODO(sourabhbajaj): Call for each replica should be updating this.
  # TODO(b/118385803): Always properly initialize replica_id.
  def __init__(self, strategy, replica_id_in_sync_group=None):
    if replica_id_in_sync_group is None:
      replica_id_in_sync_group = constant_op.constant(0, dtypes.int32)
    distribute_lib.ReplicaContext.__init__(
        self, strategy, replica_id_in_sync_group=replica_id_in_sync_group)

  @property
  def devices(self):
    distribute_lib.require_replica_context(self)
    ds = self._strategy
    replica_id = tensor_util.constant_value(self._replica_id_in_sync_group)

    if replica_id is None:  # Non-constant `Tensor` inside `tpu.replicate`.
      # TODO(cjfj): Return other devices when model parallelism is supported.
      return (tpu.core(0),)
    else:
      return (ds.extended.worker_devices[replica_id],)

  def experimental_logical_device(self, logical_device_id):
    """Places variables and ops on the specified logical device."""
    return self.strategy.extended.experimental_logical_device(logical_device_id)


def _set_last_step_outputs(ctx, last_step_tensor_outputs):
  """Sets the last step outputs on the given context."""
  # Convert replicate_outputs to the original dict structure of
  # last_step_outputs.
  last_step_tensor_outputs_dict = nest.pack_sequence_as(
      ctx.last_step_outputs, last_step_tensor_outputs)

  for name, reduce_op in ctx._last_step_outputs_reduce_ops.items():  # pylint: disable=protected-access
    output = last_step_tensor_outputs_dict[name]
    # For outputs that have already been reduced, take the first value
    # from the list, as each value should be the same. Else return the full
    # list of values.
    # TODO(josh11b): If reduce_op is NONE, we should return a PerReplica
    # value.
    if reduce_op is not None:
      # TODO(priyag): Should this return the element or a list with 1 element?
      last_step_tensor_outputs_dict[name] = output[0]
  ctx._set_last_step_outputs(last_step_tensor_outputs_dict)  # pylint: disable=protected-access