1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ====================================== 15 16"""Library of TPU helper functions.""" 17 18from __future__ import absolute_import 19from __future__ import division 20from __future__ import print_function 21 22import collections 23from six.moves import xrange # pylint: disable=redefined-builtin 24 25from tensorflow.core.framework import attr_value_pb2 26from tensorflow.core.protobuf.tpu import dynamic_padding_pb2 as dynamic_padding 27from tensorflow.python.compat import compat as api_compat 28from tensorflow.python.framework import device as pydev 29from tensorflow.python.framework import dtypes 30from tensorflow.python.framework import errors 31from tensorflow.python.framework import ops 32from tensorflow.python.framework import tensor_shape 33from tensorflow.python.ops import array_ops 34from tensorflow.python.ops import control_flow_ops 35from tensorflow.python.ops import math_ops 36from tensorflow.python.ops import variable_scope 37from tensorflow.python.platform import tf_logging as logging 38from tensorflow.python.tpu import tpu_function 39from tensorflow.python.tpu import xla 40from tensorflow.python.tpu.ops import tpu_ops 41from tensorflow.python.util import compat 42from tensorflow.python.util import nest 43 44 45# Operations that indicate some error in the users graph, e.g. a placeholder 46# that's introduced outside of the infeed. 47_BLACKLISTED_OPS = set([ 48 "Placeholder", 49]) 50 51# XLA doesn't currently support reading of intermediate tensors, thus some ops 52# are not supported. 53_UNSUPPORTED_OPS = set([ 54 "AudioSummary", 55 "AudioSummaryV2", 56 "HistogramSummary", 57 "ImageSummary", 58 "MergeSummary", 59 "Print", 60 "ScalarSummary", 61 "TensorSummary", 62 "TensorSummaryV2", 63 ]) 64 65# Ops which can be safely pruned from XLA compile if they have no consumers. 66# These ops should also have no inputs. 67_UNCONNECTED_OPS_TO_PRUNE = set(["Placeholder", "VarHandleOp"]) 68 69_MAX_WARNING_LINES = 5 70 71_TPU_REPLICATE_ATTR = "_tpu_replicate" 72_TPU_COMPILATION_STATUS_ATTR = "_tpu_compilation_status" 73_OUTSIDE_COMPILATION_ATTR = "_xla_outside_compilation" 74 75 76def _tpu_system_device_name(job): 77 """Returns the device name for the TPU_SYSTEM device of `job`.""" 78 if job is None: 79 return "/device:TPU_SYSTEM:0" 80 else: 81 return "/job:%s/device:TPU_SYSTEM:0" % job 82 83 84def initialize_system(embedding_config=None, job=None): 85 """Initializes a distributed TPU system for use with TensorFlow. 86 87 Args: 88 embedding_config: If not None, a `TPUEmbeddingConfiguration` proto 89 describing the desired configuration of the hardware embedding lookup 90 tables. If embedding_config is None, no hardware embeddings can be used. 91 job: The job (the XXX in TensorFlow device specification /job:XXX) that 92 contains the TPU devices that will be initialized. If job=None it is 93 assumed there is only one job in the TensorFlow flock, and an error will 94 be returned if this assumption does not hold. 95 Returns: 96 A serialized `TopologyProto` that describes the TPU system. Note: 97 the topology must be evaluated using `Session.run` before it can be used. 98 """ 99 config_string = ("" if embedding_config is None else 100 embedding_config.SerializeToString()) 101 with ops.device(_tpu_system_device_name(job)): 102 return tpu_ops.configure_distributed_tpu(embedding_config=config_string) 103 104 105def shutdown_system(job=None): 106 """Shuts down a running a distributed TPU system.""" 107 with ops.device(_tpu_system_device_name(job)): 108 shutdown_distributed_tpu = tpu_ops.shutdown_distributed_tpu() 109 return shutdown_distributed_tpu 110 111 112def core(num): 113 """Returns the device name for a core in a replicated TPU computation. 114 115 Args: 116 num: the virtual core number within each replica to which operators should 117 be assigned. 118 Returns: 119 A device name, suitable for passing to `tf.device()`. 120 """ 121 return "device:TPU_REPLICATED_CORE:{}".format(num) 122 123 124class TPUReplicateContext(control_flow_ops.XLAControlFlowContext): 125 """A `ControlFlowContext` for nodes inside a TPU computation. 126 127 The primary role of `TPUReplicateContext` is to mark operators inside a 128 tpu.replicate() computation with the attribute "_tpu_replicate=XYZ", where XYZ 129 is a unique name. 130 131 We use a `ControlFlowContext` to perform the annotation since it integrates 132 with Tensorflow constructs like ResourceVariables. For example, if a 133 `ResourceVariable` is constructed inside a tpu.replicate() block, the 134 `ResourceVariable` implementation can use 135 `with ops.control_dependencies(None)` to build the variable's definition 136 outside the replicated computation. 137 """ 138 139 def __init__(self, name, num_replicas, pivot): 140 """Builds a new TPUReplicateContext. 141 142 Args: 143 name: a unique name for the context, used to populate the `_tpu_replicate` 144 attribute. 145 num_replicas: an integer that gives the number of replicas for the 146 computation. 147 pivot: a pivot node. Nodes in the TPUReplicateContext that do not have any 148 inputs will have a control dependency on the pivot node. This ensures 149 that nodes are correctly included in any enclosing control flow 150 contexts. 151 """ 152 super(TPUReplicateContext, self).__init__() 153 self._num_replicas = num_replicas 154 self._outer_device_function_stack = None 155 self._oc_dev_fn_stack = None 156 self._outside_compilation_cluster = None 157 self._outside_compilation_counter = 0 158 self._in_gradient_colocation = None 159 self._gradient_colocation_stack = [] 160 self._host_compute_core = [] 161 self._name = name 162 self._name_as_bytes = compat.as_bytes(name) 163 self._unsupported_ops = [] 164 self._pivot = pivot 165 self._replicated_vars = {} 166 167 def get_replicated_var_handle(self, name, vars_): 168 """Returns a variable handle for replicated TPU variable 'var'. 169 170 This is a method used by an experimental replicated variable implementation 171 and is not intended as a public API. 172 173 Args: 174 name: The common name of the variable. 175 vars_: The replicated TPU variables. 176 177 Returns: 178 The handle of the TPU replicated input node. 179 """ 180 handle = self._replicated_vars.get(name) 181 if handle is not None: 182 return handle 183 184 # Builds a TPUReplicatedInput node for the variable, if one does not already 185 # exist. The TPUReplicatedInput node must belong to the enclosing 186 # control-flow scope of the TPUReplicateContext. 187 # TODO(phawkins): consider changing the contract of the TPU encapsulation 188 # so the TPUReplicatedInput nodes go inside the TPUReplicateContext scope 189 # instead. 190 191 # pylint: disable=protected-access 192 graph = ops.get_default_graph() 193 saved_context = graph._get_control_flow_context() 194 graph._set_control_flow_context(self.outer_context) 195 handle = tpu_ops.tpu_replicated_input( 196 [v.handle for v in vars_], name=name + "/handle") 197 graph._set_control_flow_context(saved_context) 198 # pylint: enable=protected-access 199 self._replicated_vars[name] = handle 200 return handle 201 202 def report_unsupported_operations(self): 203 if self._unsupported_ops: 204 op_str = "\n".join([" %s (%s)" % (op.type, op.name) 205 for op in self._unsupported_ops[:_MAX_WARNING_LINES]]) 206 logging.warning("%d unsupported operations found: \n%s", 207 len(self._unsupported_ops), op_str) 208 if len(self._unsupported_ops) > _MAX_WARNING_LINES: 209 logging.warning("... and %d more" % 210 (len(self._unsupported_ops) - _MAX_WARNING_LINES)) 211 212 def EnterGradientColocation(self, op, gradient_uid): 213 if op is not None: 214 self._gradient_colocation_stack.append(op) 215 if not self._outside_compilation_cluster: 216 try: 217 outside_attr = op.get_attr(_OUTSIDE_COMPILATION_ATTR) 218 if self._in_gradient_colocation: 219 raise NotImplementedError( 220 "Cannot nest gradient colocation operations outside compilation" 221 ) 222 if gradient_uid == "__unsupported__": 223 raise NotImplementedError( 224 "No gradient_uid calling gradient within outside_compilation") 225 # When we take the gradient of an op X in an outside_compilation 226 # cluster C in a forward computation we would like to put the ops 227 # corresponding to the gradient of X into a new outside_compilation 228 # cluster C'. However, if we take the gradient of X twice, the second 229 # one should get yet another new outside_compilation cluster C''. 230 # 231 # The mechanism we adopt is to use a 'root_cluster' which is the 232 # cluster that X was in before we took gradients, and a 'gradient_uid' 233 # which is different for every invocation of gradients, and put the 234 # gradient of X in cluster 'root_cluster.gradient_uid'. 235 # 236 # When taking a gradient of a gradient, some ops will be colocated 237 # with Op in the forward pass (e.g., cluster root_cluster) and some in 238 # the backward pass (e.g., cluster root_cluster.initial_gradient_uid). 239 # We need all of the grad-of-grad ops to be in the same cluster to 240 # avoid cyclic dependencies between clusters. We adopt a heuristic 241 # that puts any op clustered with root_cluster.<xxx> in 242 # root_cluster.gradient_uid, even if xxx was initial_gradient_uid. 243 self._in_gradient_colocation = op 244 parts = outside_attr.split(".") 245 cluster = parts[0] + "." + gradient_uid 246 self._EnterOutsideCompilationScope(cluster=cluster) 247 except ValueError: 248 # The attr was not present: do nothing. 249 pass 250 251 def ExitGradientColocation(self, op, gradient_uid): 252 if op is not None: 253 if not self._gradient_colocation_stack: 254 raise errors.InternalError( 255 op.node_def, op, 256 "Badly nested gradient colocation: empty stack when popping Op " + 257 op.name) 258 last_op = self._gradient_colocation_stack.pop() 259 if op is last_op: 260 if op is self._in_gradient_colocation: 261 self._in_gradient_colocation = None 262 self._ExitOutsideCompilationScope() 263 else: 264 raise errors.InternalError( 265 op.node_def, op, "Badly nested gradient colocation, expected " + 266 last_op + ", got " + op.name) 267 268 def _EnterOutsideCompilationScope(self, cluster=None): 269 270 class FakeOp(object): 271 """A helper class to determine the current device. 272 273 Supports only the type and device set/get methods needed to run the 274 graph's _apply_device_function method. 275 """ 276 277 def __init__(self): 278 self._device = "" 279 280 @property 281 def type(self): 282 return "FakeOp" 283 284 @property 285 def device(self): 286 return self._device 287 288 def _set_device(self, device): 289 if isinstance(device, pydev.DeviceSpec): 290 self._device = device.to_string() 291 else: 292 self._device = device 293 294 if self._outside_compilation_cluster: 295 raise NotImplementedError("Cannot nest outside_compilation clusters") 296 if cluster: 297 self._outside_compilation_cluster = cluster 298 else: 299 self._outside_compilation_cluster = str(self._outside_compilation_counter) 300 self._outside_compilation_counter += 1 301 graph = ops.get_default_graph() 302 fake_op = FakeOp() 303 graph._apply_device_functions(fake_op) # pylint: disable=protected-access 304 device = pydev.DeviceSpec.from_string(fake_op.device) 305 if (device.device_type == "TPU_REPLICATED_CORE" and 306 device.device_index is not None): 307 self._host_compute_core.append(self._outside_compilation_cluster + ":" + 308 str(device.device_index)) 309 self._oc_dev_fn_stack = graph._device_function_stack # pylint: disable=protected-access 310 graph._device_function_stack = self._outer_device_function_stack # pylint: disable=protected-access 311 312 def _ExitOutsideCompilationScope(self): 313 if not self._outside_compilation_cluster: 314 raise NotImplementedError( 315 "Attempted to exit outside_compilation scope when not in scope") 316 self._outside_compilation_cluster = None 317 graph = ops.get_default_graph() 318 graph._device_function_stack = self._oc_dev_fn_stack # pylint: disable=protected-access 319 320 def Enter(self): 321 if not self._outer_device_function_stack: 322 # Capture the device function stack at the time of first entry 323 # since that is the stack that will be used outside_compilation. 324 graph = ops.get_default_graph() 325 # pylint: disable=protected-access 326 self._outer_device_function_stack = graph._device_function_stack.copy() 327 # pylint: enable=protected-access 328 super(TPUReplicateContext, self).Enter() 329 330 def HostComputeCore(self): 331 return self._host_compute_core 332 333 def _RemoveExternalControlEdges(self, op): 334 """Remove any external control dependency on this op.""" 335 internal_control_inputs = [] 336 external_control_inputs = [] 337 for x in op.control_inputs: 338 # pylint: disable=protected-access 339 is_internal_op = False 340 ctxt = x._get_control_flow_context() 341 while ctxt is not None: 342 if ctxt == self: 343 is_internal_op = True 344 break 345 ctxt = ctxt._outer_context 346 if is_internal_op: 347 internal_control_inputs.append(x) 348 else: 349 external_control_inputs.append(x) 350 # pylint: enable=protected-access 351 # pylint: disable=protected-access 352 op._remove_all_control_inputs() 353 op._add_control_inputs(internal_control_inputs) 354 # pylint: enable=protected-access 355 return internal_control_inputs, external_control_inputs 356 357 def AddOp(self, op): 358 # pylint: disable=protected-access 359 if op.type in _BLACKLISTED_OPS: 360 logging.error("Operation of type %s (%s) is not supported on the TPU. " 361 "Execution will fail if this op is used in the graph. " % 362 (op.type, op.name)) 363 364 if op.type in _UNSUPPORTED_OPS: 365 self._unsupported_ops.append(op) 366 367 if any(x.dtype._is_ref_dtype for x in op.inputs): 368 raise NotImplementedError( 369 "Non-resource Variables are not supported inside TPU computations " 370 "(operator name: %s)" % op.name) 371 if _TPU_REPLICATE_ATTR in op.node_def.attr: 372 raise ValueError("TPU computations cannot be nested") 373 op._set_attr(_TPU_REPLICATE_ATTR, 374 attr_value_pb2.AttrValue(s=self._name_as_bytes)) 375 if self._outside_compilation_cluster: 376 op._set_attr( 377 _OUTSIDE_COMPILATION_ATTR, 378 attr_value_pb2.AttrValue( 379 s=compat.as_bytes(self._outside_compilation_cluster))) 380 if self._num_replicas > 1 or not self._outside_compilation_cluster: 381 # Prevent feeding or fetching anything that is being compiled, 382 # and any replicated outside_compilation Op. 383 op.graph.prevent_feeding(op) 384 op.graph.prevent_fetching(op) 385 386 # Remove any control edges from outer control flow contexts. These may cause 387 # mismatched frame errors. 388 (internal_control_inputs, 389 external_control_inputs) = self._RemoveExternalControlEdges(op) 390 391 if not op.inputs: 392 # Add a control edge from the control pivot to this op. 393 if not internal_control_inputs: 394 # pylint: disable=protected-access 395 op._add_control_input(self.GetControlPivot()) 396 # pylint: enable=protected-access 397 else: 398 for index in xrange(len(op.inputs)): 399 x = op.inputs[index] 400 real_x = self.AddValue(x) 401 if real_x != x: 402 op._update_input(index, real_x) # pylint: disable=protected-access 403 404 if external_control_inputs: 405 # Use an identity to pull control inputs as data inputs. Note that we 406 # ignore ops which don't have outputs. TODO(phawkins): fix that. 407 with ops.control_dependencies(None): 408 self.Enter() 409 external_control_inputs = [ 410 array_ops.identity(x.outputs[0]).op 411 for x in external_control_inputs 412 if x.outputs 413 ] 414 self.Exit() 415 # pylint: disable=protected-access 416 op._add_control_inputs(external_control_inputs) 417 # pylint: enable=protected-access 418 419 # Mark op's outputs as seen by this context and any outer contexts. 420 output_names = [x.name for x in op.outputs] 421 context = self 422 while context is not None: 423 # pylint: disable=protected-access 424 context._values.update(output_names) 425 context = context._outer_context 426 # pylint: enable=protected-access 427 428 if self._outer_context: 429 self._outer_context.AddInnerOp(op) 430 431 def AddValue(self, val): 432 """Add `val` to the current context and its outer context recursively.""" 433 if val.name in self._values: 434 # Use the real value if it comes from outer context. 435 result = self._external_values.get(val.name) 436 return val if result is None else result 437 438 result = val 439 self._values.add(val.name) 440 if self._outer_context: 441 result = self._outer_context.AddValue(val) 442 self._values.add(result.name) 443 444 self._external_values[val.name] = result 445 446 return result 447 448 def AddInnerOp(self, op): 449 self.AddOp(op) 450 if self._outer_context: 451 self._outer_context.AddInnerOp(op) 452 453 @property 454 def grad_state(self): 455 # Define the gradient loop state associated with the TPUReplicateContext to 456 # be None as the TPUReplicateContext does not get nested nor does the 457 # grad_state outside the TPUReplicateContext affect the graph inside so the 458 # grad_state should be as if this is the top-level gradient state. 459 return None 460 461 @property 462 def back_prop(self): 463 """Forwards to the enclosing while context, if any.""" 464 if self.GetWhileContext(): 465 return self.GetWhileContext().back_prop 466 return False 467 468 def GetControlPivot(self): 469 return self._pivot 470 471 472def outside_compilation(computation, *args, **kwargs): 473 """Builds part of a computation outside any current TPU replicate scope. 474 475 Args: 476 computation: A Python function that builds the computation to 477 place on the host. 478 *args: the positional arguments for the computation. 479 **kwargs: the keyword arguments for the computation. 480 481 Returns: 482 The Tensors returned by computation. 483 """ 484 args = [] if args is None else args 485 graph = ops.get_default_graph() 486 487 # If we are in a TPUReplicateContext, signal that we are now 488 # outside_compilation 489 initial_context = graph._get_control_flow_context() # pylint: disable=protected-access 490 context = initial_context 491 while context: 492 if isinstance(context, TPUReplicateContext): 493 context._EnterOutsideCompilationScope() # pylint: disable=protected-access 494 context = context.outer_context 495 496 retval = computation(*args, **kwargs) 497 498 # If we are in a TPUReplicateContext, signal that we are no longer 499 # outside_compilation 500 final_context = graph._get_control_flow_context() # pylint: disable=protected-access 501 if initial_context is not final_context: 502 raise NotImplementedError( 503 "Control-flow context cannot be different at start and end of an " 504 "outside_compilation scope") 505 context = initial_context 506 while context: 507 if isinstance(context, TPUReplicateContext): 508 context._ExitOutsideCompilationScope() # pylint: disable=protected-access 509 context = context.outer_context 510 511 return retval 512 513 514def replicate(computation, 515 inputs=None, 516 infeed_queue=None, 517 device_assignment=None, 518 name=None, 519 maximum_shapes=None): 520 """Builds a graph operator that runs a replicated TPU computation. 521 522 Args: 523 computation: A Python function that builds the computation to replicate. 524 inputs: A list of lists of input tensors or `None` (equivalent to 525 `[[]]`), indexed by `[replica_num][input_num]`. All replicas must 526 have the same number of inputs. Each input can be a nested structure 527 containing values that are convertible to tensors. Note that passing an 528 N-dimension list of compatible values will result in a N-dimention list of 529 scalar tensors rather than a single Rank-N tensors. If you need different 530 behavior, convert part of inputs to tensors with `tf.convert_to_tensor`. 531 infeed_queue: If not `None`, the `InfeedQueue` from which to append a tuple 532 of arguments as inputs to computation. 533 device_assignment: If not `None`, a `DeviceAssignment` describing the 534 mapping between logical cores in the computation with physical cores in 535 the TPU topology. Uses a default device assignment if `None`. The 536 `DeviceAssignment` may be omitted if each replica of the computation uses 537 only one core, and there is either only one replica, or the number of 538 replicas is equal to the number of cores in the TPU system. 539 name: (Deprecated) Does nothing. 540 maximum_shapes: A nested structure of tf.TensorShape representing the shape 541 to which the respective component of each input element in each replica 542 should be padded. Any unknown dimensions (e.g. tf.Dimension(None) in a 543 tf.TensorShape or -1 in a tensor-like object) will be padded to the 544 maximum size of that dimension over all replicas. Note that if the input 545 dimension is already static, we won't do padding on it and we require the 546 maximum_shapes to have the same value or None on that dimension. The 547 structure of `maximum_shapes` needs to be the same as `inputs[0]`. 548 Returns: 549 A list of outputs, indexed by `[replica_num]` each output can be a nested 550 structure same as what computation() returns with a few exceptions. 551 552 Exceptions include: 553 1) None output: a NoOp would be returned which control-depends on 554 computation. 555 2) Single value output: A tuple containing the value would be returned. 556 3) Operation-only outputs: a NoOp would be returned which 557 control-depends on computation. 558 TODO(b/121383831): Investigate into removing these special cases. 559 560 Raises: 561 ValueError: If all replicas do not have equal numbers of input tensors. 562 ValueError: If the number of inputs per replica does not match 563 the number of formal parameters to `computation`. 564 ValueError: If the static `inputs` dimensions don't match with the values 565 given in `maximum_shapes`. 566 ValueError: If the structure of inputs per replica does not match 567 the structure of `maximum_shapes`. 568 """ 569 return split_compile_and_replicate( 570 computation, 571 inputs, 572 infeed_queue, 573 device_assignment, 574 name, 575 maximum_shapes=maximum_shapes)[1] 576 577 578def _pad_all_input(inputs, padded_shapes): 579 """Pad all input tensors given padded_shapes. 580 581 The real shape tensors will be concatenated with the padded original inputs. 582 583 Args: 584 inputs: The original inputs. 585 padded_shapes: A list of padded shapes for each input. 586 587 Returns: 588 The padded inputs and a PaddingMap list which maps the padded input 589 dimension to the real shape argument index. 590 """ 591 input_shape_tensors = [] 592 for core_idx, inputs_per_core in enumerate(inputs): 593 for idx, input_tensor in enumerate(inputs_per_core): 594 if core_idx == 0: 595 input_shape_tensors.append([]) 596 input_shape_tensors[idx].append(array_ops.shape(input_tensor)) 597 598 maximum_shapes = [] 599 for shapes_per_input in input_shape_tensors: 600 maximum_shapes.append( 601 math_ops.reduce_max(array_ops.stack(shapes_per_input), axis=0)) 602 603 padded_inputs = [] 604 real_shapes = [] 605 padding_maps = [] 606 for core_idx, inputs_per_core in enumerate(inputs): 607 padded_inputs.append([]) 608 real_shapes.append([]) 609 real_shape_idx = len(inputs_per_core) - 1 610 for idx, input_tensor in enumerate(inputs_per_core): 611 input_shape_tensor = input_shape_tensors[idx][core_idx] 612 input_shape = input_tensor.get_shape() 613 padded_shape = padded_shapes[idx] 614 615 # The static shape of inputs should be compatible with the given padded 616 # shapes. 617 input_shape.assert_is_compatible_with(padded_shape) 618 619 if input_shape.is_fully_defined(): 620 # Do nothing if the shape of the whole tensor is already static. 621 padded_inputs[core_idx].append(input_tensor) 622 else: 623 # Only pad the non static shape dimension. 624 for i, s in enumerate(input_shape): 625 if s.value is None: 626 if core_idx == 0: 627 real_shape_idx += 1 628 padding_map = dynamic_padding.PaddingMap() 629 padding_map.arg_index = idx 630 padding_map.shape_index = i 631 padding_map.padding_arg_index = real_shape_idx 632 padding_maps.append(padding_map) 633 real_shapes[core_idx].append( 634 math_ops.cast(input_shape_tensor[i], dtypes.uint32)) 635 636 paddings = [] 637 for i, s in enumerate(padded_shape): 638 if input_shape[i].value: 639 # Don't pad if input shape is already static. 640 padding = [0, 0] 641 else: 642 if s.value: 643 # Pad to the given maximum value. 644 padding = [0, s.value - input_shape_tensor[i]] 645 else: 646 # If maximum value is not given, then pad to the maximum dimension 647 # among all the cores. 648 padding = [0, maximum_shapes[idx][i] - input_shape_tensor[i]] 649 paddings.append(padding) 650 651 padded_input = array_ops.pad(input_tensor, paddings) 652 padded_inputs[core_idx].append(padded_input) 653 654 num_replicas = len(padded_inputs) 655 for i in range(num_replicas): 656 padded_inputs[i].extend(real_shapes[i]) 657 658 return padded_inputs, padding_maps 659 660 661def split_compile_and_replicate(computation, 662 inputs=None, 663 infeed_queue=None, 664 device_assignment=None, 665 name=None, 666 use_tpu=True, 667 maximum_shapes=None): 668 """Builds graph operators that runs compilation and replicated computation. 669 670 This is a lower level interface than replicate that returns a separate compile 671 and execute output tensor. In the generated graph the compile op feeds into 672 the execute op and no additional compilation is incurred when running the 673 compile op before the execute op. The compile op returns additional 674 information about the compilation but does not return the compiled program. 675 676 Args: 677 computation: A Python function that builds the computation to replicate. 678 inputs: A list of lists of input tensors or `None` (equivalent to 679 `[[]]`), indexed by `[replica_num][input_num]`. All replicas must 680 have the same number of inputs. Each input can be a nested structure 681 containing values that are convertible to tensors. Note that passing an 682 N-dimension list of compatible values will result in a N-dimention list of 683 scalar tensors rather than a single Rank-N tensors. If you need different 684 behavior, convert part of inputs to tensors with `tf.convert_to_tensor`. 685 infeed_queue: If not `None`, the `InfeedQueue` from which to append a tuple 686 of arguments as inputs to computation. 687 device_assignment: If not `None`, a `DeviceAssignment` describing the 688 mapping between logical cores in the computation with physical cores in 689 the TPU topology. Uses a default device assignment if `None`. The 690 `DeviceAssignment` may be omitted if each replica of the computation uses 691 only one core, and there is either only one replica, or the number of 692 replicas is equal to the number of cores in the TPU system. 693 name: (Deprecated) Does nothing. 694 use_tpu: When false, the input `computation` is executed on the XLA CPU/GPU 695 backends. Currently, only supports a default placement (computation is 696 placed on GPU if one is available, and on CPU if not). 697 maximum_shapes: A nested structure of tf.TensorShape representing the shape 698 to which the respective component of each input element in each replica 699 should be padded. Any unknown dimensions (e.g. tf.Dimension(None) in a 700 tf.TensorShape or -1 in a tensor-like object) will be padded to the 701 maximum size of that dimension over all replicas. Note that if the input 702 dimension is already static, we won't do padding on it and we require the 703 maximum_shapes to have the same value or None on that dimension. The 704 structure of `maximum_shapes` needs to be the same as `inputs[0]`. 705 706 Returns: 707 A list of lists with the first list corresponding to the compile op and the 708 second a list of output tensors, indexed by `[replica_num][output_num]`. 709 Raises: 710 ValueError: If all replicas do not have equal numbers of input tensors. 711 ValueError: If the number of inputs per replica does not match 712 the number of formal parameters to `computation`. 713 ValueError: If the static `inputs` dimensions don't match with the values 714 given in `maximum_shapes`. 715 ValueError: If the structure of inputs per replica does not match 716 the structure of `maximum_shapes`. 717 """ 718 del name 719 inputs = [[]] if inputs is None else inputs 720 721 metadata_kwargs = {} 722 if device_assignment is not None: 723 # Turn the Numpy array into a flattened list so we can pass it as an 724 # operator attribute. 725 metadata_kwargs = { 726 "topology": 727 device_assignment.topology.serialized(), 728 "device_assignment": 729 device_assignment.core_assignment.flatten().tolist() 730 } 731 # TODO(phawkins): remove this case after the forward compatibility window 732 # expires on 2018-10-5. 733 if api_compat.forward_compatible(2018, 10, 5): 734 metadata_kwargs["num_cores_per_replica"] = ( 735 device_assignment.num_cores_per_replica) 736 else: 737 metadata_kwargs["computation_shape"] = [ 738 device_assignment.num_cores_per_replica 739 ] 740 741 if ((not isinstance(inputs, list)) or 742 any(not isinstance(inp, (list, tuple)) for inp in inputs)): 743 raise TypeError("tpu.replicate() inputs must be a list of lists/tuples") 744 745 num_replicas = len(inputs) 746 747 # No replicas? Nothing to do. 748 if num_replicas == 0: 749 return [] 750 751 # Checks all replicas have the same structure. 752 for i in xrange(1, num_replicas): 753 nest.assert_same_structure(inputs[0], inputs[i]) 754 755 # Flatten inputs. 756 flat_inputs = [ 757 nest.flatten(per_replica_input) for per_replica_input in inputs 758 ] 759 # Converts inputs to Tensors. 760 flat_inputs = [[ops.convert_to_tensor(x) for x in inp] for inp in flat_inputs] 761 762 # Verifies that all replicas have matching numbers and types of inputs 763 flat_input_types = [x.dtype for x in flat_inputs[0]] 764 input_arity = len(inputs[0]) 765 flat_input_arity = len(flat_input_types) 766 for i in range(num_replicas): 767 if len(inputs[i]) != input_arity: 768 raise ValueError("Replicas must have the same number of inputs. " 769 "Replica 0 had {} inputs, replica {} had {} " 770 "inputs.".format(input_arity, i, len(inputs[i]))) 771 772 types = [x.dtype for x in flat_inputs[i]] 773 if types != flat_input_types: 774 raise ValueError("Replicas must have matching input types. Replica 0 had " 775 "input types {}, replica {} had input types {}".format( 776 flat_input_types, i, types)) 777 778 arg_error = xla.check_function_argument_count( 779 computation, input_arity, infeed_queue) 780 if arg_error is not None: 781 if infeed_queue is None: 782 raise TypeError( 783 "Supplied computation cannot be called with the specified inputs. " 784 "You specified %d inputs: %s, but the computation needs %s" % ( 785 input_arity, str([i.name for i in inputs[0]]), arg_error)) 786 else: 787 raise TypeError( 788 "Supplied computation cannot be called with the specified inputs. " 789 "You specified %d inputs: %s and %d additional inputs from infeed," 790 " but the computation needs %s" % (input_arity, str( 791 [i.name 792 for i in inputs[0]]), infeed_queue.number_of_tuple_elements, 793 arg_error)) 794 795 if maximum_shapes: 796 if infeed_queue: 797 raise ValueError( 798 "Dynamic input shapes are not supported with infeed queues") 799 800 # Make sure maximum_shapes has the same structure as inputs. 801 nest.assert_same_structure(inputs[0], maximum_shapes, check_types=False) 802 803 # Flatten padded shapes. 804 flat_maximum_shapes = nest.flatten(maximum_shapes) 805 flat_maximum_shapes = [ 806 tensor_shape.TensorShape(s) for s in flat_maximum_shapes 807 ] 808 809 flat_inputs, padding_maps = _pad_all_input(flat_inputs, flat_maximum_shapes) 810 811 serialized_padding_maps = [] 812 for padding_map in padding_maps: 813 serialized_padding_maps.append(padding_map.SerializeToString()) 814 metadata_kwargs["padding_map"] = serialized_padding_maps 815 816 metadata_kwargs["step_marker_location"] = getattr( 817 computation, "step_marker_location", "STEP_MARK_AT_ENTRY") 818 819 graph = ops.get_default_graph() 820 821 # Fan-in: Builds a TPUReplicatedInput node for each input. 822 flat_replicated_inputs = [] 823 for i in range(0, len(flat_inputs[0])): 824 replicas = [flat_inputs[replica][i] for replica in xrange(num_replicas)] 825 flat_replicated_inputs.append( 826 tpu_ops.tpu_replicated_input(replicas, name="input{}".format(i))) 827 828 cluster_name = graph.unique_name("cluster") 829 pivot = control_flow_ops.no_op(name=cluster_name + "/pivot") 830 context = TPUReplicateContext( 831 name=cluster_name, num_replicas=num_replicas, pivot=pivot) 832 try: 833 context.Enter() 834 835 metadata = tpu_ops.tpu_replicate_metadata( 836 num_replicas=num_replicas, use_tpu=use_tpu, **metadata_kwargs) 837 838 with tpu_function.tpu_shard_context( 839 num_replicas), ops.control_dependencies([metadata]): 840 841 # Add identity ops so even unused inputs are "consumed" by the 842 # computation. This is to avoid orphaned TPUReplicatedInput nodes. 843 # TODO(phawkins): consider instead pruning unused TPUReplicatedInput 844 # and eliding trivial TPUReplicatedInput/TPUReplicatedOutput pairs. 845 flat_replicated_inputs = [ 846 array_ops.identity(x, name="replicated_input_{}".format(i)) 847 for i, x in enumerate(flat_replicated_inputs) 848 ] 849 for i in flat_replicated_inputs: 850 # pylint: disable=protected-access 851 # Add an attribute to the identity node so that they could be removed in 852 # encapsulate TPU computation pass if unused. However we don't remove 853 # inputs when dynamic padding is enabled. 854 # TODO(rxsang): Use other ways except argument index in padding_map so 855 # outside compilation can work with dynamic padding correctly. 856 if maximum_shapes is None: 857 i.op._set_attr("_tpu_input_identity", 858 attr_value_pb2.AttrValue(b=True)) 859 # pylint: enable=protected-access 860 861 # Unflatten the computation inputs to match original input structure. 862 computation_inputs = nest.pack_sequence_as( 863 structure=inputs[0], 864 flat_sequence=flat_replicated_inputs[:flat_input_arity]) 865 866 # If there is an infeed queue, adds the dequeued values to the 867 # computation's inputs. 868 if infeed_queue is not None: 869 infeed_queue.set_number_of_shards(num_replicas) 870 for t in infeed_queue.generate_dequeue_op(): 871 computation_inputs.append(t) 872 873 # Only resource variables work inside a TPU computation, so turn on 874 # resource variables for the computation. 875 # TODO(phawkins): consider removing this code. It will 876 # be less confusing to clients if they knowingly choose to use resource 877 # variables. 878 # Partitioned variables is not supported (b/112311320). 879 vscope = variable_scope.get_variable_scope() 880 saved_use_resource = vscope.use_resource 881 saved_custom_getter = vscope.custom_getter 882 883 def custom_getter(getter, name, *args, **kwargs): 884 """Variables on TPU have a few restrictions.""" 885 partitioner = kwargs["partitioner"] 886 if partitioner is not None: 887 kwargs["partitioner"] = None 888 logging.warning( 889 "Partitioned variables are not supported on TPU. Got " 890 "`partitioner` that is {} for variable {}. " 891 "Setting `partitioner` to `None`." 892 .format(partitioner, name)) 893 if saved_custom_getter is None: 894 return getter(name, *args, **kwargs) 895 else: 896 return saved_custom_getter(getter, name, *args, **kwargs) 897 898 vscope.set_use_resource(True) 899 vscope.set_custom_getter(custom_getter) 900 901 outputs = computation(*computation_inputs) 902 903 vscope.set_use_resource(saved_use_resource) 904 vscope.set_custom_getter(saved_custom_getter) 905 906 outputs_is_flat = xla.is_flat(outputs) 907 if outputs_is_flat: 908 output_tensors, control_deps = _postprocess_flat_outputs(outputs) 909 else: 910 output_tensors, control_deps = _postprocess_non_flat_outputs(outputs) 911 912 # tensor_tracer imports tpu.py. Local import to tensor_tracer to avoid 913 # import-cycle 914 # pylint: disable=g-import-not-at-top 915 from tensorflow.python.tpu import tensor_tracer 916 # pylint: enable=g-import-not-at-top 917 if tensor_tracer.TensorTracer.is_enabled(): 918 tt = tensor_tracer.TensorTracer() 919 output_tensors = tt.trace_tpu(ops.get_default_graph(), 920 output_tensors, control_deps, 921 num_replicas) 922 923 context.ExitResult(output_tensors) 924 finally: 925 context.report_unsupported_operations() 926 context.Exit() 927 host_compute_core = context.HostComputeCore() 928 929 if host_compute_core: 930 attr_value = attr_value_pb2.AttrValue() 931 attr_value.list.s.extend([compat.as_bytes(x) for x in host_compute_core]) 932 metadata._set_attr("host_compute_core", attr_value) # pylint: disable=protected-access 933 934 with ops.control_dependencies([metadata]): 935 if use_tpu: 936 compile_status = tpu_ops.tpu_compilation_result() 937 op = compile_status.op 938 attr_value = attr_value_pb2.AttrValue(s=compat.as_bytes(cluster_name)) 939 op._set_attr(_TPU_COMPILATION_STATUS_ATTR, attr_value) # pylint: disable=protected-access 940 else: 941 compile_status = control_flow_ops.no_op(name="compilation_status") 942 943 if not output_tensors: 944 # Returns a list of NoOps dependent on the replication Op, indexed by 945 # [replica_num]. 946 return [ 947 compile_status, 948 [ 949 control_flow_ops.group(control_deps, name="shard_%d" % i) 950 for i in range(num_replicas) 951 ] 952 ] 953 954 # Fan-out: Builds a TPUReplicatedOutput node for each output. 955 replicated_outputs = [[] for i in xrange(num_replicas)] 956 for i, t in enumerate(output_tensors): 957 # Fan-out: Builds a TPUReplicatedOutput node for each output. 958 ys = tpu_ops.tpu_replicated_output( 959 t, num_replicas, name="output{}".format(i)) 960 961 # Wraps the outputs in identity operators so the names of any possible 962 # `fetch` nodes are preserved by the replication rewrite. 963 with ops.control_dependencies(control_deps): 964 for replica in xrange(num_replicas): 965 replicated_outputs[replica].append( 966 array_ops.identity( 967 ys[replica], name="output_%d_shard_%d" % (i, replica))) 968 969 if not outputs_is_flat: 970 replicated_outputs = [ 971 nest.pack_sequence_as(outputs, replica_outs) 972 for replica_outs in replicated_outputs 973 ] 974 975 return [compile_status, replicated_outputs] 976 977 978def _postprocess_flat_outputs(outputs): 979 """Validates non-flat outputs, add backs device assignments and other attrs. 980 981 Args: 982 outputs: Output from `computation` inside `tpu.rewrite`. 983 984 Returns: 985 Tensors and Operations extracted from outputs. 986 """ 987 # Following code segment is to preserve legacy behavior. Previously we only 988 # supported flat outputs and thus for consistency it was nice to convert even 989 # single element into a tuple. But now that we support arbitrary output 990 # structure, this is no longer necessary. 991 # TODO(b/121383831): Migrate all legacy use cases and delete this special 992 # case. 993 # If the computation returns `None`, make it an empty tuple. 994 if outputs is None: 995 outputs = tuple() 996 # If the computation only returned one value, makes it a tuple. 997 if not isinstance(outputs, collections.Sequence): 998 outputs = (outputs,) 999 1000 # Append `no_op` here so that fetching any return value of this function 1001 # will trigger TPUExecute node. 1002 outputs += (control_flow_ops.no_op(),) 1003 try: 1004 with ops.device(core(0)): 1005 outputs = [ 1006 o if isinstance(o, ops.Operation) else ops.convert_to_tensor(o) 1007 for o in outputs 1008 ] 1009 except Exception as e: 1010 raise ValueError( 1011 "TPU function return values must all either be Operations or " 1012 "convertible to Tensors. Got '%s'" % str(e)) 1013 1014 # Separates the returned Operations and Tensors. 1015 output_operations = [o for o in outputs if isinstance(o, ops.Operation)] 1016 output_tensors = [o for o in outputs if not isinstance(o, ops.Operation)] 1017 1018 if outputs != output_tensors + output_operations: 1019 raise ValueError( 1020 "TPU functions must return zero-or more Tensor values followed by " 1021 "zero or more Operations.") 1022 1023 # Wraps outputs in Identity ops. Otherwise a replicated input copied 1024 # straight to an output would bypass the replicate(). This would be bad 1025 # because the TPUReplicatedInput/TPUReplicatedOutput operator would not 1026 # be rewritten away, leading to a runtime error. 1027 # TODO(phawkins): extend the rewrite to elide these nodes instead. 1028 new_output_tensors = [] 1029 for t in output_tensors: 1030 with ops.device(t.device if t.device else core(0)): 1031 o = array_ops.identity(t) 1032 # pylint: disable=protected-access 1033 o.op._set_attr("_tpu_output_identity", attr_value_pb2.AttrValue(b=True)) 1034 # pylint: enable=protected-access 1035 new_output_tensors.append(o) 1036 return new_output_tensors, output_operations 1037 1038 1039def _postprocess_non_flat_outputs(outputs): 1040 """Validates non-flat outputs, add backs device assignments and other attrs. 1041 1042 Args: 1043 outputs: Output from `computation` inside `tpu.rewrite`. 1044 1045 Returns: 1046 Tensors extracted from outputs and an empty list because Operations are not 1047 allowed in non-flat outputs.. 1048 """ 1049 1050 # Flatten output items. 1051 flat_outputs = nest.flatten(outputs) 1052 1053 # Convert all non-Operation outputs to Tensors. 1054 for i, o in enumerate(flat_outputs): 1055 if isinstance(o, ops.Operation): 1056 raise ValueError( 1057 "tpu.rewrite does not support Operation as return value in non-flat " 1058 "output structure. You can set returned Operations as control " 1059 "dependencies of returned Tensors so Operations are triggered when " 1060 'Tensors are evaluated. Operation found: "%s"' % o.name) 1061 1062 try: 1063 o = ops.convert_to_tensor(o) 1064 except Exception as e: 1065 raise ValueError( 1066 "TPU function return values must all either be Operations or " 1067 'convertible to Tensors. Got error: "%s"' % str(e)) 1068 1069 # Wraps outputs in Identity ops. Otherwise a replicated input copied 1070 # straight to an output would bypass the replicate(). This would be bad 1071 # because the TPUReplicatedInput/TPUReplicatedOutput operator would not 1072 # be rewritten away, leading to a runtime error. 1073 # TODO(phawkins): extend the rewrite to elide these nodes instead. 1074 with ops.device(core(0)): 1075 o = array_ops.identity(o) 1076 # pylint: disable=protected-access 1077 o.op._set_attr("_tpu_output_identity", attr_value_pb2.AttrValue(b=True)) 1078 # pylint: enable=protected-access 1079 flat_outputs[i] = array_ops.identity(o) 1080 1081 # All flat_outputs are Tensors, and no Operations. 1082 return flat_outputs, [] 1083 1084 1085def split_compile_and_shard(computation, 1086 inputs=None, 1087 num_shards=1, 1088 input_shard_axes=None, 1089 outputs_from_all_shards=True, 1090 output_shard_axes=None, 1091 infeed_queue=None, 1092 device_assignment=None, 1093 name=None): 1094 """Shards `computation` for parallel execution. 1095 1096 `inputs` must be a list of Tensors or None (equivalent to an empty list), each 1097 of which has a corresponding split axis (from `input_shard_axes`). Each input 1098 is split into `num_shards` pieces along the corresponding axis, and 1099 computation is applied to each shard in parallel. 1100 1101 Tensors are broadcast to all shards if they are lexically captured by 1102 `computation`. e.g., 1103 1104 x = tf.constant(7) 1105 def computation(): 1106 return x + 3 1107 ... = shard(computation, ...) 1108 1109 If `outputs_from_all_shards` is true, the outputs from all shards of 1110 `computation` are concatenated back together along their `output_shards_axes`. 1111 Otherwise, each output is taken from an arbitrary shard. 1112 1113 Inputs and outputs of the computation must be at least rank-1 Tensors. 1114 1115 Args: 1116 computation: A Python function that builds a computation to apply to each 1117 shard of the input. 1118 inputs: A list of input tensors or None (equivalent to an empty list). Each 1119 input tensor has a corresponding shard axes, given by `input_shard_axes`, 1120 which must have size divisible by `num_shards`. 1121 num_shards: The number of shards. 1122 input_shard_axes: A list of dimensions along which to shard `inputs`, or 1123 `None`. `None` means "shard all inputs along dimension 0". If not `None`, 1124 there must be one dimension per input. 1125 outputs_from_all_shards: Boolean or list of boolean. For each output, if 1126 `True`, outputs from all shards are concatenated along the corresponding 1127 `output_shard_axes` entry. Otherwise, each output is taken 1128 from an arbitrary shard. If the argument is a boolean, the argument's 1129 value is used for each output. 1130 output_shard_axes: A list of dimensions along which to concatenate the 1131 outputs of `computation`, or `None`. `None` means "concatenate all outputs 1132 along dimension 0". If not `None`, there must be one dimension per output. 1133 Ignored if `outputs_from_all_shards` is False. 1134 infeed_queue: If not `None`, the `InfeedQueue` to use to augment the inputs 1135 of `computation`. 1136 device_assignment: If not `None`, a `DeviceAssignment` describing the 1137 mapping between logical cores in the computation with physical cores in 1138 the TPU topology. Uses a default device assignment if `None`. The 1139 `DeviceAssignment` may be omitted if each shard of the computation uses 1140 only one core, and there is either only one shard, or the number of shards 1141 is equal to the number of cores in the TPU system. 1142 name: (Deprecated) Does nothing. 1143 Returns: 1144 A tuple of (compile op, [output tensors]). 1145 Raises: 1146 ValueError: If num_shards <= 0 1147 ValueError: If len(input_shard_axes) != len(inputs) 1148 ValueError: If len(output_shard_axes) != len(outputs from `computation`) 1149 """ 1150 # TODO(phawkins): consider adding support for broadcasting Tensors passed as 1151 # inputs. 1152 1153 if num_shards <= 0: 1154 raise ValueError("num_shards must be a positive integer.") 1155 1156 inputs = [] if inputs is None else inputs 1157 if not isinstance(inputs, list): 1158 raise TypeError("tpu.shard()'s inputs must be a list of Tensors or None.") 1159 1160 # Converts inputs to Tensors. 1161 inputs = [ops.convert_to_tensor(x) for x in inputs] 1162 1163 if input_shard_axes is None: 1164 input_shard_axes = [0] * len(inputs) 1165 if len(inputs) != len(input_shard_axes): 1166 raise ValueError("Length of input_shard_axes must be equal to the number " 1167 "of inputs.") 1168 1169 if inputs: 1170 # Splits the `inputs` along the corresponding `input_shard_axes`, giving 1171 # lists with layout [input][shard] 1172 split_inputs = [ 1173 array_ops.split(x, num_shards, axis=axis) 1174 for (axis, x) in zip(input_shard_axes, inputs)] 1175 1176 # Transposes the input lists to have layout [shard][input] 1177 transposed_inputs = [list(i) for i in zip(*split_inputs)] 1178 else: 1179 transposed_inputs = [[]] * num_shards 1180 1181 compile_op, outputs = split_compile_and_replicate( 1182 computation, 1183 transposed_inputs, 1184 infeed_queue=infeed_queue, 1185 device_assignment=device_assignment, 1186 name=name) 1187 1188 # There must be at least one shard since num_shards > 0. 1189 # TODO(b/36647078) remove disable when pylint bug is fixed. 1190 # pylint: disable=indexing-exception 1191 if isinstance(outputs[0], ops.Operation): 1192 # pylint: enable=indexing-exception 1193 # There were no outputs from the computation and replicate returned a list 1194 # of NoOps with control dependencies on the computation. Return the first 1195 # one so it can be used as a control dependency or fetch node. 1196 # TODO(b/36647078) remove disable when pylint bug is fixed. 1197 # pylint: disable=indexing-exception 1198 return compile_op, [outputs[0]] 1199 # pylint: enable=indexing-exception 1200 1201 # TODO(b/36647078) remove disable when pylint bug is fixed. 1202 # pylint: disable=indexing-exception 1203 num_outputs = len(outputs[0]) 1204 # pylint: enable=indexing-exception 1205 1206 if output_shard_axes is None: 1207 output_shard_axes = [0] * num_outputs 1208 if num_outputs != len(output_shard_axes): 1209 raise ValueError("Length of output_shard_axes must be equal to the number " 1210 "of outputs.") 1211 1212 if isinstance(outputs_from_all_shards, bool): 1213 outputs_from_all_shards = [outputs_from_all_shards] * num_outputs 1214 1215 if num_outputs != len(outputs_from_all_shards): 1216 raise ValueError("Length of outputs_from_all_shards must be equal to the " 1217 "number of outputs.") 1218 1219 results = [] 1220 for (axis, all_shards, x) in zip(output_shard_axes, outputs_from_all_shards, 1221 zip(*outputs)): 1222 if all_shards: 1223 # Concatenate all of the outputs together (use stack for scalars). 1224 shape = x[0].shape 1225 is_scalar = shape is not None and (shape.ndims == 0) 1226 results.append((array_ops.stack(list(x)) if is_scalar 1227 else array_ops.concat(list(x), axis=axis))) 1228 else: 1229 # TODO(phawkins): use a smarter policy, e.g., round-robin across shards. 1230 results.append(x[0]) 1231 1232 return compile_op, results 1233 1234 1235def shard(computation, 1236 inputs=None, 1237 num_shards=1, 1238 input_shard_axes=None, 1239 outputs_from_all_shards=True, 1240 output_shard_axes=None, 1241 infeed_queue=None, 1242 device_assignment=None, 1243 name=None): 1244 """Shards `computation` for parallel execution. 1245 1246 `inputs` must be a list of Tensors or None (equivalent to an empty list), each 1247 of which has a corresponding split axis (from `input_shard_axes`). Each input 1248 is split into `num_shards` pieces along the corresponding axis, and 1249 computation is applied to each shard in parallel. 1250 1251 Tensors are broadcast to all shards if they are lexically captured by 1252 `computation`. e.g., 1253 1254 x = tf.constant(7) 1255 def computation(): 1256 return x + 3 1257 ... = shard(computation, ...) 1258 1259 TODO(phawkins): consider adding support for broadcasting Tensors passed 1260 as inputs. 1261 1262 If `outputs_from_all_shards` is true, the outputs from all shards of 1263 `computation` are concatenated back together along their `output_shards_axes`. 1264 Otherwise, each output is taken from an arbitrary shard. 1265 1266 Inputs and outputs of the computation must be at least rank-1 Tensors. 1267 1268 Args: 1269 computation: A Python function that builds a computation to apply to each 1270 shard of the input. 1271 inputs: A list of input tensors or None (equivalent to an empty list). Each 1272 input tensor has a corresponding shard axes, given by `input_shard_axes`, 1273 which must have size divisible by `num_shards`. 1274 num_shards: The number of shards. 1275 input_shard_axes: A list of dimensions along which to shard `inputs`, or 1276 `None`. `None` means "shard all inputs along dimension 0". If not `None`, 1277 there must be one dimension per input. 1278 outputs_from_all_shards: Boolean or list of boolean. For each output, if 1279 `True`, outputs from all shards are concatenated along the corresponding 1280 `output_shard_axes` entry. Otherwise, each output is taken 1281 from an arbitrary shard. If the argument is a boolean, the argument's 1282 value is used for each output. 1283 output_shard_axes: A list of dimensions along which to concatenate the 1284 outputs of `computation`, or `None`. `None` means "concatenate all outputs 1285 along dimension 0". If not `None`, there must be one dimension per output. 1286 Ignored if `outputs_from_all_shards` is False. 1287 infeed_queue: If not `None`, the `InfeedQueue` to use to augment the inputs 1288 of `computation`. 1289 device_assignment: If not `None`, a `DeviceAssignment` describing the 1290 mapping between logical cores in the computation with physical cores in 1291 the TPU topology. Uses a default device assignment if `None`. The 1292 `DeviceAssignment` may be omitted if each shard of the computation uses 1293 only one core, and there is either only one shard, or the number of shards 1294 is equal to the number of cores in the TPU system. 1295 name: (Deprecated) Does nothing. 1296 Returns: 1297 A list of output tensors. 1298 Raises: 1299 ValueError: If num_shards <= 0 1300 ValueError: If len(input_shard_axes) != len(inputs) 1301 ValueError: If len(output_shard_axes) != len(outputs from `computation`) 1302 """ 1303 return split_compile_and_shard( 1304 computation, 1305 inputs=inputs, 1306 num_shards=num_shards, 1307 input_shard_axes=input_shard_axes, 1308 outputs_from_all_shards=outputs_from_all_shards, 1309 output_shard_axes=output_shard_axes, 1310 infeed_queue=infeed_queue, 1311 device_assignment=device_assignment, 1312 name=name)[1] 1313 1314 1315def batch_parallel(computation, 1316 inputs=None, 1317 num_shards=1, 1318 infeed_queue=None, 1319 device_assignment=None, 1320 name=None): 1321 """Shards `computation` along the batch dimension for parallel execution. 1322 1323 Convenience wrapper around shard(). 1324 1325 `inputs` must be a list of Tensors or None (equivalent to an empty list). 1326 Each input is split into `num_shards` pieces along the 0-th dimension, and 1327 computation is applied to each shard in parallel. 1328 1329 Tensors are broadcast to all shards if they are lexically captured by 1330 `computation`. e.g., 1331 1332 x = tf.constant(7) 1333 def computation(): 1334 return x + 3 1335 ... = shard(computation, ...) 1336 1337 The outputs from all shards are concatenated back together along their 0-th 1338 dimension. 1339 1340 Inputs and outputs of the computation must be at least rank-1 Tensors. 1341 1342 Args: 1343 computation: A Python function that builds a computation to apply to each 1344 shard of the input. 1345 inputs: A list of input tensors or None (equivalent to an empty list). The 1346 0-th dimension of each Tensor must have size divisible by `num_shards`. 1347 num_shards: The number of shards. 1348 infeed_queue: If not `None`, the `InfeedQueue` from which to append a tuple 1349 of arguments as inputs to `computation`. 1350 device_assignment: If not `None`, a `DeviceAssignment` describing the 1351 mapping between logical cores in the computation with physical cores in 1352 the TPU topology. Uses a default device assignment if `None`. The 1353 `DeviceAssignment` may be omitted if each shard of the computation uses 1354 only one core, and there is either only one shard, or the number of shards 1355 is equal to the number of cores in the TPU system. 1356 name: (Deprecated) Does nothing. 1357 Returns: 1358 A list of output tensors. 1359 Raises: 1360 ValueError: If `num_shards <= 0` 1361 """ 1362 return shard( 1363 computation, 1364 inputs, 1365 num_shards=num_shards, 1366 infeed_queue=infeed_queue, 1367 device_assignment=device_assignment, 1368 name=name) 1369 1370 1371def rewrite(computation, 1372 inputs=None, 1373 infeed_queue=None, 1374 device_assignment=None, 1375 name=None): 1376 """Rewrites `computation` for execution on a TPU system. 1377 1378 Args: 1379 computation: A Python function that builds a computation to apply to the 1380 input. If the function takes n inputs, 'inputs' should be a list of n 1381 tensors. 1382 1383 `computation` may return a list of operations and tensors. Tensors must 1384 come before operations in the returned list. The return value of 1385 `rewrite` is a list of tensors corresponding to the tensors from the 1386 output of `computation`. 1387 1388 All `Operation`s constructed during `computation` will be executed when 1389 evaluating any of the returned output tensors, not just the ones returned. 1390 inputs: A list of input tensors or `None` (equivalent to an empty list). 1391 Each input can be a nested structure containing values that are 1392 convertible to tensors. Note that passing an N-dimension list of 1393 compatible values will result in a N-dimention list of scalar tensors 1394 rather than a single Rank-N tensors. If you need different behavior, 1395 convert part of inputs to tensors with `tf.convert_to_tensor`. 1396 infeed_queue: If not `None`, the `InfeedQueue` from which to append a tuple 1397 of arguments as inputs to `computation`. 1398 device_assignment: if not `None`, a `DeviceAssignment` describing the 1399 mapping between logical cores in the computation with physical cores in 1400 the TPU topology. May be omitted for a single-core computation, in which 1401 case the core attached to task 0, TPU device 0 is used. 1402 name: (Deprecated) Does nothing. 1403 Returns: 1404 Same data structure as if computation(*inputs) is called directly with some 1405 exceptions for correctness. Exceptions include: 1406 1) None output: a NoOp would be returned which control-depends on 1407 computation. 1408 2) Single value output: A tuple containing the value would be returned. 1409 3) Operation-only outputs: a NoOp would be returned which 1410 control-depends on computation. 1411 TODO(b/121383831): Investigate into removing these special cases. 1412 """ 1413 # TODO(b/36647078) remove disable when pylint bug is fixed. 1414 # pylint: disable=indexing-exception 1415 return replicate( 1416 computation, 1417 None if inputs is None else [inputs], 1418 infeed_queue=infeed_queue, 1419 device_assignment=device_assignment, 1420 name=name)[0] 1421 # pylint: enable=indexing-exception 1422 1423 # Operations that indicate some error in the user's inference graph. 1424_BLACKLISTED_INFERENCE_OPS = set([ 1425 "ReadVariableOp", 1426 "AssignVariableOp", 1427 "AssignAddVariableOp", 1428 "AssignSubVariableOp", 1429 "VarHandleOp", 1430 "Variable", 1431 "VariableV2", 1432]) 1433 1434 1435def under_tpu_inference_context(): 1436 """Check if it is currently under `tpu.rewrite_for_inference()`.""" 1437 graph = ops.get_default_graph() 1438 1439 context = graph._get_control_flow_context() # pylint: disable=protected-access 1440 while context: 1441 if isinstance(context, _TPUInferenceContext): 1442 return True 1443 context = context.outer_context 1444 1445 return False 1446 1447 1448class _TPUInferenceContext(control_flow_ops.XLAControlFlowContext): 1449 """A `ControlFlowContext` for nodes inside a TPU inference computation. 1450 1451 The primary role of `TPUReplicateContext` is to sanity check operators inside 1452 a tpu.rewrite_for_inference() computation. 1453 """ 1454 1455 def __init__(self, name): 1456 super(_TPUInferenceContext, self).__init__() 1457 self._name = name 1458 1459 def AddOp(self, op): 1460 self._AddOpInternal(op) 1461 1462 def _AddOpInternal(self, op): 1463 # pylint: disable=protected-access 1464 if op.type in _BLACKLISTED_INFERENCE_OPS: 1465 raise NotImplementedError( 1466 "Operation of type %s (%s) is not supported on the TPU for inference." 1467 " Execution will fail if this op is used in the graph. Make sure your" 1468 " variables are using variable_scope." % (op.type, op.name)) 1469 if self._outer_context: 1470 self._outer_context.AddInnerOp(op) 1471 1472 def AddValue(self, val): 1473 result = val 1474 if self._outer_context: 1475 result = self._outer_context.AddValue(val) 1476 return result 1477 1478 def AddInnerOp(self, op): 1479 self._AddOpInternal(op) 1480 1481 @property 1482 def grad_state(self): 1483 return None 1484 1485 1486def validate_inference_rewrite_for_variables(graph): 1487 """Validates whether rewrite_for_inference() 'worked' for variables. 1488 1489 The rewrite_for_inference() method is supposed to append GuaranteeConstOps 1490 after ReadVariableOps, but this mechanism works only if you are using 1491 tf.get_variable() to create and access variables in your tpu computation. 1492 This validation method can be called immediately after calling 1493 tpu.rewrite_for_inference() to check whether GuaranteeConstOps where added 1494 to the graph. 1495 1496 Typical usages: 1497 tpu.validate_inference_rewrite_for_variables(tf.get_default_graph()) 1498 1499 tpu.validate_inference_rewrite_for_variables(sess.graph) 1500 1501 Args: 1502 graph: The graph which needs to be validated. 1503 Raises: 1504 RuntimeError: if validation failed. 1505 """ 1506 if not any(x.type == "GuaranteeConst" for x in graph.get_operations()): 1507 raise RuntimeError( 1508 "No GuaranteeConst ops found in the graph after running " 1509 "tpu.rewrite_for_inference(...). Please check that you are using " 1510 "tf.get_variable() to create and access variables in your tpu " 1511 "computation.") 1512 1513 1514def rewrite_for_inference(computation, 1515 inputs=None, 1516 infeed_queue=None, 1517 device_assignment=None, 1518 name=None): 1519 """Rewrites `computation` for inference on a TPU system. 1520 1521 Other than 'rewriting' the computation to run on a TPU, if using variables 1522 in your computation, it moves the ReadVariableOps outside the TPU 1523 computation, and adds GuaranteeConst ops just after the ReadVariableOps. 1524 This mechanism works only if you are using tf.get_variable() to create and 1525 access variables in your tpu computation. You can validate whether this 1526 worked, by calling validate_inference_rewrite_for_variables() method 1527 immediately after this method to check whether GuaranteeConstOps where 1528 added to the graph. 1529 1530 Args: 1531 computation: A Python function that builds a computation to apply to the 1532 input. If the function takes n inputs, 'inputs' should be a list of n 1533 tensors. If the function returns m outputs, rewrite will return a list of 1534 m tensors. 1535 inputs: A list of input tensors or `None` (equivalent to an empty list). 1536 infeed_queue: If not `None`, the `InfeedQueue` from which to append a tuple 1537 of arguments as inputs to `computation`. 1538 device_assignment: if not `None`, a `DeviceAssignment` describing the 1539 mapping between logical cores in the computation with physical cores in 1540 the TPU topology. May be omitted for a single-core computation, in which 1541 case the core attached to task 0, TPU device 0 is used. 1542 name: The name of the operator. 1543 Returns: 1544 A list of output tensors. 1545 """ 1546 1547 def guarantee_const_getter(getter, name, *args, **kwargs): 1548 with ops.control_dependencies(None): 1549 return array_ops.guarantee_const( 1550 getter(name, *args, **kwargs), name=name + "/GuaranteeConst") 1551 1552 def wrapped_computation(*args, **kwargs): 1553 """Execute computation under `_TPUInferenceContext`.""" 1554 context = _TPUInferenceContext( 1555 name=ops.get_default_graph().unique_name("rewrite_for_inference")) 1556 try: 1557 context.Enter() 1558 1559 vscope = variable_scope.get_variable_scope() 1560 prev_custom_getter = vscope.custom_getter 1561 prev_caching_device = vscope.caching_device 1562 vscope.set_custom_getter(guarantee_const_getter) 1563 vscope.set_caching_device(lambda op: op.device) 1564 1565 result = computation(*args, **kwargs) 1566 1567 vscope.set_custom_getter(prev_custom_getter) 1568 vscope.set_caching_device(prev_caching_device) 1569 finally: 1570 context.Exit() 1571 return result 1572 1573 # pylint: disable=undefined-variable 1574 return rewrite( 1575 wrapped_computation, 1576 inputs=inputs, 1577 infeed_queue=infeed_queue, 1578 device_assignment=device_assignment, 1579 name=name) 1580 # pylint: enable=undefined-variable 1581 1582 1583def prune_unconnected_ops_from_xla(prune_graph): 1584 """Prunes unconnected ops as listed in _UNCONNECTED_OPS_TO_PRUNE. 1585 1586 Args: 1587 prune_graph: A tensorflow graph from which we wish to prune unconnected ops 1588 as listed in _UNCONNECTED_OPS_TO_PRUNE. In general, these ops should have 1589 no inputs and no consumers. These can often be left behind due to graph 1590 construction rewiring (for instance TF-Hub). While they never execute, 1591 they will cause XLA compile to fail so we strip them from XLA compile by 1592 removing the tpu_replicate attribute. 1593 """ 1594 # Scan over the top level graph and all function graphs. 1595 for graph in [prune_graph] + list(prune_graph._functions.values()): # pylint: disable=protected-access 1596 for op in graph.get_operations(): 1597 if op.type not in _UNCONNECTED_OPS_TO_PRUNE: 1598 continue 1599 outputs_consumed = False 1600 for output in op.outputs: 1601 if output.consumers(): 1602 outputs_consumed = True 1603 break 1604 if not outputs_consumed: 1605 logging.info( 1606 "Pruning OP %s of type %s from XLA Compile due to " 1607 "it being disconnected.", op.name, op.type) 1608 op._clear_attr(_TPU_REPLICATE_ATTR) # pylint: disable=protected-access 1609