# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ======================================

"""Library of TPU helper functions."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import enum
import typing
from typing import Any, Callable, Iterable, List, Optional, Text, Tuple, Union

from absl import logging
import numpy as np
from six.moves import xrange  # pylint: disable=redefined-builtin

from tensorflow.compiler.tf2xla.python import xla as tf2xla
from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.protobuf.tpu import dynamic_padding_pb2 as dynamic_padding
from tensorflow.core.protobuf.tpu import tpu_embedding_configuration_pb2 as embedding_pb2
from tensorflow.python.compiler.xla import xla
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.framework import auto_control_deps
from tensorflow.python.framework import c_api_util
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import config
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import device as pydev
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import func_graph
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.tpu import device_assignment as device_assignment_lib
from tensorflow.python.tpu import tpu_feed
from tensorflow.python.tpu import tpu_function
from tensorflow.python.tpu import tpu_name_util
from tensorflow.python.tpu.ops import tpu_ops
from tensorflow.python.types import core as core_types
from tensorflow.python.util import compat
from tensorflow.python.util import nest
from tensorflow.python.util import object_identity
from tensorflow.python.util.tf_export import tf_export

ops.NotDifferentiable("TPUReplicatedInput")

# Operations that indicate some error in the user's graph, e.g. a placeholder
# that's introduced outside of the infeed.
_DENYLISTED_OPS = set([
    "Placeholder",
])

# XLA doesn't currently support reading of intermediate tensors, thus some ops
# are not supported.
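# Such ops can usually be run on the host instead by wrapping them with
# `tf.tpu.outside_compilation` (see `outside_compilation` below).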
_UNSUPPORTED_OPS = set([
    "AudioSummary",
    "AudioSummaryV2",
    "HistogramSummary",
    "ImageSummary",
    "MergeSummary",
    "Print",
    "ScalarSummary",
    "TensorSummary",
    "TensorSummaryV2",
    ])

# Ops which can be safely pruned from XLA compile if they have no consumers.
# These ops should also have no inputs.
_UNCONNECTED_OPS_TO_PRUNE = set(["Placeholder", "VarHandleOp"])

_MAX_WARNING_LINES = 5

_TPU_REPLICATE_ATTR = "_tpu_replicate"
_POST_DEVICE_REWRITE_ATTR = "_post_device_rewrite"
_TPU_COMPILATION_STATUS_ATTR = "_tpu_compilation_status"
_OUTSIDE_COMPILATION_ATTR = "_xla_outside_compilation"
_PIVOT_FOR_CLUSTER = "_pivot_for_cluster"


core = tpu_name_util.core


def _tpu_system_device_name(job: Optional[Text]) -> Text:
  """Returns the device name for the TPU_SYSTEM device of `job`."""
  if job is None:
    return "/device:TPU_SYSTEM:0"
  else:
    return "/job:%s/device:TPU_SYSTEM:0" % job


@tf_export(v1=["tpu.initialize_system"])
def initialize_system(
    embedding_config: Optional[embedding_pb2.TPUEmbeddingConfiguration] = None,
    job: Optional[Text] = None,
    compilation_failure_closes_chips: bool = True
) -> core_types.Tensor:
  """Initializes a distributed TPU system for use with TensorFlow.

  Args:
    embedding_config: If not None, a `TPUEmbeddingConfiguration` proto
      describing the desired configuration of the hardware embedding lookup
      tables. If embedding_config is None, no hardware embeddings can be used.
    job: The job (the XXX in TensorFlow device specification /job:XXX) that
      contains the TPU devices that will be initialized. If job=None it is
      assumed there is only one job in the TensorFlow flock, and an error will
      be returned if this assumption does not hold.
    compilation_failure_closes_chips: Whether to close TPU chips when a
      compilation failure occurs.

  Returns:
    A serialized `TopologyProto` that describes the TPU system. Note:
      the topology must be evaluated using `Session.run` before it can be used.
  """
  config_string = ("" if embedding_config is None else
                   embedding_config.SerializeToString())
  with ops.device(_tpu_system_device_name(job)):
    topology = tpu_ops.configure_distributed_tpu(
        compilation_failure_closes_chips=compilation_failure_closes_chips)

    if embedding_config is None:
      return topology

    # This set of control dependencies is needed as this function is expected to
    # return an op which will return the topology when executed, but we need to
    # call the embedding initialization op between initializing the TPU and
    # returning the topology.
    with ops.control_dependencies([topology]):
      embedding_init = tpu_ops.configure_tpu_embedding(config=config_string)
    with ops.control_dependencies([embedding_init]):
      return array_ops.identity(topology, name="tpu_init_identity")


def initialize_system_for_tpu_embedding(
    embedding_config: embedding_pb2.TPUEmbeddingConfiguration,
    job: Optional[Text] = None,
) -> ops.Operation:
  """Initializes a distributed TPU Embedding system for use with TensorFlow.

  The following two are equivalent:
  1. initialize_system() with embedding_config.
  2. initialize_system() without embedding_config, then
     initialize_system_for_tpu_embedding().
  initialize_system() should not be called with embedding_config if
  initialize_system_for_tpu_embedding() is meant to be called later.
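
  For example, the second route looks as follows (a sketch; `embedding_config`
  is assumed to be an already populated `TPUEmbeddingConfiguration` proto):

  ```python
  topology = tf.compat.v1.tpu.initialize_system()
  embedding_init = initialize_system_for_tpu_embedding(embedding_config)
  # Both returned ops must be evaluated (e.g. with `Session.run`) before the
  # embedding tables can be used.
  ```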

  Args:
    embedding_config: a `TPUEmbeddingConfiguration` proto describing the desired
      configuration of the hardware embedding lookup tables.
    job: The job (the XXX in TensorFlow device specification /job:XXX) that
      contains the TPU devices that will be initialized. If job=None it is
      assumed there is only one job in the TensorFlow flock, and an error will
      be returned if this assumption does not hold.

  Returns:
    A no-op.
  """
  config_string = embedding_config.SerializeToString()
  with ops.device(_tpu_system_device_name(job)):
    return tpu_ops.configure_tpu_embedding(config=config_string)


@tf_export(v1=["tpu.shutdown_system"])
def shutdown_system(job: Optional[Text] = None) -> ops.Operation:
  """Shuts down a running distributed TPU system.

  Args:
    job: The job (the XXX in TensorFlow device specification /job:XXX) that
      contains the TPU devices that will be shut down. If job=None it is
      assumed there is only one job in the TensorFlow flock, and an error will
      be returned if this assumption does not hold.
  """
  with ops.device(_tpu_system_device_name(job)):
    shutdown_distributed_tpu = tpu_ops.shutdown_distributed_tpu()
  return shutdown_distributed_tpu


def _enclosing_tpu_context_and_graph() -> Tuple[Any, Any]:
  """Returns the TPUReplicateContext and its associated graph."""
  graph = ops.get_default_graph()
  while graph is not None:
    # pylint: disable=protected-access
    context_ = graph._get_control_flow_context()
    # pylint: enable=protected-access
    while context_ is not None:
      if isinstance(context_, TPUReplicateContext):
        return context_, graph
      context_ = context_.outer_context
    graph = getattr(graph, "outer_graph", None)
  raise ValueError("get_replicated_var_handle() called without "
                   "TPUReplicateContext. This shouldn't happen. Please file "
                   "a bug.")


def is_tpu_strategy(strategy: Any) -> bool:
  is_tpu_strat = lambda k: k.__name__.startswith("TPUStrategy")
  clz = strategy.__class__
  return is_tpu_strat(clz) or any(map(is_tpu_strat, clz.__bases__))


def _enclosing_tpu_device_assignment(
) -> Optional[device_assignment_lib.DeviceAssignment]:
  if not distribution_strategy_context.has_strategy():
    return None
  strategy = distribution_strategy_context.get_strategy()
  if not is_tpu_strategy(strategy):
    return None
  return strategy.extended._device_assignment  # pylint: disable=protected-access


@auto_control_deps.register_acd_resource_resolver
def tpu_replicated_input_resolver(
    op: ops.Operation,
    resource_reads: object_identity.ObjectIdentitySet,
    resource_writes: object_identity.ObjectIdentitySet) -> bool:
  """Replaces TPUReplicatedInput outputs with its inputs in resource_inputs."""
  # Ignore TPUReplicatedInput for ACD purposes since we will be directly adding
  # control deps on the replicated inputs.
  if op.type == "TPUReplicatedInput":
    if resource_reads or resource_writes:
      resource_reads.clear()
      resource_writes.clear()
      return True
    else:
      return False
  # Replace tensors in `resource_inputs` which are outputs of TPUReplicatedInput
  # with the actual replicated inputs. This allows ACD to correctly add control
  # deps when there are multiple calls to `run` in a
  # `tf.function`.
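  # For example (hypothetical): if `resource_reads` contains the output of a
  # TPUReplicatedInput op whose inputs are the per-replica handles h0 and h1,
  # the set is rewritten to contain h0 and h1 instead of the replicated output.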
  def replace_with_unreplicated_resources(resource_inputs):
    """Replaces handles in `resource_inputs` with their unreplicated inputs."""
    to_remove = []
    to_add = []
    for resource in resource_inputs:
      if resource.op.type == "TPUReplicatedInput":
        to_remove.append(resource)
        to_add.extend(resource.op.inputs)
    for t in to_remove:
      resource_inputs.discard(t)
    resource_inputs.update(to_add)
    return to_add or to_remove

  return bool(replace_with_unreplicated_resources(resource_reads) or
              replace_with_unreplicated_resources(resource_writes))


class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
  """A `ControlFlowContext` for nodes inside a TPU computation.

  The primary role of `TPUReplicateContext` is to mark operators inside a
  tpu.replicate() computation with the attribute "_tpu_replicate=XYZ", where XYZ
  is a unique name.

  We use a `ControlFlowContext` to perform the annotation since it integrates
  with TensorFlow constructs like ResourceVariables. For example, if a
  `ResourceVariable` is constructed inside a tpu.replicate() block, the
  `ResourceVariable` implementation can use
  `with ops.control_dependencies(None)` to build the variable's definition
  outside the replicated computation.
  """

  def __init__(self, name: Text, num_replicas: int, pivot: ops.Operation):
    """Builds a new TPUReplicateContext.

    Args:
      name: a unique name for the context, used to populate the `_tpu_replicate`
        attribute.
      num_replicas: an integer that gives the number of replicas for the
        computation.
      pivot: a pivot node. Nodes in the TPUReplicateContext that do not have any
        inputs will have a control dependency on the pivot node. This ensures
        that nodes are correctly included in any enclosing control flow
        contexts.
    """
    super(TPUReplicateContext, self).__init__()
    self._num_replicas = num_replicas
    self._outer_device_function_stack = None
    self._oc_dev_fn_stack = None
    self._outside_compilation_cluster = None
    self._outside_compilation_v2_context = None
    self._outside_compilation_counter = 0
    self._in_gradient_colocation = None
    self._gradient_colocation_stack = []
    self._host_compute_core = []
    self._name = name
    self._name_as_bytes = compat.as_bytes(name)
    self._tpu_relicate_attr_buf = c_api_util.ScopedTFBuffer(
        attr_value_pb2.AttrValue(s=self._name_as_bytes).SerializeToString())
    self._unsupported_ops = []
    self._pivot = pivot
    self._replicated_vars = {}

  def get_replicated_var_handle(self,
                                name: Text,
                                vars_: Union[List[core_types.Tensor],
                                             List[variables.Variable]],
                                is_mirrored: bool = False,
                                is_packed: bool = False) -> core_types.Tensor:
    """Returns a variable handle for replicated TPU variable 'var'.

    This is a method used by an experimental replicated variable implementation
    and is not intended as a public API.

    Args:
      name: The common name of the variable.
      vars_: The replicated TPU variables or handles.
      is_mirrored: Whether the variables are mirrored, which guarantees the
        values in each replica are always the same.
      is_packed: Whether the replicated variables are packed into one variable.

    Returns:
      The handle of the TPU replicated input node.
332 """ 333 device_assignment = _enclosing_tpu_device_assignment() 334 # We don't need to put device assignment as part of the replicated_vars key 335 # because each TPUReplicateContext will only have one device assignment. 336 handle = self._replicated_vars.get(name) 337 if handle is not None: 338 return handle 339 340 if device_assignment is not None and not is_packed: 341 # Find a variable copy for each replica in the device assignment. 342 # Note that the order of devices for replicas for the variable and the 343 # device assignment might not match. 344 job_name = pydev.DeviceSpec.from_string(vars_[0].device).job 345 devices_to_vars = {device_util.canonicalize(v.device): v for v in vars_} 346 replicated_vars = [] 347 for replica_id in range(device_assignment.num_replicas): 348 for logical_core in range(device_assignment.num_cores_per_replica): 349 device = device_util.canonicalize( 350 device_assignment.tpu_device( 351 replica=replica_id, logical_core=logical_core, job=job_name)) 352 if device in devices_to_vars: 353 replicated_vars.append(devices_to_vars[device]) 354 break 355 else: 356 raise ValueError( 357 "Failed to find a variable on any device in replica {} for " 358 "current device assignment".format(replica_id)) 359 else: 360 replicated_vars = vars_ 361 362 # Builds a TPUReplicatedInput node for the variable, if one does not already 363 # exist. The TPUReplicatedInput node must belong to the enclosing 364 # control-flow scope of the TPUReplicateContext. 365 # TODO(phawkins): consider changing the contract of the TPU encapsulation 366 # so the TPUReplicatedInput nodes go inside the TPUReplicateContext scope 367 # instead. 368 369 _, graph = _enclosing_tpu_context_and_graph() 370 with graph.as_default(): 371 # If replicated_vars are variables, get the handles. Note that this can be 372 # done inside TPUReplicateContext because replicated_vars.handle may 373 # create new ops. 374 if isinstance(replicated_vars[0], variables.Variable): 375 replicated_vars = [v.handle for v in replicated_vars] 376 # pylint: disable=protected-access 377 saved_context = graph._get_control_flow_context() 378 graph._set_control_flow_context(self.outer_context) 379 handle = tpu_ops.tpu_replicated_input(replicated_vars, 380 name=name + "/handle", 381 is_mirrored_variable=is_mirrored, 382 is_packed=is_packed) 383 graph._set_control_flow_context(saved_context) 384 # pylint: enable=protected-access 385 self._replicated_vars[name] = handle 386 return handle 387 388 def report_unsupported_operations(self) -> None: 389 if self._unsupported_ops: 390 op_str = "\n".join(" %s (%s)" % (op.type, op.name) 391 for op in self._unsupported_ops[:_MAX_WARNING_LINES]) 392 logging.warning("%d unsupported operations found: \n%s", 393 len(self._unsupported_ops), op_str) 394 if len(self._unsupported_ops) > _MAX_WARNING_LINES: 395 logging.warning("... and %d more" % 396 (len(self._unsupported_ops) - _MAX_WARNING_LINES)) 397 398 def EnterGradientColocation(self, op: ops.Operation, gradient_uid: Text): 399 if op is not None: 400 if ops.get_default_graph()._control_flow_context is None: # pylint: disable=protected-access 401 # If we are in TF 2 functions (control flow V2 functions, or 402 # tf.function()), we need to attach _xla_outside_compilation attribute 403 # directly because we are not in TPUReplicateContext. 404 try: 405 outside_attr = op.get_attr(_OUTSIDE_COMPILATION_ATTR).decode("ascii") 406 except ValueError: 407 # The attr was not present: do nothing. 
          return
        parts = outside_attr.split(".")
        cluster = parts[0] + "." + gradient_uid
        self._outside_compilation_v2_context = OutsideCompilationV2Context(
            cluster)
        self._outside_compilation_v2_context.Enter()
        return
      self._gradient_colocation_stack.append(op)
      if not self._outside_compilation_cluster:
        try:
          outside_attr = op.get_attr(_OUTSIDE_COMPILATION_ATTR).decode("ascii")
          if self._in_gradient_colocation:
            raise NotImplementedError(
                "Cannot nest gradient colocation operations outside compilation"
            )
          if gradient_uid == "__unsupported__":
            raise NotImplementedError(
                "No gradient_uid calling gradient within outside_compilation")
          # When we take the gradient of an op X in an outside_compilation
          # cluster C in a forward computation we would like to put the ops
          # corresponding to the gradient of X into a new outside_compilation
          # cluster C'. However, if we take the gradient of X twice, the second
          # one should get yet another new outside_compilation cluster C''.
          #
          # The mechanism we adopt is to use a 'root_cluster' which is the
          # cluster that X was in before we took gradients, and a 'gradient_uid'
          # which is different for every invocation of gradients, and put the
          # gradient of X in cluster 'root_cluster.gradient_uid'.
          #
          # When taking a gradient of a gradient, some ops will be colocated
          # with Op in the forward pass (e.g., cluster root_cluster) and some in
          # the backward pass (e.g., cluster root_cluster.initial_gradient_uid).
          # We need all of the grad-of-grad ops to be in the same cluster to
          # avoid cyclic dependencies between clusters. We adopt a heuristic
          # that puts any op clustered with root_cluster.<xxx> in
          # root_cluster.gradient_uid, even if xxx was initial_gradient_uid.
          self._in_gradient_colocation = op
          parts = outside_attr.split(".")
          cluster = parts[0] + "." + gradient_uid
          self._EnterOutsideCompilationScope(cluster=cluster)
        except ValueError:
          # The attr was not present: do nothing.
          pass

  def ExitGradientColocation(self, op: ops.Operation, gradient_uid: Text):
    if op is not None:
      if ops.get_default_graph()._control_flow_context is None:  # pylint: disable=protected-access
        # Inside a TF2 tf.function or control flow graph and `op` was not
        # marked to be outside compiled.
        assert self._outside_compilation_v2_context is None
        return
      if self._outside_compilation_v2_context is not None:
        # Inside a TF2 tf.function or control flow graph and `op` was
        # marked to be outside compiled.
        self._outside_compilation_v2_context.Exit()
        self._outside_compilation_v2_context = None
        return
      if not self._gradient_colocation_stack:
        raise errors.InternalError(
            op.node_def, op,
            f"Badly nested gradient colocation: empty stack when popping Op {op.name}"
        )
      last_op = self._gradient_colocation_stack.pop()
      if op is last_op:
        if op is self._in_gradient_colocation:
          self._in_gradient_colocation = None
          self._ExitOutsideCompilationScope()
      else:
        raise errors.InternalError(
            op.node_def, op,
            f"Badly nested gradient colocation, expected {last_op}, got {op.name}"
        )

  def _EnterOutsideCompilationScope(self, cluster: Optional[Text] = None):

    class FakeOp(object):
      """A helper class to determine the current device.

      Supports only the type and device set/get methods needed to run the
      graph's _apply_device_functions method.
488 """ 489 490 def __init__(self): 491 self._device = "" 492 493 @property 494 def type(self): 495 return "FakeOp" 496 497 @property 498 def device(self): 499 return self._device 500 501 def _set_device(self, device): 502 if isinstance(device, pydev.DeviceSpec): 503 self._device = device.to_string() 504 else: 505 self._device = device 506 507 def _set_device_from_string(self, device_str): 508 self._device = device_str 509 510 if self._outside_compilation_cluster: 511 raise NotImplementedError("Cannot nest outside_compilation clusters") 512 if cluster: 513 self._outside_compilation_cluster = cluster 514 else: 515 self._outside_compilation_cluster = str(self._outside_compilation_counter) 516 self._outside_compilation_counter += 1 517 graph = ops.get_default_graph() 518 fake_op = FakeOp() 519 graph._apply_device_functions(fake_op) # pylint: disable=protected-access 520 device = pydev.DeviceSpec.from_string(fake_op.device) 521 if (device.device_type == "TPU_REPLICATED_CORE" and 522 device.device_index is not None): 523 self._host_compute_core.append(self._outside_compilation_cluster + ":" + 524 str(device.device_index)) 525 self._oc_dev_fn_stack = graph._device_function_stack # pylint: disable=protected-access 526 graph._device_function_stack = self._outer_device_function_stack # pylint: disable=protected-access 527 528 def _ExitOutsideCompilationScope(self): 529 if not self._outside_compilation_cluster: 530 raise ValueError( 531 "Attempted to exit outside_compilation scope when not in scope") 532 self._outside_compilation_cluster = None 533 graph = ops.get_default_graph() 534 graph._device_function_stack = self._oc_dev_fn_stack # pylint: disable=protected-access 535 536 def Enter(self) -> None: 537 if not self._outer_device_function_stack: 538 # Capture the device function stack at the time of first entry 539 # since that is the stack that will be used outside_compilation. 540 graph = ops.get_default_graph() 541 # pylint: disable=protected-access 542 self._outer_device_function_stack = graph._device_function_stack.copy() 543 # pylint: enable=protected-access 544 super(TPUReplicateContext, self).Enter() 545 546 def HostComputeCore(self) -> List[Text]: 547 return self._host_compute_core 548 549 def _RemoveExternalControlEdges( 550 self, op: ops.Operation 551 ) -> Tuple[List[ops.Operation], List[ops.Operation]]: 552 """Remove any external control dependency on this op.""" 553 internal_control_inputs = [] 554 external_control_inputs = [] 555 for x in op.control_inputs: 556 # pylint: disable=protected-access 557 is_internal_op = False 558 ctxt = x._get_control_flow_context() 559 while ctxt is not None: 560 if ctxt == self: 561 is_internal_op = True 562 break 563 ctxt = ctxt._outer_context 564 if is_internal_op: 565 internal_control_inputs.append(x) 566 else: 567 external_control_inputs.append(x) 568 # pylint: enable=protected-access 569 # pylint: disable=protected-access 570 op._remove_all_control_inputs() 571 op._add_control_inputs(internal_control_inputs) 572 # pylint: enable=protected-access 573 return internal_control_inputs, external_control_inputs 574 575 def AddOp(self, op: ops.Operation) -> None: 576 # pylint: disable=protected-access 577 if op.type in _DENYLISTED_OPS: 578 logging.error("Operation of type %s (%s) is not supported on the TPU. " 579 "Execution will fail if this op is used in the graph. 
", 580 op.type, op.name) 581 582 if op.type in _UNSUPPORTED_OPS: 583 self._unsupported_ops.append(op) 584 585 if any(x.dtype._is_ref_dtype for x in op.inputs): 586 raise NotImplementedError( 587 f"Non-resource Variables are not supported inside TPU computations " 588 f"(operator name: {op.name})") 589 590 # TensorFlowOpLayer may clone nodes that are in tpu.rewrite()s. It'll add 591 # the "_cloned" attribute and we should continue in that case. 592 if (_TPU_REPLICATE_ATTR in op.node_def.attr and 593 "_cloned" not in op.node_def.attr): 594 raise ValueError(f"TPU computations cannot be nested on op ({op})") 595 op._set_attr_with_buf(_TPU_REPLICATE_ATTR, 596 self._tpu_relicate_attr_buf.buffer) 597 if self._outside_compilation_cluster: 598 op._set_attr( 599 _OUTSIDE_COMPILATION_ATTR, 600 attr_value_pb2.AttrValue( 601 s=compat.as_bytes(self._outside_compilation_cluster))) 602 if self._num_replicas > 1 or not self._outside_compilation_cluster: 603 # Prevent feeding or fetching anything that is being compiled, 604 # and any replicated outside_compilation Op. 605 op.graph.prevent_feeding(op) 606 op.graph.prevent_fetching(op) 607 608 # Remove any control edges from outer control flow contexts. These may cause 609 # mismatched frame errors. 610 (internal_control_inputs, 611 external_control_inputs) = self._RemoveExternalControlEdges(op) 612 613 if not op.inputs: 614 # Add a control edge from the control pivot to this op. 615 if not internal_control_inputs: 616 # pylint: disable=protected-access 617 op._add_control_input(self.GetControlPivot()) 618 # pylint: enable=protected-access 619 else: 620 for index in xrange(len(op.inputs)): 621 x = op.inputs[index] 622 real_x = self.AddValue(x) 623 if real_x is not x: 624 op._update_input(index, real_x) # pylint: disable=protected-access 625 626 if external_control_inputs: 627 # Use an identity to pull control inputs as data inputs. Note that we 628 # ignore ops which don't have outputs. TODO(phawkins): fix that. 629 with ops.control_dependencies(None): 630 self.Enter() 631 external_control_inputs = [ 632 array_ops.identity(x.outputs[0]).op 633 for x in external_control_inputs 634 if x.outputs 635 ] 636 self.Exit() 637 # pylint: disable=protected-access 638 op._add_control_inputs(external_control_inputs) 639 # pylint: enable=protected-access 640 641 # Mark op's outputs as seen by this context and any outer contexts. 642 output_names = [x.name for x in op.outputs] 643 context = self 644 while context is not None: 645 # pylint: disable=protected-access 646 context._values.update(output_names) 647 context = context._outer_context 648 # pylint: enable=protected-access 649 650 if self._outer_context: 651 self._outer_context.AddInnerOp(op) 652 653 def AddValue(self, val: core_types.Tensor) -> core_types.Tensor: 654 """Add `val` to the current context and its outer context recursively.""" 655 if not self._outer_context: 656 return val 657 658 if val.name in self._values: 659 # Use the real value if it comes from outer context. 
      result = self._external_values.get(val.name)
      return val if result is None else result

    result = val
    self._values.add(val.name)
    if self._outer_context:
      result = self._outer_context.AddValue(val)
      self._values.add(result.name)

    self._external_values[val.name] = result

    return result

  def AddInnerOp(self, op: ops.Operation):
    self.AddOp(op)
    if self._outer_context:
      self._outer_context.AddInnerOp(op)

  @property
  def grad_state(self):
    # Define the gradient loop state associated with the TPUReplicateContext to
    # be None as the TPUReplicateContext does not get nested nor does the
    # grad_state outside the TPUReplicateContext affect the graph inside so the
    # grad_state should be as if this is the top-level gradient state.
    return None

  @property
  def back_prop(self):
    """Forwards to the enclosing while context, if any."""
    if self.GetWhileContext():
      return self.GetWhileContext().back_prop
    return False

  def GetControlPivot(self) -> ops.Operation:
    return self._pivot

  def RequiresUniqueFunctionRetracing(self):
    # More context: b/158152827. TPU stack uses the TPUReplicateContext to
    # create replicated variable handles and cluster TPU computations, thus we
    # always retrace a tf.function when the wrapped TPUReplicateContext changes.
    return True


class OutsideCompilationV2Context(control_flow_ops.ControlFlowContext):
  """The context for outside compilation in TensorFlow 2.0.

  Every op added in this context will be assigned an _xla_outside_compilation
  attribute.
  """

  def __init__(self, name: Text):
    control_flow_ops.ControlFlowContext.__init__(self)
    self._name = name

  def AddOp(self, op: ops.Operation) -> None:
    if self._outer_context:
      self._outer_context.AddOp(op)
    # pylint: disable=protected-access
    op._set_attr("_xla_outside_compilation",
                 attr_value_pb2.AttrValue(s=compat.as_bytes(self._name)))
    # pylint: enable=protected-access

  def AddInnerOp(self, op: ops.Operation) -> None:
    if self._outer_context:
      self._outer_context.AddInnerOp(op)
    # pylint: disable=protected-access
    op._set_attr("_xla_outside_compilation",
                 attr_value_pb2.AttrValue(s=compat.as_bytes(self._name)))
    # pylint: enable=protected-access

  def to_control_flow_context_def(self, context_def, export_scope=None):
    raise NotImplementedError


@tf_export(v1=["tpu.outside_compilation"])
def outside_compilation(
    computation: Callable[..., Any], *args, **kwargs
    ) -> Any:
  """Builds part of a computation outside any current TPU replicate scope.

  `tf.tpu.outside_compilation()` is used to run ops in `computation` on CPU
  instead of running on TPU. For example, users can run ops that are not
  supported on TPUs (e.g. tf.summary.write()) by explicitly placing those
  ops on CPUs. The usage of outside compilation below places the ops in
  `computation_with_string_ops` on the CPU.

  Example usage:

  ```python
  def computation_with_string_ops(x):
    # String types are not supported on TPUs, so the ops below must
    # run on the CPU instead.
    output = tf.strings.format('1{}', x)
    return tf.strings.to_number(output)

  def tpu_computation():
    # Expected output is 11.
    output = tf.tpu.outside_compilation(computation_with_string_ops, 1)
  ```
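
  In a complete program, `tpu_computation` above would then be run on the TPU,
  for example (a sketch; `strategy` is assumed to be an already constructed
  `tf.distribute.TPUStrategy`):

  ```python
  strategy.run(tpu_computation)
  ```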

  Outside compilation should be called inside TPUReplicateContext. That is,
  `tf.tpu.outside_compilation()` should be called inside a function that is
  passed to `tpu.split_compile_and_replicate()` -- this is implied when
  outside compilation is invoked inside a function passed to TPUStrategy
  `run()`. If invoked outside of TPUReplicateContext,
  then this simply returns the result of `computation`, and therefore,
  would be a no-op. Note that outside compilation is different from
  `tf.distribute.experimental.TPUStrategy.merge_call()` as logic in
  outside compilation is replicated and executed separately for each
  replica. On the other hand, `merge_call()` requires a `merge_fn`
  to aggregate the inputs from different replicas and is executed only
  once.

  For variables placed on the TPU device, which includes variables created
  inside a TPUStrategy scope, outside compilation logic must not include
  variable reads or writes. For variables placed on the host, which is the
  case for variables created via TPUEstimator, variable reads and writes are
  only allowed if the variable is not accessed by any other op in the TPU
  computation. Variable reads and writes from the outside compilation cluster
  are not visible from the TPU computation and vice versa. Therefore, if the
  outside compilation logic contains such host variable reads or writes and
  the variables are also accessed by the TPU computation, this may lead to
  deadlock.

  Internally, `tf.tpu.outside_compilation()` adds outside compilation
  attributes to all ops in `computation`. During a later graph pass, the ops
  with the outside compilation attribute are extracted out and replicated
  into a host-side graph. Inputs to this extracted host-side graph are sent
  from the TPU computation graph to the host graph via a pair of XlaSendToHost
  and XlaRecvFromHost ops. Note that using `tf.tpu.outside_compilation()`
  may result in tensor transfer between TPU and CPU, leading to non-trivial
  performance impact.

  Args:
    computation: A Python function that builds the computation to
      place on the host.
    *args: the positional arguments for the computation.
    **kwargs: the keyword arguments for the computation.

  Returns:
    The Tensors returned by computation.
  """
  args = [] if args is None else args
  graph = ops.get_default_graph()

  # If we are in TF 2 functions (control flow V2 functions, or tf.function()),
  # we need to attach _xla_outside_compilation attribute directly because we are
  # not in TPUReplicateContext.
  if isinstance(graph, func_graph.FuncGraph):
    try:
      tpu_context, _ = _enclosing_tpu_context_and_graph()
    except ValueError:
      logging.warning(
          "Outside compilation attempted outside TPUReplicateContext "
          "scope. As no enclosing TPUReplicateContext can be found, "
          "returning the result of `computation` as is.")
      return computation(*args, **kwargs)

    # pylint: disable=protected-access
    outside_compilation_name = str(tpu_context._outside_compilation_counter)
    tpu_context._outside_compilation_counter = (
        tpu_context._outside_compilation_counter + 1)
    # pylint: enable=protected-access

    outside_compilation_context = OutsideCompilationV2Context(
        outside_compilation_name)
    outside_compilation_context.Enter()
    args = [] if args is None else args
    retval = computation(*args, **kwargs)
    outside_compilation_context.Exit()
    return retval

  # If we are in a TPUReplicateContext, signal that we are now
  # outside_compilation
  initial_context = graph._get_control_flow_context()  # pylint: disable=protected-access
  context = initial_context
  while context:
    if isinstance(context, TPUReplicateContext):
      context._EnterOutsideCompilationScope()  # pylint: disable=protected-access
    context = context.outer_context

  retval = computation(*args, **kwargs)

  # If we are in a TPUReplicateContext, signal that we are no longer
  # outside_compilation
  final_context = graph._get_control_flow_context()  # pylint: disable=protected-access
  if initial_context is not final_context:
    raise NotImplementedError(
        "Control-flow context cannot be different at start and end of an "
        "outside_compilation scope")
  context = initial_context
  while context:
    if isinstance(context, TPUReplicateContext):
      context._ExitOutsideCompilationScope()  # pylint: disable=protected-access
    context = context.outer_context

  return retval


@tf_export(v1=["tpu.PaddingSpec"])
class PaddingSpec(enum.IntEnum):
  """Represents the type of padding policies for tpu.replicate."""
  # By default the policy is set to AUTO; the dynamic input shape dimension
  # will be padded to the maximum across all the replicas.
  AUTO = 0
  # Bucketize the dynamic input shape dimension into a power of 2.
  POWER_OF_TWO = 1


@tf_export("tpu.XLAOptions")
class XLAOptions(
    collections.namedtuple("XLAOptions", [
        "use_spmd_for_xla_partitioning",
        "enable_xla_dynamic_padder",
    ])):
  """XLA compilation options.

  Attributes:
    use_spmd_for_xla_partitioning: Boolean. Whether to use XLA's SPMD
      partitioner instead of MPMD partitioner when compiler partitioning is
      requested.
    enable_xla_dynamic_padder: Boolean. Whether to enable XLA dynamic padder
      infrastructure to handle dynamic shaped inputs inside XLA. True by
      default. Disabling this may cause correctness issues with dynamic shape
      inputs, as XLA will just assume the inputs have padded shapes. However,
      users can optionally set it to False to improve device time if masking
      is already handled on the user side.
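
  For example, SPMD partitioning can be turned off while keeping the dynamic
  padder enabled (a sketch):

  ```python
  options = tf.tpu.XLAOptions(use_spmd_for_xla_partitioning=False)
  # `options` can then be passed as the `xla_options` argument of
  # `tf.compat.v1.tpu.replicate` or `split_compile_and_replicate`.
  ```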
886 """ 887 888 def __new__(cls, 889 use_spmd_for_xla_partitioning=True, 890 enable_xla_dynamic_padder=True): 891 return super(XLAOptions, cls).__new__(cls, use_spmd_for_xla_partitioning, 892 enable_xla_dynamic_padder) 893 894 895@tf_export(v1=["tpu.replicate"]) 896def replicate( 897 computation: Callable[..., Any], 898 inputs: Optional[List[List[core_types.Tensor]]] = None, 899 infeed_queue: Optional[tpu_feed.InfeedQueue] = None, 900 device_assignment: Optional[device_assignment_lib.DeviceAssignment] = None, 901 name: Optional[Text] = None, 902 maximum_shapes: Optional[Any] = None, 903 padding_spec: Optional[PaddingSpec] = None, 904 xla_options: Optional[XLAOptions] = None) -> List[Any]: 905 """Builds a graph operator that runs a replicated TPU computation. 906 907 Example for the basic usage that `inputs` has static shape: 908 909 ```python 910 911 def computation(x): 912 x = x + 1 913 return tf.math.reduce_mean(x) 914 915 x = tf.convert_to_tensor([1., 2., 3.]) 916 y = tf.convert_to_tensor([4., 5., 6.]) 917 tf.compat.v1.tpu.replicate(computation, inputs=[[x], [y]]) 918 ``` 919 920 If the `inputs` has dynamic shapes and you would like to automatically 921 bucketize the inputs to avoid XLA recompilation. See the advanced example 922 below: 923 924 ```python 925 926 def computation(x): 927 x = x + 1 928 return tf.math.reduce_mean(x) 929 930 # Assume input tensors in two replicas `x` and `y` both have dynamic shape 931 # ([None, 2]). 932 tf.compat.v1.tpu.replicate( 933 computation, 934 inputs=[x, y], 935 maximum_shapes=[tf.TensorShape([None, None])], 936 padding_spec=tf.compat.v1.tpu.PaddingSpec.POWER_OF_TWO) 937 ``` 938 939 Args: 940 computation: A Python function that builds the computation to replicate. 941 inputs: A list of lists of input tensors or `None` (equivalent to 942 `[[]]`), indexed by `[replica_num][input_num]`. All replicas must 943 have the same number of inputs. Each input can be a nested structure 944 containing values that are convertible to tensors. Note that passing an 945 N-dimension list of compatible values will result in a N-dimension list of 946 scalar tensors rather than a single Rank-N tensors. If you need different 947 behavior, convert part of inputs to tensors with `tf.convert_to_tensor`. 948 infeed_queue: If not `None`, the `InfeedQueue` from which to append a tuple 949 of arguments as inputs to computation. 950 device_assignment: If not `None`, a `DeviceAssignment` describing the 951 mapping between logical cores in the computation with physical cores in 952 the TPU topology. Uses a default device assignment if `None`. The 953 `DeviceAssignment` may be omitted if each replica of the computation uses 954 only one core, and there is either only one replica, or the number of 955 replicas is equal to the number of cores in the TPU system. 956 name: (Deprecated) Does nothing. 957 maximum_shapes: A nested structure of tf.TensorShape representing the shape 958 to which the respective component of each input element in each replica 959 should be padded. Any unknown dimensions (e.g. 960 tf.compat.v1.Dimension(None) in a tf.TensorShape or -1 in a tensor-like 961 object) will be padded to the maximum size of that dimension over all 962 replicas. The structure of `maximum_shapes` needs to be the same as 963 `inputs[0]`. 964 padding_spec: An enum specified by `tpu.PaddingSpec`. This describes the 965 padding policy when the `inputs` to `tpu.replicate` is dynamic. 
      One usage is to enable automatic bucketizing on the inputs by setting the
      value to `tpu.PaddingSpec.POWER_OF_TWO`, which can help to reduce
      recompilation on the XLA side.
    xla_options: An instance of `tpu.XLAOptions` which indicates the options
      passed to the XLA compiler. Use `None` for default options.

  Returns:
    A list of outputs, indexed by `[replica_num]`. Each output can be a nested
    structure matching what computation() returns, with a few exceptions.

    Exceptions include:
      1) None output: a NoOp would be returned which control-depends on
         computation.
      2) Single value output: A tuple containing the value would be returned.
      3) Operation-only outputs: a NoOp would be returned which
         control-depends on computation.
      TODO(b/121383831): Investigate removing these special cases.

  Raises:
    ValueError: If all replicas do not have equal numbers of input tensors.
    ValueError: If the number of inputs per replica does not match
      the number of formal parameters to `computation`.
    ValueError: If the static `inputs` dimensions don't match with the values
      given in `maximum_shapes`.
    ValueError: If the structure of inputs per replica does not match
      the structure of `maximum_shapes`.
  """
  return split_compile_and_replicate(
      computation,
      inputs,
      infeed_queue,
      device_assignment,
      name,
      maximum_shapes=maximum_shapes,
      padding_spec=padding_spec,
      xla_options=xla_options)[1]


def _ceil_to_pow_of_n(x, n):
  """Ceil input `x` to power of `n`."""
  x = math_ops.cast(x, dtypes.float32)
  lognx = math_ops.log(x) / math_ops.log(n * 1.0)
  lognx = math_ops.ceil(lognx)
  result = math_ops.pow(n * 1.0, lognx)
  result = math_ops.cast(result, dtypes.int32)
  return result


def _pad_all_input(
    inputs: Iterable[core_types.Tensor],
    padded_shapes: List[Optional[tensor_shape.TensorShape]],
    padding_spec: PaddingSpec
) -> Tuple[List[List[Any]], List[dynamic_padding.PaddingMap]]:
  """Pad all input tensors given padded_shapes.

  The real shape tensors will be concatenated with the padded original inputs.

  Args:
    inputs: The original inputs.
    padded_shapes: A list of padded shapes for each input. If an entry is None,
      no padding is performed.
    padding_spec: An enum specified by `tpu.PaddingSpec`. This describes the
      padding policy when the `inputs` to `tf.tpu.replicate` is dynamic.
      One usage is to enable automatic bucketizing on the inputs by setting the
      value to `tpu.PaddingSpec.POWER_OF_TWO`, which can help to reduce
      recompilation on the XLA side.

  Returns:
    The padded inputs and a PaddingMap list which maps the padded input
    dimension to the real shape argument index.
  """
  # maximum_static_shapes[idx][i] indicates the maximum static size of ith
  # dimension of the idx input among all the replicas.
  maximum_static_shapes = []
  # need_padding[idx][i] indicates whether the ith dimension of the idx input
  # needs padding.
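  # For example (hypothetical): if the idx-th input has static shape [2, 5] in
  # replica 0 and [3, 5] in replica 1, maximum_static_shapes[idx] ends up as
  # [3, 5] and need_padding[idx] as [True, False].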
  need_padding = []
  input_shape_tensors = []
  for core_idx, inputs_per_core in enumerate(inputs):
    for idx, input_tensor in enumerate(inputs_per_core):
      input_shape = input_tensor.get_shape().as_list()
      if core_idx == 0:
        input_shape_tensors.append([])
        maximum_static_shapes.append(input_shape)
        need_padding.append(np.full_like(input_shape, False, dtype=bool))
      else:
        for i, s in enumerate(input_shape):
          if s is None or s != maximum_static_shapes[idx][i]:
            need_padding[idx][i] = True
        maximum_static_shapes[idx] = max(input_shape,
                                         maximum_static_shapes[idx])

      # Append _POST_DEVICE_REWRITE_ATTR attributes to the real shape ops.
      real_input_shape = array_ops.shape(input_tensor)
      real_input_shape.op._set_attr(  # pylint: disable=protected-access
          _POST_DEVICE_REWRITE_ATTR,
          attr_value_pb2.AttrValue(b=True))
      input_shape_tensors[idx].append(real_input_shape)

  maximum_shapes = []
  for shapes_per_input in input_shape_tensors:
    maximum_shapes.append(
        math_ops.reduce_max(array_ops.stack(shapes_per_input), axis=0))

  padded_inputs = []
  real_shapes = []
  padding_maps = []
  for core_idx, inputs_per_core in enumerate(inputs):
    padded_inputs.append([])
    real_shapes.append([])
    real_shape_idx = len(inputs_per_core) - 1
    for idx, input_tensor in enumerate(inputs_per_core):
      input_shape_tensor = input_shape_tensors[idx][core_idx]
      input_shape = input_tensor.get_shape().as_list()
      padded_shape = padded_shapes[idx]

      # If we have no padded_shape, then skip padding.
      if any(need_padding[idx]) and padded_shape is not None:
        for i, s in enumerate(input_shape):
          if need_padding[idx][i]:
            if core_idx == 0:
              real_shape_idx += 1
              padding_map = dynamic_padding.PaddingMap()
              padding_map.arg_index = idx
              padding_map.shape_index = i
              padding_map.padding_arg_index = real_shape_idx
              padding_maps.append(padding_map)
            real_shapes[core_idx].append(
                math_ops.cast(input_shape_tensor[i], dtypes.int32))

        paddings = []
        for i, s in enumerate(padded_shape.dims):
          if need_padding[idx][i]:
            # The minimum padded dimension size is 2 as XLA doesn't support size
            # 1 dynamic size.
            minimum_dynamic_dim_size = 2
            if s.value is not None:
              # Pad to the given maximum value.
              max_dim_size = max(s.value, minimum_dynamic_dim_size)
            else:
              # If maximum value is not given, then pad to the maximum dimension
              # among all the cores.
              max_dim_size = math_ops.maximum(maximum_shapes[idx][i],
                                              minimum_dynamic_dim_size)
              if padding_spec == PaddingSpec.POWER_OF_TWO:
                max_dim_size = _ceil_to_pow_of_n(max_dim_size, 2)
            # Pad to the given maximum value.
            padding = [0, max_dim_size - input_shape_tensor[i]]
          else:
            padding = [0, 0]
          paddings.append(padding)

        if input_tensor.get_shape().is_fully_defined():
          # TODO(rxsang): This is a hack to make sure padded_input has dynamic
          # shapes, so any tf.size/tf.shape op performed on it won't be constant
          # folded. Do we have better ways to do it?
          padded_input = control_flow_ops.cond(
              array_ops.constant(True),
              lambda: array_ops.pad(input_tensor, paddings),  # pylint: disable=cell-var-from-loop
              lambda: input_tensor)
        else:
          padded_input = array_ops.pad(input_tensor, paddings)

        # Append _POST_DEVICE_REWRITE_ATTR attributes to all padded inputs.
        padded_input.op._set_attr(  # pylint: disable=protected-access
            _POST_DEVICE_REWRITE_ATTR,
            attr_value_pb2.AttrValue(b=True))

        padded_inputs[core_idx].append(padded_input)
      else:
        padded_inputs[core_idx].append(input_tensor)

  num_replicas = len(padded_inputs)
  for i in range(num_replicas):
    padded_inputs[i].extend(real_shapes[i])

  return padded_inputs, padding_maps


def _flatten_and_filter_composite(maybe_composite, non_composite_output,
                                  composite_output=None):
  """For an input, replaces the input by a tuple if the input is composite.

  If `maybe_composite` is not composite, return the parameter
  `non_composite_output` otherwise return a tuple which consists of the value of
  the parameter `composite_output` the same number of times as there are
  components of the composite tensor.

  This is useful for computing a mask when flattening nested data with
  `expand_composites=True`. For example

  ```python
  nest.flatten(data, expand_composites=True)
  ```

  and

  ```python
  nest.flatten(nest.map(
      data, lambda x: _flatten_and_filter_composite(x, False, True)))
  ```

  will have the same length, and the second will be True where the tensor in
  the first is derived from expanding a composite tensor.

  Args:
    maybe_composite: A value to test for being a composite tensor.
    non_composite_output: The value to return when `maybe_composite` is not a
      composite.
    composite_output: the value to fill the output tuple with if
      `maybe_composite` is a composite.

  Returns:
    `non_composite_output` or a tuple with multiple copies of
    `composite_output`.
  """

  if isinstance(maybe_composite, composite_tensor.CompositeTensor):
    num_components = len(nest.flatten(maybe_composite, expand_composites=True))
    return (composite_output,) * num_components
  return non_composite_output


def split_compile_and_replicate(
    computation: Callable[..., Any],
    inputs: Optional[List[List[core_types.Tensor]]] = None,
    infeed_queue: Optional[tpu_feed.InfeedQueue] = None,
    device_assignment: Optional[device_assignment_lib.DeviceAssignment] = None,
    name: Optional[Text] = None,
    use_tpu: bool = True,
    maximum_shapes: Optional[Any] = None,
    padding_spec: Optional[PaddingSpec] = None,
    xla_options: Optional[XLAOptions] = None,
) -> List[List[core_types.Tensor]]:
  """Builds graph operators that run compilation and replicated computation.

  This is a lower level interface than replicate that returns a separate compile
  and execute output tensor. In the generated graph the compile op feeds into
  the execute op and no additional compilation is incurred when running the
  compile op before the execute op. The compile op returns additional
  information about the compilation but does not return the compiled program.

  Args:
    computation: A Python function that builds the computation to replicate.
    inputs: A list of lists of input tensors or `None` (equivalent to
      `[[]]`), indexed by `[replica_num][input_num]`. All replicas must
      have the same number of inputs. Each input can be a nested structure
      containing values that are convertible to tensors. Note that passing an
      N-dimensional list of compatible values will result in an N-dimensional
      list of scalar tensors rather than a single rank-N tensor. If you need
      different behavior, convert part of the inputs to tensors with
      `tf.convert_to_tensor`.
    infeed_queue: If not `None`, the `InfeedQueue` from which to append a tuple
      of arguments as inputs to computation.
    device_assignment: If not `None`, a `DeviceAssignment` describing the
      mapping between logical cores in the computation and physical cores in
      the TPU topology. Uses a default device assignment if `None`. The
      `DeviceAssignment` may be omitted if each replica of the computation uses
      only one core, and there is either only one replica, or the number of
      replicas is equal to the number of cores in the TPU system.
    name: (Deprecated) Does nothing.
    use_tpu: When false, the input `computation` is executed on the XLA CPU/GPU
      backends. Currently, only supports a default placement (computation is
      placed on GPU if one is available, and on CPU if not).
    maximum_shapes: A nested structure of tf.TensorShape representing the shape
      to which the respective component of each input element in each replica
      should be padded. Any unknown dimensions (e.g.
      tf.compat.v1.Dimension(None) in a tf.TensorShape or -1 in a tensor-like
      object) will be padded to the maximum size of that dimension over all
      replicas. The structure of `maximum_shapes` needs to be the same as
      `inputs[0]`.
    padding_spec: An enum specified by `tf.tpu.PaddingSpec`. This describes the
      padding policy when the `inputs` to `tf.tpu.replicate` is dynamic.
      One usage is to enable automatic bucketizing on the inputs by setting the
      value to `tpu.PaddingSpec.POWER_OF_TWO`, which can help to reduce
      recompilation on the XLA side.
    xla_options: An instance of `tpu.XLAOptions` which indicates the options
      passed to the XLA compiler. Use `None` for default options.

  Returns:
    A list of lists with the first list corresponding to the compile op and the
    second a list of output tensors, indexed by `[replica_num][output_num]`.

  Raises:
    ValueError: If all replicas do not have equal numbers of input tensors.
    ValueError: If the number of inputs per replica does not match
      the number of formal parameters to `computation`.
    ValueError: If the static `inputs` dimensions don't match with the values
      given in `maximum_shapes`.
    ValueError: If the structure of inputs per replica does not match
      the structure of `maximum_shapes`.
  """
  del name
  inputs = [[]] if inputs is None else inputs
  xla_options = xla_options or XLAOptions()

  metadata_kwargs = {}
  if device_assignment is not None:
    # Turn the Numpy array into a flattened list so we can pass it as an
    # operator attribute.
    metadata_kwargs = {
        "topology":
            device_assignment.topology.serialized(),
        "device_assignment":
            device_assignment.core_assignment.flatten().tolist()
    }
    metadata_kwargs["num_cores_per_replica"] = (
        device_assignment.num_cores_per_replica)

  # This entry is used for enabling automatic outside compilation.
  metadata_kwargs["allow_soft_placement"] = config.get_soft_device_placement()
  if config.get_soft_device_placement():
" 1276 "Ops without XLA kernels will be automatically " 1277 "placed on CPU.") 1278 1279 if not isinstance(inputs, list): 1280 raise TypeError("tpu.replicate() inputs must be a list of lists/tuples, " 1281 f"received {type(inputs)}") 1282 if any(not isinstance(inp, (list, tuple)) for inp in inputs): 1283 raise TypeError( 1284 "tpu.replicate() inputs must be a list of lists/tuples, " 1285 f"received types: {[type(inp) for inp in inputs]}") 1286 1287 num_replicas = len(inputs) 1288 1289 # No replicas? Nothing to do. 1290 if num_replicas == 0: 1291 return [] 1292 1293 # Checks all replicas have the same structure. 1294 for i in xrange(1, num_replicas): 1295 nest.assert_same_structure(inputs[0], inputs[i]) 1296 1297 # Flatten inputs. This structure may contain None values, which will be 1298 # handled later. 1299 flat_inputs_with_nones = [ 1300 nest.flatten(per_replica_input, expand_composites=True) 1301 for per_replica_input in inputs 1302 ] 1303 # Mask parallel to one replica's inputs with True for tensors coming from 1304 # composites. 1305 is_composite = nest.flatten(nest.map_structure( 1306 lambda x: _flatten_and_filter_composite(x, False, True), inputs[0])) 1307 1308 # Converts inputs to Tensors, replacing Nones with a placeholder 0 since 1309 # tpu_ops.tpu_replicated_input() can't handle non-Tensor values. 1310 flat_inputs = [] 1311 for inp in flat_inputs_with_nones: 1312 flat_inputs.append([ 1313 constant_op.constant(0) if x is None else ops.convert_to_tensor(x) 1314 for x in inp 1315 ]) 1316 1317 # Verifies that all replicas have matching numbers and types of inputs 1318 flat_input_types = [x.dtype for x in flat_inputs[0]] 1319 input_arity = len(inputs[0]) 1320 flat_input_arity = len(flat_input_types) 1321 for i in range(num_replicas): 1322 if len(inputs[i]) != input_arity: 1323 raise ValueError("Replicas must have the same number of inputs. " 1324 "Replica 0 had {} inputs, replica {} had {} " 1325 "inputs.".format(input_arity, i, len(inputs[i]))) 1326 1327 types = [x.dtype for x in flat_inputs[i]] 1328 if types != flat_input_types: 1329 raise ValueError("Replicas must have matching input types. Replica 0 had " 1330 "input types {}, replica {} had input types {}".format( 1331 flat_input_types, i, types)) 1332 1333 arg_error = xla.check_function_argument_count( 1334 computation, input_arity, infeed_queue) 1335 if arg_error is not None: 1336 if infeed_queue is None: 1337 raise TypeError( 1338 "Supplied computation cannot be called with the specified inputs. " 1339 f"You specified {input_arity} inputs: {[i.name for i in inputs[0]]}, " 1340 f"but the computation needs{arg_error}") 1341 else: 1342 raise TypeError( 1343 "Supplied computation cannot be called with the specified inputs. " 1344 f"You specified {input_arity} inputs: {[i.name for i in inputs[0]]} ", 1345 f"and {infeed_queue.number_of_tuple_elements} additional inputs " 1346 f"from infeed, but the computation needs {arg_error}") 1347 1348 dynamic_shape_inputs = False 1349 if maximum_shapes: 1350 if infeed_queue: 1351 raise ValueError( 1352 "Dynamic input shapes are not supported with infeed queues") 1353 1354 # Make sure maximum_shapes has the same structure as inputs. 1355 nest.assert_same_structure(inputs[0], maximum_shapes, check_types=False) 1356 1357 # Flatten padded shapes: 1358 # For composite tensor components, we don't want to pad them. 
    # For composite tensor components, we don't want to pad them. For each
    # entry of maximum_shapes that corresponds to a composite tensor, replace it
    # by a tuple of Nones of the same length as the number of components of the
    # composite tensor. When we flatten a second time, this makes
    # flat_maximum_shapes have the same length as flat_inputs[i]. We can then
    # avoid padding these tensors. The assumption is that they will be used by
    # outside compilation or that the components are statically shaped and will
    # be used by tpu compatible ops.
    flat_maximum_shapes = nest.flatten(
        [_flatten_and_filter_composite(x, y)
         for x, y in zip(nest.flatten(inputs[0]),
                         nest.flatten(maximum_shapes))])
    flat_maximum_shapes = [
        tensor_shape.TensorShape(s) if s is not None else None
        for s in flat_maximum_shapes
    ]
    nest.assert_same_structure(flat_inputs[0], flat_maximum_shapes,
                               check_types=False)

    unpadded_inputs = flat_inputs
    flat_inputs, padding_maps = _pad_all_input(unpadded_inputs,
                                               flat_maximum_shapes,
                                               padding_spec)
    if padding_maps:
      dynamic_shape_inputs = True
      logging.info("TPU has inputs with dynamic shapes: %s", unpadded_inputs[0])

  metadata_kwargs["step_marker_location"] = getattr(
      computation, "step_marker_location", "STEP_MARK_AT_ENTRY")
  metadata_kwargs["use_spmd_for_xla_partitioning"] = \
      xla_options.use_spmd_for_xla_partitioning

  graph = ops.get_default_graph()

  # Fan-in: Builds a TPUReplicatedInput node for each input.
  flat_replicated_inputs = []
  for i in range(0, len(flat_inputs[0])):
    replicas = [flat_inputs[replica][i] for replica in xrange(num_replicas)]
    flat_replicated_inputs.append(
        tpu_ops.tpu_replicated_input(
            replicas, name="input{}".format(i), index=i))
  if isinstance(graph, func_graph.FuncGraph):
    # When we are in a TensorFlow 2.0 function, 'graph' will be a FuncGraph
    # object. If both outside graph and this function have a TPU cluster,
    # they will have the same cluster name and it will cause problems (because
    # we lower functional ops in TensorFlow 2.0). Append function name to
    # 'cluster_name' to avoid cluster name collision.
    cluster_name = graph.unique_name("cluster_" + graph.name)
  else:
    cluster_name = graph.unique_name("cluster")
  pivot = control_flow_ops.no_op(name=cluster_name + "/pivot")
  pivot._set_attr(_PIVOT_FOR_CLUSTER,  # pylint: disable=protected-access
                  attr_value_pb2.AttrValue(s=compat.as_bytes(cluster_name)))
  context = TPUReplicateContext(
      name=cluster_name, num_replicas=num_replicas, pivot=pivot)
  try:
    context.Enter()

    metadata = tpu_ops.tpu_replicate_metadata(
        num_replicas=num_replicas, use_tpu=use_tpu, **metadata_kwargs)

    with tpu_function.tpu_shard_context(
        num_replicas), ops.control_dependencies([metadata]):

      if dynamic_shape_inputs and xla_options.enable_xla_dynamic_padder:
        for padding_map in padding_maps:
          input_shape = flat_replicated_inputs[padding_map.arg_index].shape
          flat_replicated_inputs[
              padding_map.arg_index] = tf2xla.set_dynamic_dimension_size(
                  flat_replicated_inputs[padding_map.arg_index],
                  padding_map.shape_index,
                  flat_replicated_inputs[padding_map.padding_arg_index])
          flat_replicated_inputs[padding_map.arg_index].set_shape(input_shape)

      # Add identity ops so even unused inputs are "consumed" by the
      # computation. This is to avoid orphaned TPUReplicatedInput nodes.
1434 # TODO(phawkins): consider instead pruning unused TPUReplicatedInput 1435 # and eliding trivial TPUReplicatedInput/TPUReplicatedOutput pairs. 1436 flat_replicated_inputs = [ 1437 array_ops.identity(x, name="replicated_input_{}".format(i)) 1438 for i, x in enumerate(flat_replicated_inputs) 1439 ] 1440 for i, composite in zip(flat_replicated_inputs, is_composite): 1441 # pylint: disable=protected-access 1442 # Add an attribute to the identity node so that they could be removed in 1443 # encapsulate TPU computation pass if unused. However we don't remove 1444 # inputs when dynamic padding is enabled. 1445 # TODO(rxsang): Use other ways except argument index in padding_map so 1446 # outside compilation can work with dynamic padding correctly. 1447 if not dynamic_shape_inputs or composite: 1448 i.op._set_attr("_tpu_input_identity", 1449 attr_value_pb2.AttrValue(b=True)) 1450 # pylint: enable=protected-access 1451 1452 # Clobber replicated placeholders with Nones. 1453 computation_inputs = [ 1454 None if inp is None else replicated for replicated, inp in zip( 1455 flat_replicated_inputs, flat_inputs_with_nones[0]) 1456 ] 1457 1458 # Unflatten the computation inputs to match original input structure. 1459 computation_inputs = nest.pack_sequence_as( 1460 structure=inputs[0], 1461 flat_sequence=computation_inputs[:flat_input_arity], 1462 expand_composites=True) 1463 1464 # If there is an infeed queue, adds the dequeued values to the 1465 # computation's inputs. 1466 if infeed_queue is not None: 1467 infeed_queue.set_number_of_shards(num_replicas) 1468 for t in infeed_queue.generate_dequeue_op(): 1469 computation_inputs.append(t) 1470 1471 # Only resource variables work inside a TPU computation, so turn on 1472 # resource variables for the computation. 1473 # TODO(phawkins): consider removing this code. It will 1474 # be less confusing to clients if they knowingly choose to use resource 1475 # variables. 1476 # Partitioned variables is not supported (b/112311320). 1477 vscope = variable_scope.get_variable_scope() 1478 saved_use_resource = vscope.use_resource 1479 saved_custom_getter = vscope.custom_getter 1480 1481 def custom_getter(getter, name, *args, **kwargs): 1482 """Variables on TPU have a few restrictions.""" 1483 partitioner = kwargs.get("partitioner", None) 1484 if partitioner is not None: 1485 kwargs["partitioner"] = None 1486 logging.warning( 1487 "Partitioned variables are not supported on TPU. Got " 1488 "`partitioner` that is %s for variable %s. " 1489 "Setting `partitioner` to `None`.", partitioner, name) 1490 if saved_custom_getter is None: 1491 return getter(name, *args, **kwargs) 1492 else: 1493 return saved_custom_getter(getter, name, *args, **kwargs) 1494 1495 vscope.set_use_resource(True) 1496 vscope.set_custom_getter(custom_getter) 1497 1498 outputs = computation(*computation_inputs) 1499 1500 vscope.set_use_resource(saved_use_resource) 1501 vscope.set_custom_getter(saved_custom_getter) 1502 1503 outputs_is_flat = xla.is_flat(outputs) 1504 if outputs_is_flat: 1505 output_tensors, control_deps, pack_template = _postprocess_flat_outputs( 1506 outputs) 1507 else: 1508 output_tensors, control_deps, pack_template = ( 1509 _postprocess_non_flat_outputs(outputs)) 1510 1511 # tensor_tracer imports tpu.py. 
Local import to tensor_tracer to avoid 1512 # import-cycle 1513 if typing.TYPE_CHECKING: 1514 tensor_tracer = Any 1515 else: 1516 # pylint: disable=g-import-not-at-top 1517 from tensorflow.python.tpu import tensor_tracer 1518 # pylint: enable=g-import-not-at-top 1519 if tensor_tracer.TensorTracer.is_enabled(): 1520 tt = tensor_tracer.TensorTracer() 1521 output_tensors = tt.trace_tpu(ops.get_default_graph(), 1522 output_tensors, control_deps, 1523 num_replicas) 1524 1525 context.ExitResult(output_tensors) 1526 finally: 1527 context.report_unsupported_operations() 1528 context.Exit() 1529 host_compute_core = context.HostComputeCore() 1530 1531 if host_compute_core: 1532 attr_value = attr_value_pb2.AttrValue() 1533 attr_value.list.s.extend(compat.as_bytes(x) for x in host_compute_core) 1534 metadata._set_attr("host_compute_core", attr_value) # pylint: disable=protected-access 1535 1536 with ops.control_dependencies([metadata]): 1537 if use_tpu: 1538 compile_status = tpu_ops.tpu_compilation_result() 1539 op = compile_status.op 1540 attr_value = attr_value_pb2.AttrValue(s=compat.as_bytes(cluster_name)) 1541 op._set_attr(_TPU_COMPILATION_STATUS_ATTR, attr_value) # pylint: disable=protected-access 1542 else: 1543 compile_status = control_flow_ops.no_op(name="compilation_status") 1544 1545 if not output_tensors: 1546 # Returns a list of NoOps dependent on the replication Op, indexed by 1547 # [replica_num]. 1548 return [ 1549 compile_status, 1550 [ 1551 control_flow_ops.group(control_deps, name="shard_%d" % i) 1552 for i in range(num_replicas) 1553 ] 1554 ] 1555 1556 # Fan-out: Builds a TPUReplicatedOutput node for each output. 1557 replicated_outputs = [[] for i in range(num_replicas)] 1558 for i, t in enumerate(output_tensors): 1559 1560 # None values returned by the computation can't be sent to 1561 # tpu_ops.tpu_replicated_output(), we handle them specially here. We can 1562 # avoid the placeholder 0 routine required on the inputs since outputs are 1563 # replicated per-tensor, not per-replica, so we can skip replication. 1564 if t is None: 1565 for replica in range(num_replicas): 1566 replicated_outputs[replica].append(None) 1567 continue 1568 1569 # Fan-out: Builds a TPUReplicatedOutput node for each output. 1570 ys = tpu_ops.tpu_replicated_output( 1571 t, num_replicas, name="output{}".format(i)) 1572 1573 # Wraps the outputs in identity operators so the names of any possible 1574 # `fetch` nodes are preserved by the replication rewrite. 1575 with ops.control_dependencies(control_deps): 1576 for replica in range(num_replicas): 1577 replicated_outputs[replica].append( 1578 array_ops.identity( 1579 ys[replica], name="output_%d_shard_%d" % (i, replica))) 1580 1581 replicated_outputs = [ 1582 nest.pack_sequence_as(pack_template, replica_outs, expand_composites=True) 1583 for replica_outs in replicated_outputs 1584 ] 1585 1586 return [compile_status, replicated_outputs] 1587 1588 1589def _postprocess_flat_outputs( 1590 outputs: Any 1591) -> Tuple[List[Optional[core_types.Tensor]], List[ops.Operation], List[Any]]: 1592 """Validates non-flat outputs, add backs device assignments and other attrs. 1593 1594 Args: 1595 outputs: Output from `computation` inside `tpu.rewrite`. 1596 1597 Returns: 1598 - Tensors extracted from outputs. 1599 - Operations extracted from outputs. 1600 - A pack template for use with nest.pack_sequence_as to pack the tensors. 1601 """ 1602 # Following code segment is to preserve legacy behavior. 
Previously we only
  # supported flat outputs and thus for consistency it was nice to convert even
  # single element into a tuple. But now that we support arbitrary output
  # structure, this is no longer necessary.
  # TODO(b/121383831): Migrate all legacy use cases and delete this special
  # case.
  # If the computation returns `None`, make it an empty tuple.
  if outputs is None:
    outputs = tuple()

  # For legacy / backwards compatibility reasons we return a list for "flat"
  # output values (even if the user's flat return value was a different type or
  # even just a scalar value) so use nest.flatten to compute a flat list pack
  # template.
  pack_template = nest.flatten(outputs, expand_composites=False)

  # Even though outputs is already "flat", we flatten any composites so their
  # component tensors can be tagged and replicated. The pack_template will be
  # used by the caller to repack the composite tensors.
  outputs = nest.flatten(outputs, expand_composites=True)

  # Append `no_op` here so that fetching any return value of this function
  # will trigger TPUExecute node.
  outputs += (control_flow_ops.no_op(),)

  maybe_convert = lambda x: None if x is None else ops.convert_to_tensor(x)
  try:
    with ops.device(core(0)):
      outputs = [
          o if isinstance(o, ops.Operation) else maybe_convert(o)
          for o in outputs
      ]
  except Exception as e:
    raise ValueError(
        "TPU function return values must all either be Operations or "
        f"convertible to Tensors. Got error: {e}")

  # Separates the returned Operations and Tensors.
  output_operations = [o for o in outputs if isinstance(o, ops.Operation)]
  output_tensors = [o for o in outputs if not isinstance(o, ops.Operation)]

  if outputs != output_tensors + output_operations:
    raise ValueError(
        "TPU functions must return zero or more Tensor values followed by "
        "zero or more Operations.")

  # Trim operations off the end of the pack template. output_operations has 1
  # extra element due to the no-op that is added.
  if len(output_operations) > 1:
    pack_template = pack_template[:1 - len(output_operations)]

  # Wraps outputs in Identity ops. Otherwise a replicated input copied
  # straight to an output would bypass the replicate(). This would be bad
  # because the TPUReplicatedInput/TPUReplicatedOutput operator would not
  # be rewritten away, leading to a runtime error.
  # TODO(phawkins): extend the rewrite to elide these nodes instead.
  new_output_tensors = []
  for t in output_tensors:
    if t is None:
      # Pass None values through untouched; replicate() handles them specially.
      new_output_tensors.append(None)
      continue
    with ops.device(t.device if t.device else core(0)):
      o = array_ops.identity(t)
      # pylint: disable=protected-access
      o.op._set_attr("_tpu_output_identity", attr_value_pb2.AttrValue(b=True))
      # pylint: enable=protected-access
      new_output_tensors.append(o)
  return new_output_tensors, output_operations, pack_template


def _postprocess_non_flat_outputs(
    outputs: Any
) -> Tuple[List[Optional[core_types.Tensor]], List[ops.Operation], List[Any]]:
  """Validates non-flat outputs, adds back device assignments and other attrs.

  Args:
    outputs: Output from `computation` inside `tpu.rewrite`.

  Returns:
    - Tensors extracted from outputs.
    - An empty Operations list because Operations are not allowed in non-flat
      outputs.
1683 - A pack template for use with nest.pack_sequence_as to pack the tensors. 1684 """ 1685 1686 # Flatten output items. 1687 flat_outputs = nest.flatten(outputs, expand_composites=True) 1688 1689 # Convert all non-None non-Operation outputs to Tensors. 1690 for i, o in enumerate(flat_outputs): 1691 if o is None: 1692 flat_outputs[i] = None 1693 continue 1694 1695 if isinstance(o, ops.Operation): 1696 raise ValueError( 1697 "tpu.rewrite does not support Operation as return value in non-flat " 1698 "output structure. You can set returned Operations as control " 1699 "dependencies of returned Tensors so Operations are triggered when " 1700 f'Tensors are evaluated. Operation found: "{o.name}"') 1701 1702 try: 1703 o = ops.convert_to_tensor(o) 1704 except Exception as e: 1705 raise ValueError( 1706 "TPU function return values must all either be Operations or " 1707 f'convertible to Tensors. Got error: "{e}"') 1708 1709 # Wraps outputs in Identity ops. Otherwise a replicated input copied 1710 # straight to an output would bypass the replicate(). This would be bad 1711 # because the TPUReplicatedInput/TPUReplicatedOutput operator would not 1712 # be rewritten away, leading to a runtime error. 1713 # TODO(phawkins): extend the rewrite to elide these nodes instead. 1714 with ops.device(o.device if o.device else core(0)): 1715 o = array_ops.identity(o) 1716 # pylint: disable=protected-access 1717 o.op._set_attr("_tpu_output_identity", attr_value_pb2.AttrValue(b=True)) 1718 # pylint: enable=protected-access 1719 flat_outputs[i] = array_ops.identity(o) 1720 1721 # All flat_outputs are Tensors, and no Operations. 1722 return flat_outputs, [], outputs 1723 1724 1725def split_compile_and_shard( 1726 computation: Callable[..., Any], 1727 inputs: Optional[List[List[Optional[core_types.Tensor]]]] = None, 1728 num_shards: int = 1, 1729 input_shard_axes: Optional[List[int]] = None, 1730 outputs_from_all_shards: Union[bool, List[bool]] = True, 1731 output_shard_axes: Optional[List[int]] = None, 1732 infeed_queue: Optional[tpu_feed.InfeedQueue] = None, 1733 device_assignment: Optional[device_assignment_lib.DeviceAssignment] = None, 1734 name: Optional[Text] = None, 1735 xla_options: Optional[XLAOptions] = None, 1736 ) -> Tuple[ops.Operation, List[core_types.Tensor]]: 1737 """Shards `computation` for parallel execution. 1738 1739 `inputs` must be a list of Tensors or None (equivalent to an empty list), each 1740 of which has a corresponding split axis (from `input_shard_axes`). Each input 1741 is split into `num_shards` pieces along the corresponding axis, and 1742 computation is applied to each shard in parallel. 1743 1744 Tensors are broadcast to all shards if they are lexically captured by 1745 `computation`. e.g., 1746 1747 x = tf.constant(7) 1748 def computation(): 1749 return x + 3 1750 ... = shard(computation, ...) 1751 1752 If `outputs_from_all_shards` is true, the outputs from all shards of 1753 `computation` are concatenated back together along their `output_shard_axes`. 1754 Otherwise, each output is taken from an arbitrary shard. 1755 1756 Inputs and outputs of the computation must be at least rank-1 Tensors. 1757 1758 Args: 1759 computation: A Python function that builds a computation to apply to each 1760 shard of the input. 1761 inputs: A list of input tensors or None (equivalent to an empty list). Each 1762 input tensor has a corresponding shard axes, given by `input_shard_axes`, 1763 which must have size divisible by `num_shards`. 1764 num_shards: The number of shards. 
1765 input_shard_axes: A list of dimensions along which to shard `inputs`, or 1766 `None`. `None` means "shard all inputs along dimension 0". If not `None`, 1767 there must be one dimension per input. 1768 outputs_from_all_shards: Boolean or list of boolean. For each output, if 1769 `True`, outputs from all shards are concatenated along the corresponding 1770 `output_shard_axes` entry. Otherwise, each output is taken 1771 from an arbitrary shard. If the argument is a boolean, the argument's 1772 value is used for each output. 1773 output_shard_axes: A list of dimensions along which to concatenate the 1774 outputs of `computation`, or `None`. `None` means "concatenate all outputs 1775 along dimension 0". If not `None`, there must be one dimension per output. 1776 Ignored if `outputs_from_all_shards` is False. 1777 infeed_queue: If not `None`, the `InfeedQueue` to use to augment the inputs 1778 of `computation`. 1779 device_assignment: If not `None`, a `DeviceAssignment` describing the 1780 mapping between logical cores in the computation with physical cores in 1781 the TPU topology. Uses a default device assignment if `None`. The 1782 `DeviceAssignment` may be omitted if each shard of the computation uses 1783 only one core, and there is either only one shard, or the number of shards 1784 is equal to the number of cores in the TPU system. 1785 name: (Deprecated) Does nothing. 1786 xla_options: An instance of `tpu.XLAOptions` which indicates the options 1787 passed to XLA compiler. Use `None` for default options. 1788 Returns: 1789 A tuple of (compile op, [output tensors]). 1790 Raises: 1791 ValueError: If num_shards <= 0 1792 ValueError: If len(input_shard_axes) != len(inputs) 1793 ValueError: If len(output_shard_axes) != len(outputs from `computation`) 1794 """ 1795 # TODO(phawkins): consider adding support for broadcasting Tensors passed as 1796 # inputs. 1797 1798 if num_shards <= 0: 1799 raise ValueError( 1800 f"num_shards must be a positive integer. Received {num_shards}") 1801 1802 inputs = [] if inputs is None else inputs 1803 if not isinstance(inputs, list): 1804 raise TypeError("tpu.shard()'s inputs must be a list of Tensors or None. " 1805 f"Received {type(inputs)}") 1806 1807 # Converts inputs to Tensors. 1808 inputs = [ops.convert_to_tensor(x) for x in inputs] 1809 1810 if input_shard_axes is None: 1811 input_shard_axes = [0] * len(inputs) 1812 if len(inputs) != len(input_shard_axes): 1813 raise ValueError("Length of input_shard_axes must be equal to the number " 1814 f"of inputs. Received {len(inputs)} inputs and " 1815 f"{len(input_shard_axes)} input_shard_axes.") 1816 1817 if inputs: 1818 # Splits the `inputs` along the corresponding `input_shard_axes`, giving 1819 # lists with layout [input][shard] 1820 split_inputs = [ 1821 array_ops.split(x, num_shards, axis=axis) 1822 for (axis, x) in zip(input_shard_axes, inputs)] 1823 1824 # Transposes the input lists to have layout [shard][input] 1825 transposed_inputs = [list(i) for i in zip(*split_inputs)] 1826 else: 1827 transposed_inputs = [[]] * num_shards 1828 1829 compile_op, outputs = split_compile_and_replicate( 1830 computation, 1831 transposed_inputs, 1832 infeed_queue=infeed_queue, 1833 device_assignment=device_assignment, 1834 name=name, 1835 xla_options=xla_options) 1836 1837 # There must be at least one shard since num_shards > 0. 1838 # TODO(b/36647078) remove disable when pylint bug is fixed. 
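  # `outputs` here has layout [replica][output]; it is transposed below via
  # zip(*outputs) so that each output position can be stacked or concatenated
  # across shards.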
1839 # pylint: disable=indexing-exception 1840 if isinstance(outputs[0], ops.Operation): 1841 # pylint: enable=indexing-exception 1842 # There were no outputs from the computation and replicate returned a list 1843 # of NoOps with control dependencies on the computation. Return the first 1844 # one so it can be used as a control dependency or fetch node. 1845 # TODO(b/36647078) remove disable when pylint bug is fixed. 1846 # pylint: disable=indexing-exception 1847 return compile_op, [outputs[0]] 1848 # pylint: enable=indexing-exception 1849 1850 # TODO(b/36647078) remove disable when pylint bug is fixed. 1851 # pylint: disable=indexing-exception 1852 num_outputs = len(outputs[0]) 1853 # pylint: enable=indexing-exception 1854 1855 if output_shard_axes is None: 1856 output_shard_axes = [0] * num_outputs 1857 if num_outputs != len(output_shard_axes): 1858 raise ValueError("Length of output_shard_axes must be equal to the number " 1859 f"of outputs. Received {num_outputs} outputs " 1860 f"and {len(output_shard_axes)} output_shard_axes.") 1861 1862 if isinstance(outputs_from_all_shards, bool): 1863 outputs_from_all_shards = [outputs_from_all_shards] * num_outputs 1864 1865 if num_outputs != len(outputs_from_all_shards): 1866 raise ValueError( 1867 "Length of outputs_from_all_shards must be equal to the number of " 1868 f"outputs. Received {num_outputs} outputs and " 1869 f"{len(outputs_from_all_shards)} outputs_from_all_shards.") 1870 1871 results = [] 1872 for (axis, all_shards, x) in zip(output_shard_axes, outputs_from_all_shards, 1873 zip(*outputs)): 1874 if all_shards: 1875 # Concatenate all of the outputs together (use stack for scalars). 1876 shape = x[0].shape 1877 is_scalar = shape is not None and (shape.ndims == 0) 1878 results.append((array_ops.stack(list(x)) if is_scalar 1879 else array_ops.concat(list(x), axis=axis))) 1880 else: 1881 # TODO(phawkins): use a smarter policy, e.g., round-robin across shards. 1882 results.append(x[0]) 1883 1884 return compile_op, results 1885 1886 1887@tf_export(v1=["tpu.shard"]) 1888def shard( 1889 computation: Callable[..., Any], 1890 inputs: Optional[List[core_types.Tensor]] = None, 1891 num_shards: int = 1, 1892 input_shard_axes: Optional[List[int]] = None, 1893 outputs_from_all_shards: Union[bool, List[bool]] = True, 1894 output_shard_axes: Optional[List[int]] = None, 1895 infeed_queue: Optional[tpu_feed.InfeedQueue] = None, 1896 device_assignment: Optional[device_assignment_lib.DeviceAssignment] = None, 1897 name: Optional[Text] = None, 1898 xla_options: Optional[XLAOptions] = None) -> List[core_types.Tensor]: 1899 """Shards `computation` for parallel execution. 1900 1901 `inputs` must be a list of Tensors or None (equivalent to an empty list), each 1902 of which has a corresponding split axis (from `input_shard_axes`). Each input 1903 is split into `num_shards` pieces along the corresponding axis, and 1904 computation is applied to each shard in parallel. 1905 1906 Tensors are broadcast to all shards if they are lexically captured by 1907 `computation`. e.g., 1908 1909 x = tf.constant(7) 1910 def computation(): 1911 return x + 3 1912 ... = shard(computation, ...) 1913 1914 TODO(phawkins): consider adding support for broadcasting Tensors passed 1915 as inputs. 1916 1917 If `outputs_from_all_shards` is true, the outputs from all shards of 1918 `computation` are concatenated back together along their `output_shard_axes`. 1919 Otherwise, each output is taken from an arbitrary shard. 
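
  For example, the following sketch (with a hypothetical `model_fn` and a
  batch-major `images` tensor, neither of which is defined here) splits
  `images` into two pieces along dimension 0, applies `model_fn` to each
  piece, and concatenates the per-shard results back along dimension 0:

    def model_fn(images):
      return tf.reduce_mean(images, axis=1, keepdims=True)

    outputs = shard(
        model_fn,
        inputs=[images],
        num_shards=2,
        outputs_from_all_shards=True)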
1920 1921 Inputs and outputs of the computation must be at least rank-1 Tensors. 1922 1923 Args: 1924 computation: A Python function that builds a computation to apply to each 1925 shard of the input. 1926 inputs: A list of input tensors or None (equivalent to an empty list). Each 1927 input tensor has a corresponding shard axes, given by `input_shard_axes`, 1928 which must have size divisible by `num_shards`. 1929 num_shards: The number of shards. 1930 input_shard_axes: A list of dimensions along which to shard `inputs`, or 1931 `None`. `None` means "shard all inputs along dimension 0". If not `None`, 1932 there must be one dimension per input. 1933 outputs_from_all_shards: Boolean or list of boolean. For each output, if 1934 `True`, outputs from all shards are concatenated along the corresponding 1935 `output_shard_axes` entry. Otherwise, each output is taken 1936 from an arbitrary shard. If the argument is a boolean, the argument's 1937 value is used for each output. 1938 output_shard_axes: A list of dimensions along which to concatenate the 1939 outputs of `computation`, or `None`. `None` means "concatenate all outputs 1940 along dimension 0". If not `None`, there must be one dimension per output. 1941 Ignored if `outputs_from_all_shards` is False. 1942 infeed_queue: If not `None`, the `InfeedQueue` to use to augment the inputs 1943 of `computation`. 1944 device_assignment: If not `None`, a `DeviceAssignment` describing the 1945 mapping between logical cores in the computation with physical cores in 1946 the TPU topology. Uses a default device assignment if `None`. The 1947 `DeviceAssignment` may be omitted if each shard of the computation uses 1948 only one core, and there is either only one shard, or the number of shards 1949 is equal to the number of cores in the TPU system. 1950 name: (Deprecated) Does nothing. 1951 xla_options: An instance of `tpu.XLAOptions` which indicates the options 1952 passed to XLA compiler. Use `None` for default options. 1953 Returns: 1954 A list of output tensors. 1955 Raises: 1956 ValueError: If num_shards <= 0 1957 ValueError: If len(input_shard_axes) != len(inputs) 1958 ValueError: If len(output_shard_axes) != len(outputs from `computation`) 1959 """ 1960 return split_compile_and_shard( 1961 computation, 1962 inputs=inputs, 1963 num_shards=num_shards, 1964 input_shard_axes=input_shard_axes, 1965 outputs_from_all_shards=outputs_from_all_shards, 1966 output_shard_axes=output_shard_axes, 1967 infeed_queue=infeed_queue, 1968 device_assignment=device_assignment, 1969 name=name, 1970 xla_options=xla_options)[1] 1971 1972 1973@tf_export(v1=["tpu.batch_parallel"]) 1974def batch_parallel( 1975 computation: Callable[..., Any], 1976 inputs: Optional[List[List[Optional[core_types.Tensor]]]] = None, 1977 num_shards: int = 1, 1978 infeed_queue: Optional[tpu_feed.InfeedQueue] = None, 1979 device_assignment: Optional[device_assignment_lib.DeviceAssignment] = None, 1980 name: Optional[Text] = None, 1981 xla_options: Optional[XLAOptions] = None): 1982 """Shards `computation` along the batch dimension for parallel execution. 1983 1984 Convenience wrapper around shard(). 1985 1986 `inputs` must be a list of Tensors or None (equivalent to an empty list). 1987 Each input is split into `num_shards` pieces along the 0-th dimension, and 1988 computation is applied to each shard in parallel. 1989 1990 Tensors are broadcast to all shards if they are lexically captured by 1991 `computation`. e.g., 1992 1993 x = tf.constant(7) 1994 def computation(): 1995 return x + 3 1996 ... 
= shard(computation, ...) 1997 1998 The outputs from all shards are concatenated back together along their 0-th 1999 dimension. 2000 2001 Inputs and outputs of the computation must be at least rank-1 Tensors. 2002 2003 Args: 2004 computation: A Python function that builds a computation to apply to each 2005 shard of the input. 2006 inputs: A list of input tensors or None (equivalent to an empty list). The 2007 0-th dimension of each Tensor must have size divisible by `num_shards`. 2008 num_shards: The number of shards. 2009 infeed_queue: If not `None`, the `InfeedQueue` from which to append a tuple 2010 of arguments as inputs to `computation`. 2011 device_assignment: If not `None`, a `DeviceAssignment` describing the 2012 mapping between logical cores in the computation with physical cores in 2013 the TPU topology. Uses a default device assignment if `None`. The 2014 `DeviceAssignment` may be omitted if each shard of the computation uses 2015 only one core, and there is either only one shard, or the number of shards 2016 is equal to the number of cores in the TPU system. 2017 name: (Deprecated) Does nothing. 2018 xla_options: An instance of `tpu.XLAOptions` which indicates the options 2019 passed to XLA compiler. Use `None` for default options. 2020 Returns: 2021 A list of output tensors. 2022 Raises: 2023 ValueError: If `num_shards <= 0` 2024 """ 2025 return shard( 2026 computation, 2027 inputs, 2028 num_shards=num_shards, 2029 infeed_queue=infeed_queue, 2030 device_assignment=device_assignment, 2031 name=name, 2032 xla_options=xla_options) 2033 2034 2035@tf_export(v1=["tpu.rewrite"]) 2036def rewrite( 2037 computation: Callable[..., Any], 2038 inputs: Optional[List[List[Optional[core_types.Tensor]]]] = None, 2039 infeed_queue: Optional[tpu_feed.InfeedQueue] = None, 2040 device_assignment: Optional[device_assignment_lib.DeviceAssignment] = None, 2041 name: Optional[Text] = None, 2042 xla_options: Optional[XLAOptions] = None) -> Any: 2043 """Rewrites `computation` for execution on a TPU system. 2044 2045 Args: 2046 computation: A Python function that builds a computation to apply to the 2047 input. If the function takes n inputs, 'inputs' should be a list of n 2048 tensors. 2049 2050 `computation` may return a list of operations and tensors. Tensors must 2051 come before operations in the returned list. The return value of 2052 `rewrite` is a list of tensors corresponding to the tensors from the 2053 output of `computation`. 2054 2055 All `Operation`s constructed during `computation` will be executed when 2056 evaluating any of the returned output tensors, not just the ones returned. 2057 inputs: A list of input tensors or `None` (equivalent to an empty list). 2058 Each input can be a nested structure containing values that are 2059 convertible to tensors. Note that passing an N-dimension list of 2060 compatible values will result in a N-dimension list of scalar tensors 2061 rather than a single Rank-N tensors. If you need different behavior, 2062 convert part of inputs to tensors with `tf.convert_to_tensor`. 2063 infeed_queue: If not `None`, the `InfeedQueue` from which to append a tuple 2064 of arguments as inputs to `computation`. 2065 device_assignment: if not `None`, a `DeviceAssignment` describing the 2066 mapping between logical cores in the computation with physical cores in 2067 the TPU topology. May be omitted for a single-core computation, in which 2068 case the core attached to task 0, TPU device 0 is used. 2069 name: (Deprecated) Does nothing. 
2070 xla_options: An instance of `tpu.XLAOptions` which indicates the options 2071 passed to XLA compiler. Use `None` for default options. 2072 Returns: 2073 Same data structure as if computation(*inputs) is called directly with some 2074 exceptions for correctness. Exceptions include: 2075 1) None output: a NoOp would be returned which control-depends on 2076 computation. 2077 2) Single value output: A tuple containing the value would be returned. 2078 3) Operation-only outputs: a NoOp would be returned which 2079 control-depends on computation. 2080 TODO(b/121383831): Investigate into removing these special cases. 2081 """ 2082 # TODO(b/36647078) remove disable when pylint bug is fixed. 2083 # pylint: disable=indexing-exception 2084 return replicate( 2085 computation, 2086 None if inputs is None else [inputs], 2087 infeed_queue=infeed_queue, 2088 device_assignment=device_assignment, 2089 name=name, 2090 xla_options=xla_options)[0] 2091 # pylint: enable=indexing-exception 2092 2093 # Operations that indicate some error in the user's inference graph. 2094 2095 2096_DENYLISTED_INFERENCE_OPS = set([ 2097 "ReadVariableOp", 2098 "AssignVariableOp", 2099 "AssignAddVariableOp", 2100 "AssignSubVariableOp", 2101 "VarHandleOp", 2102 "Variable", 2103 "VariableV2", 2104]) 2105 2106 2107def under_tpu_inference_context() -> bool: 2108 """Check if it is currently under `_TPUInferenceContext`.""" 2109 graph = ops.get_default_graph() 2110 while graph: 2111 context = graph._get_control_flow_context() # pylint: disable=protected-access 2112 while context: 2113 if isinstance(context, _TPUInferenceContext): 2114 return True 2115 context = context.outer_context 2116 if isinstance(graph, function._FuncGraph): # pylint: disable=protected-access 2117 graph = graph._outer_graph # pylint: disable=protected-access 2118 elif isinstance(graph, func_graph.FuncGraph): 2119 graph = graph.outer_graph 2120 else: 2121 return False 2122 2123 2124class _TPUInferenceContext(control_flow_ops.XLAControlFlowContext): 2125 """A `ControlFlowContext` for nodes inside a TPU inference computation. 2126 2127 The primary role of `_TPUInferenceContext` is to indicate the mode of 2128 operation and possibly sanity check operators inside a 2129 tpu.rewrite_for_inference() computation. 2130 """ 2131 2132 def __init__(self, name: Text, check_ops: bool = True): 2133 super(_TPUInferenceContext, self).__init__() 2134 self._name = name 2135 self._check_ops = check_ops 2136 2137 def AddOp(self, op): 2138 self._AddOpInternal(op) 2139 2140 def _AddOpInternal(self, op): 2141 # pylint: disable=protected-access 2142 if self._check_ops and op.type in _DENYLISTED_INFERENCE_OPS: 2143 raise NotImplementedError( 2144 f"Operation of type {op.type} ({op.name}) is not supported on the " 2145 "TPU for inference. Execution will fail if this op is used in the " 2146 "graph. Make sure your variables are using variable_scope.") 2147 if self._outer_context: 2148 self._outer_context.AddInnerOp(op) 2149 2150 def AddValue(self, val): 2151 result = val 2152 if self._outer_context: 2153 result = self._outer_context.AddValue(val) 2154 return result 2155 2156 def AddInnerOp(self, op): 2157 self._AddOpInternal(op) 2158 2159 @property 2160 def grad_state(self): 2161 return None 2162 2163 2164def validate_inference_rewrite_for_variables(graph: ops.Graph): 2165 """Validates whether rewrite_for_inference() 'worked' for variables. 

  The rewrite_for_inference() method is supposed to append GuaranteeConstOps
  after ReadVariableOps, but this mechanism works only if you are using
  tf.compat.v1.get_variable() to create and access variables in your tpu
  computation. This validation method can be called immediately after calling
  tpu.rewrite_for_inference() to check whether GuaranteeConstOps were added
  to the graph.

  Typical usages:
    tpu.validate_inference_rewrite_for_variables(
        tf.compat.v1.get_default_graph())

    tpu.validate_inference_rewrite_for_variables(sess.graph)

  Args:
    graph: The graph which needs to be validated.
  Raises:
    RuntimeError: if validation failed.
  """
  if not any(x.type == "GuaranteeConst" for x in graph.get_operations()):
    raise RuntimeError(
        "No GuaranteeConst ops found in the graph after running "
        "tpu.rewrite_for_inference(...). Please check that you are using "
        "tf.get_variable() to create and access variables in your tpu "
        "computation.")


def rewrite_for_inference(
    computation: Callable[..., Any],
    inputs: Optional[List[core_types.Tensor]] = None,
    infeed_queue: Optional[tpu_feed.InfeedQueue] = None,
    device_assignment: Optional[device_assignment_lib.DeviceAssignment] = None,
    name: Optional[Text] = None) -> List[core_types.Tensor]:
  """Rewrites `computation` for inference on a TPU system.

  Other than 'rewriting' the computation to run on a TPU, if using variables
  in your computation, it moves the ReadVariableOps outside the TPU
  computation, and adds GuaranteeConst ops just after the ReadVariableOps.
  This mechanism works only if you are using tf.compat.v1.get_variable() to
  create and access variables in your tpu computation. You can validate
  whether this worked by calling the validate_inference_rewrite_for_variables()
  method immediately after this method to check whether GuaranteeConstOps
  were added to the graph.

  Args:
    computation: A Python function that builds a computation to apply to the
      input. If the function takes n inputs, 'inputs' should be a list of n
      tensors. If the function returns m outputs, rewrite will return a list of
      m tensors.
    inputs: A list of input tensors or `None` (equivalent to an empty list).
    infeed_queue: If not `None`, the `InfeedQueue` from which to append a tuple
      of arguments as inputs to `computation`.
    device_assignment: if not `None`, a `DeviceAssignment` describing the
      mapping between logical cores in the computation with physical cores in
      the TPU topology. May be omitted for a single-core computation, in which
      case the core attached to task 0, TPU device 0 is used.
    name: The name of the operator.
  Returns:
    A list of output tensors.
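
  Typical usage (a sketch only; `inference_fn` and `features` stand in for a
  model function built with tf.compat.v1.get_variable() and its input tensor):

    outputs = tpu.rewrite_for_inference(inference_fn, [features])
    tpu.validate_inference_rewrite_for_variables(
        tf.compat.v1.get_default_graph())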
2225 """ 2226 2227 def guarantee_const_getter(getter, name, *args, **kwargs): 2228 with ops.control_dependencies(None): 2229 return array_ops.guarantee_const( 2230 getter(name, *args, **kwargs), name=name + "/GuaranteeConst") 2231 2232 def wrapped_computation(*args, **kwargs): 2233 """Execute computation under `_TPUInferenceContext`.""" 2234 context = _TPUInferenceContext( 2235 name=ops.get_default_graph().unique_name("rewrite_for_inference")) 2236 try: 2237 context.Enter() 2238 2239 vscope = variable_scope.get_variable_scope() 2240 prev_custom_getter = vscope.custom_getter 2241 prev_caching_device = vscope.caching_device 2242 vscope.set_custom_getter(guarantee_const_getter) 2243 vscope.set_caching_device(lambda op: op.device) 2244 2245 result = computation(*args, **kwargs) 2246 2247 vscope.set_custom_getter(prev_custom_getter) 2248 vscope.set_caching_device(prev_caching_device) 2249 finally: 2250 context.Exit() 2251 return result 2252 2253 # pylint: disable=undefined-variable 2254 return rewrite( 2255 wrapped_computation, 2256 inputs=inputs, 2257 infeed_queue=infeed_queue, 2258 device_assignment=device_assignment, 2259 name=name) 2260 # pylint: enable=undefined-variable 2261 2262 2263def prune_unconnected_ops_from_xla(prune_graph: ops.Graph): 2264 """Prunes unconnected ops as listed in _UNCONNECTED_OPS_TO_PRUNE. 2265 2266 Args: 2267 prune_graph: A tensorflow graph from which we wish to prune unconnected ops 2268 as listed in _UNCONNECTED_OPS_TO_PRUNE. In general, these ops should have 2269 no inputs and no consumers. These can often be left behind due to graph 2270 construction rewiring (for instance TF-Hub). While they never execute, 2271 they will cause XLA compile to fail so we strip them from XLA compile by 2272 removing the tpu_replicate attribute. 2273 """ 2274 # Scan over the top level graph and all function graphs. 2275 for graph in [prune_graph] + [ 2276 f for f in prune_graph._functions.values() # pylint: disable=protected-access 2277 ]: 2278 if not isinstance(graph, ops.Graph): 2279 continue 2280 for op in graph.get_operations(): 2281 if op.type not in _UNCONNECTED_OPS_TO_PRUNE: 2282 continue 2283 outputs_consumed = False 2284 for output in op.outputs: 2285 if output.consumers(): 2286 outputs_consumed = True 2287 break 2288 if not outputs_consumed: 2289 logging.info( 2290 "Pruning OP %s of type %s from XLA Compile due to " 2291 "it being disconnected.", op.name, op.type) 2292 op._clear_attr(_TPU_REPLICATE_ATTR) # pylint: disable=protected-access 2293