1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ====================================== 15 16"""Library of TPU helper functions.""" 17 18from __future__ import absolute_import 19from __future__ import division 20from __future__ import print_function 21 22import collections 23import enum 24import typing 25from typing import Any, Callable, Iterable, List, Optional, Text, Tuple, Union 26 27from absl import logging 28import numpy as np 29from six.moves import xrange # pylint: disable=redefined-builtin 30 31from tensorflow.compiler.tf2xla.python import xla as tf2xla 32from tensorflow.core.framework import attr_value_pb2 33from tensorflow.core.protobuf.tpu import dynamic_padding_pb2 as dynamic_padding 34from tensorflow.core.protobuf.tpu import tpu_embedding_configuration_pb2 as embedding_pb2 35from tensorflow.python.compiler.xla import xla 36from tensorflow.python.distribute import device_util 37from tensorflow.python.distribute import distribution_strategy_context 38from tensorflow.python.framework import auto_control_deps 39from tensorflow.python.framework import c_api_util 40from tensorflow.python.framework import composite_tensor 41from tensorflow.python.framework import config 42from tensorflow.python.framework import constant_op 43from tensorflow.python.framework import device as pydev 44from tensorflow.python.framework import dtypes 45from tensorflow.python.framework import errors 46from tensorflow.python.framework import func_graph 47from tensorflow.python.framework import function 48from tensorflow.python.framework import ops 49from tensorflow.python.framework import tensor_shape 50from tensorflow.python.ops import array_ops 51from tensorflow.python.ops import control_flow_ops 52from tensorflow.python.ops import math_ops 53from tensorflow.python.ops import variable_scope 54from tensorflow.python.ops import variables 55from tensorflow.python.tpu import device_assignment as device_assignment_lib 56from tensorflow.python.tpu import tpu_feed 57from tensorflow.python.tpu import tpu_function 58from tensorflow.python.tpu import tpu_name_util 59from tensorflow.python.tpu.ops import tpu_ops 60from tensorflow.python.types import core as core_types 61from tensorflow.python.util import compat 62from tensorflow.python.util import nest 63from tensorflow.python.util import object_identity 64from tensorflow.python.util.tf_export import tf_export 65 66ops.NotDifferentiable("TPUReplicatedInput") 67 68# Operations that indicate some error in the users graph, e.g. a placeholder 69# that's introduced outside of the infeed. 70_DENYLISTED_OPS = set([ 71 "Placeholder", 72]) 73 74# XLA doesn't currently support reading of intermediate tensors, thus some ops 75# are not supported. 76_UNSUPPORTED_OPS = set([ 77 "AudioSummary", 78 "AudioSummaryV2", 79 "HistogramSummary", 80 "ImageSummary", 81 "MergeSummary", 82 "Print", 83 "ScalarSummary", 84 "TensorSummary", 85 "TensorSummaryV2", 86 ]) 87 88# Ops which can be safely pruned from XLA compile if they have no consumers. 89# These ops should also have no inputs. 90_UNCONNECTED_OPS_TO_PRUNE = set(["Placeholder", "VarHandleOp"]) 91 92_MAX_WARNING_LINES = 5 93 94_TPU_REPLICATE_ATTR = "_tpu_replicate" 95_POST_DEVICE_REWRITE_ATTR = "_post_device_rewrite" 96_TPU_COMPILATION_STATUS_ATTR = "_tpu_compilation_status" 97_OUTSIDE_COMPILATION_ATTR = "_xla_outside_compilation" 98_PIVOT_FOR_CLUSTER = "_pivot_for_cluster" 99 100 101core = tpu_name_util.core 102 103 104def _tpu_system_device_name(job: Optional[Text]) -> Text: 105 """Returns the device name for the TPU_SYSTEM device of `job`.""" 106 if job is None: 107 return "/device:TPU_SYSTEM:0" 108 else: 109 return "/job:%s/device:TPU_SYSTEM:0" % job 110 111 112@tf_export(v1=["tpu.initialize_system"]) 113def initialize_system( 114 embedding_config: Optional[embedding_pb2.TPUEmbeddingConfiguration] = None, 115 job: Optional[Text] = None, 116 compilation_failure_closes_chips: bool = True 117) -> core_types.Tensor: 118 """Initializes a distributed TPU system for use with TensorFlow. 119 120 Args: 121 embedding_config: If not None, a `TPUEmbeddingConfiguration` proto 122 describing the desired configuration of the hardware embedding lookup 123 tables. If embedding_config is None, no hardware embeddings can be used. 124 job: The job (the XXX in TensorFlow device specification /job:XXX) that 125 contains the TPU devices that will be initialized. If job=None it is 126 assumed there is only one job in the TensorFlow flock, and an error will 127 be returned if this assumption does not hold. 128 compilation_failure_closes_chips: Set the configuration whether 129 we want to close TPU chips when there is a compilation failure. 130 Returns: 131 A serialized `TopologyProto` that describes the TPU system. Note: 132 the topology must be evaluated using `Session.run` before it can be used. 133 """ 134 config_string = ("" if embedding_config is None else 135 embedding_config.SerializeToString()) 136 with ops.device(_tpu_system_device_name(job)): 137 topology = tpu_ops.configure_distributed_tpu( 138 compilation_failure_closes_chips=compilation_failure_closes_chips) 139 140 if embedding_config is None: 141 return topology 142 143 # This set of control dependencies is needed as this function is expected to 144 # return an op which will return the topology when executed, but we need to 145 # call the embedding initialization op between initializing the TPU and 146 # returning the topology. 147 with ops.control_dependencies([topology]): 148 embedding_init = tpu_ops.configure_tpu_embedding(config=config_string) 149 with ops.control_dependencies([embedding_init]): 150 return array_ops.identity(topology, name="tpu_init_identity") 151 152 153def initialize_system_for_tpu_embedding( 154 embedding_config: embedding_pb2.TPUEmbeddingConfiguration, 155 job: Optional[Text] = None, 156) -> ops.Operation: 157 """Initializes a distributed TPU Embedding system for use with TensorFlow. 158 159 The following two are equivalent: 160 1. initialize_system() with embedding_config. 161 2. initialize_system() without embedding_config, then 162 initialize_system_for_tpu_embedding(). 163 initialize_system() should not be called with embedding_config if 164 initialize_system_for_tpu_embedding() is meant to be called later. 165 166 Args: 167 embedding_config: a `TPUEmbeddingConfiguration` proto describing the desired 168 configuration of the hardware embedding lookup tables. 169 job: The job (the XXX in TensorFlow device specification /job:XXX) that 170 contains the TPU devices that will be initialized. If job=None it is 171 assumed there is only one job in the TensorFlow flock, and an error will 172 be returned if this assumption does not hold. 173 174 Returns: 175 A no-op. 176 """ 177 config_string = embedding_config.SerializeToString() 178 with ops.device(_tpu_system_device_name(job)): 179 return tpu_ops.configure_tpu_embedding(config=config_string) 180 181 182@tf_export(v1=["tpu.shutdown_system"]) 183def shutdown_system(job: Optional[Text] = None) -> ops.Operation: 184 """Shuts down a running a distributed TPU system. 185 186 Args: 187 job: The job (the XXX in TensorFlow device specification /job:XXX) that 188 contains the TPU devices that will be shutdown. If job=None it is 189 assumed there is only one job in the TensorFlow flock, and an error will 190 be returned if this assumption does not hold. 191 """ 192 with ops.device(_tpu_system_device_name(job)): 193 shutdown_distributed_tpu = tpu_ops.shutdown_distributed_tpu() 194 return shutdown_distributed_tpu 195 196 197def _enclosing_tpu_context_and_graph() -> Tuple[Any, Any]: 198 """Returns the TPUReplicateContext and its associated graph.""" 199 graph = ops.get_default_graph() 200 while graph is not None: 201 # pylint: disable=protected-access 202 context_ = graph._get_control_flow_context() 203 # pylint: enable=protected-access 204 while context_ is not None: 205 if isinstance(context_, TPUReplicateContext): 206 return context_, graph 207 context_ = context_.outer_context 208 graph = getattr(graph, "outer_graph", None) 209 raise ValueError("get_replicated_var_handle() called without " 210 "TPUReplicateContext. This shouldn't happen. Please file " 211 "a bug.") 212 213 214def is_tpu_strategy(strategy: Any) -> bool: 215 is_tpu_strat = lambda k: k.__name__.startswith("TPUStrategy") 216 clz = strategy.__class__ 217 return is_tpu_strat(clz) or any(map(is_tpu_strat, clz.__bases__)) 218 219 220def _enclosing_tpu_device_assignment( 221) -> Optional[device_assignment_lib.DeviceAssignment]: 222 if not distribution_strategy_context.has_strategy(): 223 return None 224 strategy = distribution_strategy_context.get_strategy() 225 if not is_tpu_strategy(strategy): 226 return None 227 return strategy.extended._device_assignment # pylint: disable=protected-access 228 229 230@auto_control_deps.register_acd_resource_resolver 231def tpu_replicated_input_resolver( 232 op: ops.Operation, 233 resource_reads: object_identity.ObjectIdentitySet, 234 resource_writes: object_identity.ObjectIdentitySet) -> bool: 235 """Replaces TPUReplicatedInput outputs with its inputs in resource_inputs.""" 236 # Ignore TPUReplicatedInput for ACD purposes since we will be directly adding 237 # control deps on the replicated inputs. 238 if op.type == "TPUReplicatedInput": 239 if resource_reads or resource_writes: 240 resource_reads.clear() 241 resource_writes.clear() 242 return True 243 else: 244 return False 245 # Replace tensors in `resource_inputs` which are outputs of TPUReplicatedInput 246 # with the actual replicated inputs. This allows ACD to correct add control 247 # deps when there are multiple calls to `run` in a 248 # `tf.function`. 249 def replace_with_unreplicated_resources(resource_inputs): 250 """Replaces handles in `resource_inputs` with their unreplicated inputs.""" 251 to_remove = [] 252 to_add = [] 253 for resource in resource_inputs: 254 if resource.op.type == "TPUReplicatedInput": 255 to_remove.append(resource) 256 to_add.extend(resource.op.inputs) 257 for t in to_remove: 258 resource_inputs.discard(t) 259 resource_inputs.update(to_add) 260 return to_add or to_remove 261 262 return bool(replace_with_unreplicated_resources(resource_reads) or 263 replace_with_unreplicated_resources(resource_writes)) 264 265 266class TPUReplicateContext(control_flow_ops.XLAControlFlowContext): 267 """A `ControlFlowContext` for nodes inside a TPU computation. 268 269 The primary role of `TPUReplicateContext` is to mark operators inside a 270 tpu.replicate() computation with the attribute "_tpu_replicate=XYZ", where XYZ 271 is a unique name. 272 273 We use a `ControlFlowContext` to perform the annotation since it integrates 274 with Tensorflow constructs like ResourceVariables. For example, if a 275 `ResourceVariable` is constructed inside a tpu.replicate() block, the 276 `ResourceVariable` implementation can use 277 `with ops.control_dependencies(None)` to build the variable's definition 278 outside the replicated computation. 279 """ 280 281 def __init__(self, name: Text, num_replicas: int, pivot: ops.Operation): 282 """Builds a new TPUReplicateContext. 283 284 Args: 285 name: a unique name for the context, used to populate the `_tpu_replicate` 286 attribute. 287 num_replicas: an integer that gives the number of replicas for the 288 computation. 289 pivot: a pivot node. Nodes in the TPUReplicateContext that do not have any 290 inputs will have a control dependency on the pivot node. This ensures 291 that nodes are correctly included in any enclosing control flow 292 contexts. 293 """ 294 super(TPUReplicateContext, self).__init__() 295 self._num_replicas = num_replicas 296 self._outer_device_function_stack = None 297 self._oc_dev_fn_stack = None 298 self._outside_compilation_cluster = None 299 self._outside_compilation_v2_context = None 300 self._outside_compilation_counter = 0 301 self._in_gradient_colocation = None 302 self._gradient_colocation_stack = [] 303 self._host_compute_core = [] 304 self._name = name 305 self._name_as_bytes = compat.as_bytes(name) 306 self._tpu_relicate_attr_buf = c_api_util.ScopedTFBuffer( 307 attr_value_pb2.AttrValue(s=self._name_as_bytes).SerializeToString()) 308 self._unsupported_ops = [] 309 self._pivot = pivot 310 self._replicated_vars = {} 311 312 def get_replicated_var_handle( 313 self, 314 name: Text, 315 vars_: List[variables.Variable], 316 is_mirrored: bool = False, 317 is_packed: bool = False) -> core_types.Tensor: 318 """Returns a variable handle for replicated TPU variable 'var'. 319 320 This is a method used by an experimental replicated variable implementation 321 and is not intended as a public API. 322 323 Args: 324 name: The common name of the variable. 325 vars_: The replicated TPU variables. 326 is_mirrored: Whether the variables are mirrored, which guarantees the 327 values in each replica are always the same. 328 is_packed: Whether the replicated variables are packed into one variable. 329 330 Returns: 331 The handle of the TPU replicated input node. 332 """ 333 device_assignment = _enclosing_tpu_device_assignment() 334 # We don't need to put device assignment as part of the replicated_vars key 335 # because each TPUReplicateContext will only have one device assignment. 336 handle = self._replicated_vars.get(name) 337 if handle is not None: 338 return handle 339 340 if device_assignment is not None and not is_packed: 341 # Find a variable copy for each replica in the device assignment. 342 # Note that the order of devices for replicas for the variable and the 343 # device assignment might not match. 344 job_name = pydev.DeviceSpec.from_string(vars_[0].device).job 345 devices_to_vars = {device_util.canonicalize(v.device): v for v in vars_} 346 replicated_vars = [] 347 for replica_id in range(device_assignment.num_replicas): 348 for logical_core in range(device_assignment.num_cores_per_replica): 349 device = device_util.canonicalize( 350 device_assignment.tpu_device( 351 replica=replica_id, logical_core=logical_core, job=job_name)) 352 if device in devices_to_vars: 353 replicated_vars.append(devices_to_vars[device]) 354 break 355 else: 356 raise ValueError( 357 "Failed to find a variable on any device in replica {} for " 358 "current device assignment".format(replica_id)) 359 else: 360 replicated_vars = vars_ 361 362 # Builds a TPUReplicatedInput node for the variable, if one does not already 363 # exist. The TPUReplicatedInput node must belong to the enclosing 364 # control-flow scope of the TPUReplicateContext. 365 # TODO(phawkins): consider changing the contract of the TPU encapsulation 366 # so the TPUReplicatedInput nodes go inside the TPUReplicateContext scope 367 # instead. 368 369 _, graph = _enclosing_tpu_context_and_graph() 370 with graph.as_default(): 371 # pylint: disable=protected-access 372 saved_context = graph._get_control_flow_context() 373 graph._set_control_flow_context(self.outer_context) 374 handle = tpu_ops.tpu_replicated_input([v.handle for v in replicated_vars], 375 name=name + "/handle", 376 is_mirrored_variable=is_mirrored, 377 is_packed=is_packed) 378 graph._set_control_flow_context(saved_context) 379 # pylint: enable=protected-access 380 self._replicated_vars[name] = handle 381 return handle 382 383 def report_unsupported_operations(self) -> None: 384 if self._unsupported_ops: 385 op_str = "\n".join(" %s (%s)" % (op.type, op.name) 386 for op in self._unsupported_ops[:_MAX_WARNING_LINES]) 387 logging.warning("%d unsupported operations found: \n%s", 388 len(self._unsupported_ops), op_str) 389 if len(self._unsupported_ops) > _MAX_WARNING_LINES: 390 logging.warning("... and %d more" % 391 (len(self._unsupported_ops) - _MAX_WARNING_LINES)) 392 393 def EnterGradientColocation(self, op: ops.Operation, gradient_uid: Text): 394 if op is not None: 395 if ops.get_default_graph()._control_flow_context is None: # pylint: disable=protected-access 396 # If we are in TF 2 functions (control flow V2 functions, or 397 # tf.function()), we need to attach _xla_outside_compilation attribute 398 # directly because we are not in TPUReplicateContext. 399 try: 400 outside_attr = op.get_attr(_OUTSIDE_COMPILATION_ATTR).decode("ascii") 401 except ValueError: 402 # The attr was not present: do nothing. 403 return 404 parts = outside_attr.split(".") 405 cluster = parts[0] + "." + gradient_uid 406 self._outside_compilation_v2_context = OutsideCompilationV2Context( 407 cluster) 408 self._outside_compilation_v2_context.Enter() 409 return 410 self._gradient_colocation_stack.append(op) 411 if not self._outside_compilation_cluster: 412 try: 413 outside_attr = op.get_attr(_OUTSIDE_COMPILATION_ATTR).decode("ascii") 414 if self._in_gradient_colocation: 415 raise NotImplementedError( 416 "Cannot nest gradient colocation operations outside compilation" 417 ) 418 if gradient_uid == "__unsupported__": 419 raise NotImplementedError( 420 "No gradient_uid calling gradient within outside_compilation") 421 # When we take the gradient of an op X in an outside_compilation 422 # cluster C in a forward computation we would like to put the ops 423 # corresponding to the gradient of X into a new outside_compilation 424 # cluster C'. However, if we take the gradient of X twice, the second 425 # one should get yet another new outside_compilation cluster C''. 426 # 427 # The mechanism we adopt is to use a 'root_cluster' which is the 428 # cluster that X was in before we took gradients, and a 'gradient_uid' 429 # which is different for every invocation of gradients, and put the 430 # gradient of X in cluster 'root_cluster.gradient_uid'. 431 # 432 # When taking a gradient of a gradient, some ops will be colocated 433 # with Op in the forward pass (e.g., cluster root_cluster) and some in 434 # the backward pass (e.g., cluster root_cluster.initial_gradient_uid). 435 # We need all of the grad-of-grad ops to be in the same cluster to 436 # avoid cyclic dependencies between clusters. We adopt a heuristic 437 # that puts any op clustered with root_cluster.<xxx> in 438 # root_cluster.gradient_uid, even if xxx was initial_gradient_uid. 439 self._in_gradient_colocation = op 440 parts = outside_attr.split(".") 441 cluster = parts[0] + "." + gradient_uid 442 self._EnterOutsideCompilationScope(cluster=cluster) 443 except ValueError: 444 # The attr was not present: do nothing. 445 pass 446 447 def ExitGradientColocation(self, op: ops.Operation, gradient_uid: Text): 448 if op is not None: 449 if ops.get_default_graph()._control_flow_context is None: # pylint: disable=protected-access 450 # Inside a TF2 tf.function or control flow graph and `op` was not 451 # marked to be outside compiled. 452 assert self._outside_compilation_v2_context is None 453 return 454 if self._outside_compilation_v2_context is not None: 455 # Inside a TF2 tf.function or control flow graph and `op` was 456 # marked to be outside compiled. 457 self._outside_compilation_v2_context.Exit() 458 self._outside_compilation_v2_context = None 459 return 460 if not self._gradient_colocation_stack: 461 raise errors.InternalError( 462 op.node_def, op, 463 "Badly nested gradient colocation: empty stack when popping Op " + 464 op.name) 465 last_op = self._gradient_colocation_stack.pop() 466 if op is last_op: 467 if op is self._in_gradient_colocation: 468 self._in_gradient_colocation = None 469 self._ExitOutsideCompilationScope() 470 else: 471 raise errors.InternalError( 472 op.node_def, op, "Badly nested gradient colocation, expected " + 473 last_op + ", got " + op.name) 474 475 def _EnterOutsideCompilationScope(self, cluster: Optional[Text] = None): 476 477 class FakeOp(object): 478 """A helper class to determine the current device. 479 480 Supports only the type and device set/get methods needed to run the 481 graph's _apply_device_function method. 482 """ 483 484 def __init__(self): 485 self._device = "" 486 487 @property 488 def type(self): 489 return "FakeOp" 490 491 @property 492 def device(self): 493 return self._device 494 495 def _set_device(self, device): 496 if isinstance(device, pydev.DeviceSpec): 497 self._device = device.to_string() 498 else: 499 self._device = device 500 501 def _set_device_from_string(self, device_str): 502 self._device = device_str 503 504 if self._outside_compilation_cluster: 505 raise NotImplementedError("Cannot nest outside_compilation clusters") 506 if cluster: 507 self._outside_compilation_cluster = cluster 508 else: 509 self._outside_compilation_cluster = str(self._outside_compilation_counter) 510 self._outside_compilation_counter += 1 511 graph = ops.get_default_graph() 512 fake_op = FakeOp() 513 graph._apply_device_functions(fake_op) # pylint: disable=protected-access 514 device = pydev.DeviceSpec.from_string(fake_op.device) 515 if (device.device_type == "TPU_REPLICATED_CORE" and 516 device.device_index is not None): 517 self._host_compute_core.append(self._outside_compilation_cluster + ":" + 518 str(device.device_index)) 519 self._oc_dev_fn_stack = graph._device_function_stack # pylint: disable=protected-access 520 graph._device_function_stack = self._outer_device_function_stack # pylint: disable=protected-access 521 522 def _ExitOutsideCompilationScope(self): 523 if not self._outside_compilation_cluster: 524 raise NotImplementedError( 525 "Attempted to exit outside_compilation scope when not in scope") 526 self._outside_compilation_cluster = None 527 graph = ops.get_default_graph() 528 graph._device_function_stack = self._oc_dev_fn_stack # pylint: disable=protected-access 529 530 def Enter(self) -> None: 531 if not self._outer_device_function_stack: 532 # Capture the device function stack at the time of first entry 533 # since that is the stack that will be used outside_compilation. 534 graph = ops.get_default_graph() 535 # pylint: disable=protected-access 536 self._outer_device_function_stack = graph._device_function_stack.copy() 537 # pylint: enable=protected-access 538 super(TPUReplicateContext, self).Enter() 539 540 def HostComputeCore(self) -> List[Text]: 541 return self._host_compute_core 542 543 def _RemoveExternalControlEdges( 544 self, op: ops.Operation 545 ) -> Tuple[List[ops.Operation], List[ops.Operation]]: 546 """Remove any external control dependency on this op.""" 547 internal_control_inputs = [] 548 external_control_inputs = [] 549 for x in op.control_inputs: 550 # pylint: disable=protected-access 551 is_internal_op = False 552 ctxt = x._get_control_flow_context() 553 while ctxt is not None: 554 if ctxt == self: 555 is_internal_op = True 556 break 557 ctxt = ctxt._outer_context 558 if is_internal_op: 559 internal_control_inputs.append(x) 560 else: 561 external_control_inputs.append(x) 562 # pylint: enable=protected-access 563 # pylint: disable=protected-access 564 op._remove_all_control_inputs() 565 op._add_control_inputs(internal_control_inputs) 566 # pylint: enable=protected-access 567 return internal_control_inputs, external_control_inputs 568 569 def AddOp(self, op: ops.Operation) -> None: 570 # pylint: disable=protected-access 571 if op.type in _DENYLISTED_OPS: 572 logging.error("Operation of type %s (%s) is not supported on the TPU. " 573 "Execution will fail if this op is used in the graph. ", 574 op.type, op.name) 575 576 if op.type in _UNSUPPORTED_OPS: 577 self._unsupported_ops.append(op) 578 579 if any(x.dtype._is_ref_dtype for x in op.inputs): 580 raise NotImplementedError( 581 "Non-resource Variables are not supported inside TPU computations " 582 "(operator name: %s)" % op.name) 583 584 # TensorFlowOpLayer may clone nodes that are in tpu.rewrite()s. It'll add 585 # the "_cloned" attribute and we should continue in that case. 586 if (_TPU_REPLICATE_ATTR in op.node_def.attr and 587 "_cloned" not in op.node_def.attr): 588 raise ValueError("TPU computations cannot be nested on op (%s)" % 589 op) 590 op._set_attr_with_buf(_TPU_REPLICATE_ATTR, 591 self._tpu_relicate_attr_buf.buffer) 592 if self._outside_compilation_cluster: 593 op._set_attr( 594 _OUTSIDE_COMPILATION_ATTR, 595 attr_value_pb2.AttrValue( 596 s=compat.as_bytes(self._outside_compilation_cluster))) 597 if self._num_replicas > 1 or not self._outside_compilation_cluster: 598 # Prevent feeding or fetching anything that is being compiled, 599 # and any replicated outside_compilation Op. 600 op.graph.prevent_feeding(op) 601 op.graph.prevent_fetching(op) 602 603 # Remove any control edges from outer control flow contexts. These may cause 604 # mismatched frame errors. 605 (internal_control_inputs, 606 external_control_inputs) = self._RemoveExternalControlEdges(op) 607 608 if not op.inputs: 609 # Add a control edge from the control pivot to this op. 610 if not internal_control_inputs: 611 # pylint: disable=protected-access 612 op._add_control_input(self.GetControlPivot()) 613 # pylint: enable=protected-access 614 else: 615 for index in xrange(len(op.inputs)): 616 x = op.inputs[index] 617 real_x = self.AddValue(x) 618 if real_x is not x: 619 op._update_input(index, real_x) # pylint: disable=protected-access 620 621 if external_control_inputs: 622 # Use an identity to pull control inputs as data inputs. Note that we 623 # ignore ops which don't have outputs. TODO(phawkins): fix that. 624 with ops.control_dependencies(None): 625 self.Enter() 626 external_control_inputs = [ 627 array_ops.identity(x.outputs[0]).op 628 for x in external_control_inputs 629 if x.outputs 630 ] 631 self.Exit() 632 # pylint: disable=protected-access 633 op._add_control_inputs(external_control_inputs) 634 # pylint: enable=protected-access 635 636 # Mark op's outputs as seen by this context and any outer contexts. 637 output_names = [x.name for x in op.outputs] 638 context = self 639 while context is not None: 640 # pylint: disable=protected-access 641 context._values.update(output_names) 642 context = context._outer_context 643 # pylint: enable=protected-access 644 645 if self._outer_context: 646 self._outer_context.AddInnerOp(op) 647 648 def AddValue(self, val: core_types.Tensor) -> core_types.Tensor: 649 """Add `val` to the current context and its outer context recursively.""" 650 if not self._outer_context: 651 return val 652 653 if val.name in self._values: 654 # Use the real value if it comes from outer context. 655 result = self._external_values.get(val.name) 656 return val if result is None else result 657 658 result = val 659 self._values.add(val.name) 660 if self._outer_context: 661 result = self._outer_context.AddValue(val) 662 self._values.add(result.name) 663 664 self._external_values[val.name] = result 665 666 return result 667 668 def AddInnerOp(self, op: ops.Operation): 669 self.AddOp(op) 670 if self._outer_context: 671 self._outer_context.AddInnerOp(op) 672 673 @property 674 def grad_state(self): 675 # Define the gradient loop state associated with the TPUReplicateContext to 676 # be None as the TPUReplicateContext does not get nested nor does the 677 # grad_state outside the TPUReplicateContext affect the graph inside so the 678 # grad_state should be as if this is the top-level gradient state. 679 return None 680 681 @property 682 def back_prop(self): 683 """Forwards to the enclosing while context, if any.""" 684 if self.GetWhileContext(): 685 return self.GetWhileContext().back_prop 686 return False 687 688 def GetControlPivot(self) -> ops.Operation: 689 return self._pivot 690 691 def RequiresUniqueFunctionRetracing(self): 692 # More context: b/158152827. TPU stack uses the TPUReplicateContext to 693 # create replicated variable handles and cluster TPU computations, thus we 694 # always retrace a tf.function when the wrapped TPUReplicateContext changes. 695 return True 696 697 698class OutsideCompilationV2Context(control_flow_ops.ControlFlowContext): 699 """The context for outside compilation in Tensorflow 2.0. 700 701 Every op added in this context will be assigned an _xla_outside_compilation 702 attribute. 703 """ 704 705 def __init__(self, name: Text): 706 control_flow_ops.ControlFlowContext.__init__(self) 707 self._name = name 708 709 def AddOp(self, op: ops.Operation) -> None: 710 if self._outer_context: 711 self._outer_context.AddOp(op) 712 # pylint: disable=protected-access 713 op._set_attr("_xla_outside_compilation", 714 attr_value_pb2.AttrValue(s=compat.as_bytes(self._name))) 715 # pylint: enable=protected-access 716 717 def AddInnerOp(self, op: ops.Operation) -> None: 718 if self._outer_context: 719 self._outer_context.AddInnerOp(op) 720 # pylint: disable=protected-access 721 op._set_attr("_xla_outside_compilation", 722 attr_value_pb2.AttrValue(s=compat.as_bytes(self._name))) 723 # pylint: enable=protected-access 724 725 def to_control_flow_context_def(self, context_def, export_scope=None): 726 raise NotImplementedError("to_control_flow_context_def not implemented") 727 728 729@tf_export(v1=["tpu.outside_compilation"]) 730def outside_compilation( 731 computation: Callable[..., Any], *args, **kwargs 732 ) -> Any: 733 """Builds part of a computation outside any current TPU replicate scope. 734 735 `tf.tpu.outside_compilation()` is used to run ops in `computation` on CPU 736 instead of running on TPU. For example, users can run ops that are not 737 supported on TPU's (e.g. tf.summary.write()) by explicitly placing those 738 ops on CPU's. Below usage of outside compilation will place ops in 739 `computation_with_string_ops` on CPU. 740 741 Example usage: 742 743 ```python 744 def computation_with_string_ops(x): 745 # strings types are not supported on TPU's and below ops must 746 # run on CPU instead. 747 output = tf.strings.format('1{}', x) 748 return tf.strings.to_number(output) 749 750 def tpu_computation(): 751 # Expected output is 11. 752 output = tf.tpu.outside_compilation(computation_with_string_ops, 1) 753 ``` 754 755 Outside compilation should be called inside TPUReplicateContext. That is, 756 `tf.tpu.outside_compilation()` should be called inside a function that is 757 passed to `tpu.split_compile_and_replicate()` -- this is implied when 758 outside compilation is invoked inside a function passed to TPUStrategy 759 `run()`. If invoked outside of TPUReplicateContext, 760 then this simply returns the result of `computation`, and therefore, 761 would be a no-op. Note that outside compilation is different from 762 `tf.distribute.experimental.TPUStrategy.merge_call()` as logic in 763 outside compilation is replicated and executed separately for each 764 replica. On the other hand, `merge_call()` requires a `merge_fn` 765 to aggregate the inputs from different replicas and is executed only 766 once. 767 768 For variables placed in TPU device, which includes variables created inside 769 TPUStrategy scope, outside compilation logic must not include variable 770 read/write. For variables placed on host, which is the case when variables 771 created via TPUEstimator, variable read/write is only allowed if the variable 772 is not accessed by any other ops in the TPU computation. Variable read/write 773 from outside compilation cluster is not visible from TPU computation and 774 vice versa. Therefore, if outside compilation logic contains such host 775 variables read/write ops and if the variables are accessed by TPU 776 computation as well, then this may lead to deadlock. 777 778 Internally, `tf.tpu.outside_compilation()` adds outside compilation 779 attributes to all ops in `computation`. During later graph pass, these 780 ops with outside compilation attribute is extracted out and replicated 781 into a host-side graph. Inputs to this extract host-side graph is sent 782 from TPU computation graph to host graph via a pair of XlaSendToHost and 783 XlaRecvFromHost ops. Note that using `tf.tpu.outside_compilation()` 784 may result in tensor transfer between TPU and CPU, leading to non-trivial 785 performance impact. 786 787 Args: 788 computation: A Python function that builds the computation to 789 place on the host. 790 *args: the positional arguments for the computation. 791 **kwargs: the keyword arguments for the computation. 792 793 Returns: 794 The Tensors returned by computation. 795 """ 796 args = [] if args is None else args 797 graph = ops.get_default_graph() 798 799 # If we are in TF 2 functions (control flow V2 functions, or tf.function()), 800 # we need to attach _xla_outside_compilation attribute directly because we are 801 # not in TPUReplicateContext. 802 if isinstance(graph, func_graph.FuncGraph): 803 try: 804 tpu_context, _ = _enclosing_tpu_context_and_graph() 805 except ValueError: 806 logging.warning( 807 "Outside compilation attempted outside TPUReplicateContext " 808 "scope. As no enclosing TPUReplicateContext can be found, " 809 "returning the result of `computation` as is.") 810 return computation(*args, **kwargs) 811 812 # pylint: disable=protected-access 813 outside_compilation_name = str(tpu_context._outside_compilation_counter) 814 tpu_context._outside_compilation_counter = ( 815 tpu_context._outside_compilation_counter + 1) 816 # pylint: enable=protected-access 817 818 outside_compilation_context = OutsideCompilationV2Context( 819 outside_compilation_name) 820 outside_compilation_context.Enter() 821 args = [] if args is None else args 822 retval = computation(*args, **kwargs) 823 outside_compilation_context.Exit() 824 return retval 825 826 # If we are in a TPUReplicateContext, signal that we are now 827 # outside_compilation 828 initial_context = graph._get_control_flow_context() # pylint: disable=protected-access 829 context = initial_context 830 while context: 831 if isinstance(context, TPUReplicateContext): 832 context._EnterOutsideCompilationScope() # pylint: disable=protected-access 833 context = context.outer_context 834 835 retval = computation(*args, **kwargs) 836 837 # If we are in a TPUReplicateContext, signal that we are no longer 838 # outside_compilation 839 final_context = graph._get_control_flow_context() # pylint: disable=protected-access 840 if initial_context is not final_context: 841 raise NotImplementedError( 842 "Control-flow context cannot be different at start and end of an " 843 "outside_compilation scope") 844 context = initial_context 845 while context: 846 if isinstance(context, TPUReplicateContext): 847 context._ExitOutsideCompilationScope() # pylint: disable=protected-access 848 context = context.outer_context 849 850 return retval 851 852 853@tf_export(v1=["tpu.PaddingSpec"]) 854class PaddingSpec(enum.IntEnum): 855 """Represents the type of padding policies for tpu.replicate.""" 856 # By default the policy is set to AUTO, the dynamic input shape dimension will 857 # be pad to maximum of all the replicas. 858 AUTO = 0 859 # Bucketize the dynamic input shape dimension into a power of 2. 860 POWER_OF_TWO = 1 861 862 863@tf_export(v1=["tpu.XLAOptions"]) 864class XLAOptions( 865 collections.namedtuple("XLAOptions", [ 866 "use_spmd_for_xla_partitioning", 867 ])): 868 """XLA compilation options. 869 870 Attributes: 871 use_spmd_for_xla_partitioning: Boolean. Whether to use XLA's SPMD 872 partitioner instead of MPMD partitioner when compiler partitioning is 873 requested. 874 """ 875 876 def __new__(cls, use_spmd_for_xla_partitioning=True): 877 return super(XLAOptions, cls).__new__(cls, use_spmd_for_xla_partitioning) 878 879 880@tf_export(v1=["tpu.replicate"]) 881def replicate( 882 computation: Callable[..., Any], 883 inputs: Optional[List[List[core_types.Tensor]]] = None, 884 infeed_queue: Optional[tpu_feed.InfeedQueue] = None, 885 device_assignment: Optional[device_assignment_lib.DeviceAssignment] = None, 886 name: Optional[Text] = None, 887 maximum_shapes: Any = None, 888 padding_spec: Optional[PaddingSpec] = None, 889 xla_options: Optional[XLAOptions] = None) -> List[Any]: 890 """Builds a graph operator that runs a replicated TPU computation. 891 892 Example for the basic usage that `inputs` has static shape: 893 894 ```python 895 896 def computation(x): 897 x = x + 1 898 return tf.math.reduce_mean(x) 899 900 x = tf.convert_to_tensor([1., 2., 3.]) 901 y = tf.convert_to_tensor([4., 5., 6.]) 902 tf.compat.v1.tpu.replicate(computation, inputs=[[x], [y]]) 903 ``` 904 905 If the `inputs` has dynamic shapes and you would like to automatically 906 bucketize the inputs to avoid XLA recompilation. See the advanced example 907 below: 908 909 ```python 910 911 def computation(x): 912 x = x + 1 913 return tf.math.reduce_mean(x) 914 915 # Assume input tensors in two replicas `x` and `y` both have dynamic shape 916 # ([None, 2]). 917 tf.compat.v1.tpu.replicate( 918 computation, 919 inputs=[x, y], 920 maximum_shapes=[tf.TensorShape([None, None])], 921 padding_spec=tf.compat.v1.tpu.PaddingSpec.POWER_OF_TWO) 922 ``` 923 924 Args: 925 computation: A Python function that builds the computation to replicate. 926 inputs: A list of lists of input tensors or `None` (equivalent to 927 `[[]]`), indexed by `[replica_num][input_num]`. All replicas must 928 have the same number of inputs. Each input can be a nested structure 929 containing values that are convertible to tensors. Note that passing an 930 N-dimension list of compatible values will result in a N-dimension list of 931 scalar tensors rather than a single Rank-N tensors. If you need different 932 behavior, convert part of inputs to tensors with `tf.convert_to_tensor`. 933 infeed_queue: If not `None`, the `InfeedQueue` from which to append a tuple 934 of arguments as inputs to computation. 935 device_assignment: If not `None`, a `DeviceAssignment` describing the 936 mapping between logical cores in the computation with physical cores in 937 the TPU topology. Uses a default device assignment if `None`. The 938 `DeviceAssignment` may be omitted if each replica of the computation uses 939 only one core, and there is either only one replica, or the number of 940 replicas is equal to the number of cores in the TPU system. 941 name: (Deprecated) Does nothing. 942 maximum_shapes: A nested structure of tf.TensorShape representing the shape 943 to which the respective component of each input element in each replica 944 should be padded. Any unknown dimensions (e.g. 945 tf.compat.v1.Dimension(None) in a tf.TensorShape or -1 in a tensor-like 946 object) will be padded to the maximum size of that dimension over all 947 replicas. The structure of `maximum_shapes` needs to be the same as 948 `inputs[0]`. 949 padding_spec: An enum specified by `tpu.PaddingSpec`. This describes the 950 padding policy when the `inputs` to `tpu.replicate` is dynamic. 951 One usage is to enable automatic bucketizing on the inputs by setting the 952 value to `tpu.PaddingSpec.POWER_OF_TWO`, which can help to reduce the 953 recompilation in the XLA side. 954 xla_options: An instance of `tpu.XLAOptions` which indicates the options 955 passed to XLA compiler. Use `None` for default options. 956 Returns: 957 A list of outputs, indexed by `[replica_num]` each output can be a nested 958 structure same as what computation() returns with a few exceptions. 959 960 Exceptions include: 961 1) None output: a NoOp would be returned which control-depends on 962 computation. 963 2) Single value output: A tuple containing the value would be returned. 964 3) Operation-only outputs: a NoOp would be returned which 965 control-depends on computation. 966 TODO(b/121383831): Investigate into removing these special cases. 967 968 Raises: 969 ValueError: If all replicas do not have equal numbers of input tensors. 970 ValueError: If the number of inputs per replica does not match 971 the number of formal parameters to `computation`. 972 ValueError: If the static `inputs` dimensions don't match with the values 973 given in `maximum_shapes`. 974 ValueError: If the structure of inputs per replica does not match 975 the structure of `maximum_shapes`. 976 """ 977 return split_compile_and_replicate( 978 computation, 979 inputs, 980 infeed_queue, 981 device_assignment, 982 name, 983 maximum_shapes=maximum_shapes, 984 padding_spec=padding_spec, 985 xla_options=xla_options)[1] 986 987 988def _ceil_to_pow_of_n(x, n): 989 """Ceil input `x` to power of `n`.""" 990 x = math_ops.cast(x, dtypes.float32) 991 lognx = math_ops.log(x) / math_ops.log(n * 1.0) 992 lognx = math_ops.ceil(lognx) 993 result = math_ops.pow(n * 1.0, lognx) 994 result = math_ops.cast(result, dtypes.int32) 995 return result 996 997 998def _pad_all_input( 999 inputs: Iterable[core_types.Tensor], 1000 padded_shapes: List[Optional[tensor_shape.TensorShape]], 1001 padding_spec: PaddingSpec 1002) -> Tuple[List[List[Any]], List[dynamic_padding.PaddingMap]]: 1003 """Pad all input tensors given padded_shapes. 1004 1005 The real shape tensors will be concatenated with the padded original inputs. 1006 1007 Args: 1008 inputs: The original inputs. 1009 padded_shapes: A list of padded shapes for each input. If an entry is None, 1010 no padding is performed. 1011 padding_spec: An enum specified by `tpu.PaddingSpec`. This describes the 1012 padding policy when the `inputs` to `tf.tpu.replicate` is dynamic. 1013 One usage is to enable automatic bucketizing on the inputs by setting the 1014 value to `tpu.PaddingSpec.POWER_OF_TWO`, which can help to reduce the 1015 recompilation in the XLA side. 1016 1017 Returns: 1018 The padded inputs and a PaddingMap list which maps the padded input 1019 dimension to the real shape argument index. 1020 """ 1021 # maximum_static_shapes[idx][i] indicates the maximum static size of ith 1022 # dimension of the idx input among all the replicas. 1023 maximum_static_shapes = [] 1024 # need_padding[idx][i] indicates whether the ith dimension of the idx input 1025 # needs padding. 1026 need_padding = [] 1027 input_shape_tensors = [] 1028 for core_idx, inputs_per_core in enumerate(inputs): 1029 for idx, input_tensor in enumerate(inputs_per_core): 1030 input_shape = input_tensor.get_shape().as_list() 1031 if core_idx == 0: 1032 input_shape_tensors.append([]) 1033 maximum_static_shapes.append(input_shape) 1034 need_padding.append(np.full_like(input_shape, False, dtype=bool)) 1035 else: 1036 for i, s in enumerate(input_shape): 1037 if s is None or s != maximum_static_shapes[idx][i]: 1038 need_padding[idx][i] = True 1039 maximum_static_shapes[idx] = max(input_shape, 1040 maximum_static_shapes[idx]) 1041 1042 # Append _POST_DEVICE_REWRITE_ATTR attributes to the real shape ops. 1043 real_input_shape = array_ops.shape(input_tensor) 1044 real_input_shape.op._set_attr( # pylint: disable=protected-access 1045 _POST_DEVICE_REWRITE_ATTR, 1046 attr_value_pb2.AttrValue(b=True)) 1047 input_shape_tensors[idx].append(real_input_shape) 1048 1049 maximum_shapes = [] 1050 for shapes_per_input in input_shape_tensors: 1051 maximum_shapes.append( 1052 math_ops.reduce_max(array_ops.stack(shapes_per_input), axis=0)) 1053 1054 padded_inputs = [] 1055 real_shapes = [] 1056 padding_maps = [] 1057 for core_idx, inputs_per_core in enumerate(inputs): 1058 padded_inputs.append([]) 1059 real_shapes.append([]) 1060 real_shape_idx = len(inputs_per_core) - 1 1061 for idx, input_tensor in enumerate(inputs_per_core): 1062 input_shape_tensor = input_shape_tensors[idx][core_idx] 1063 input_shape = input_tensor.get_shape().as_list() 1064 padded_shape = padded_shapes[idx] 1065 1066 # If we have no padded_shape, then skip padding. 1067 if any(need_padding[idx]) and padded_shape is not None: 1068 for i, s in enumerate(input_shape): 1069 if need_padding[idx][i]: 1070 if core_idx == 0: 1071 real_shape_idx += 1 1072 padding_map = dynamic_padding.PaddingMap() 1073 padding_map.arg_index = idx 1074 padding_map.shape_index = i 1075 padding_map.padding_arg_index = real_shape_idx 1076 padding_maps.append(padding_map) 1077 real_shapes[core_idx].append( 1078 math_ops.cast(input_shape_tensor[i], dtypes.int32)) 1079 1080 paddings = [] 1081 for i, s in enumerate(padded_shape.dims): 1082 if need_padding[idx][i]: 1083 # The minimum padded dimension size is 2 as XLA doesn't support size 1084 # 1 dynamic size. 1085 minimum_dynamic_dim_size = 2 1086 if s.value is not None: 1087 # Pad to the given maximum value. 1088 max_dim_size = max(s.value, minimum_dynamic_dim_size) 1089 else: 1090 # If maximum value is not given, then pad to the maximum dimension 1091 # among all the cores. 1092 max_dim_size = math_ops.maximum(maximum_shapes[idx][i], 1093 minimum_dynamic_dim_size) 1094 if padding_spec == PaddingSpec.POWER_OF_TWO: 1095 max_dim_size = _ceil_to_pow_of_n(max_dim_size, 2) 1096 # Pad to the given maximum value. 1097 padding = [0, max_dim_size - input_shape_tensor[i]] 1098 else: 1099 padding = [0, 0] 1100 paddings.append(padding) 1101 1102 if input_tensor.get_shape().is_fully_defined(): 1103 # TODO(rxsang): This is a hack to make sure padded_input has dynamic 1104 # shapes, so any tf.size/tf.shape op performed on it won't be constant 1105 # folded. Do we have better ways to do it? 1106 padded_input = control_flow_ops.cond( 1107 array_ops.constant(True), 1108 lambda: array_ops.pad(input_tensor, paddings), # pylint: disable=cell-var-from-loop 1109 lambda: input_tensor) 1110 else: 1111 padded_input = array_ops.pad(input_tensor, paddings) 1112 1113 # Append _POST_DEVICE_REWRITE_ATTR attributes to all padded inputs. 1114 padded_input.op._set_attr( # pylint: disable=protected-access 1115 _POST_DEVICE_REWRITE_ATTR, 1116 attr_value_pb2.AttrValue(b=True)) 1117 1118 padded_inputs[core_idx].append(padded_input) 1119 else: 1120 padded_inputs[core_idx].append(input_tensor) 1121 1122 num_replicas = len(padded_inputs) 1123 for i in range(num_replicas): 1124 padded_inputs[i].extend(real_shapes[i]) 1125 1126 return padded_inputs, padding_maps 1127 1128 1129def _flatten_and_filter_composite(maybe_composite, non_composite_output, 1130 composite_output=None): 1131 """For an input, replaced the input by a tuple if the input is composite. 1132 1133 If `maybe_composite` is not composite, return the parameter 1134 `non_composite_output` otherwise return a tuple which consists of the value of 1135 the parameter `composite_output` the same number of times as there are 1136 components of the composite tensor. 1137 1138 This is useful for computing a mask when flattening nested data with 1139 `expand_composites=True`. For example 1140 1141 ```python 1142 nest.flatten(data, expand_composites=True) 1143 ``` 1144 1145 and 1146 1147 ```python 1148 nest.flatten(nest.map( 1149 data, lambda x: _flatten_and_filter_composite(x, False, True))) 1150 ``` 1151 1152 will have the same length and second will be True if the tensor in the first 1153 is derived from a expanding a composite tensor. 1154 1155 Args: 1156 maybe_composite: A value to test for being a composite tensor. 1157 non_composite_output: The value to return when `maybe_composite` is not a 1158 composite. 1159 composite_output: the value to fill the output tuple with if 1160 `maybe_composite` is a composite. 1161 1162 Returns: 1163 `non_composite_output` or a tuple with multiple copies of 1164 `composite_output`. 1165 """ 1166 1167 if isinstance(maybe_composite, composite_tensor.CompositeTensor): 1168 num_components = len( 1169 maybe_composite._type_spec._to_components(maybe_composite)) # pylint: disable=protected-access 1170 return (composite_output,) * num_components 1171 return non_composite_output 1172 1173 1174def split_compile_and_replicate( 1175 computation: Callable[..., Any], 1176 inputs: Optional[List[List[core_types.Tensor]]] = None, 1177 infeed_queue: Optional[tpu_feed.InfeedQueue] = None, 1178 device_assignment: Optional[device_assignment_lib.DeviceAssignment] = None, 1179 name: Optional[Text] = None, 1180 use_tpu: bool = True, 1181 maximum_shapes: Any = None, 1182 padding_spec: Optional[PaddingSpec] = None, 1183 xla_options: Optional[XLAOptions] = None, 1184) -> List[List[core_types.Tensor]]: 1185 """Builds graph operators that runs compilation and replicated computation. 1186 1187 This is a lower level interface than replicate that returns a separate compile 1188 and execute output tensor. In the generated graph the compile op feeds into 1189 the execute op and no additional compilation is incurred when running the 1190 compile op before the execute op. The compile op returns additional 1191 information about the compilation but does not return the compiled program. 1192 1193 Args: 1194 computation: A Python function that builds the computation to replicate. 1195 inputs: A list of lists of input tensors or `None` (equivalent to 1196 `[[]]`), indexed by `[replica_num][input_num]`. All replicas must 1197 have the same number of inputs. Each input can be a nested structure 1198 containing values that are convertible to tensors. Note that passing an 1199 N-dimension list of compatible values will result in a N-dimension list of 1200 scalar tensors rather than a single Rank-N tensors. If you need different 1201 behavior, convert part of inputs to tensors with `tf.convert_to_tensor`. 1202 infeed_queue: If not `None`, the `InfeedQueue` from which to append a tuple 1203 of arguments as inputs to computation. 1204 device_assignment: If not `None`, a `DeviceAssignment` describing the 1205 mapping between logical cores in the computation with physical cores in 1206 the TPU topology. Uses a default device assignment if `None`. The 1207 `DeviceAssignment` may be omitted if each replica of the computation uses 1208 only one core, and there is either only one replica, or the number of 1209 replicas is equal to the number of cores in the TPU system. 1210 name: (Deprecated) Does nothing. 1211 use_tpu: When false, the input `computation` is executed on the XLA CPU/GPU 1212 backends. Currently, only supports a default placement (computation is 1213 placed on GPU if one is available, and on CPU if not). 1214 maximum_shapes: A nested structure of tf.TensorShape representing the shape 1215 to which the respective component of each input element in each replica 1216 should be padded. Any unknown dimensions (e.g. 1217 tf.compat.v1.Dimension(None) in a tf.TensorShape or -1 in a tensor-like 1218 object) will be padded to the maximum size of that dimension over all 1219 replicas. The structure of `maximum_shapes` needs to be the same as 1220 `inputs[0]`. 1221 padding_spec: An enum specified by `tf.tpu.PaddingSpec`. This describes the 1222 padding policy when the `inputs` to `tf.tpu.replicate` is dynamic. 1223 One usage is to enable automatic bucketizing on the inputs by setting the 1224 value to `tpu.PaddingSpec.POWER_OF_TWO`, which can help to reduce the 1225 recompilation in the XLA side. 1226 xla_options: An instance of `tpu.XLAOptions` which indicates the options 1227 passed to XLA compiler. Use `None` for default options. 1228 1229 Returns: 1230 A list of lists with the first list corresponding to the compile op and the 1231 second a list of output tensors, indexed by `[replica_num][output_num]`. 1232 Raises: 1233 ValueError: If all replicas do not have equal numbers of input tensors. 1234 ValueError: If the number of inputs per replica does not match 1235 the number of formal parameters to `computation`. 1236 ValueError: If the static `inputs` dimensions don't match with the values 1237 given in `maximum_shapes`. 1238 ValueError: If the structure of inputs per replica does not match 1239 the structure of `maximum_shapes`. 1240 """ 1241 del name 1242 inputs = [[]] if inputs is None else inputs 1243 xla_options = xla_options or XLAOptions() 1244 1245 metadata_kwargs = {} 1246 if device_assignment is not None: 1247 # Turn the Numpy array into a flattened list so we can pass it as an 1248 # operator attribute. 1249 metadata_kwargs = { 1250 "topology": 1251 device_assignment.topology.serialized(), 1252 "device_assignment": 1253 device_assignment.core_assignment.flatten().tolist() 1254 } 1255 metadata_kwargs["num_cores_per_replica"] = ( 1256 device_assignment.num_cores_per_replica) 1257 1258 # This entry is used for enabling automatic outside compilation. 1259 metadata_kwargs["allow_soft_placement"] = config.get_soft_device_placement() 1260 if config.get_soft_device_placement(): 1261 logging.info("Automatic outside compilation is enabled. " 1262 "Ops without XLA kernels will be automatically " 1263 "placed on CPU.") 1264 1265 if ((not isinstance(inputs, list)) or 1266 any(not isinstance(inp, (list, tuple)) for inp in inputs)): 1267 raise TypeError("tpu.replicate() inputs must be a list of lists/tuples") 1268 1269 num_replicas = len(inputs) 1270 1271 # No replicas? Nothing to do. 1272 if num_replicas == 0: 1273 return [] 1274 1275 # Checks all replicas have the same structure. 1276 for i in xrange(1, num_replicas): 1277 nest.assert_same_structure(inputs[0], inputs[i]) 1278 1279 # Flatten inputs. This structure may contain None values, which will be 1280 # handled later. 1281 flat_inputs_with_nones = [ 1282 nest.flatten(per_replica_input, expand_composites=True) 1283 for per_replica_input in inputs 1284 ] 1285 # Mask parallel to one replica's inputs with True for tensors coming from 1286 # composites. 1287 is_composite = nest.flatten(nest.map_structure( 1288 lambda x: _flatten_and_filter_composite(x, False, True), inputs[0])) 1289 1290 # Converts inputs to Tensors, replacing Nones with a placeholder 0 since 1291 # tpu_ops.tpu_replicated_input() can't handle non-Tensor values. 1292 flat_inputs = [] 1293 for inp in flat_inputs_with_nones: 1294 flat_inputs.append([ 1295 constant_op.constant(0) if x is None else ops.convert_to_tensor(x) 1296 for x in inp 1297 ]) 1298 1299 # Verifies that all replicas have matching numbers and types of inputs 1300 flat_input_types = [x.dtype for x in flat_inputs[0]] 1301 input_arity = len(inputs[0]) 1302 flat_input_arity = len(flat_input_types) 1303 for i in range(num_replicas): 1304 if len(inputs[i]) != input_arity: 1305 raise ValueError("Replicas must have the same number of inputs. " 1306 "Replica 0 had {} inputs, replica {} had {} " 1307 "inputs.".format(input_arity, i, len(inputs[i]))) 1308 1309 types = [x.dtype for x in flat_inputs[i]] 1310 if types != flat_input_types: 1311 raise ValueError("Replicas must have matching input types. Replica 0 had " 1312 "input types {}, replica {} had input types {}".format( 1313 flat_input_types, i, types)) 1314 1315 arg_error = xla.check_function_argument_count( 1316 computation, input_arity, infeed_queue) 1317 if arg_error is not None: 1318 if infeed_queue is None: 1319 raise TypeError( 1320 "Supplied computation cannot be called with the specified inputs. " 1321 "You specified %d inputs: %s, but the computation needs %s" % ( 1322 input_arity, str([i.name for i in inputs[0]]), arg_error)) 1323 else: 1324 raise TypeError( 1325 "Supplied computation cannot be called with the specified inputs. " 1326 "You specified %d inputs: %s and %d additional inputs from infeed," 1327 " but the computation needs %s" % (input_arity, str( 1328 [i.name 1329 for i in inputs[0]]), infeed_queue.number_of_tuple_elements, 1330 arg_error)) 1331 1332 dynamic_shape_inputs = False 1333 if maximum_shapes: 1334 if infeed_queue: 1335 raise ValueError( 1336 "Dynamic input shapes are not supported with infeed queues") 1337 1338 # Make sure maximum_shapes has the same structure as inputs. 1339 nest.assert_same_structure(inputs[0], maximum_shapes, check_types=False) 1340 1341 # Flatten padded shapes: 1342 # For composite tensor components, we don't want to pad them. For each 1343 # entry of maximum_shapes that corresponds to a composite tensor, replace it 1344 # by a tuple of Nones of the same length as the number of components of the 1345 # composite tensor. When we flatten a second time, this makes 1346 # flat_maximum_shapes have the same length as flat_inputs[i]. We can then 1347 # avoid padding these tensors. The assumption is that they will be used by 1348 # outside compilation or that the components are statically shaped and will 1349 # be used by tpu compatible ops. 1350 flat_maximum_shapes = nest.flatten( 1351 [_flatten_and_filter_composite(x, y) 1352 for x, y in zip(nest.flatten(inputs[0]), 1353 nest.flatten(maximum_shapes))]) 1354 flat_maximum_shapes = [ 1355 tensor_shape.TensorShape(s) if s is not None else None 1356 for s in flat_maximum_shapes 1357 ] 1358 nest.assert_same_structure(flat_inputs[0], flat_maximum_shapes, 1359 check_types=False) 1360 1361 unpadded_inputs = flat_inputs 1362 flat_inputs, padding_maps = _pad_all_input(unpadded_inputs, 1363 flat_maximum_shapes, 1364 padding_spec) 1365 if padding_maps: 1366 dynamic_shape_inputs = True 1367 logging.info("TPU has inputs with dynamic shapes: %s", unpadded_inputs[0]) 1368 1369 metadata_kwargs["step_marker_location"] = getattr( 1370 computation, "step_marker_location", "STEP_MARK_AT_ENTRY") 1371 metadata_kwargs["use_spmd_for_xla_partitioning"] = \ 1372 xla_options.use_spmd_for_xla_partitioning 1373 1374 graph = ops.get_default_graph() 1375 1376 # Fan-in: Builds a TPUReplicatedInput node for each input. 1377 flat_replicated_inputs = [] 1378 for i in range(0, len(flat_inputs[0])): 1379 replicas = [flat_inputs[replica][i] for replica in xrange(num_replicas)] 1380 flat_replicated_inputs.append( 1381 tpu_ops.tpu_replicated_input( 1382 replicas, name="input{}".format(i), index=i)) 1383 if isinstance(graph, func_graph.FuncGraph): 1384 # When we are in Tensorflow 2.0 function, 'graph' will be a FuncGraph 1385 # object. If both outside graph and this function have a TPU cluster, 1386 # they will have the same cluster name and it will cause problems (because 1387 # we lower functional ops in Tensorflow 2.0). Append function name to 1388 # 'cluster_name' to avoid cluster name collision. 1389 cluster_name = graph.unique_name("cluster_" + graph.name) 1390 else: 1391 cluster_name = graph.unique_name("cluster") 1392 pivot = control_flow_ops.no_op(name=cluster_name + "/pivot") 1393 pivot._set_attr(_PIVOT_FOR_CLUSTER, # pylint: disable=protected-access 1394 attr_value_pb2.AttrValue(s=compat.as_bytes(cluster_name))) 1395 context = TPUReplicateContext( 1396 name=cluster_name, num_replicas=num_replicas, pivot=pivot) 1397 try: 1398 context.Enter() 1399 1400 metadata = tpu_ops.tpu_replicate_metadata( 1401 num_replicas=num_replicas, use_tpu=use_tpu, **metadata_kwargs) 1402 1403 with tpu_function.tpu_shard_context( 1404 num_replicas), ops.control_dependencies([metadata]): 1405 1406 if dynamic_shape_inputs: 1407 for padding_map in padding_maps: 1408 input_shape = flat_replicated_inputs[padding_map.arg_index].shape 1409 flat_replicated_inputs[ 1410 padding_map.arg_index] = tf2xla.set_dynamic_dimension_size( 1411 flat_replicated_inputs[padding_map.arg_index], 1412 padding_map.shape_index, 1413 flat_replicated_inputs[padding_map.padding_arg_index]) 1414 flat_replicated_inputs[padding_map.arg_index].set_shape(input_shape) 1415 1416 # Add identity ops so even unused inputs are "consumed" by the 1417 # computation. This is to avoid orphaned TPUReplicatedInput nodes. 1418 # TODO(phawkins): consider instead pruning unused TPUReplicatedInput 1419 # and eliding trivial TPUReplicatedInput/TPUReplicatedOutput pairs. 1420 flat_replicated_inputs = [ 1421 array_ops.identity(x, name="replicated_input_{}".format(i)) 1422 for i, x in enumerate(flat_replicated_inputs) 1423 ] 1424 for i, composite in zip(flat_replicated_inputs, is_composite): 1425 # pylint: disable=protected-access 1426 # Add an attribute to the identity node so that they could be removed in 1427 # encapsulate TPU computation pass if unused. However we don't remove 1428 # inputs when dynamic padding is enabled. 1429 # TODO(rxsang): Use other ways except argument index in padding_map so 1430 # outside compilation can work with dynamic padding correctly. 1431 if not dynamic_shape_inputs or composite: 1432 i.op._set_attr("_tpu_input_identity", 1433 attr_value_pb2.AttrValue(b=True)) 1434 # pylint: enable=protected-access 1435 1436 # Clobber replicated placeholders with Nones. 1437 computation_inputs = [ 1438 None if inp is None else replicated for replicated, inp in zip( 1439 flat_replicated_inputs, flat_inputs_with_nones[0]) 1440 ] 1441 1442 # Unflatten the computation inputs to match original input structure. 1443 computation_inputs = nest.pack_sequence_as( 1444 structure=inputs[0], 1445 flat_sequence=computation_inputs[:flat_input_arity], 1446 expand_composites=True) 1447 1448 # If there is an infeed queue, adds the dequeued values to the 1449 # computation's inputs. 1450 if infeed_queue is not None: 1451 infeed_queue.set_number_of_shards(num_replicas) 1452 for t in infeed_queue.generate_dequeue_op(): 1453 computation_inputs.append(t) 1454 1455 # Only resource variables work inside a TPU computation, so turn on 1456 # resource variables for the computation. 1457 # TODO(phawkins): consider removing this code. It will 1458 # be less confusing to clients if they knowingly choose to use resource 1459 # variables. 1460 # Partitioned variables is not supported (b/112311320). 1461 vscope = variable_scope.get_variable_scope() 1462 saved_use_resource = vscope.use_resource 1463 saved_custom_getter = vscope.custom_getter 1464 1465 def custom_getter(getter, name, *args, **kwargs): 1466 """Variables on TPU have a few restrictions.""" 1467 partitioner = kwargs.get("partitioner", None) 1468 if partitioner is not None: 1469 kwargs["partitioner"] = None 1470 logging.warning( 1471 "Partitioned variables are not supported on TPU. Got " 1472 "`partitioner` that is %s for variable %s. " 1473 "Setting `partitioner` to `None`.", partitioner, name) 1474 if saved_custom_getter is None: 1475 return getter(name, *args, **kwargs) 1476 else: 1477 return saved_custom_getter(getter, name, *args, **kwargs) 1478 1479 vscope.set_use_resource(True) 1480 vscope.set_custom_getter(custom_getter) 1481 1482 outputs = computation(*computation_inputs) 1483 1484 vscope.set_use_resource(saved_use_resource) 1485 vscope.set_custom_getter(saved_custom_getter) 1486 1487 outputs_is_flat = xla.is_flat(outputs) 1488 if outputs_is_flat: 1489 output_tensors, control_deps, pack_template = _postprocess_flat_outputs( 1490 outputs) 1491 else: 1492 output_tensors, control_deps, pack_template = ( 1493 _postprocess_non_flat_outputs(outputs)) 1494 1495 # tensor_tracer imports tpu.py. Local import to tensor_tracer to avoid 1496 # import-cycle 1497 if typing.TYPE_CHECKING: 1498 tensor_tracer = Any 1499 else: 1500 # pylint: disable=g-import-not-at-top 1501 from tensorflow.python.tpu import tensor_tracer 1502 # pylint: enable=g-import-not-at-top 1503 if tensor_tracer.TensorTracer.is_enabled(): 1504 tt = tensor_tracer.TensorTracer() 1505 output_tensors = tt.trace_tpu(ops.get_default_graph(), 1506 output_tensors, control_deps, 1507 num_replicas) 1508 1509 context.ExitResult(output_tensors) 1510 finally: 1511 context.report_unsupported_operations() 1512 context.Exit() 1513 host_compute_core = context.HostComputeCore() 1514 1515 if host_compute_core: 1516 attr_value = attr_value_pb2.AttrValue() 1517 attr_value.list.s.extend(compat.as_bytes(x) for x in host_compute_core) 1518 metadata._set_attr("host_compute_core", attr_value) # pylint: disable=protected-access 1519 1520 with ops.control_dependencies([metadata]): 1521 if use_tpu: 1522 compile_status = tpu_ops.tpu_compilation_result() 1523 op = compile_status.op 1524 attr_value = attr_value_pb2.AttrValue(s=compat.as_bytes(cluster_name)) 1525 op._set_attr(_TPU_COMPILATION_STATUS_ATTR, attr_value) # pylint: disable=protected-access 1526 else: 1527 compile_status = control_flow_ops.no_op(name="compilation_status") 1528 1529 if not output_tensors: 1530 # Returns a list of NoOps dependent on the replication Op, indexed by 1531 # [replica_num]. 1532 return [ 1533 compile_status, 1534 [ 1535 control_flow_ops.group(control_deps, name="shard_%d" % i) 1536 for i in range(num_replicas) 1537 ] 1538 ] 1539 1540 # Fan-out: Builds a TPUReplicatedOutput node for each output. 1541 replicated_outputs = [[] for i in range(num_replicas)] 1542 for i, t in enumerate(output_tensors): 1543 1544 # None values returned by the computation can't be sent to 1545 # tpu_ops.tpu_replicated_output(), we handle them specially here. We can 1546 # avoid the placeholder 0 routine required on the inputs since outputs are 1547 # replicated per-tensor, not per-replica, so we can skip replication. 1548 if t is None: 1549 for replica in range(num_replicas): 1550 replicated_outputs[replica].append(None) 1551 continue 1552 1553 # Fan-out: Builds a TPUReplicatedOutput node for each output. 1554 ys = tpu_ops.tpu_replicated_output( 1555 t, num_replicas, name="output{}".format(i)) 1556 1557 # Wraps the outputs in identity operators so the names of any possible 1558 # `fetch` nodes are preserved by the replication rewrite. 1559 with ops.control_dependencies(control_deps): 1560 for replica in range(num_replicas): 1561 replicated_outputs[replica].append( 1562 array_ops.identity( 1563 ys[replica], name="output_%d_shard_%d" % (i, replica))) 1564 1565 replicated_outputs = [ 1566 nest.pack_sequence_as(pack_template, replica_outs, expand_composites=True) 1567 for replica_outs in replicated_outputs 1568 ] 1569 1570 return [compile_status, replicated_outputs] 1571 1572 1573def _postprocess_flat_outputs( 1574 outputs: Any 1575) -> Tuple[List[Optional[core_types.Tensor]], List[ops.Operation], List[Any]]: 1576 """Validates non-flat outputs, add backs device assignments and other attrs. 1577 1578 Args: 1579 outputs: Output from `computation` inside `tpu.rewrite`. 1580 1581 Returns: 1582 - Tensors extracted from outputs. 1583 - Operations extracted from outputs. 1584 - A pack template for use with nest.pack_sequence_as to pack the tensors. 1585 """ 1586 # Following code segment is to preserve legacy behavior. Previously we only 1587 # supported flat outputs and thus for consistency it was nice to convert even 1588 # single element into a tuple. But now that we support arbitrary output 1589 # structure, this is no longer necessary. 1590 # TODO(b/121383831): Migrate all legacy use cases and delete this special 1591 # case. 1592 # If the computation returns `None`, make it an empty tuple. 1593 if outputs is None: 1594 outputs = tuple() 1595 1596 # For legacy / backwards compatibility reasons we return a list for "flat" 1597 # output values (even if the user's flat return value was a different type or 1598 # even just a scalar value) so use nest.flatten to compute a flat list pack 1599 # template. 1600 pack_template = nest.flatten(outputs, expand_composites=False) 1601 1602 # Even though outputs is already "flat", we flatten any composites so their 1603 # component tensors can be tagged and replicated. The pack_template will be 1604 # used by the caller to repack the composite tensors. 1605 outputs = nest.flatten(outputs, expand_composites=True) 1606 1607 # Append `no_op` here so that fetching any return value of this function 1608 # will trigger TPUExecute node. 1609 outputs += (control_flow_ops.no_op(),) 1610 1611 maybe_convert = lambda x: None if x is None else ops.convert_to_tensor(x) 1612 try: 1613 with ops.device(core(0)): 1614 outputs = [ 1615 o if isinstance(o, ops.Operation) else maybe_convert(o) 1616 for o in outputs 1617 ] 1618 except Exception as e: 1619 raise ValueError( 1620 "TPU function return values must all either be Operations or " 1621 "convertible to Tensors. Got '%s'" % str(e)) 1622 1623 # Separates the returned Operations and Tensors. 1624 output_operations = [o for o in outputs if isinstance(o, ops.Operation)] 1625 output_tensors = [o for o in outputs if not isinstance(o, ops.Operation)] 1626 1627 if outputs != output_tensors + output_operations: 1628 raise ValueError( 1629 "TPU functions must return zero-or more Tensor values followed by " 1630 "zero or more Operations.") 1631 1632 # Trim operations off the end of the pack template. output_operations has 1 1633 # extra element due to the no-op that is added. 1634 if len(output_operations) > 1: 1635 pack_template = pack_template[:1 - len(output_operations)] 1636 1637 # Wraps outputs in Identity ops. Otherwise a replicated input copied 1638 # straight to an output would bypass the replicate(). This would be bad 1639 # because the TPUReplicatedInput/TPUReplicatedOutput operator would not 1640 # be rewritten away, leading to a runtime error. 1641 # TODO(phawkins): extend the rewrite to elide these nodes instead. 1642 new_output_tensors = [] 1643 for t in output_tensors: 1644 if t is None: 1645 new_output_tensors.append(None) 1646 with ops.device(t.device if t.device else core(0)): 1647 o = array_ops.identity(t) 1648 # pylint: disable=protected-access 1649 o.op._set_attr("_tpu_output_identity", attr_value_pb2.AttrValue(b=True)) 1650 # pylint: enable=protected-access 1651 new_output_tensors.append(o) 1652 return new_output_tensors, output_operations, pack_template 1653 1654 1655def _postprocess_non_flat_outputs( 1656 outputs: Any 1657) -> Tuple[List[Optional[core_types.Tensor]], List[ops.Operation], List[Any]]: 1658 """Validates non-flat outputs, add backs device assignments and other attrs. 1659 1660 Args: 1661 outputs: Output from `computation` inside `tpu.rewrite`. 1662 1663 Returns: 1664 - Tensors extracted from outputs. 1665 - An empty Operations list because Operations are not allowed in non-flat 1666 outputs. 1667 - A pack template for use with nest.pack_sequence_as to pack the tensors. 1668 """ 1669 1670 # Flatten output items. 1671 flat_outputs = nest.flatten(outputs, expand_composites=True) 1672 1673 # Convert all non-None non-Operation outputs to Tensors. 1674 for i, o in enumerate(flat_outputs): 1675 if o is None: 1676 flat_outputs[i] = None 1677 continue 1678 1679 if isinstance(o, ops.Operation): 1680 raise ValueError( 1681 "tpu.rewrite does not support Operation as return value in non-flat " 1682 "output structure. You can set returned Operations as control " 1683 "dependencies of returned Tensors so Operations are triggered when " 1684 'Tensors are evaluated. Operation found: "%s"' % o.name) 1685 1686 try: 1687 o = ops.convert_to_tensor(o) 1688 except Exception as e: 1689 raise ValueError( 1690 "TPU function return values must all either be Operations or " 1691 'convertible to Tensors. Got error: "%s"' % str(e)) 1692 1693 # Wraps outputs in Identity ops. Otherwise a replicated input copied 1694 # straight to an output would bypass the replicate(). This would be bad 1695 # because the TPUReplicatedInput/TPUReplicatedOutput operator would not 1696 # be rewritten away, leading to a runtime error. 1697 # TODO(phawkins): extend the rewrite to elide these nodes instead. 1698 with ops.device(o.device if o.device else core(0)): 1699 o = array_ops.identity(o) 1700 # pylint: disable=protected-access 1701 o.op._set_attr("_tpu_output_identity", attr_value_pb2.AttrValue(b=True)) 1702 # pylint: enable=protected-access 1703 flat_outputs[i] = array_ops.identity(o) 1704 1705 # All flat_outputs are Tensors, and no Operations. 1706 return flat_outputs, [], outputs 1707 1708 1709def split_compile_and_shard( 1710 computation: Callable[..., Any], 1711 inputs: List[List[Optional[core_types.Tensor]]] = None, 1712 num_shards: int = 1, 1713 input_shard_axes: Optional[List[int]] = None, 1714 outputs_from_all_shards: Union[bool, List[bool]] = True, 1715 output_shard_axes: Optional[List[int]] = None, 1716 infeed_queue: Optional[tpu_feed.InfeedQueue] = None, 1717 device_assignment: Optional[device_assignment_lib.DeviceAssignment] = None, 1718 name: Optional[Text] = None, 1719 xla_options: Optional[XLAOptions] = None, 1720 ) -> Tuple[ops.Operation, List[core_types.Tensor]]: 1721 """Shards `computation` for parallel execution. 1722 1723 `inputs` must be a list of Tensors or None (equivalent to an empty list), each 1724 of which has a corresponding split axis (from `input_shard_axes`). Each input 1725 is split into `num_shards` pieces along the corresponding axis, and 1726 computation is applied to each shard in parallel. 1727 1728 Tensors are broadcast to all shards if they are lexically captured by 1729 `computation`. e.g., 1730 1731 x = tf.constant(7) 1732 def computation(): 1733 return x + 3 1734 ... = shard(computation, ...) 1735 1736 If `outputs_from_all_shards` is true, the outputs from all shards of 1737 `computation` are concatenated back together along their `output_shard_axes`. 1738 Otherwise, each output is taken from an arbitrary shard. 1739 1740 Inputs and outputs of the computation must be at least rank-1 Tensors. 1741 1742 Args: 1743 computation: A Python function that builds a computation to apply to each 1744 shard of the input. 1745 inputs: A list of input tensors or None (equivalent to an empty list). Each 1746 input tensor has a corresponding shard axes, given by `input_shard_axes`, 1747 which must have size divisible by `num_shards`. 1748 num_shards: The number of shards. 1749 input_shard_axes: A list of dimensions along which to shard `inputs`, or 1750 `None`. `None` means "shard all inputs along dimension 0". If not `None`, 1751 there must be one dimension per input. 1752 outputs_from_all_shards: Boolean or list of boolean. For each output, if 1753 `True`, outputs from all shards are concatenated along the corresponding 1754 `output_shard_axes` entry. Otherwise, each output is taken 1755 from an arbitrary shard. If the argument is a boolean, the argument's 1756 value is used for each output. 1757 output_shard_axes: A list of dimensions along which to concatenate the 1758 outputs of `computation`, or `None`. `None` means "concatenate all outputs 1759 along dimension 0". If not `None`, there must be one dimension per output. 1760 Ignored if `outputs_from_all_shards` is False. 1761 infeed_queue: If not `None`, the `InfeedQueue` to use to augment the inputs 1762 of `computation`. 1763 device_assignment: If not `None`, a `DeviceAssignment` describing the 1764 mapping between logical cores in the computation with physical cores in 1765 the TPU topology. Uses a default device assignment if `None`. The 1766 `DeviceAssignment` may be omitted if each shard of the computation uses 1767 only one core, and there is either only one shard, or the number of shards 1768 is equal to the number of cores in the TPU system. 1769 name: (Deprecated) Does nothing. 1770 xla_options: An instance of `tpu.XLAOptions` which indicates the options 1771 passed to XLA compiler. Use `None` for default options. 1772 Returns: 1773 A tuple of (compile op, [output tensors]). 1774 Raises: 1775 ValueError: If num_shards <= 0 1776 ValueError: If len(input_shard_axes) != len(inputs) 1777 ValueError: If len(output_shard_axes) != len(outputs from `computation`) 1778 """ 1779 # TODO(phawkins): consider adding support for broadcasting Tensors passed as 1780 # inputs. 1781 1782 if num_shards <= 0: 1783 raise ValueError("num_shards must be a positive integer.") 1784 1785 inputs = [] if inputs is None else inputs 1786 if not isinstance(inputs, list): 1787 raise TypeError("tpu.shard()'s inputs must be a list of Tensors or None.") 1788 1789 # Converts inputs to Tensors. 1790 inputs = [ops.convert_to_tensor(x) for x in inputs] 1791 1792 if input_shard_axes is None: 1793 input_shard_axes = [0] * len(inputs) 1794 if len(inputs) != len(input_shard_axes): 1795 raise ValueError("Length of input_shard_axes must be equal to the number " 1796 "of inputs.") 1797 1798 if inputs: 1799 # Splits the `inputs` along the corresponding `input_shard_axes`, giving 1800 # lists with layout [input][shard] 1801 split_inputs = [ 1802 array_ops.split(x, num_shards, axis=axis) 1803 for (axis, x) in zip(input_shard_axes, inputs)] 1804 1805 # Transposes the input lists to have layout [shard][input] 1806 transposed_inputs = [list(i) for i in zip(*split_inputs)] 1807 else: 1808 transposed_inputs = [[]] * num_shards 1809 1810 compile_op, outputs = split_compile_and_replicate( 1811 computation, 1812 transposed_inputs, 1813 infeed_queue=infeed_queue, 1814 device_assignment=device_assignment, 1815 name=name, 1816 xla_options=xla_options) 1817 1818 # There must be at least one shard since num_shards > 0. 1819 # TODO(b/36647078) remove disable when pylint bug is fixed. 1820 # pylint: disable=indexing-exception 1821 if isinstance(outputs[0], ops.Operation): 1822 # pylint: enable=indexing-exception 1823 # There were no outputs from the computation and replicate returned a list 1824 # of NoOps with control dependencies on the computation. Return the first 1825 # one so it can be used as a control dependency or fetch node. 1826 # TODO(b/36647078) remove disable when pylint bug is fixed. 1827 # pylint: disable=indexing-exception 1828 return compile_op, [outputs[0]] 1829 # pylint: enable=indexing-exception 1830 1831 # TODO(b/36647078) remove disable when pylint bug is fixed. 1832 # pylint: disable=indexing-exception 1833 num_outputs = len(outputs[0]) 1834 # pylint: enable=indexing-exception 1835 1836 if output_shard_axes is None: 1837 output_shard_axes = [0] * num_outputs 1838 if num_outputs != len(output_shard_axes): 1839 raise ValueError("Length of output_shard_axes must be equal to the number " 1840 "of outputs.") 1841 1842 if isinstance(outputs_from_all_shards, bool): 1843 outputs_from_all_shards = [outputs_from_all_shards] * num_outputs 1844 1845 if num_outputs != len(outputs_from_all_shards): 1846 raise ValueError("Length of outputs_from_all_shards must be equal to the " 1847 "number of outputs.") 1848 1849 results = [] 1850 for (axis, all_shards, x) in zip(output_shard_axes, outputs_from_all_shards, 1851 zip(*outputs)): 1852 if all_shards: 1853 # Concatenate all of the outputs together (use stack for scalars). 1854 shape = x[0].shape 1855 is_scalar = shape is not None and (shape.ndims == 0) 1856 results.append((array_ops.stack(list(x)) if is_scalar 1857 else array_ops.concat(list(x), axis=axis))) 1858 else: 1859 # TODO(phawkins): use a smarter policy, e.g., round-robin across shards. 1860 results.append(x[0]) 1861 1862 return compile_op, results 1863 1864 1865@tf_export(v1=["tpu.shard"]) 1866def shard( 1867 computation: Callable[..., Any], 1868 inputs: Optional[List[core_types.Tensor]] = None, 1869 num_shards: int = 1, 1870 input_shard_axes: Optional[List[int]] = None, 1871 outputs_from_all_shards: Union[bool, List[bool]] = True, 1872 output_shard_axes: Optional[List[int]] = None, 1873 infeed_queue: Optional[tpu_feed.InfeedQueue] = None, 1874 device_assignment: Optional[device_assignment_lib.DeviceAssignment] = None, 1875 name: Optional[Text] = None, 1876 xla_options: Optional[XLAOptions] = None) -> List[core_types.Tensor]: 1877 """Shards `computation` for parallel execution. 1878 1879 `inputs` must be a list of Tensors or None (equivalent to an empty list), each 1880 of which has a corresponding split axis (from `input_shard_axes`). Each input 1881 is split into `num_shards` pieces along the corresponding axis, and 1882 computation is applied to each shard in parallel. 1883 1884 Tensors are broadcast to all shards if they are lexically captured by 1885 `computation`. e.g., 1886 1887 x = tf.constant(7) 1888 def computation(): 1889 return x + 3 1890 ... = shard(computation, ...) 1891 1892 TODO(phawkins): consider adding support for broadcasting Tensors passed 1893 as inputs. 1894 1895 If `outputs_from_all_shards` is true, the outputs from all shards of 1896 `computation` are concatenated back together along their `output_shard_axes`. 1897 Otherwise, each output is taken from an arbitrary shard. 1898 1899 Inputs and outputs of the computation must be at least rank-1 Tensors. 1900 1901 Args: 1902 computation: A Python function that builds a computation to apply to each 1903 shard of the input. 1904 inputs: A list of input tensors or None (equivalent to an empty list). Each 1905 input tensor has a corresponding shard axes, given by `input_shard_axes`, 1906 which must have size divisible by `num_shards`. 1907 num_shards: The number of shards. 1908 input_shard_axes: A list of dimensions along which to shard `inputs`, or 1909 `None`. `None` means "shard all inputs along dimension 0". If not `None`, 1910 there must be one dimension per input. 1911 outputs_from_all_shards: Boolean or list of boolean. For each output, if 1912 `True`, outputs from all shards are concatenated along the corresponding 1913 `output_shard_axes` entry. Otherwise, each output is taken 1914 from an arbitrary shard. If the argument is a boolean, the argument's 1915 value is used for each output. 1916 output_shard_axes: A list of dimensions along which to concatenate the 1917 outputs of `computation`, or `None`. `None` means "concatenate all outputs 1918 along dimension 0". If not `None`, there must be one dimension per output. 1919 Ignored if `outputs_from_all_shards` is False. 1920 infeed_queue: If not `None`, the `InfeedQueue` to use to augment the inputs 1921 of `computation`. 1922 device_assignment: If not `None`, a `DeviceAssignment` describing the 1923 mapping between logical cores in the computation with physical cores in 1924 the TPU topology. Uses a default device assignment if `None`. The 1925 `DeviceAssignment` may be omitted if each shard of the computation uses 1926 only one core, and there is either only one shard, or the number of shards 1927 is equal to the number of cores in the TPU system. 1928 name: (Deprecated) Does nothing. 1929 xla_options: An instance of `tpu.XLAOptions` which indicates the options 1930 passed to XLA compiler. Use `None` for default options. 1931 Returns: 1932 A list of output tensors. 1933 Raises: 1934 ValueError: If num_shards <= 0 1935 ValueError: If len(input_shard_axes) != len(inputs) 1936 ValueError: If len(output_shard_axes) != len(outputs from `computation`) 1937 """ 1938 return split_compile_and_shard( 1939 computation, 1940 inputs=inputs, 1941 num_shards=num_shards, 1942 input_shard_axes=input_shard_axes, 1943 outputs_from_all_shards=outputs_from_all_shards, 1944 output_shard_axes=output_shard_axes, 1945 infeed_queue=infeed_queue, 1946 device_assignment=device_assignment, 1947 name=name, 1948 xla_options=xla_options)[1] 1949 1950 1951@tf_export(v1=["tpu.batch_parallel"]) 1952def batch_parallel( 1953 computation: Callable[..., Any], 1954 inputs: List[List[Optional[core_types.Tensor]]] = None, 1955 num_shards: int = 1, 1956 infeed_queue: Optional[tpu_feed.InfeedQueue] = None, 1957 device_assignment: Optional[device_assignment_lib.DeviceAssignment] = None, 1958 name: Optional[Text] = None, 1959 xla_options: Optional[XLAOptions] = None): 1960 """Shards `computation` along the batch dimension for parallel execution. 1961 1962 Convenience wrapper around shard(). 1963 1964 `inputs` must be a list of Tensors or None (equivalent to an empty list). 1965 Each input is split into `num_shards` pieces along the 0-th dimension, and 1966 computation is applied to each shard in parallel. 1967 1968 Tensors are broadcast to all shards if they are lexically captured by 1969 `computation`. e.g., 1970 1971 x = tf.constant(7) 1972 def computation(): 1973 return x + 3 1974 ... = shard(computation, ...) 1975 1976 The outputs from all shards are concatenated back together along their 0-th 1977 dimension. 1978 1979 Inputs and outputs of the computation must be at least rank-1 Tensors. 1980 1981 Args: 1982 computation: A Python function that builds a computation to apply to each 1983 shard of the input. 1984 inputs: A list of input tensors or None (equivalent to an empty list). The 1985 0-th dimension of each Tensor must have size divisible by `num_shards`. 1986 num_shards: The number of shards. 1987 infeed_queue: If not `None`, the `InfeedQueue` from which to append a tuple 1988 of arguments as inputs to `computation`. 1989 device_assignment: If not `None`, a `DeviceAssignment` describing the 1990 mapping between logical cores in the computation with physical cores in 1991 the TPU topology. Uses a default device assignment if `None`. The 1992 `DeviceAssignment` may be omitted if each shard of the computation uses 1993 only one core, and there is either only one shard, or the number of shards 1994 is equal to the number of cores in the TPU system. 1995 name: (Deprecated) Does nothing. 1996 xla_options: An instance of `tpu.XLAOptions` which indicates the options 1997 passed to XLA compiler. Use `None` for default options. 1998 Returns: 1999 A list of output tensors. 2000 Raises: 2001 ValueError: If `num_shards <= 0` 2002 """ 2003 return shard( 2004 computation, 2005 inputs, 2006 num_shards=num_shards, 2007 infeed_queue=infeed_queue, 2008 device_assignment=device_assignment, 2009 name=name, 2010 xla_options=xla_options) 2011 2012 2013@tf_export(v1=["tpu.rewrite"]) 2014def rewrite( 2015 computation: Callable[..., Any], 2016 inputs: List[List[Optional[core_types.Tensor]]] = None, 2017 infeed_queue: Optional[tpu_feed.InfeedQueue] = None, 2018 device_assignment: Optional[device_assignment_lib.DeviceAssignment] = None, 2019 name: Optional[Text] = None, 2020 xla_options: Optional[XLAOptions] = None) -> Any: 2021 """Rewrites `computation` for execution on a TPU system. 2022 2023 Args: 2024 computation: A Python function that builds a computation to apply to the 2025 input. If the function takes n inputs, 'inputs' should be a list of n 2026 tensors. 2027 2028 `computation` may return a list of operations and tensors. Tensors must 2029 come before operations in the returned list. The return value of 2030 `rewrite` is a list of tensors corresponding to the tensors from the 2031 output of `computation`. 2032 2033 All `Operation`s constructed during `computation` will be executed when 2034 evaluating any of the returned output tensors, not just the ones returned. 2035 inputs: A list of input tensors or `None` (equivalent to an empty list). 2036 Each input can be a nested structure containing values that are 2037 convertible to tensors. Note that passing an N-dimension list of 2038 compatible values will result in a N-dimension list of scalar tensors 2039 rather than a single Rank-N tensors. If you need different behavior, 2040 convert part of inputs to tensors with `tf.convert_to_tensor`. 2041 infeed_queue: If not `None`, the `InfeedQueue` from which to append a tuple 2042 of arguments as inputs to `computation`. 2043 device_assignment: if not `None`, a `DeviceAssignment` describing the 2044 mapping between logical cores in the computation with physical cores in 2045 the TPU topology. May be omitted for a single-core computation, in which 2046 case the core attached to task 0, TPU device 0 is used. 2047 name: (Deprecated) Does nothing. 2048 xla_options: An instance of `tpu.XLAOptions` which indicates the options 2049 passed to XLA compiler. Use `None` for default options. 2050 Returns: 2051 Same data structure as if computation(*inputs) is called directly with some 2052 exceptions for correctness. Exceptions include: 2053 1) None output: a NoOp would be returned which control-depends on 2054 computation. 2055 2) Single value output: A tuple containing the value would be returned. 2056 3) Operation-only outputs: a NoOp would be returned which 2057 control-depends on computation. 2058 TODO(b/121383831): Investigate into removing these special cases. 2059 """ 2060 # TODO(b/36647078) remove disable when pylint bug is fixed. 2061 # pylint: disable=indexing-exception 2062 return replicate( 2063 computation, 2064 None if inputs is None else [inputs], 2065 infeed_queue=infeed_queue, 2066 device_assignment=device_assignment, 2067 name=name, 2068 xla_options=xla_options)[0] 2069 # pylint: enable=indexing-exception 2070 2071 # Operations that indicate some error in the user's inference graph. 2072 2073 2074_DENYLISTED_INFERENCE_OPS = set([ 2075 "ReadVariableOp", 2076 "AssignVariableOp", 2077 "AssignAddVariableOp", 2078 "AssignSubVariableOp", 2079 "VarHandleOp", 2080 "Variable", 2081 "VariableV2", 2082]) 2083 2084 2085def under_tpu_inference_context() -> bool: 2086 """Check if it is currently under `_TPUInferenceContext`.""" 2087 graph = ops.get_default_graph() 2088 while graph: 2089 context = graph._get_control_flow_context() # pylint: disable=protected-access 2090 while context: 2091 if isinstance(context, _TPUInferenceContext): 2092 return True 2093 context = context.outer_context 2094 if isinstance(graph, function._FuncGraph): # pylint: disable=protected-access 2095 graph = graph._outer_graph # pylint: disable=protected-access 2096 elif isinstance(graph, func_graph.FuncGraph): 2097 graph = graph.outer_graph 2098 else: 2099 return False 2100 2101 2102class _TPUInferenceContext(control_flow_ops.XLAControlFlowContext): 2103 """A `ControlFlowContext` for nodes inside a TPU inference computation. 2104 2105 The primary role of `_TPUInferenceContext` is to indicate the mode of 2106 operation and possibly sanity check operators inside a 2107 tpu.rewrite_for_inference() computation. 2108 """ 2109 2110 def __init__(self, name: Text, check_ops: bool = True): 2111 super(_TPUInferenceContext, self).__init__() 2112 self._name = name 2113 self._check_ops = check_ops 2114 2115 def AddOp(self, op): 2116 self._AddOpInternal(op) 2117 2118 def _AddOpInternal(self, op): 2119 # pylint: disable=protected-access 2120 if self._check_ops and op.type in _DENYLISTED_INFERENCE_OPS: 2121 raise NotImplementedError( 2122 "Operation of type %s (%s) is not supported on the TPU for inference." 2123 " Execution will fail if this op is used in the graph. Make sure your" 2124 " variables are using variable_scope." % (op.type, op.name)) 2125 if self._outer_context: 2126 self._outer_context.AddInnerOp(op) 2127 2128 def AddValue(self, val): 2129 result = val 2130 if self._outer_context: 2131 result = self._outer_context.AddValue(val) 2132 return result 2133 2134 def AddInnerOp(self, op): 2135 self._AddOpInternal(op) 2136 2137 @property 2138 def grad_state(self): 2139 return None 2140 2141 2142def validate_inference_rewrite_for_variables(graph: ops.Graph): 2143 """Validates whether rewrite_for_inference() 'worked' for variables. 2144 2145 The rewrite_for_inference() method is supposed to append GuaranteeConstOps 2146 after ReadVariableOps, but this mechanism works only if you are using 2147 tf.compat.v1.get_variable() to create and access variables in your tpu 2148 computation. This validation method can be called immediately after calling 2149 tpu.rewrite_for_inference() to check whether GuaranteeConstOps where added 2150 to the graph. 2151 2152 Typical usages: 2153 tpu.validate_inference_rewrite_for_variables( 2154 tf.compat.v1.get_default_graph()) 2155 2156 tpu.validate_inference_rewrite_for_variables(sess.graph) 2157 2158 Args: 2159 graph: The graph which needs to be validated. 2160 Raises: 2161 RuntimeError: if validation failed. 2162 """ 2163 if not any(x.type == "GuaranteeConst" for x in graph.get_operations()): 2164 raise RuntimeError( 2165 "No GuaranteeConst ops found in the graph after running " 2166 "tpu.rewrite_for_inference(...). Please check that you are using " 2167 "tf.get_variable() to create and access variables in your tpu " 2168 "computation.") 2169 2170 2171def rewrite_for_inference( 2172 computation: Callable[..., Any], 2173 inputs: Optional[List[core_types.Tensor]] = None, 2174 infeed_queue: Optional[tpu_feed.InfeedQueue] = None, 2175 device_assignment: Optional[device_assignment_lib.DeviceAssignment] = None, 2176 name: Optional[Text] = None) -> List[core_types.Tensor]: 2177 """Rewrites `computation` for inference on a TPU system. 2178 2179 Other than 'rewriting' the computation to run on a TPU, if using variables 2180 in your computation, it moves the ReadVariableOps outside the TPU 2181 computation, and adds GuaranteeConst ops just after the ReadVariableOps. 2182 This mechanism works only if you are using tf.compat.v1.get_variable() to 2183 create and access variables in your tpu computation. You can validate 2184 whether this worked, by calling validate_inference_rewrite_for_variables() 2185 method immediately after this method to check whether GuaranteeConstOps 2186 where added to the graph. 2187 2188 Args: 2189 computation: A Python function that builds a computation to apply to the 2190 input. If the function takes n inputs, 'inputs' should be a list of n 2191 tensors. If the function returns m outputs, rewrite will return a list of 2192 m tensors. 2193 inputs: A list of input tensors or `None` (equivalent to an empty list). 2194 infeed_queue: If not `None`, the `InfeedQueue` from which to append a tuple 2195 of arguments as inputs to `computation`. 2196 device_assignment: if not `None`, a `DeviceAssignment` describing the 2197 mapping between logical cores in the computation with physical cores in 2198 the TPU topology. May be omitted for a single-core computation, in which 2199 case the core attached to task 0, TPU device 0 is used. 2200 name: The name of the operator. 2201 Returns: 2202 A list of output tensors. 2203 """ 2204 2205 def guarantee_const_getter(getter, name, *args, **kwargs): 2206 with ops.control_dependencies(None): 2207 return array_ops.guarantee_const( 2208 getter(name, *args, **kwargs), name=name + "/GuaranteeConst") 2209 2210 def wrapped_computation(*args, **kwargs): 2211 """Execute computation under `_TPUInferenceContext`.""" 2212 context = _TPUInferenceContext( 2213 name=ops.get_default_graph().unique_name("rewrite_for_inference")) 2214 try: 2215 context.Enter() 2216 2217 vscope = variable_scope.get_variable_scope() 2218 prev_custom_getter = vscope.custom_getter 2219 prev_caching_device = vscope.caching_device 2220 vscope.set_custom_getter(guarantee_const_getter) 2221 vscope.set_caching_device(lambda op: op.device) 2222 2223 result = computation(*args, **kwargs) 2224 2225 vscope.set_custom_getter(prev_custom_getter) 2226 vscope.set_caching_device(prev_caching_device) 2227 finally: 2228 context.Exit() 2229 return result 2230 2231 # pylint: disable=undefined-variable 2232 return rewrite( 2233 wrapped_computation, 2234 inputs=inputs, 2235 infeed_queue=infeed_queue, 2236 device_assignment=device_assignment, 2237 name=name) 2238 # pylint: enable=undefined-variable 2239 2240 2241def prune_unconnected_ops_from_xla(prune_graph: ops.Graph): 2242 """Prunes unconnected ops as listed in _UNCONNECTED_OPS_TO_PRUNE. 2243 2244 Args: 2245 prune_graph: A tensorflow graph from which we wish to prune unconnected ops 2246 as listed in _UNCONNECTED_OPS_TO_PRUNE. In general, these ops should have 2247 no inputs and no consumers. These can often be left behind due to graph 2248 construction rewiring (for instance TF-Hub). While they never execute, 2249 they will cause XLA compile to fail so we strip them from XLA compile by 2250 removing the tpu_replicate attribute. 2251 """ 2252 # Scan over the top level graph and all function graphs. 2253 for graph in [prune_graph] + [ 2254 f for f in prune_graph._functions.values() # pylint: disable=protected-access 2255 ]: 2256 if not isinstance(graph, ops.Graph): 2257 continue 2258 for op in graph.get_operations(): 2259 if op.type not in _UNCONNECTED_OPS_TO_PRUNE: 2260 continue 2261 outputs_consumed = False 2262 for output in op.outputs: 2263 if output.consumers(): 2264 outputs_consumed = True 2265 break 2266 if not outputs_consumed: 2267 logging.info( 2268 "Pruning OP %s of type %s from XLA Compile due to " 2269 "it being disconnected.", op.name, op.type) 2270 op._clear_attr(_TPU_REPLICATE_ATTR) # pylint: disable=protected-access 2271