1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ====================================== 15 16"""Library of TPU helper functions.""" 17 18from __future__ import absolute_import 19from __future__ import division 20from __future__ import print_function 21 22from absl import logging 23import numpy as np 24from six.moves import xrange # pylint: disable=redefined-builtin 25 26from tensorflow.core.framework import attr_value_pb2 27from tensorflow.core.protobuf.tpu import dynamic_padding_pb2 as dynamic_padding 28from tensorflow.python.client import pywrap_tf_session 29from tensorflow.python.compiler.xla import xla 30from tensorflow.python.distribute import device_util 31from tensorflow.python.distribute import distribution_strategy_context 32from tensorflow.python.framework import auto_control_deps 33from tensorflow.python.framework import config 34from tensorflow.python.framework import device as pydev 35from tensorflow.python.framework import dtypes 36from tensorflow.python.framework import errors 37from tensorflow.python.framework import func_graph 38from tensorflow.python.framework import function 39from tensorflow.python.framework import ops 40from tensorflow.python.framework import tensor_shape 41from tensorflow.python.ops import array_ops 42from tensorflow.python.ops import control_flow_ops 43from tensorflow.python.ops import math_ops 44from tensorflow.python.ops import variable_scope 45from tensorflow.python.tpu import tpu_function 46from tensorflow.python.tpu.ops import tpu_ops 47from tensorflow.python.util import compat 48from tensorflow.python.util import nest 49from tensorflow.python.util.compat import collections_abc 50from tensorflow.python.util.tf_export import tf_export 51 52ops.NotDifferentiable("TPUReplicatedInput") 53 54# Operations that indicate some error in the users graph, e.g. a placeholder 55# that's introduced outside of the infeed. 56_BLACKLISTED_OPS = set([ 57 "Placeholder", 58]) 59 60# XLA doesn't currently support reading of intermediate tensors, thus some ops 61# are not supported. 62_UNSUPPORTED_OPS = set([ 63 "AudioSummary", 64 "AudioSummaryV2", 65 "HistogramSummary", 66 "ImageSummary", 67 "MergeSummary", 68 "Print", 69 "ScalarSummary", 70 "TensorSummary", 71 "TensorSummaryV2", 72 ]) 73 74# Ops which can be safely pruned from XLA compile if they have no consumers. 75# These ops should also have no inputs. 76_UNCONNECTED_OPS_TO_PRUNE = set(["Placeholder", "VarHandleOp"]) 77 78_MAX_WARNING_LINES = 5 79 80_TPU_REPLICATE_ATTR = "_tpu_replicate" 81_POST_DEVICE_REWRITE_ATTR = "_post_device_rewrite" 82_TPU_COMPILATION_STATUS_ATTR = "_tpu_compilation_status" 83_OUTSIDE_COMPILATION_ATTR = "_xla_outside_compilation" 84 85 86def _tpu_system_device_name(job): 87 """Returns the device name for the TPU_SYSTEM device of `job`.""" 88 if job is None: 89 return "/device:TPU_SYSTEM:0" 90 else: 91 return "/job:%s/device:TPU_SYSTEM:0" % job 92 93 94@tf_export(v1=["tpu.initialize_system"]) 95def initialize_system(embedding_config=None, 96 job=None, 97 compilation_failure_closes_chips=True): 98 """Initializes a distributed TPU system for use with TensorFlow. 99 100 Args: 101 embedding_config: If not None, a `TPUEmbeddingConfiguration` proto 102 describing the desired configuration of the hardware embedding lookup 103 tables. If embedding_config is None, no hardware embeddings can be used. 104 job: The job (the XXX in TensorFlow device specification /job:XXX) that 105 contains the TPU devices that will be initialized. If job=None it is 106 assumed there is only one job in the TensorFlow flock, and an error will 107 be returned if this assumption does not hold. 108 compilation_failure_closes_chips: Set the configuration whether 109 we want to close TPU chips when there is a compilation failure. 110 Returns: 111 A serialized `TopologyProto` that describes the TPU system. Note: 112 the topology must be evaluated using `Session.run` before it can be used. 113 """ 114 config_string = ("" if embedding_config is None else 115 embedding_config.SerializeToString()) 116 with ops.device(_tpu_system_device_name(job)): 117 return tpu_ops.configure_distributed_tpu( 118 embedding_config=config_string, 119 compilation_failure_closes_chips=compilation_failure_closes_chips) 120 121 122def initialize_system_for_tpu_embedding(embedding_config, job=None): 123 """Initializes a distributed TPU Embedding system for use with TensorFlow. 124 125 The following two are equivalent: 126 1. initialize_system() with embedding_config. 127 2. initialize_system() without embedding_config, then 128 initialize_system_for_tpu_embedding(). 129 initialize_system() should not be called with embedding_config if 130 initialize_system_for_tpu_embedding() is meant to be called later. 131 132 Args: 133 embedding_config: a `TPUEmbeddingConfiguration` proto describing the desired 134 configuration of the hardware embedding lookup tables. 135 job: The job (the XXX in TensorFlow device specification /job:XXX) that 136 contains the TPU devices that will be initialized. If job=None it is 137 assumed there is only one job in the TensorFlow flock, and an error will 138 be returned if this assumption does not hold. 139 140 Returns: 141 A no-op. 142 """ 143 config_string = embedding_config.SerializeToString() 144 with ops.device(_tpu_system_device_name(job)): 145 return tpu_ops.configure_tpu_embedding(config=config_string) 146 147 148@tf_export(v1=["tpu.shutdown_system"]) 149def shutdown_system(job=None): 150 """Shuts down a running a distributed TPU system. 151 152 Args: 153 job: The job (the XXX in TensorFlow device specification /job:XXX) that 154 contains the TPU devices that will be shutdown. If job=None it is 155 assumed there is only one job in the TensorFlow flock, and an error will 156 be returned if this assumption does not hold. 157 """ 158 with ops.device(_tpu_system_device_name(job)): 159 shutdown_distributed_tpu = tpu_ops.shutdown_distributed_tpu() 160 return shutdown_distributed_tpu 161 162 163@tf_export(v1=["tpu.core"]) 164def core(num): 165 """Returns the device name for a core in a replicated TPU computation. 166 167 Args: 168 num: the virtual core number within each replica to which operators should 169 be assigned. 170 Returns: 171 A device name, suitable for passing to `tf.device()`. 172 """ 173 return "device:TPU_REPLICATED_CORE:{}".format(num) 174 175 176def _enclosing_tpu_context_and_graph(): 177 """Returns the TPUReplicateContext and its associated graph.""" 178 graph = ops.get_default_graph() 179 while graph is not None: 180 # pylint: disable=protected-access 181 context_ = graph._get_control_flow_context() 182 # pylint: enable=protected-access 183 while context_ is not None: 184 if isinstance(context_, TPUReplicateContext): 185 return context_, graph 186 context_ = context_.outer_context 187 graph = getattr(graph, "outer_graph", None) 188 raise ValueError("get_replicated_var_handle() called without " 189 "TPUReplicateContext. This shouldn't happen. Please file " 190 "a bug.") 191 192 193def is_tpu_strategy(strategy): 194 is_tpu_strat = lambda k: k.__name__.startswith("TPUStrategy") 195 clz = strategy.__class__ 196 return is_tpu_strat(clz) or any(map(is_tpu_strat, clz.__bases__)) 197 198 199def _enclosing_tpu_device_assignment(): 200 if not distribution_strategy_context.has_strategy(): 201 return None 202 strategy = distribution_strategy_context.get_strategy() 203 if not is_tpu_strategy(strategy): 204 return None 205 return strategy.extended._device_assignment # pylint: disable=protected-access 206 207 208@auto_control_deps.register_acd_resource_resolver 209def tpu_replicated_input_resolver(op, resource_inputs): 210 """Replaces TPUReplicatedInput outputs with its inputs in resource_inputs.""" 211 # Ignore TPUReplicatedInput for ACD purposes since we will be directly adding 212 # control deps on the replicated inputs. 213 if op.type == "TPUReplicatedInput": 214 if resource_inputs: 215 resource_inputs.clear() 216 return True 217 else: 218 return False 219 # Replace tensors in `resource_inputs` which are outputs of TPUReplicatedInput 220 # with the actual replicated inputs. This allows ACD to correct add control 221 # deps when there are multiple calls to `experimental_run_v2` in a 222 # `tf.function`. 223 to_remove = [] 224 to_add = [] 225 for resource in resource_inputs: 226 if resource.op.type == "TPUReplicatedInput": 227 to_remove.append(resource) 228 to_add.extend(resource.op.inputs) 229 if not to_add and not to_remove: 230 return False 231 for t in to_remove: 232 resource_inputs.discard(t) 233 resource_inputs.update(to_add) 234 return True 235 236 237class TPUReplicateContext(control_flow_ops.XLAControlFlowContext): 238 """A `ControlFlowContext` for nodes inside a TPU computation. 239 240 The primary role of `TPUReplicateContext` is to mark operators inside a 241 tpu.replicate() computation with the attribute "_tpu_replicate=XYZ", where XYZ 242 is a unique name. 243 244 We use a `ControlFlowContext` to perform the annotation since it integrates 245 with Tensorflow constructs like ResourceVariables. For example, if a 246 `ResourceVariable` is constructed inside a tpu.replicate() block, the 247 `ResourceVariable` implementation can use 248 `with ops.control_dependencies(None)` to build the variable's definition 249 outside the replicated computation. 250 """ 251 252 class _TFBufferWrapper(object): 253 """An internal class to help manage the TF_Buffer lifetime.""" 254 255 def __init__(self, buf_string): 256 self._buffer = pywrap_tf_session.TF_NewBufferFromString( 257 compat.as_bytes(buf_string)) 258 259 def __del__(self): 260 pywrap_tf_session.TF_DeleteBuffer(self._buffer) 261 262 def __init__(self, name, num_replicas, pivot): 263 """Builds a new TPUReplicateContext. 264 265 Args: 266 name: a unique name for the context, used to populate the `_tpu_replicate` 267 attribute. 268 num_replicas: an integer that gives the number of replicas for the 269 computation. 270 pivot: a pivot node. Nodes in the TPUReplicateContext that do not have any 271 inputs will have a control dependency on the pivot node. This ensures 272 that nodes are correctly included in any enclosing control flow 273 contexts. 274 """ 275 super(TPUReplicateContext, self).__init__() 276 self._num_replicas = num_replicas 277 self._outer_device_function_stack = None 278 self._oc_dev_fn_stack = None 279 self._outside_compilation_cluster = None 280 self._outside_compilation_counter = 0 281 self._in_gradient_colocation = None 282 self._gradient_colocation_stack = [] 283 self._host_compute_core = [] 284 self._name = name 285 self._name_as_bytes = compat.as_bytes(name) 286 self._tpu_relicate_attr_buf = self._TFBufferWrapper( 287 attr_value_pb2.AttrValue(s=self._name_as_bytes).SerializeToString()) 288 self._unsupported_ops = [] 289 self._pivot = pivot 290 self._replicated_vars = {} 291 292 def get_replicated_var_handle(self, name, vars_, is_mirrored=False): 293 """Returns a variable handle for replicated TPU variable 'var'. 294 295 This is a method used by an experimental replicated variable implementation 296 and is not intended as a public API. 297 298 Args: 299 name: The common name of the variable. 300 vars_: The replicated TPU variables. 301 is_mirrored: Whether the variables are mirrored, which guarantees the 302 values in each replica are always the same. 303 304 Returns: 305 The handle of the TPU replicated input node. 306 """ 307 device_assignment = _enclosing_tpu_device_assignment() 308 # We don't need to put device assignment as part of the replicated_vars key 309 # because each TPUReplicateContext will only have one device assignment. 310 handle = self._replicated_vars.get(name) 311 if handle is not None: 312 return handle 313 314 if device_assignment is not None: 315 # Find a variable copy for each replica in the device assignment. 316 # Note that the order of devices for replicas for the variable and the 317 # device assignment might not match. 318 job_name = pydev.DeviceSpec.from_string(vars_[0].device).job 319 devices_to_vars = {v.device: v for v in vars_} 320 replicated_vars = [] 321 for replica_id in range(device_assignment.num_replicas): 322 for logical_core in range(device_assignment.num_cores_per_replica): 323 device = device_util.canonicalize( 324 device_assignment.tpu_device( 325 replica=replica_id, logical_core=logical_core, job=job_name)) 326 if device in devices_to_vars: 327 replicated_vars.append(devices_to_vars[device]) 328 break 329 else: 330 raise ValueError( 331 "Failed to find a variable on any device in replica {} for " 332 "current device assignment".format(replica_id)) 333 else: 334 replicated_vars = vars_ 335 336 # Builds a TPUReplicatedInput node for the variable, if one does not already 337 # exist. The TPUReplicatedInput node must belong to the enclosing 338 # control-flow scope of the TPUReplicateContext. 339 # TODO(phawkins): consider changing the contract of the TPU encapsulation 340 # so the TPUReplicatedInput nodes go inside the TPUReplicateContext scope 341 # instead. 342 343 _, graph = _enclosing_tpu_context_and_graph() 344 with graph.as_default(): 345 # pylint: disable=protected-access 346 saved_context = graph._get_control_flow_context() 347 graph._set_control_flow_context(self.outer_context) 348 handle = tpu_ops.tpu_replicated_input([v.handle for v in replicated_vars], 349 name=name + "/handle", 350 is_mirrored_variable=is_mirrored) 351 graph._set_control_flow_context(saved_context) 352 # pylint: enable=protected-access 353 self._replicated_vars[name] = handle 354 return handle 355 356 def report_unsupported_operations(self): 357 if self._unsupported_ops: 358 op_str = "\n".join(" %s (%s)" % (op.type, op.name) 359 for op in self._unsupported_ops[:_MAX_WARNING_LINES]) 360 logging.warning("%d unsupported operations found: \n%s", 361 len(self._unsupported_ops), op_str) 362 if len(self._unsupported_ops) > _MAX_WARNING_LINES: 363 logging.warning("... and %d more" % 364 (len(self._unsupported_ops) - _MAX_WARNING_LINES)) 365 366 def EnterGradientColocation(self, op, gradient_uid): 367 if op is not None: 368 self._gradient_colocation_stack.append(op) 369 if not self._outside_compilation_cluster: 370 try: 371 outside_attr = op.get_attr(_OUTSIDE_COMPILATION_ATTR).decode("ascii") 372 if self._in_gradient_colocation: 373 raise NotImplementedError( 374 "Cannot nest gradient colocation operations outside compilation" 375 ) 376 if gradient_uid == "__unsupported__": 377 raise NotImplementedError( 378 "No gradient_uid calling gradient within outside_compilation") 379 # When we take the gradient of an op X in an outside_compilation 380 # cluster C in a forward computation we would like to put the ops 381 # corresponding to the gradient of X into a new outside_compilation 382 # cluster C'. However, if we take the gradient of X twice, the second 383 # one should get yet another new outside_compilation cluster C''. 384 # 385 # The mechanism we adopt is to use a 'root_cluster' which is the 386 # cluster that X was in before we took gradients, and a 'gradient_uid' 387 # which is different for every invocation of gradients, and put the 388 # gradient of X in cluster 'root_cluster.gradient_uid'. 389 # 390 # When taking a gradient of a gradient, some ops will be colocated 391 # with Op in the forward pass (e.g., cluster root_cluster) and some in 392 # the backward pass (e.g., cluster root_cluster.initial_gradient_uid). 393 # We need all of the grad-of-grad ops to be in the same cluster to 394 # avoid cyclic dependencies between clusters. We adopt a heuristic 395 # that puts any op clustered with root_cluster.<xxx> in 396 # root_cluster.gradient_uid, even if xxx was initial_gradient_uid. 397 self._in_gradient_colocation = op 398 parts = outside_attr.split(".") 399 cluster = parts[0] + "." + gradient_uid 400 self._EnterOutsideCompilationScope(cluster=cluster) 401 except ValueError: 402 # The attr was not present: do nothing. 403 pass 404 405 def ExitGradientColocation(self, op, gradient_uid): 406 if op is not None: 407 if not self._gradient_colocation_stack: 408 raise errors.InternalError( 409 op.node_def, op, 410 "Badly nested gradient colocation: empty stack when popping Op " + 411 op.name) 412 last_op = self._gradient_colocation_stack.pop() 413 if op is last_op: 414 if op is self._in_gradient_colocation: 415 self._in_gradient_colocation = None 416 self._ExitOutsideCompilationScope() 417 else: 418 raise errors.InternalError( 419 op.node_def, op, "Badly nested gradient colocation, expected " + 420 last_op + ", got " + op.name) 421 422 def _EnterOutsideCompilationScope(self, cluster=None): 423 424 class FakeOp(object): 425 """A helper class to determine the current device. 426 427 Supports only the type and device set/get methods needed to run the 428 graph's _apply_device_function method. 429 """ 430 431 def __init__(self): 432 self._device = "" 433 434 @property 435 def type(self): 436 return "FakeOp" 437 438 @property 439 def device(self): 440 return self._device 441 442 def _set_device(self, device): 443 if isinstance(device, pydev.DeviceSpec): 444 self._device = device.to_string() 445 else: 446 self._device = device 447 448 def _set_device_from_string(self, device_str): 449 self._device = device_str 450 451 if self._outside_compilation_cluster: 452 raise NotImplementedError("Cannot nest outside_compilation clusters") 453 if cluster: 454 self._outside_compilation_cluster = cluster 455 else: 456 self._outside_compilation_cluster = str(self._outside_compilation_counter) 457 self._outside_compilation_counter += 1 458 graph = ops.get_default_graph() 459 fake_op = FakeOp() 460 graph._apply_device_functions(fake_op) # pylint: disable=protected-access 461 device = pydev.DeviceSpec.from_string(fake_op.device) 462 if (device.device_type == "TPU_REPLICATED_CORE" and 463 device.device_index is not None): 464 self._host_compute_core.append(self._outside_compilation_cluster + ":" + 465 str(device.device_index)) 466 self._oc_dev_fn_stack = graph._device_function_stack # pylint: disable=protected-access 467 graph._device_function_stack = self._outer_device_function_stack # pylint: disable=protected-access 468 469 def _ExitOutsideCompilationScope(self): 470 if not self._outside_compilation_cluster: 471 raise NotImplementedError( 472 "Attempted to exit outside_compilation scope when not in scope") 473 self._outside_compilation_cluster = None 474 graph = ops.get_default_graph() 475 graph._device_function_stack = self._oc_dev_fn_stack # pylint: disable=protected-access 476 477 def Enter(self): 478 if not self._outer_device_function_stack: 479 # Capture the device function stack at the time of first entry 480 # since that is the stack that will be used outside_compilation. 481 graph = ops.get_default_graph() 482 # pylint: disable=protected-access 483 self._outer_device_function_stack = graph._device_function_stack.copy() 484 # pylint: enable=protected-access 485 super(TPUReplicateContext, self).Enter() 486 487 def HostComputeCore(self): 488 return self._host_compute_core 489 490 def _RemoveExternalControlEdges(self, op): 491 """Remove any external control dependency on this op.""" 492 internal_control_inputs = [] 493 external_control_inputs = [] 494 for x in op.control_inputs: 495 # pylint: disable=protected-access 496 is_internal_op = False 497 ctxt = x._get_control_flow_context() 498 while ctxt is not None: 499 if ctxt == self: 500 is_internal_op = True 501 break 502 ctxt = ctxt._outer_context 503 if is_internal_op: 504 internal_control_inputs.append(x) 505 else: 506 external_control_inputs.append(x) 507 # pylint: enable=protected-access 508 # pylint: disable=protected-access 509 op._remove_all_control_inputs() 510 op._add_control_inputs(internal_control_inputs) 511 # pylint: enable=protected-access 512 return internal_control_inputs, external_control_inputs 513 514 def AddOp(self, op): 515 # pylint: disable=protected-access 516 if op.type in _BLACKLISTED_OPS: 517 logging.error("Operation of type %s (%s) is not supported on the TPU. " 518 "Execution will fail if this op is used in the graph. " % 519 (op.type, op.name)) 520 521 if op.type in _UNSUPPORTED_OPS: 522 self._unsupported_ops.append(op) 523 524 if any(x.dtype._is_ref_dtype for x in op.inputs): 525 raise NotImplementedError( 526 "Non-resource Variables are not supported inside TPU computations " 527 "(operator name: %s)" % op.name) 528 529 # TensorFlowOpLayer may clone nodes that are in tpu.rewrite()s. It'll add 530 # the "_cloned" attribute and we should continue in that case. 531 if (_TPU_REPLICATE_ATTR in op.node_def.attr and 532 "_cloned" not in op.node_def.attr): 533 raise ValueError("TPU computations cannot be nested on op (%s)" % 534 op) 535 op._set_attr_with_buf( 536 _TPU_REPLICATE_ATTR, self._tpu_relicate_attr_buf._buffer) 537 if self._outside_compilation_cluster: 538 op._set_attr( 539 _OUTSIDE_COMPILATION_ATTR, 540 attr_value_pb2.AttrValue( 541 s=compat.as_bytes(self._outside_compilation_cluster))) 542 if self._num_replicas > 1 or not self._outside_compilation_cluster: 543 # Prevent feeding or fetching anything that is being compiled, 544 # and any replicated outside_compilation Op. 545 op.graph.prevent_feeding(op) 546 op.graph.prevent_fetching(op) 547 548 # Remove any control edges from outer control flow contexts. These may cause 549 # mismatched frame errors. 550 (internal_control_inputs, 551 external_control_inputs) = self._RemoveExternalControlEdges(op) 552 553 if not op.inputs: 554 # Add a control edge from the control pivot to this op. 555 if not internal_control_inputs: 556 # pylint: disable=protected-access 557 op._add_control_input(self.GetControlPivot()) 558 # pylint: enable=protected-access 559 else: 560 for index in xrange(len(op.inputs)): 561 x = op.inputs[index] 562 real_x = self.AddValue(x) 563 if real_x is not x: 564 op._update_input(index, real_x) # pylint: disable=protected-access 565 566 if external_control_inputs: 567 # Use an identity to pull control inputs as data inputs. Note that we 568 # ignore ops which don't have outputs. TODO(phawkins): fix that. 569 with ops.control_dependencies(None): 570 self.Enter() 571 external_control_inputs = [ 572 array_ops.identity(x.outputs[0]).op 573 for x in external_control_inputs 574 if x.outputs 575 ] 576 self.Exit() 577 # pylint: disable=protected-access 578 op._add_control_inputs(external_control_inputs) 579 # pylint: enable=protected-access 580 581 # Mark op's outputs as seen by this context and any outer contexts. 582 output_names = [x.name for x in op.outputs] 583 context = self 584 while context is not None: 585 # pylint: disable=protected-access 586 context._values.update(output_names) 587 context = context._outer_context 588 # pylint: enable=protected-access 589 590 if self._outer_context: 591 self._outer_context.AddInnerOp(op) 592 593 def AddValue(self, val): 594 """Add `val` to the current context and its outer context recursively.""" 595 if val.name in self._values: 596 # Use the real value if it comes from outer context. 597 result = self._external_values.get(val.name) 598 return val if result is None else result 599 600 result = val 601 self._values.add(val.name) 602 if self._outer_context: 603 result = self._outer_context.AddValue(val) 604 self._values.add(result.name) 605 606 self._external_values[val.name] = result 607 608 return result 609 610 def AddInnerOp(self, op): 611 self.AddOp(op) 612 if self._outer_context: 613 self._outer_context.AddInnerOp(op) 614 615 @property 616 def grad_state(self): 617 # Define the gradient loop state associated with the TPUReplicateContext to 618 # be None as the TPUReplicateContext does not get nested nor does the 619 # grad_state outside the TPUReplicateContext affect the graph inside so the 620 # grad_state should be as if this is the top-level gradient state. 621 return None 622 623 @property 624 def back_prop(self): 625 """Forwards to the enclosing while context, if any.""" 626 if self.GetWhileContext(): 627 return self.GetWhileContext().back_prop 628 return False 629 630 def GetControlPivot(self): 631 return self._pivot 632 633 634class OutsideCompilationV2Context(control_flow_ops.ControlFlowContext): 635 """The context for outside compilation in Tensorflow 2.0. 636 637 Every op added in this context will be assigned an _xla_outside_compilation 638 attribute. 639 """ 640 641 def __init__(self, name): 642 control_flow_ops.ControlFlowContext.__init__(self) 643 self._name = name 644 645 def AddOp(self, op): 646 if self._outer_context: 647 self._outer_context.AddOp(op) 648 # pylint: disable=protected-access 649 op._set_attr("_xla_outside_compilation", 650 attr_value_pb2.AttrValue(s=compat.as_bytes(self._name))) 651 # pylint: enable=protected-access 652 653 def AddInnerOp(self, op): 654 if self._outer_context: 655 self._outer_context.AddInnerOp(op) 656 # pylint: disable=protected-access 657 op._set_attr("_xla_outside_compilation", 658 attr_value_pb2.AttrValue(s=compat.as_bytes(self._name))) 659 # pylint: enable=protected-access 660 661 def to_control_flow_context_def(self, context_def, export_scope=None): 662 raise NotImplementedError("to_control_flow_context_def not implemented") 663 664 665@tf_export(v1=["tpu.outside_compilation"]) 666def outside_compilation(computation, *args, **kwargs): 667 """Builds part of a computation outside any current TPU replicate scope. 668 669 `tf.tpu.outside_compilation()` is used to run ops in `computation` on CPU 670 instead of running on TPU. For example, users can run ops that are not 671 supported on TPU's (e.g. tf.summary.write()) by explicitly placing those 672 ops on CPU's. Below usage of outside compilation will place ops in 673 `computation_with_string_ops` on CPU. 674 675 def computation_with_string_ops(x): 676 # strings types are not supported on TPU's and below ops must 677 # run on CPU instead. 678 output = tf.strings.format('1{}', x) 679 return tf.strings.to_number(output) 680 681 def tpu_computation(): 682 # Expected output is 11. 683 output = tf.tpu.outside_compilation(computation_with_string_ops, 1) 684 685 Outside compilation should be called inside TPUReplicateContext. That is, 686 `tf.tpu.outside_compilation()` should be called inside a function that is 687 passed to `tpu.split_compile_and_replicate()` -- this is implied when 688 outside compilation is invoked inside a function passed to TPUStrategy 689 `experimental_run_v2()`. If invoked outside of TPUReplicateContext, 690 then this simply returns the result of `computation`, and therefore, 691 would be a no-op. Note that outside compilation is different from 692 `tf.distribute.experimental.TPUStrategy.merge_call()` as logic in 693 outside compilation is replicated and executed separately for each 694 replica. On the other hand, `merge_call()` requires a `merge_fn` 695 to aggregate the inputs from different replicas and is executed only 696 once. 697 698 For variables placed in TPU device, which includes variables created inside 699 TPUStrategy scope, outside compilation logic must not include variable 700 read/write. For variables placed on host, which is the case when variables 701 created via TPUEstimator, variable read/write is only allowed if the variable 702 is not accessed by any other ops in the TPU computation. Variable read/write 703 from outside compilation cluster is not visible from TPU computation and 704 vice versa. Therefore, if outside compilation logic contains such host 705 variables read/write ops and if the variables are accessed by TPU 706 computation as well, then this may lead to deadlock. 707 708 Internally, `tf.tpu.outside_compilation()` adds outside compilation 709 attributes to all ops in `computation`. During later graph pass, these 710 ops with outside compilation attribute is extracted out and replicated 711 into a host-side graph. Inputs to this extract host-side graph is sent 712 from TPU computation graph to host graph via a pair of XlaSendToHost and 713 XlaRecvFromHost ops. Note that using `tf.tpu.outside_compilation()` 714 may result in tensor transfer between TPU and CPU, leading to non-trivial 715 performance impact. 716 717 Args: 718 computation: A Python function that builds the computation to 719 place on the host. 720 *args: the positional arguments for the computation. 721 **kwargs: the keyword arguments for the computation. 722 723 Returns: 724 The Tensors returned by computation. 725 """ 726 args = [] if args is None else args 727 graph = ops.get_default_graph() 728 729 # If we are in TF 2 functions (control flow V2 functions, or tf.function()), 730 # we need to attach _xla_outside_compilation attribute directly because we are 731 # not in TPUReplicateContext. 732 if isinstance(graph, func_graph.FuncGraph): 733 try: 734 tpu_context, _ = _enclosing_tpu_context_and_graph() 735 except ValueError: 736 logging.warning( 737 "Outside compilation attempted outside TPUReplicateContext " 738 "scope. As no enclosing TPUReplicateContext can be found, " 739 "returning the result of `computation` as is.") 740 return computation(*args, **kwargs) 741 742 # pylint: disable=protected-access 743 outside_compilation_name = str(tpu_context._outside_compilation_counter) 744 tpu_context._outside_compilation_counter = ( 745 tpu_context._outside_compilation_counter + 1) 746 # pylint: enable=protected-access 747 748 outside_compilation_context = OutsideCompilationV2Context( 749 outside_compilation_name) 750 outside_compilation_context.Enter() 751 args = [] if args is None else args 752 retval = computation(*args, **kwargs) 753 outside_compilation_context.Exit() 754 return retval 755 756 # If we are in a TPUReplicateContext, signal that we are now 757 # outside_compilation 758 initial_context = graph._get_control_flow_context() # pylint: disable=protected-access 759 context = initial_context 760 while context: 761 if isinstance(context, TPUReplicateContext): 762 context._EnterOutsideCompilationScope() # pylint: disable=protected-access 763 context = context.outer_context 764 765 retval = computation(*args, **kwargs) 766 767 # If we are in a TPUReplicateContext, signal that we are no longer 768 # outside_compilation 769 final_context = graph._get_control_flow_context() # pylint: disable=protected-access 770 if initial_context is not final_context: 771 raise NotImplementedError( 772 "Control-flow context cannot be different at start and end of an " 773 "outside_compilation scope") 774 context = initial_context 775 while context: 776 if isinstance(context, TPUReplicateContext): 777 context._ExitOutsideCompilationScope() # pylint: disable=protected-access 778 context = context.outer_context 779 780 return retval 781 782 783@tf_export(v1=["tpu.replicate"]) 784def replicate(computation, 785 inputs=None, 786 infeed_queue=None, 787 device_assignment=None, 788 name=None, 789 maximum_shapes=None): 790 """Builds a graph operator that runs a replicated TPU computation. 791 792 Args: 793 computation: A Python function that builds the computation to replicate. 794 inputs: A list of lists of input tensors or `None` (equivalent to 795 `[[]]`), indexed by `[replica_num][input_num]`. All replicas must 796 have the same number of inputs. Each input can be a nested structure 797 containing values that are convertible to tensors. Note that passing an 798 N-dimension list of compatible values will result in a N-dimension list of 799 scalar tensors rather than a single Rank-N tensors. If you need different 800 behavior, convert part of inputs to tensors with `tf.convert_to_tensor`. 801 infeed_queue: If not `None`, the `InfeedQueue` from which to append a tuple 802 of arguments as inputs to computation. 803 device_assignment: If not `None`, a `DeviceAssignment` describing the 804 mapping between logical cores in the computation with physical cores in 805 the TPU topology. Uses a default device assignment if `None`. The 806 `DeviceAssignment` may be omitted if each replica of the computation uses 807 only one core, and there is either only one replica, or the number of 808 replicas is equal to the number of cores in the TPU system. 809 name: (Deprecated) Does nothing. 810 maximum_shapes: A nested structure of tf.TensorShape representing the shape 811 to which the respective component of each input element in each replica 812 should be padded. Any unknown dimensions (e.g. 813 tf.compat.v1.Dimension(None) in a tf.TensorShape or -1 in a tensor-like 814 object) will be padded to the maximum size of that dimension over all 815 replicas. The structure of `maximum_shapes` needs to be the same as 816 `inputs[0]`. 817 Returns: 818 A list of outputs, indexed by `[replica_num]` each output can be a nested 819 structure same as what computation() returns with a few exceptions. 820 821 Exceptions include: 822 1) None output: a NoOp would be returned which control-depends on 823 computation. 824 2) Single value output: A tuple containing the value would be returned. 825 3) Operation-only outputs: a NoOp would be returned which 826 control-depends on computation. 827 TODO(b/121383831): Investigate into removing these special cases. 828 829 Raises: 830 ValueError: If all replicas do not have equal numbers of input tensors. 831 ValueError: If the number of inputs per replica does not match 832 the number of formal parameters to `computation`. 833 ValueError: If the static `inputs` dimensions don't match with the values 834 given in `maximum_shapes`. 835 ValueError: If the structure of inputs per replica does not match 836 the structure of `maximum_shapes`. 837 """ 838 return split_compile_and_replicate( 839 computation, 840 inputs, 841 infeed_queue, 842 device_assignment, 843 name, 844 maximum_shapes=maximum_shapes)[1] 845 846 847def _pad_all_input(inputs, padded_shapes): 848 """Pad all input tensors given padded_shapes. 849 850 The real shape tensors will be concatenated with the padded original inputs. 851 852 Args: 853 inputs: The original inputs. 854 padded_shapes: A list of padded shapes for each input. 855 856 Returns: 857 The padded inputs and a PaddingMap list which maps the padded input 858 dimension to the real shape argument index. 859 """ 860 # maximum_static_shapes[idx][i] indicates the maximum static size of ith 861 # dimension of the idx input among all the replicas. 862 maximum_static_shapes = [] 863 # need_padding[idx][i] indicates whether the ith dimension of the idx input 864 # needs padding. 865 need_padding = [] 866 input_shape_tensors = [] 867 for core_idx, inputs_per_core in enumerate(inputs): 868 for idx, input_tensor in enumerate(inputs_per_core): 869 input_shape = input_tensor.get_shape().as_list() 870 if core_idx == 0: 871 input_shape_tensors.append([]) 872 maximum_static_shapes.append(input_shape) 873 need_padding.append(np.full_like(input_shape, False, dtype=bool)) 874 else: 875 for i, s in enumerate(input_shape): 876 if not s or s != maximum_static_shapes[idx][i]: 877 need_padding[idx][i] = True 878 maximum_static_shapes[idx] = max(input_shape, 879 maximum_static_shapes[idx]) 880 881 # Append _POST_DEVICE_REWRITE_ATTR attributes to the real shape ops. 882 real_input_shape = array_ops.shape(input_tensor) 883 real_input_shape.op._set_attr( # pylint: disable=protected-access 884 _POST_DEVICE_REWRITE_ATTR, 885 attr_value_pb2.AttrValue(b=True)) 886 input_shape_tensors[idx].append(real_input_shape) 887 888 maximum_shapes = [] 889 for shapes_per_input in input_shape_tensors: 890 maximum_shapes.append( 891 math_ops.reduce_max(array_ops.stack(shapes_per_input), axis=0)) 892 893 padded_inputs = [] 894 real_shapes = [] 895 padding_maps = [] 896 for core_idx, inputs_per_core in enumerate(inputs): 897 padded_inputs.append([]) 898 real_shapes.append([]) 899 real_shape_idx = len(inputs_per_core) - 1 900 for idx, input_tensor in enumerate(inputs_per_core): 901 input_shape_tensor = input_shape_tensors[idx][core_idx] 902 input_shape = input_tensor.get_shape().as_list() 903 padded_shape = padded_shapes[idx] 904 905 if any(need_padding[idx]): 906 for i, s in enumerate(input_shape): 907 if need_padding[idx][i]: 908 if core_idx == 0: 909 real_shape_idx += 1 910 padding_map = dynamic_padding.PaddingMap() 911 padding_map.arg_index = idx 912 padding_map.shape_index = i 913 padding_map.padding_arg_index = real_shape_idx 914 padding_maps.append(padding_map) 915 real_shapes[core_idx].append( 916 math_ops.cast(input_shape_tensor[i], dtypes.int32)) 917 918 paddings = [] 919 for i, s in enumerate(padded_shape.dims): 920 if need_padding[idx][i]: 921 # The minimum padded dimension size is 2 as XLA doesn't support size 922 # 1 dynamic size. 923 minimum_dynamic_dim_size = 2 924 if s.value: 925 # Pad to the given maximum value. 926 max_dim_size = max(s.value, minimum_dynamic_dim_size) 927 else: 928 # If maximum value is not given, then pad to the maximum dimension 929 # among all the cores. 930 max_dim_size = math_ops.maximum(maximum_shapes[idx][i], 931 minimum_dynamic_dim_size) 932 # Pad to the given maximum value. 933 padding = [0, max_dim_size - input_shape_tensor[i]] 934 else: 935 padding = [0, 0] 936 paddings.append(padding) 937 938 if input_tensor.get_shape().is_fully_defined(): 939 # TODO(rxsang): This is a hack to make sure padded_input has dynamic 940 # shapes, so any tf.size/tf.shape op performed on it won't be constant 941 # folded. Do we have better ways to do it? 942 padded_input = control_flow_ops.cond( 943 array_ops.constant(True), 944 lambda: array_ops.pad(input_tensor, paddings), # pylint: disable=cell-var-from-loop 945 lambda: input_tensor) 946 else: 947 padded_input = array_ops.pad(input_tensor, paddings) 948 949 # Append _POST_DEVICE_REWRITE_ATTR attributes to all padded inputs. 950 padded_input.op._set_attr( # pylint: disable=protected-access 951 _POST_DEVICE_REWRITE_ATTR, 952 attr_value_pb2.AttrValue(b=True)) 953 954 padded_inputs[core_idx].append(padded_input) 955 else: 956 padded_inputs[core_idx].append(input_tensor) 957 958 num_replicas = len(padded_inputs) 959 for i in range(num_replicas): 960 padded_inputs[i].extend(real_shapes[i]) 961 962 return padded_inputs, padding_maps 963 964 965def split_compile_and_replicate(computation, 966 inputs=None, 967 infeed_queue=None, 968 device_assignment=None, 969 name=None, 970 use_tpu=True, 971 maximum_shapes=None): 972 """Builds graph operators that runs compilation and replicated computation. 973 974 This is a lower level interface than replicate that returns a separate compile 975 and execute output tensor. In the generated graph the compile op feeds into 976 the execute op and no additional compilation is incurred when running the 977 compile op before the execute op. The compile op returns additional 978 information about the compilation but does not return the compiled program. 979 980 Args: 981 computation: A Python function that builds the computation to replicate. 982 inputs: A list of lists of input tensors or `None` (equivalent to 983 `[[]]`), indexed by `[replica_num][input_num]`. All replicas must 984 have the same number of inputs. Each input can be a nested structure 985 containing values that are convertible to tensors. Note that passing an 986 N-dimension list of compatible values will result in a N-dimension list of 987 scalar tensors rather than a single Rank-N tensors. If you need different 988 behavior, convert part of inputs to tensors with `tf.convert_to_tensor`. 989 infeed_queue: If not `None`, the `InfeedQueue` from which to append a tuple 990 of arguments as inputs to computation. 991 device_assignment: If not `None`, a `DeviceAssignment` describing the 992 mapping between logical cores in the computation with physical cores in 993 the TPU topology. Uses a default device assignment if `None`. The 994 `DeviceAssignment` may be omitted if each replica of the computation uses 995 only one core, and there is either only one replica, or the number of 996 replicas is equal to the number of cores in the TPU system. 997 name: (Deprecated) Does nothing. 998 use_tpu: When false, the input `computation` is executed on the XLA CPU/GPU 999 backends. Currently, only supports a default placement (computation is 1000 placed on GPU if one is available, and on CPU if not). 1001 maximum_shapes: A nested structure of tf.TensorShape representing the shape 1002 to which the respective component of each input element in each replica 1003 should be padded. Any unknown dimensions (e.g. 1004 tf.compat.v1.Dimension(None) in a tf.TensorShape or -1 in a tensor-like 1005 object) will be padded to the maximum size of that dimension over all 1006 replicas. The structure of `maximum_shapes` needs to be the same as 1007 `inputs[0]`. 1008 1009 Returns: 1010 A list of lists with the first list corresponding to the compile op and the 1011 second a list of output tensors, indexed by `[replica_num][output_num]`. 1012 Raises: 1013 ValueError: If all replicas do not have equal numbers of input tensors. 1014 ValueError: If the number of inputs per replica does not match 1015 the number of formal parameters to `computation`. 1016 ValueError: If the static `inputs` dimensions don't match with the values 1017 given in `maximum_shapes`. 1018 ValueError: If the structure of inputs per replica does not match 1019 the structure of `maximum_shapes`. 1020 """ 1021 del name 1022 inputs = [[]] if inputs is None else inputs 1023 1024 metadata_kwargs = {} 1025 if device_assignment is not None: 1026 # Turn the Numpy array into a flattened list so we can pass it as an 1027 # operator attribute. 1028 metadata_kwargs = { 1029 "topology": 1030 device_assignment.topology.serialized(), 1031 "device_assignment": 1032 device_assignment.core_assignment.flatten().tolist() 1033 } 1034 metadata_kwargs["num_cores_per_replica"] = ( 1035 device_assignment.num_cores_per_replica) 1036 # This entry is used for enabling automatic outside compilation. 1037 metadata_kwargs["allow_soft_placement"] = config.get_soft_device_placement() 1038 1039 if ((not isinstance(inputs, list)) or 1040 any(not isinstance(inp, (list, tuple)) for inp in inputs)): 1041 raise TypeError("tpu.replicate() inputs must be a list of lists/tuples") 1042 1043 num_replicas = len(inputs) 1044 1045 # No replicas? Nothing to do. 1046 if num_replicas == 0: 1047 return [] 1048 1049 # Checks all replicas have the same structure. 1050 for i in xrange(1, num_replicas): 1051 nest.assert_same_structure(inputs[0], inputs[i]) 1052 1053 # Flatten inputs. 1054 flat_inputs = [ 1055 nest.flatten(per_replica_input) for per_replica_input in inputs 1056 ] 1057 # Converts inputs to Tensors. 1058 flat_inputs = [[ops.convert_to_tensor(x) for x in inp] for inp in flat_inputs] 1059 1060 # Verifies that all replicas have matching numbers and types of inputs 1061 flat_input_types = [x.dtype for x in flat_inputs[0]] 1062 input_arity = len(inputs[0]) 1063 flat_input_arity = len(flat_input_types) 1064 for i in range(num_replicas): 1065 if len(inputs[i]) != input_arity: 1066 raise ValueError("Replicas must have the same number of inputs. " 1067 "Replica 0 had {} inputs, replica {} had {} " 1068 "inputs.".format(input_arity, i, len(inputs[i]))) 1069 1070 types = [x.dtype for x in flat_inputs[i]] 1071 if types != flat_input_types: 1072 raise ValueError("Replicas must have matching input types. Replica 0 had " 1073 "input types {}, replica {} had input types {}".format( 1074 flat_input_types, i, types)) 1075 1076 arg_error = xla.check_function_argument_count( 1077 computation, input_arity, infeed_queue) 1078 if arg_error is not None: 1079 if infeed_queue is None: 1080 raise TypeError( 1081 "Supplied computation cannot be called with the specified inputs. " 1082 "You specified %d inputs: %s, but the computation needs %s" % ( 1083 input_arity, str([i.name for i in inputs[0]]), arg_error)) 1084 else: 1085 raise TypeError( 1086 "Supplied computation cannot be called with the specified inputs. " 1087 "You specified %d inputs: %s and %d additional inputs from infeed," 1088 " but the computation needs %s" % (input_arity, str( 1089 [i.name 1090 for i in inputs[0]]), infeed_queue.number_of_tuple_elements, 1091 arg_error)) 1092 1093 if maximum_shapes: 1094 if infeed_queue: 1095 raise ValueError( 1096 "Dynamic input shapes are not supported with infeed queues") 1097 1098 # Make sure maximum_shapes has the same structure as inputs. 1099 nest.assert_same_structure(inputs[0], maximum_shapes, check_types=False) 1100 1101 # Flatten padded shapes. 1102 flat_maximum_shapes = nest.flatten(maximum_shapes) 1103 flat_maximum_shapes = [ 1104 tensor_shape.TensorShape(s) for s in flat_maximum_shapes 1105 ] 1106 1107 flat_inputs, padding_maps = _pad_all_input(flat_inputs, flat_maximum_shapes) 1108 1109 serialized_padding_maps = [] 1110 for padding_map in padding_maps: 1111 serialized_padding_maps.append(padding_map.SerializeToString()) 1112 metadata_kwargs["padding_map"] = serialized_padding_maps 1113 1114 metadata_kwargs["step_marker_location"] = getattr( 1115 computation, "step_marker_location", "STEP_MARK_AT_ENTRY") 1116 1117 graph = ops.get_default_graph() 1118 1119 # Fan-in: Builds a TPUReplicatedInput node for each input. 1120 flat_replicated_inputs = [] 1121 for i in range(0, len(flat_inputs[0])): 1122 replicas = [flat_inputs[replica][i] for replica in xrange(num_replicas)] 1123 flat_replicated_inputs.append( 1124 tpu_ops.tpu_replicated_input( 1125 replicas, name="input{}".format(i), index=i)) 1126 if isinstance(graph, func_graph.FuncGraph): 1127 # When we are in Tensorflow 2.0 function, 'graph' will be a FuncGraph 1128 # object. If both outside graph and this function have a TPU cluster, 1129 # they will have the same cluster name and it will cause problems (because 1130 # we lower functional ops in Tensorflow 2.0). Append function name to 1131 # 'cluster_name' to avoid cluster name collision. 1132 cluster_name = graph.unique_name("cluster_" + graph.name) 1133 else: 1134 cluster_name = graph.unique_name("cluster") 1135 pivot = control_flow_ops.no_op(name=cluster_name + "/pivot") 1136 context = TPUReplicateContext( 1137 name=cluster_name, num_replicas=num_replicas, pivot=pivot) 1138 try: 1139 context.Enter() 1140 1141 metadata = tpu_ops.tpu_replicate_metadata( 1142 num_replicas=num_replicas, use_tpu=use_tpu, **metadata_kwargs) 1143 1144 with tpu_function.tpu_shard_context( 1145 num_replicas), ops.control_dependencies([metadata]): 1146 1147 # Add identity ops so even unused inputs are "consumed" by the 1148 # computation. This is to avoid orphaned TPUReplicatedInput nodes. 1149 # TODO(phawkins): consider instead pruning unused TPUReplicatedInput 1150 # and eliding trivial TPUReplicatedInput/TPUReplicatedOutput pairs. 1151 flat_replicated_inputs = [ 1152 array_ops.identity(x, name="replicated_input_{}".format(i)) 1153 for i, x in enumerate(flat_replicated_inputs) 1154 ] 1155 for i in flat_replicated_inputs: 1156 # pylint: disable=protected-access 1157 # Add an attribute to the identity node so that they could be removed in 1158 # encapsulate TPU computation pass if unused. However we don't remove 1159 # inputs when dynamic padding is enabled. 1160 # TODO(rxsang): Use other ways except argument index in padding_map so 1161 # outside compilation can work with dynamic padding correctly. 1162 if maximum_shapes is None: 1163 i.op._set_attr("_tpu_input_identity", 1164 attr_value_pb2.AttrValue(b=True)) 1165 # pylint: enable=protected-access 1166 1167 # Unflatten the computation inputs to match original input structure. 1168 computation_inputs = nest.pack_sequence_as( 1169 structure=inputs[0], 1170 flat_sequence=flat_replicated_inputs[:flat_input_arity]) 1171 1172 # If there is an infeed queue, adds the dequeued values to the 1173 # computation's inputs. 1174 if infeed_queue is not None: 1175 infeed_queue.set_number_of_shards(num_replicas) 1176 for t in infeed_queue.generate_dequeue_op(): 1177 computation_inputs.append(t) 1178 1179 # Only resource variables work inside a TPU computation, so turn on 1180 # resource variables for the computation. 1181 # TODO(phawkins): consider removing this code. It will 1182 # be less confusing to clients if they knowingly choose to use resource 1183 # variables. 1184 # Partitioned variables is not supported (b/112311320). 1185 vscope = variable_scope.get_variable_scope() 1186 saved_use_resource = vscope.use_resource 1187 saved_custom_getter = vscope.custom_getter 1188 1189 def custom_getter(getter, name, *args, **kwargs): 1190 """Variables on TPU have a few restrictions.""" 1191 partitioner = kwargs["partitioner"] 1192 if partitioner is not None: 1193 kwargs["partitioner"] = None 1194 logging.warning( 1195 "Partitioned variables are not supported on TPU. Got " 1196 "`partitioner` that is {} for variable {}. " 1197 "Setting `partitioner` to `None`." 1198 .format(partitioner, name)) 1199 if saved_custom_getter is None: 1200 return getter(name, *args, **kwargs) 1201 else: 1202 return saved_custom_getter(getter, name, *args, **kwargs) 1203 1204 vscope.set_use_resource(True) 1205 vscope.set_custom_getter(custom_getter) 1206 1207 outputs = computation(*computation_inputs) 1208 1209 vscope.set_use_resource(saved_use_resource) 1210 vscope.set_custom_getter(saved_custom_getter) 1211 1212 outputs_is_flat = xla.is_flat(outputs) 1213 if outputs_is_flat: 1214 output_tensors, control_deps = _postprocess_flat_outputs(outputs) 1215 else: 1216 output_tensors, control_deps = _postprocess_non_flat_outputs(outputs) 1217 1218 # tensor_tracer imports tpu.py. Local import to tensor_tracer to avoid 1219 # import-cycle 1220 # pylint: disable=g-import-not-at-top 1221 from tensorflow.python.tpu import tensor_tracer 1222 # pylint: enable=g-import-not-at-top 1223 if tensor_tracer.TensorTracer.is_enabled(): 1224 tt = tensor_tracer.TensorTracer() 1225 output_tensors = tt.trace_tpu(ops.get_default_graph(), 1226 output_tensors, control_deps, 1227 num_replicas) 1228 1229 context.ExitResult(output_tensors) 1230 finally: 1231 context.report_unsupported_operations() 1232 context.Exit() 1233 host_compute_core = context.HostComputeCore() 1234 1235 if host_compute_core: 1236 attr_value = attr_value_pb2.AttrValue() 1237 attr_value.list.s.extend(compat.as_bytes(x) for x in host_compute_core) 1238 metadata._set_attr("host_compute_core", attr_value) # pylint: disable=protected-access 1239 1240 with ops.control_dependencies([metadata]): 1241 if use_tpu: 1242 compile_status = tpu_ops.tpu_compilation_result() 1243 op = compile_status.op 1244 attr_value = attr_value_pb2.AttrValue(s=compat.as_bytes(cluster_name)) 1245 op._set_attr(_TPU_COMPILATION_STATUS_ATTR, attr_value) # pylint: disable=protected-access 1246 else: 1247 compile_status = control_flow_ops.no_op(name="compilation_status") 1248 1249 if not output_tensors: 1250 # Returns a list of NoOps dependent on the replication Op, indexed by 1251 # [replica_num]. 1252 return [ 1253 compile_status, 1254 [ 1255 control_flow_ops.group(control_deps, name="shard_%d" % i) 1256 for i in range(num_replicas) 1257 ] 1258 ] 1259 1260 # Fan-out: Builds a TPUReplicatedOutput node for each output. 1261 replicated_outputs = [[] for i in xrange(num_replicas)] 1262 for i, t in enumerate(output_tensors): 1263 # Fan-out: Builds a TPUReplicatedOutput node for each output. 1264 ys = tpu_ops.tpu_replicated_output( 1265 t, num_replicas, name="output{}".format(i)) 1266 1267 # Wraps the outputs in identity operators so the names of any possible 1268 # `fetch` nodes are preserved by the replication rewrite. 1269 with ops.control_dependencies(control_deps): 1270 for replica in xrange(num_replicas): 1271 replicated_outputs[replica].append( 1272 array_ops.identity( 1273 ys[replica], name="output_%d_shard_%d" % (i, replica))) 1274 1275 if not outputs_is_flat: 1276 replicated_outputs = [ 1277 nest.pack_sequence_as(outputs, replica_outs) 1278 for replica_outs in replicated_outputs 1279 ] 1280 1281 return [compile_status, replicated_outputs] 1282 1283 1284def _postprocess_flat_outputs(outputs): 1285 """Validates non-flat outputs, add backs device assignments and other attrs. 1286 1287 Args: 1288 outputs: Output from `computation` inside `tpu.rewrite`. 1289 1290 Returns: 1291 Tensors and Operations extracted from outputs. 1292 """ 1293 # Following code segment is to preserve legacy behavior. Previously we only 1294 # supported flat outputs and thus for consistency it was nice to convert even 1295 # single element into a tuple. But now that we support arbitrary output 1296 # structure, this is no longer necessary. 1297 # TODO(b/121383831): Migrate all legacy use cases and delete this special 1298 # case. 1299 # If the computation returns `None`, make it an empty tuple. 1300 if outputs is None: 1301 outputs = tuple() 1302 # If the computation only returned one value, makes it a tuple. 1303 if not isinstance(outputs, collections_abc.Sequence): 1304 outputs = (outputs,) 1305 1306 # Append `no_op` here so that fetching any return value of this function 1307 # will trigger TPUExecute node. 1308 outputs += (control_flow_ops.no_op(),) 1309 try: 1310 with ops.device(core(0)): 1311 outputs = [ 1312 o if isinstance(o, ops.Operation) else ops.convert_to_tensor(o) 1313 for o in outputs 1314 ] 1315 except Exception as e: 1316 raise ValueError( 1317 "TPU function return values must all either be Operations or " 1318 "convertible to Tensors. Got '%s'" % str(e)) 1319 1320 # Separates the returned Operations and Tensors. 1321 output_operations = [o for o in outputs if isinstance(o, ops.Operation)] 1322 output_tensors = [o for o in outputs if not isinstance(o, ops.Operation)] 1323 1324 if outputs != output_tensors + output_operations: 1325 raise ValueError( 1326 "TPU functions must return zero-or more Tensor values followed by " 1327 "zero or more Operations.") 1328 1329 # Wraps outputs in Identity ops. Otherwise a replicated input copied 1330 # straight to an output would bypass the replicate(). This would be bad 1331 # because the TPUReplicatedInput/TPUReplicatedOutput operator would not 1332 # be rewritten away, leading to a runtime error. 1333 # TODO(phawkins): extend the rewrite to elide these nodes instead. 1334 new_output_tensors = [] 1335 for t in output_tensors: 1336 with ops.device(t.device if t.device else core(0)): 1337 o = array_ops.identity(t) 1338 # pylint: disable=protected-access 1339 o.op._set_attr("_tpu_output_identity", attr_value_pb2.AttrValue(b=True)) 1340 # pylint: enable=protected-access 1341 new_output_tensors.append(o) 1342 return new_output_tensors, output_operations 1343 1344 1345def _postprocess_non_flat_outputs(outputs): 1346 """Validates non-flat outputs, add backs device assignments and other attrs. 1347 1348 Args: 1349 outputs: Output from `computation` inside `tpu.rewrite`. 1350 1351 Returns: 1352 Tensors extracted from outputs and an empty list because Operations are not 1353 allowed in non-flat outputs.. 1354 """ 1355 1356 # Flatten output items. 1357 flat_outputs = nest.flatten(outputs) 1358 1359 # Convert all non-Operation outputs to Tensors. 1360 for i, o in enumerate(flat_outputs): 1361 if isinstance(o, ops.Operation): 1362 raise ValueError( 1363 "tpu.rewrite does not support Operation as return value in non-flat " 1364 "output structure. You can set returned Operations as control " 1365 "dependencies of returned Tensors so Operations are triggered when " 1366 'Tensors are evaluated. Operation found: "%s"' % o.name) 1367 1368 try: 1369 o = ops.convert_to_tensor(o) 1370 except Exception as e: 1371 raise ValueError( 1372 "TPU function return values must all either be Operations or " 1373 'convertible to Tensors. Got error: "%s"' % str(e)) 1374 1375 # Wraps outputs in Identity ops. Otherwise a replicated input copied 1376 # straight to an output would bypass the replicate(). This would be bad 1377 # because the TPUReplicatedInput/TPUReplicatedOutput operator would not 1378 # be rewritten away, leading to a runtime error. 1379 # TODO(phawkins): extend the rewrite to elide these nodes instead. 1380 with ops.device(core(0)): 1381 o = array_ops.identity(o) 1382 # pylint: disable=protected-access 1383 o.op._set_attr("_tpu_output_identity", attr_value_pb2.AttrValue(b=True)) 1384 # pylint: enable=protected-access 1385 flat_outputs[i] = array_ops.identity(o) 1386 1387 # All flat_outputs are Tensors, and no Operations. 1388 return flat_outputs, [] 1389 1390 1391def split_compile_and_shard(computation, 1392 inputs=None, 1393 num_shards=1, 1394 input_shard_axes=None, 1395 outputs_from_all_shards=True, 1396 output_shard_axes=None, 1397 infeed_queue=None, 1398 device_assignment=None, 1399 name=None): 1400 """Shards `computation` for parallel execution. 1401 1402 `inputs` must be a list of Tensors or None (equivalent to an empty list), each 1403 of which has a corresponding split axis (from `input_shard_axes`). Each input 1404 is split into `num_shards` pieces along the corresponding axis, and 1405 computation is applied to each shard in parallel. 1406 1407 Tensors are broadcast to all shards if they are lexically captured by 1408 `computation`. e.g., 1409 1410 x = tf.constant(7) 1411 def computation(): 1412 return x + 3 1413 ... = shard(computation, ...) 1414 1415 If `outputs_from_all_shards` is true, the outputs from all shards of 1416 `computation` are concatenated back together along their `output_shard_axes`. 1417 Otherwise, each output is taken from an arbitrary shard. 1418 1419 Inputs and outputs of the computation must be at least rank-1 Tensors. 1420 1421 Args: 1422 computation: A Python function that builds a computation to apply to each 1423 shard of the input. 1424 inputs: A list of input tensors or None (equivalent to an empty list). Each 1425 input tensor has a corresponding shard axes, given by `input_shard_axes`, 1426 which must have size divisible by `num_shards`. 1427 num_shards: The number of shards. 1428 input_shard_axes: A list of dimensions along which to shard `inputs`, or 1429 `None`. `None` means "shard all inputs along dimension 0". If not `None`, 1430 there must be one dimension per input. 1431 outputs_from_all_shards: Boolean or list of boolean. For each output, if 1432 `True`, outputs from all shards are concatenated along the corresponding 1433 `output_shard_axes` entry. Otherwise, each output is taken 1434 from an arbitrary shard. If the argument is a boolean, the argument's 1435 value is used for each output. 1436 output_shard_axes: A list of dimensions along which to concatenate the 1437 outputs of `computation`, or `None`. `None` means "concatenate all outputs 1438 along dimension 0". If not `None`, there must be one dimension per output. 1439 Ignored if `outputs_from_all_shards` is False. 1440 infeed_queue: If not `None`, the `InfeedQueue` to use to augment the inputs 1441 of `computation`. 1442 device_assignment: If not `None`, a `DeviceAssignment` describing the 1443 mapping between logical cores in the computation with physical cores in 1444 the TPU topology. Uses a default device assignment if `None`. The 1445 `DeviceAssignment` may be omitted if each shard of the computation uses 1446 only one core, and there is either only one shard, or the number of shards 1447 is equal to the number of cores in the TPU system. 1448 name: (Deprecated) Does nothing. 1449 Returns: 1450 A tuple of (compile op, [output tensors]). 1451 Raises: 1452 ValueError: If num_shards <= 0 1453 ValueError: If len(input_shard_axes) != len(inputs) 1454 ValueError: If len(output_shard_axes) != len(outputs from `computation`) 1455 """ 1456 # TODO(phawkins): consider adding support for broadcasting Tensors passed as 1457 # inputs. 1458 1459 if num_shards <= 0: 1460 raise ValueError("num_shards must be a positive integer.") 1461 1462 inputs = [] if inputs is None else inputs 1463 if not isinstance(inputs, list): 1464 raise TypeError("tpu.shard()'s inputs must be a list of Tensors or None.") 1465 1466 # Converts inputs to Tensors. 1467 inputs = [ops.convert_to_tensor(x) for x in inputs] 1468 1469 if input_shard_axes is None: 1470 input_shard_axes = [0] * len(inputs) 1471 if len(inputs) != len(input_shard_axes): 1472 raise ValueError("Length of input_shard_axes must be equal to the number " 1473 "of inputs.") 1474 1475 if inputs: 1476 # Splits the `inputs` along the corresponding `input_shard_axes`, giving 1477 # lists with layout [input][shard] 1478 split_inputs = [ 1479 array_ops.split(x, num_shards, axis=axis) 1480 for (axis, x) in zip(input_shard_axes, inputs)] 1481 1482 # Transposes the input lists to have layout [shard][input] 1483 transposed_inputs = [list(i) for i in zip(*split_inputs)] 1484 else: 1485 transposed_inputs = [[]] * num_shards 1486 1487 compile_op, outputs = split_compile_and_replicate( 1488 computation, 1489 transposed_inputs, 1490 infeed_queue=infeed_queue, 1491 device_assignment=device_assignment, 1492 name=name) 1493 1494 # There must be at least one shard since num_shards > 0. 1495 # TODO(b/36647078) remove disable when pylint bug is fixed. 1496 # pylint: disable=indexing-exception 1497 if isinstance(outputs[0], ops.Operation): 1498 # pylint: enable=indexing-exception 1499 # There were no outputs from the computation and replicate returned a list 1500 # of NoOps with control dependencies on the computation. Return the first 1501 # one so it can be used as a control dependency or fetch node. 1502 # TODO(b/36647078) remove disable when pylint bug is fixed. 1503 # pylint: disable=indexing-exception 1504 return compile_op, [outputs[0]] 1505 # pylint: enable=indexing-exception 1506 1507 # TODO(b/36647078) remove disable when pylint bug is fixed. 1508 # pylint: disable=indexing-exception 1509 num_outputs = len(outputs[0]) 1510 # pylint: enable=indexing-exception 1511 1512 if output_shard_axes is None: 1513 output_shard_axes = [0] * num_outputs 1514 if num_outputs != len(output_shard_axes): 1515 raise ValueError("Length of output_shard_axes must be equal to the number " 1516 "of outputs.") 1517 1518 if isinstance(outputs_from_all_shards, bool): 1519 outputs_from_all_shards = [outputs_from_all_shards] * num_outputs 1520 1521 if num_outputs != len(outputs_from_all_shards): 1522 raise ValueError("Length of outputs_from_all_shards must be equal to the " 1523 "number of outputs.") 1524 1525 results = [] 1526 for (axis, all_shards, x) in zip(output_shard_axes, outputs_from_all_shards, 1527 zip(*outputs)): 1528 if all_shards: 1529 # Concatenate all of the outputs together (use stack for scalars). 1530 shape = x[0].shape 1531 is_scalar = shape is not None and (shape.ndims == 0) 1532 results.append((array_ops.stack(list(x)) if is_scalar 1533 else array_ops.concat(list(x), axis=axis))) 1534 else: 1535 # TODO(phawkins): use a smarter policy, e.g., round-robin across shards. 1536 results.append(x[0]) 1537 1538 return compile_op, results 1539 1540 1541@tf_export(v1=["tpu.shard"]) 1542def shard(computation, 1543 inputs=None, 1544 num_shards=1, 1545 input_shard_axes=None, 1546 outputs_from_all_shards=True, 1547 output_shard_axes=None, 1548 infeed_queue=None, 1549 device_assignment=None, 1550 name=None): 1551 """Shards `computation` for parallel execution. 1552 1553 `inputs` must be a list of Tensors or None (equivalent to an empty list), each 1554 of which has a corresponding split axis (from `input_shard_axes`). Each input 1555 is split into `num_shards` pieces along the corresponding axis, and 1556 computation is applied to each shard in parallel. 1557 1558 Tensors are broadcast to all shards if they are lexically captured by 1559 `computation`. e.g., 1560 1561 x = tf.constant(7) 1562 def computation(): 1563 return x + 3 1564 ... = shard(computation, ...) 1565 1566 TODO(phawkins): consider adding support for broadcasting Tensors passed 1567 as inputs. 1568 1569 If `outputs_from_all_shards` is true, the outputs from all shards of 1570 `computation` are concatenated back together along their `output_shard_axes`. 1571 Otherwise, each output is taken from an arbitrary shard. 1572 1573 Inputs and outputs of the computation must be at least rank-1 Tensors. 1574 1575 Args: 1576 computation: A Python function that builds a computation to apply to each 1577 shard of the input. 1578 inputs: A list of input tensors or None (equivalent to an empty list). Each 1579 input tensor has a corresponding shard axes, given by `input_shard_axes`, 1580 which must have size divisible by `num_shards`. 1581 num_shards: The number of shards. 1582 input_shard_axes: A list of dimensions along which to shard `inputs`, or 1583 `None`. `None` means "shard all inputs along dimension 0". If not `None`, 1584 there must be one dimension per input. 1585 outputs_from_all_shards: Boolean or list of boolean. For each output, if 1586 `True`, outputs from all shards are concatenated along the corresponding 1587 `output_shard_axes` entry. Otherwise, each output is taken 1588 from an arbitrary shard. If the argument is a boolean, the argument's 1589 value is used for each output. 1590 output_shard_axes: A list of dimensions along which to concatenate the 1591 outputs of `computation`, or `None`. `None` means "concatenate all outputs 1592 along dimension 0". If not `None`, there must be one dimension per output. 1593 Ignored if `outputs_from_all_shards` is False. 1594 infeed_queue: If not `None`, the `InfeedQueue` to use to augment the inputs 1595 of `computation`. 1596 device_assignment: If not `None`, a `DeviceAssignment` describing the 1597 mapping between logical cores in the computation with physical cores in 1598 the TPU topology. Uses a default device assignment if `None`. The 1599 `DeviceAssignment` may be omitted if each shard of the computation uses 1600 only one core, and there is either only one shard, or the number of shards 1601 is equal to the number of cores in the TPU system. 1602 name: (Deprecated) Does nothing. 1603 Returns: 1604 A list of output tensors. 1605 Raises: 1606 ValueError: If num_shards <= 0 1607 ValueError: If len(input_shard_axes) != len(inputs) 1608 ValueError: If len(output_shard_axes) != len(outputs from `computation`) 1609 """ 1610 return split_compile_and_shard( 1611 computation, 1612 inputs=inputs, 1613 num_shards=num_shards, 1614 input_shard_axes=input_shard_axes, 1615 outputs_from_all_shards=outputs_from_all_shards, 1616 output_shard_axes=output_shard_axes, 1617 infeed_queue=infeed_queue, 1618 device_assignment=device_assignment, 1619 name=name)[1] 1620 1621 1622@tf_export(v1=["tpu.batch_parallel"]) 1623def batch_parallel(computation, 1624 inputs=None, 1625 num_shards=1, 1626 infeed_queue=None, 1627 device_assignment=None, 1628 name=None): 1629 """Shards `computation` along the batch dimension for parallel execution. 1630 1631 Convenience wrapper around shard(). 1632 1633 `inputs` must be a list of Tensors or None (equivalent to an empty list). 1634 Each input is split into `num_shards` pieces along the 0-th dimension, and 1635 computation is applied to each shard in parallel. 1636 1637 Tensors are broadcast to all shards if they are lexically captured by 1638 `computation`. e.g., 1639 1640 x = tf.constant(7) 1641 def computation(): 1642 return x + 3 1643 ... = shard(computation, ...) 1644 1645 The outputs from all shards are concatenated back together along their 0-th 1646 dimension. 1647 1648 Inputs and outputs of the computation must be at least rank-1 Tensors. 1649 1650 Args: 1651 computation: A Python function that builds a computation to apply to each 1652 shard of the input. 1653 inputs: A list of input tensors or None (equivalent to an empty list). The 1654 0-th dimension of each Tensor must have size divisible by `num_shards`. 1655 num_shards: The number of shards. 1656 infeed_queue: If not `None`, the `InfeedQueue` from which to append a tuple 1657 of arguments as inputs to `computation`. 1658 device_assignment: If not `None`, a `DeviceAssignment` describing the 1659 mapping between logical cores in the computation with physical cores in 1660 the TPU topology. Uses a default device assignment if `None`. The 1661 `DeviceAssignment` may be omitted if each shard of the computation uses 1662 only one core, and there is either only one shard, or the number of shards 1663 is equal to the number of cores in the TPU system. 1664 name: (Deprecated) Does nothing. 1665 Returns: 1666 A list of output tensors. 1667 Raises: 1668 ValueError: If `num_shards <= 0` 1669 """ 1670 return shard( 1671 computation, 1672 inputs, 1673 num_shards=num_shards, 1674 infeed_queue=infeed_queue, 1675 device_assignment=device_assignment, 1676 name=name) 1677 1678 1679@tf_export(v1=["tpu.rewrite"]) 1680def rewrite(computation, 1681 inputs=None, 1682 infeed_queue=None, 1683 device_assignment=None, 1684 name=None): 1685 """Rewrites `computation` for execution on a TPU system. 1686 1687 Args: 1688 computation: A Python function that builds a computation to apply to the 1689 input. If the function takes n inputs, 'inputs' should be a list of n 1690 tensors. 1691 1692 `computation` may return a list of operations and tensors. Tensors must 1693 come before operations in the returned list. The return value of 1694 `rewrite` is a list of tensors corresponding to the tensors from the 1695 output of `computation`. 1696 1697 All `Operation`s constructed during `computation` will be executed when 1698 evaluating any of the returned output tensors, not just the ones returned. 1699 inputs: A list of input tensors or `None` (equivalent to an empty list). 1700 Each input can be a nested structure containing values that are 1701 convertible to tensors. Note that passing an N-dimension list of 1702 compatible values will result in a N-dimention list of scalar tensors 1703 rather than a single Rank-N tensors. If you need different behavior, 1704 convert part of inputs to tensors with `tf.convert_to_tensor`. 1705 infeed_queue: If not `None`, the `InfeedQueue` from which to append a tuple 1706 of arguments as inputs to `computation`. 1707 device_assignment: if not `None`, a `DeviceAssignment` describing the 1708 mapping between logical cores in the computation with physical cores in 1709 the TPU topology. May be omitted for a single-core computation, in which 1710 case the core attached to task 0, TPU device 0 is used. 1711 name: (Deprecated) Does nothing. 1712 Returns: 1713 Same data structure as if computation(*inputs) is called directly with some 1714 exceptions for correctness. Exceptions include: 1715 1) None output: a NoOp would be returned which control-depends on 1716 computation. 1717 2) Single value output: A tuple containing the value would be returned. 1718 3) Operation-only outputs: a NoOp would be returned which 1719 control-depends on computation. 1720 TODO(b/121383831): Investigate into removing these special cases. 1721 """ 1722 # TODO(b/36647078) remove disable when pylint bug is fixed. 1723 # pylint: disable=indexing-exception 1724 return replicate( 1725 computation, 1726 None if inputs is None else [inputs], 1727 infeed_queue=infeed_queue, 1728 device_assignment=device_assignment, 1729 name=name)[0] 1730 # pylint: enable=indexing-exception 1731 1732 # Operations that indicate some error in the user's inference graph. 1733_BLACKLISTED_INFERENCE_OPS = set([ 1734 "ReadVariableOp", 1735 "AssignVariableOp", 1736 "AssignAddVariableOp", 1737 "AssignSubVariableOp", 1738 "VarHandleOp", 1739 "Variable", 1740 "VariableV2", 1741]) 1742 1743 1744def under_tpu_inference_context(): 1745 """Check if it is currently under `_TPUInferenceContext`.""" 1746 graph = ops.get_default_graph() 1747 while graph: 1748 context = graph._get_control_flow_context() # pylint: disable=protected-access 1749 while context: 1750 if isinstance(context, _TPUInferenceContext): 1751 return True 1752 context = context.outer_context 1753 if isinstance(graph, function._FuncGraph): # pylint: disable=protected-access 1754 graph = graph._outer_graph # pylint: disable=protected-access 1755 elif isinstance(graph, func_graph.FuncGraph): 1756 graph = graph.outer_graph 1757 else: 1758 return False 1759 1760 1761class _TPUInferenceContext(control_flow_ops.XLAControlFlowContext): 1762 """A `ControlFlowContext` for nodes inside a TPU inference computation. 1763 1764 The primary role of `_TPUInferenceContext` is to indicate the mode of 1765 operation and possibly sanity check operators inside a 1766 tpu.rewrite_for_inference() computation. 1767 """ 1768 1769 def __init__(self, name, check_ops=True): 1770 super(_TPUInferenceContext, self).__init__() 1771 self._name = name 1772 self._check_ops = check_ops 1773 1774 def AddOp(self, op): 1775 self._AddOpInternal(op) 1776 1777 def _AddOpInternal(self, op): 1778 # pylint: disable=protected-access 1779 if self._check_ops and op.type in _BLACKLISTED_INFERENCE_OPS: 1780 raise NotImplementedError( 1781 "Operation of type %s (%s) is not supported on the TPU for inference." 1782 " Execution will fail if this op is used in the graph. Make sure your" 1783 " variables are using variable_scope." % (op.type, op.name)) 1784 if self._outer_context: 1785 self._outer_context.AddInnerOp(op) 1786 1787 def AddValue(self, val): 1788 result = val 1789 if self._outer_context: 1790 result = self._outer_context.AddValue(val) 1791 return result 1792 1793 def AddInnerOp(self, op): 1794 self._AddOpInternal(op) 1795 1796 @property 1797 def grad_state(self): 1798 return None 1799 1800 1801def validate_inference_rewrite_for_variables(graph): 1802 """Validates whether rewrite_for_inference() 'worked' for variables. 1803 1804 The rewrite_for_inference() method is supposed to append GuaranteeConstOps 1805 after ReadVariableOps, but this mechanism works only if you are using 1806 tf.compat.v1.get_variable() to create and access variables in your tpu 1807 computation. This validation method can be called immediately after calling 1808 tpu.rewrite_for_inference() to check whether GuaranteeConstOps where added 1809 to the graph. 1810 1811 Typical usages: 1812 tpu.validate_inference_rewrite_for_variables( 1813 tf.compat.v1.get_default_graph()) 1814 1815 tpu.validate_inference_rewrite_for_variables(sess.graph) 1816 1817 Args: 1818 graph: The graph which needs to be validated. 1819 Raises: 1820 RuntimeError: if validation failed. 1821 """ 1822 if not any(x.type == "GuaranteeConst" for x in graph.get_operations()): 1823 raise RuntimeError( 1824 "No GuaranteeConst ops found in the graph after running " 1825 "tpu.rewrite_for_inference(...). Please check that you are using " 1826 "tf.get_variable() to create and access variables in your tpu " 1827 "computation.") 1828 1829 1830def rewrite_for_inference(computation, 1831 inputs=None, 1832 infeed_queue=None, 1833 device_assignment=None, 1834 name=None): 1835 """Rewrites `computation` for inference on a TPU system. 1836 1837 Other than 'rewriting' the computation to run on a TPU, if using variables 1838 in your computation, it moves the ReadVariableOps outside the TPU 1839 computation, and adds GuaranteeConst ops just after the ReadVariableOps. 1840 This mechanism works only if you are using tf.compat.v1.get_variable() to 1841 create and access variables in your tpu computation. You can validate 1842 whether this worked, by calling validate_inference_rewrite_for_variables() 1843 method immediately after this method to check whether GuaranteeConstOps 1844 where added to the graph. 1845 1846 Args: 1847 computation: A Python function that builds a computation to apply to the 1848 input. If the function takes n inputs, 'inputs' should be a list of n 1849 tensors. If the function returns m outputs, rewrite will return a list of 1850 m tensors. 1851 inputs: A list of input tensors or `None` (equivalent to an empty list). 1852 infeed_queue: If not `None`, the `InfeedQueue` from which to append a tuple 1853 of arguments as inputs to `computation`. 1854 device_assignment: if not `None`, a `DeviceAssignment` describing the 1855 mapping between logical cores in the computation with physical cores in 1856 the TPU topology. May be omitted for a single-core computation, in which 1857 case the core attached to task 0, TPU device 0 is used. 1858 name: The name of the operator. 1859 Returns: 1860 A list of output tensors. 1861 """ 1862 1863 def guarantee_const_getter(getter, name, *args, **kwargs): 1864 with ops.control_dependencies(None): 1865 return array_ops.guarantee_const( 1866 getter(name, *args, **kwargs), name=name + "/GuaranteeConst") 1867 1868 def wrapped_computation(*args, **kwargs): 1869 """Execute computation under `_TPUInferenceContext`.""" 1870 context = _TPUInferenceContext( 1871 name=ops.get_default_graph().unique_name("rewrite_for_inference")) 1872 try: 1873 context.Enter() 1874 1875 vscope = variable_scope.get_variable_scope() 1876 prev_custom_getter = vscope.custom_getter 1877 prev_caching_device = vscope.caching_device 1878 vscope.set_custom_getter(guarantee_const_getter) 1879 vscope.set_caching_device(lambda op: op.device) 1880 1881 result = computation(*args, **kwargs) 1882 1883 vscope.set_custom_getter(prev_custom_getter) 1884 vscope.set_caching_device(prev_caching_device) 1885 finally: 1886 context.Exit() 1887 return result 1888 1889 # pylint: disable=undefined-variable 1890 return rewrite( 1891 wrapped_computation, 1892 inputs=inputs, 1893 infeed_queue=infeed_queue, 1894 device_assignment=device_assignment, 1895 name=name) 1896 # pylint: enable=undefined-variable 1897 1898 1899def prune_unconnected_ops_from_xla(prune_graph): 1900 """Prunes unconnected ops as listed in _UNCONNECTED_OPS_TO_PRUNE. 1901 1902 Args: 1903 prune_graph: A tensorflow graph from which we wish to prune unconnected ops 1904 as listed in _UNCONNECTED_OPS_TO_PRUNE. In general, these ops should have 1905 no inputs and no consumers. These can often be left behind due to graph 1906 construction rewiring (for instance TF-Hub). While they never execute, 1907 they will cause XLA compile to fail so we strip them from XLA compile by 1908 removing the tpu_replicate attribute. 1909 """ 1910 # Scan over the top level graph and all function graphs. 1911 for graph in [prune_graph] + [ 1912 f for f in prune_graph._functions.values() # pylint: disable=protected-access 1913 ]: 1914 if not isinstance(graph, ops.Graph): 1915 continue 1916 for op in graph.get_operations(): 1917 if op.type not in _UNCONNECTED_OPS_TO_PRUNE: 1918 continue 1919 outputs_consumed = False 1920 for output in op.outputs: 1921 if output.consumers(): 1922 outputs_consumed = True 1923 break 1924 if not outputs_consumed: 1925 logging.info( 1926 "Pruning OP %s of type %s from XLA Compile due to " 1927 "it being disconnected.", op.name, op.type) 1928 op._clear_attr(_TPU_REPLICATE_ATTR) # pylint: disable=protected-access 1929