# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Library of TPU helper functions."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
from six.moves import xrange  # pylint: disable=redefined-builtin

from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.protobuf.tpu import dynamic_padding_pb2 as dynamic_padding
from tensorflow.python.compat import compat as api_compat
from tensorflow.python.framework import device as pydev
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.tpu import tpu_function
from tensorflow.python.tpu import xla
from tensorflow.python.tpu.ops import tpu_ops
from tensorflow.python.util import compat
from tensorflow.python.util import nest


# Operations that indicate some error in the user's graph, e.g. a placeholder
# that's introduced outside of the infeed.
_BLACKLISTED_OPS = set([
    "Placeholder",
])

# XLA doesn't currently support reading of intermediate tensors, thus some ops
# are not supported.
_UNSUPPORTED_OPS = set([
    "AudioSummary",
    "AudioSummaryV2",
    "HistogramSummary",
    "ImageSummary",
    "MergeSummary",
    "Print",
    "ScalarSummary",
    "TensorSummary",
    "TensorSummaryV2",
    ])

# Ops which can be safely pruned from XLA compile if they have no consumers.
#  These ops should also have no inputs.
_UNCONNECTED_OPS_TO_PRUNE = set(["Placeholder", "VarHandleOp"])

_MAX_WARNING_LINES = 5

_TPU_REPLICATE_ATTR = "_tpu_replicate"
_TPU_COMPILATION_STATUS_ATTR = "_tpu_compilation_status"
_OUTSIDE_COMPILATION_ATTR = "_xla_outside_compilation"


def _tpu_system_device_name(job):
  """Returns the device name for the TPU_SYSTEM device of `job`."""
  if job is None:
    return "/device:TPU_SYSTEM:0"
  else:
    return "/job:%s/device:TPU_SYSTEM:0" % job


def initialize_system(embedding_config=None, job=None):
  """Initializes a distributed TPU system for use with TensorFlow.

  Args:
    embedding_config: If not None, a `TPUEmbeddingConfiguration` proto
      describing the desired configuration of the hardware embedding lookup
      tables. If embedding_config is None, no hardware embeddings can be used.
    job: The job (the XXX in TensorFlow device specification /job:XXX) that
      contains the TPU devices that will be initialized. If job=None it is
      assumed there is only one job in the TensorFlow flock, and an error will
      be returned if this assumption does not hold.
  Returns:
    A serialized `TopologyProto` that describes the TPU system. Note:
      the topology must be evaluated using `Session.run` before it can be used.
  """
  config_string = ("" if embedding_config is None else
                   embedding_config.SerializeToString())
  with ops.device(_tpu_system_device_name(job)):
    return tpu_ops.configure_distributed_tpu(embedding_config=config_string)


def shutdown_system(job=None):
  """Shuts down a running distributed TPU system."""
  with ops.device(_tpu_system_device_name(job)):
    shutdown_distributed_tpu = tpu_ops.shutdown_distributed_tpu()
  return shutdown_distributed_tpu
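

# Illustrative usage sketch (hypothetical helper, not part of this module's
# API): the typical lifecycle of a TPU system in a TF 1.x graph/session
# program. The topology returned by `initialize_system` must be evaluated with
# `Session.run` before any TPU computation is launched, and `shutdown_system`
# should be run once all TPU work is finished.
def _example_tpu_lifecycle(session):
  """Sketch: initializes the TPU system, runs work, then shuts it down."""
  topology_op = initialize_system()
  shutdown_op = shutdown_system()
  serialized_topology = session.run(topology_op)
  # ... build and run replicated computations against `session` here ...
  session.run(shutdown_op)
  return serialized_topology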


def core(num):
  """Returns the device name for a core in a replicated TPU computation.

  Args:
    num: the virtual core number within each replica to which operators should
    be assigned.
  Returns:
    A device name, suitable for passing to `tf.device()`.
  """
  return "device:TPU_REPLICATED_CORE:{}".format(num)


class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
  """A `ControlFlowContext` for nodes inside a TPU computation.

  The primary role of `TPUReplicateContext` is to mark operators inside a
  tpu.replicate() computation with the attribute "_tpu_replicate=XYZ", where XYZ
  is a unique name.

  We use a `ControlFlowContext` to perform the annotation since it integrates
  with TensorFlow constructs like ResourceVariables. For example, if a
  `ResourceVariable` is constructed inside a tpu.replicate() block, the
  `ResourceVariable` implementation can use
  `with ops.control_dependencies(None)` to build the variable's definition
  outside the replicated computation.
  """

  def __init__(self, name, num_replicas, pivot):
    """Builds a new TPUReplicateContext.

    Args:
      name: a unique name for the context, used to populate the `_tpu_replicate`
        attribute.
      num_replicas: an integer that gives the number of replicas for the
        computation.
      pivot: a pivot node. Nodes in the TPUReplicateContext that do not have any
        inputs will have a control dependency on the pivot node. This ensures
        that nodes are correctly included in any enclosing control flow
        contexts.
    """
    super(TPUReplicateContext, self).__init__()
    self._num_replicas = num_replicas
    self._outer_device_function_stack = None
    self._oc_dev_fn_stack = None
    self._outside_compilation_cluster = None
    self._outside_compilation_counter = 0
    self._in_gradient_colocation = None
    self._gradient_colocation_stack = []
    self._host_compute_core = []
    self._name = name
    self._name_as_bytes = compat.as_bytes(name)
    self._unsupported_ops = []
    self._pivot = pivot
    self._replicated_vars = {}

  def get_replicated_var_handle(self, name, vars_):
    """Returns a variable handle for replicated TPU variable 'var'.

    This is a method used by an experimental replicated variable implementation
    and is not intended as a public API.

    Args:
      name: The common name of the variable.
      vars_: The replicated TPU variables.

    Returns:
      The handle of the TPU replicated input node.
    """
    handle = self._replicated_vars.get(name)
    if handle is not None:
      return handle

    # Builds a TPUReplicatedInput node for the variable, if one does not already
    # exist. The TPUReplicatedInput node must belong to the enclosing
    # control-flow scope of the TPUReplicateContext.
    # TODO(phawkins): consider changing the contract of the TPU encapsulation
    # so the TPUReplicatedInput nodes go inside the TPUReplicateContext scope
    # instead.

    # pylint: disable=protected-access
    graph = ops.get_default_graph()
    saved_context = graph._get_control_flow_context()
    graph._set_control_flow_context(self.outer_context)
    handle = tpu_ops.tpu_replicated_input(
        [v.handle for v in vars_], name=name + "/handle")
    graph._set_control_flow_context(saved_context)
    # pylint: enable=protected-access
    self._replicated_vars[name] = handle
    return handle

  def report_unsupported_operations(self):
    if self._unsupported_ops:
      op_str = "\n".join(["  %s (%s)" % (op.type, op.name)
                          for op in self._unsupported_ops[:_MAX_WARNING_LINES]])
      logging.warning("%d unsupported operations found: \n%s",
                      len(self._unsupported_ops), op_str)
      if len(self._unsupported_ops) > _MAX_WARNING_LINES:
        logging.warning("... and %d more" %
                        (len(self._unsupported_ops) - _MAX_WARNING_LINES))

  def EnterGradientColocation(self, op, gradient_uid):
    if op is not None:
      self._gradient_colocation_stack.append(op)
      if not self._outside_compilation_cluster:
        try:
          outside_attr = op.get_attr(_OUTSIDE_COMPILATION_ATTR)
          if self._in_gradient_colocation:
            raise NotImplementedError(
                "Cannot nest gradient colocation operations outside compilation"
            )
          if gradient_uid == "__unsupported__":
            raise NotImplementedError(
                "No gradient_uid calling gradient within outside_compilation")
          # When we take the gradient of an op X in an outside_compilation
          # cluster C in a forward computation we would like to put the ops
          # corresponding to the gradient of X into a new outside_compilation
          # cluster C'. However, if we take the gradient of X twice, the second
          # one should get yet another new outside_compilation cluster C''.
          #
          # The mechanism we adopt is to use a 'root_cluster' which is the
          # cluster that X was in before we took gradients, and a 'gradient_uid'
          # which is different for every invocation of gradients, and put the
          # gradient of X in cluster 'root_cluster.gradient_uid'.
          #
          # When taking a gradient of a gradient, some ops will be colocated
          # with Op in the forward pass (e.g., cluster root_cluster) and some in
          # the backward pass (e.g., cluster root_cluster.initial_gradient_uid).
          # We need all of the grad-of-grad ops to be in the same cluster to
          # avoid cyclic dependencies between clusters. We adopt a heuristic
          # that puts any op clustered with root_cluster.<xxx> in
          # root_cluster.gradient_uid, even if xxx was initial_gradient_uid.
          self._in_gradient_colocation = op
          parts = outside_attr.split(".")
          cluster = parts[0] + "." + gradient_uid
          self._EnterOutsideCompilationScope(cluster=cluster)
        except ValueError:
          # The attr was not present: do nothing.
          pass

  def ExitGradientColocation(self, op, gradient_uid):
    if op is not None:
      if not self._gradient_colocation_stack:
        raise errors.InternalError(
            op.node_def, op,
            "Badly nested gradient colocation: empty stack when popping Op " +
            op.name)
      last_op = self._gradient_colocation_stack.pop()
      if op is last_op:
        if op is self._in_gradient_colocation:
          self._in_gradient_colocation = None
          self._ExitOutsideCompilationScope()
      else:
        raise errors.InternalError(
            op.node_def, op, "Badly nested gradient colocation, expected " +
            last_op.name + ", got " + op.name)

  def _EnterOutsideCompilationScope(self, cluster=None):

    class FakeOp(object):
      """A helper class to determine the current device.

      Supports only the type and device set/get methods needed to run the
      graph's _apply_device_functions method.
      """

      def __init__(self):
        self._device = ""

      @property
      def type(self):
        return "FakeOp"

      @property
      def device(self):
        return self._device

      def _set_device(self, device):
        if isinstance(device, pydev.DeviceSpec):
          self._device = device.to_string()
        else:
          self._device = device

    if self._outside_compilation_cluster:
      raise NotImplementedError("Cannot nest outside_compilation clusters")
    if cluster:
      self._outside_compilation_cluster = cluster
    else:
      self._outside_compilation_cluster = str(self._outside_compilation_counter)
      self._outside_compilation_counter += 1
    graph = ops.get_default_graph()
    fake_op = FakeOp()
    graph._apply_device_functions(fake_op)  # pylint: disable=protected-access
    device = pydev.DeviceSpec.from_string(fake_op.device)
    if (device.device_type == "TPU_REPLICATED_CORE" and
        device.device_index is not None):
      self._host_compute_core.append(self._outside_compilation_cluster + ":" +
                                     str(device.device_index))
    self._oc_dev_fn_stack = graph._device_function_stack  # pylint: disable=protected-access
    graph._device_function_stack = self._outer_device_function_stack  # pylint: disable=protected-access

  def _ExitOutsideCompilationScope(self):
    if not self._outside_compilation_cluster:
      raise NotImplementedError(
          "Attempted to exit outside_compilation scope when not in scope")
    self._outside_compilation_cluster = None
    graph = ops.get_default_graph()
    graph._device_function_stack = self._oc_dev_fn_stack  # pylint: disable=protected-access

  def Enter(self):
    if not self._outer_device_function_stack:
      # Capture the device function stack at the time of first entry
      # since that is the stack that will be used outside_compilation.
      graph = ops.get_default_graph()
      # pylint: disable=protected-access
      self._outer_device_function_stack = graph._device_function_stack.copy()
      # pylint: enable=protected-access
    super(TPUReplicateContext, self).Enter()

  def HostComputeCore(self):
    return self._host_compute_core

  def _RemoveExternalControlEdges(self, op):
    """Remove any external control dependency on this op."""
    internal_control_inputs = []
    external_control_inputs = []
    for x in op.control_inputs:
      # pylint: disable=protected-access
      is_internal_op = False
      ctxt = x._get_control_flow_context()
      while ctxt is not None:
        if ctxt == self:
          is_internal_op = True
          break
        ctxt = ctxt._outer_context
      if is_internal_op:
        internal_control_inputs.append(x)
      else:
        external_control_inputs.append(x)
      # pylint: enable=protected-access
    # pylint: disable=protected-access
    op._remove_all_control_inputs()
    op._add_control_inputs(internal_control_inputs)
    # pylint: enable=protected-access
    return internal_control_inputs, external_control_inputs

  def AddOp(self, op):
    # pylint: disable=protected-access
    if op.type in _BLACKLISTED_OPS:
      logging.error("Operation of type %s (%s) is not supported on the TPU. "
                    "Execution will fail if this op is used in the graph. " %
                    (op.type, op.name))

    if op.type in _UNSUPPORTED_OPS:
      self._unsupported_ops.append(op)

    if any(x.dtype._is_ref_dtype for x in op.inputs):
      raise NotImplementedError(
          "Non-resource Variables are not supported inside TPU computations "
          "(operator name: %s)" % op.name)
    if _TPU_REPLICATE_ATTR in op.node_def.attr:
      raise ValueError("TPU computations cannot be nested")
    op._set_attr(_TPU_REPLICATE_ATTR,
                 attr_value_pb2.AttrValue(s=self._name_as_bytes))
    if self._outside_compilation_cluster:
      op._set_attr(
          _OUTSIDE_COMPILATION_ATTR,
          attr_value_pb2.AttrValue(
              s=compat.as_bytes(self._outside_compilation_cluster)))
    if self._num_replicas > 1 or not self._outside_compilation_cluster:
      # Prevent feeding or fetching anything that is being compiled,
      # and any replicated outside_compilation Op.
      op.graph.prevent_feeding(op)
      op.graph.prevent_fetching(op)

    # Remove any control edges from outer control flow contexts. These may cause
    # mismatched frame errors.
    (internal_control_inputs,
     external_control_inputs) = self._RemoveExternalControlEdges(op)

    if not op.inputs:
      # Add a control edge from the control pivot to this op.
      if not internal_control_inputs:
        # pylint: disable=protected-access
        op._add_control_input(self.GetControlPivot())
        # pylint: enable=protected-access
    else:
      for index in xrange(len(op.inputs)):
        x = op.inputs[index]
        real_x = self.AddValue(x)
        if real_x != x:
          op._update_input(index, real_x)  # pylint: disable=protected-access

    if external_control_inputs:
      # Use an identity to pull control inputs as data inputs. Note that we
      # ignore ops which don't have outputs. TODO(phawkins): fix that.
      with ops.control_dependencies(None):
        self.Enter()
        external_control_inputs = [
            array_ops.identity(x.outputs[0]).op
            for x in external_control_inputs
            if x.outputs
        ]
        self.Exit()
      # pylint: disable=protected-access
      op._add_control_inputs(external_control_inputs)
      # pylint: enable=protected-access

    # Mark op's outputs as seen by this context and any outer contexts.
    output_names = [x.name for x in op.outputs]
    context = self
    while context is not None:
      # pylint: disable=protected-access
      context._values.update(output_names)
      context = context._outer_context
      # pylint: enable=protected-access

    if self._outer_context:
      self._outer_context.AddInnerOp(op)

  def AddValue(self, val):
    """Add `val` to the current context and its outer context recursively."""
    if val.name in self._values:
      # Use the real value if it comes from outer context.
      result = self._external_values.get(val.name)
      return val if result is None else result

    result = val
    self._values.add(val.name)
    if self._outer_context:
      result = self._outer_context.AddValue(val)
      self._values.add(result.name)

    self._external_values[val.name] = result

    return result

  def AddInnerOp(self, op):
    self.AddOp(op)
    if self._outer_context:
      self._outer_context.AddInnerOp(op)

  @property
  def grad_state(self):
    # Define the gradient loop state associated with the TPUReplicateContext to
    # be None as the TPUReplicateContext does not get nested nor does the
    # grad_state outside the TPUReplicateContext affect the graph inside so the
    # grad_state should be as if this is the top-level gradient state.
    return None

  @property
  def back_prop(self):
    """Forwards to the enclosing while context, if any."""
    if self.GetWhileContext():
      return self.GetWhileContext().back_prop
    return False

  def GetControlPivot(self):
    return self._pivot


def outside_compilation(computation, *args, **kwargs):
  """Builds part of a computation outside any current TPU replicate scope.

  Args:
    computation: A Python function that builds the computation to
      place on the host.
    *args: the positional arguments for the computation.
    **kwargs: the keyword arguments for the computation.

  Returns:
    The Tensors returned by computation.
  """
  args = [] if args is None else args
  graph = ops.get_default_graph()

  # If we are in a TPUReplicateContext, signal that we are now
  # outside_compilation
  initial_context = graph._get_control_flow_context()  # pylint: disable=protected-access
  context = initial_context
  while context:
    if isinstance(context, TPUReplicateContext):
      context._EnterOutsideCompilationScope()  # pylint: disable=protected-access
    context = context.outer_context

  retval = computation(*args, **kwargs)

  # If we are in a TPUReplicateContext, signal that we are no longer
  # outside_compilation
  final_context = graph._get_control_flow_context()  # pylint: disable=protected-access
  if initial_context is not final_context:
    raise NotImplementedError(
        "Control-flow context cannot be different at start and end of an "
        "outside_compilation scope")
  context = initial_context
  while context:
    if isinstance(context, TPUReplicateContext):
      context._ExitOutsideCompilationScope()  # pylint: disable=protected-access
    context = context.outer_context

  return retval
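

# Illustrative sketch (hypothetical helper, not part of this module's API):
# using `outside_compilation` from inside a computation passed to replicate().
# Host-only ops (summaries, prints, unsupported string ops, ...) can be moved
# out of the TPU cluster while the rest of the step stays on the device.
def _example_step_with_host_call(x):
  """Sketch: doubles `x` on TPU, touches it on the host, adds one on TPU."""
  y = math_ops.multiply(x, 2.0, name="on_tpu")

  def host_fn(t):
    # Everything built here is placed on the host rather than compiled by XLA.
    return array_ops.identity(t, name="on_host")

  y = outside_compilation(host_fn, y)
  return math_ops.add(y, 1.0, name="back_on_tpu")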


def replicate(computation,
              inputs=None,
              infeed_queue=None,
              device_assignment=None,
              name=None,
              maximum_shapes=None):
  """Builds a graph operator that runs a replicated TPU computation.

  Args:
    computation: A Python function that builds the computation to replicate.
    inputs: A list of lists of input tensors or `None` (equivalent to
      `[[]]`), indexed by `[replica_num][input_num]`. All replicas must
      have the same number of inputs. Each input can be a nested structure
      containing values that are convertible to tensors. Note that passing an
      N-dimension list of compatible values will result in an N-dimension list
      of scalar tensors rather than a single rank-N tensor. If you need
      different behavior, convert part of inputs to tensors with
      `tf.convert_to_tensor`.
    infeed_queue: If not `None`, the `InfeedQueue` from which to append a tuple
      of arguments as inputs to computation.
    device_assignment: If not `None`, a `DeviceAssignment` describing the
      mapping between logical cores in the computation with physical cores in
      the TPU topology. Uses a default device assignment if `None`. The
      `DeviceAssignment` may be omitted if each replica of the computation uses
      only one core, and there is either only one replica, or the number of
      replicas is equal to the number of cores in the TPU system.
    name: (Deprecated) Does nothing.
    maximum_shapes: A nested structure of tf.TensorShape representing the shape
      to which the respective component of each input element in each replica
      should be padded. Any unknown dimensions (e.g. tf.Dimension(None) in a
      tf.TensorShape or -1 in a tensor-like object) will be padded to the
      maximum size of that dimension over all replicas. Note that if the input
      dimension is already static, we won't do padding on it and we require the
      maximum_shapes to have the same value or None on that dimension. The
      structure of `maximum_shapes` needs to be the same as `inputs[0]`.
  Returns:
    A list of outputs, indexed by `[replica_num]`. Each output can be a nested
    structure that matches what computation() returns, with a few exceptions.

    Exceptions include:
      1) None output: a NoOp would be returned which control-depends on
         computation.
      2) Single value output: A tuple containing the value would be returned.
      3) Operation-only outputs: a NoOp would be returned which
         control-depends on computation.
      TODO(b/121383831): Investigate removing these special cases.

  Raises:
    ValueError: If all replicas do not have equal numbers of input tensors.
    ValueError: If the number of inputs per replica does not match
      the number of formal parameters to `computation`.
    ValueError: If the static `inputs` dimensions don't match with the values
      given in `maximum_shapes`.
    ValueError: If the structure of inputs per replica does not match
      the structure of `maximum_shapes`.
  """
  return split_compile_and_replicate(
      computation,
      inputs,
      infeed_queue,
      device_assignment,
      name,
      maximum_shapes=maximum_shapes)[1]
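

# Illustrative usage sketch (hypothetical, for documentation only):
# replicating a two-argument computation over two replicas. Plain Python
# numbers are accepted because replicate() converts each flattened input with
# `ops.convert_to_tensor`. The returned list is indexed by [replica_num].
def _example_replicate_add():
  """Sketch: runs `a + b` on two replicas with different per-replica inputs."""

  def computation(a, b):
    return math_ops.add(a, b)

  # inputs is indexed by [replica_num][input_num].
  return replicate(computation, inputs=[[1.0, 2.0], [3.0, 4.0]])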


def _pad_all_input(inputs, padded_shapes):
  """Pads all input tensors given padded_shapes.

  The real shape tensors will be concatenated with the padded original inputs.

  Args:
    inputs: The original inputs.
    padded_shapes: A list of padded shapes for each input.

  Returns:
    The padded inputs and a PaddingMap list which maps the padded input
    dimension to the real shape argument index.
  """
  input_shape_tensors = []
  for core_idx, inputs_per_core in enumerate(inputs):
    for idx, input_tensor in enumerate(inputs_per_core):
      if core_idx == 0:
        input_shape_tensors.append([])
      input_shape_tensors[idx].append(array_ops.shape(input_tensor))

  maximum_shapes = []
  for shapes_per_input in input_shape_tensors:
    maximum_shapes.append(
        math_ops.reduce_max(array_ops.stack(shapes_per_input), axis=0))

  padded_inputs = []
  real_shapes = []
  padding_maps = []
  for core_idx, inputs_per_core in enumerate(inputs):
    padded_inputs.append([])
    real_shapes.append([])
    real_shape_idx = len(inputs_per_core) - 1
    for idx, input_tensor in enumerate(inputs_per_core):
      input_shape_tensor = input_shape_tensors[idx][core_idx]
      input_shape = input_tensor.get_shape()
      padded_shape = padded_shapes[idx]

      # The static shape of inputs should be compatible with the given padded
      # shapes.
      input_shape.assert_is_compatible_with(padded_shape)

      if input_shape.is_fully_defined():
        # Do nothing if the shape of the whole tensor is already static.
        padded_inputs[core_idx].append(input_tensor)
      else:
        # Only pad the non-static shape dimensions.
        for i, s in enumerate(input_shape):
          if s.value is None:
            if core_idx == 0:
              real_shape_idx += 1
              padding_map = dynamic_padding.PaddingMap()
              padding_map.arg_index = idx
              padding_map.shape_index = i
              padding_map.padding_arg_index = real_shape_idx
              padding_maps.append(padding_map)
            real_shapes[core_idx].append(
                math_ops.cast(input_shape_tensor[i], dtypes.uint32))

        paddings = []
        for i, s in enumerate(padded_shape):
          if input_shape[i].value:
            # Don't pad if the input shape is already static.
            padding = [0, 0]
          else:
            if s.value:
              # Pad to the given maximum value.
              padding = [0, s.value - input_shape_tensor[i]]
            else:
              # If the maximum value is not given, then pad to the maximum
              # dimension among all the cores.
              padding = [0, maximum_shapes[idx][i] - input_shape_tensor[i]]
          paddings.append(padding)

        padded_input = array_ops.pad(input_tensor, paddings)
        padded_inputs[core_idx].append(padded_input)

  num_replicas = len(padded_inputs)
  for i in range(num_replicas):
    padded_inputs[i].extend(real_shapes[i])

  return padded_inputs, padding_maps
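

# Illustrative sketch (hypothetical): how the padding above is requested from
# the public entry point. Passing `maximum_shapes` to replicate() makes every
# unknown (None / -1) dimension get padded to the per-step maximum across
# replicas, with the real shapes appended as extra arguments handled by the
# runtime. The argument tensors and the 128-column bound are made up.
def _example_replicate_with_dynamic_padding(batch0, batch1):
  """Sketch: `batch0`/`batch1` are rank-2 tensors with a dynamic row count."""

  def computation(x):
    return math_ops.reduce_sum(x)

  return replicate(
      computation,
      inputs=[[batch0], [batch1]],
      maximum_shapes=[tensor_shape.TensorShape([None, 128])])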


def split_compile_and_replicate(computation,
                                inputs=None,
                                infeed_queue=None,
                                device_assignment=None,
                                name=None,
                                use_tpu=True,
                                maximum_shapes=None):
  """Builds graph operators that run compilation and replicated computation.

  This is a lower level interface than replicate that returns a separate compile
  and execute output tensor. In the generated graph the compile op feeds into
  the execute op and no additional compilation is incurred when running the
  compile op before the execute op. The compile op returns additional
  information about the compilation but does not return the compiled program.

  Args:
    computation: A Python function that builds the computation to replicate.
    inputs: A list of lists of input tensors or `None` (equivalent to
      `[[]]`), indexed by `[replica_num][input_num]`. All replicas must
      have the same number of inputs. Each input can be a nested structure
      containing values that are convertible to tensors. Note that passing an
      N-dimension list of compatible values will result in an N-dimension list
      of scalar tensors rather than a single rank-N tensor. If you need
      different behavior, convert part of inputs to tensors with
      `tf.convert_to_tensor`.
    infeed_queue: If not `None`, the `InfeedQueue` from which to append a tuple
      of arguments as inputs to computation.
    device_assignment: If not `None`, a `DeviceAssignment` describing the
      mapping between logical cores in the computation with physical cores in
      the TPU topology. Uses a default device assignment if `None`. The
      `DeviceAssignment` may be omitted if each replica of the computation uses
      only one core, and there is either only one replica, or the number of
      replicas is equal to the number of cores in the TPU system.
    name: (Deprecated) Does nothing.
    use_tpu: When false, the input `computation` is executed on the XLA CPU/GPU
      backends. Currently, only supports a default placement (computation is
      placed on GPU if one is available, and on CPU if not).
    maximum_shapes: A nested structure of tf.TensorShape representing the shape
      to which the respective component of each input element in each replica
      should be padded. Any unknown dimensions (e.g. tf.Dimension(None) in a
      tf.TensorShape or -1 in a tensor-like object) will be padded to the
      maximum size of that dimension over all replicas. Note that if the input
      dimension is already static, we won't do padding on it and we require the
      maximum_shapes to have the same value or None on that dimension. The
      structure of `maximum_shapes` needs to be the same as `inputs[0]`.

  Returns:
    A list of lists with the first list corresponding to the compile op and the
    second a list of output tensors, indexed by `[replica_num][output_num]`.
  Raises:
    ValueError: If all replicas do not have equal numbers of input tensors.
    ValueError: If the number of inputs per replica does not match
      the number of formal parameters to `computation`.
    ValueError: If the static `inputs` dimensions don't match with the values
      given in `maximum_shapes`.
    ValueError: If the structure of inputs per replica does not match
      the structure of `maximum_shapes`.
  """
  del name
  inputs = [[]] if inputs is None else inputs

  metadata_kwargs = {}
  if device_assignment is not None:
    # Turn the Numpy array into a flattened list so we can pass it as an
    # operator attribute.
    metadata_kwargs = {
        "topology":
            device_assignment.topology.serialized(),
        "device_assignment":
            device_assignment.core_assignment.flatten().tolist()
    }
    # TODO(phawkins): remove this case after the forward compatibility window
    # expires on 2018-10-5.
    if api_compat.forward_compatible(2018, 10, 5):
      metadata_kwargs["num_cores_per_replica"] = (
          device_assignment.num_cores_per_replica)
    else:
      metadata_kwargs["computation_shape"] = [
          device_assignment.num_cores_per_replica
      ]

  if ((not isinstance(inputs, list)) or
      any(not isinstance(inp, (list, tuple)) for inp in inputs)):
    raise TypeError("tpu.replicate() inputs must be a list of lists/tuples")

  num_replicas = len(inputs)

  # No replicas? Nothing to do.
  if num_replicas == 0:
    return []

  # Checks all replicas have the same structure.
  for i in xrange(1, num_replicas):
    nest.assert_same_structure(inputs[0], inputs[i])

  # Flatten inputs.
  flat_inputs = [
      nest.flatten(per_replica_input) for per_replica_input in inputs
  ]
  # Converts inputs to Tensors.
  flat_inputs = [[ops.convert_to_tensor(x) for x in inp] for inp in flat_inputs]

  # Verifies that all replicas have matching numbers and types of inputs.
  flat_input_types = [x.dtype for x in flat_inputs[0]]
  input_arity = len(inputs[0])
  flat_input_arity = len(flat_input_types)
  for i in range(num_replicas):
    if len(inputs[i]) != input_arity:
      raise ValueError("Replicas must have the same number of inputs. "
                       "Replica 0 had {} inputs, replica {} had {} "
                       "inputs.".format(input_arity, i, len(inputs[i])))

    types = [x.dtype for x in flat_inputs[i]]
    if types != flat_input_types:
      raise ValueError("Replicas must have matching input types. Replica 0 had "
                       "input types {}, replica {} had input types {}".format(
                           flat_input_types, i, types))

  arg_error = xla.check_function_argument_count(
      computation, input_arity, infeed_queue)
  if arg_error is not None:
    if infeed_queue is None:
      raise TypeError(
          "Supplied computation cannot be called with the specified inputs. "
          "You specified %d inputs: %s, but the computation needs %s" % (
              input_arity, str([i.name for i in inputs[0]]), arg_error))
    else:
      raise TypeError(
          "Supplied computation cannot be called with the specified inputs. "
          "You specified %d inputs: %s and %d additional inputs from infeed,"
          " but the computation needs %s" % (input_arity, str(
              [i.name
               for i in inputs[0]]), infeed_queue.number_of_tuple_elements,
                                             arg_error))

  if maximum_shapes:
    if infeed_queue:
      raise ValueError(
          "Dynamic input shapes are not supported with infeed queues")

    # Make sure maximum_shapes has the same structure as inputs.
    nest.assert_same_structure(inputs[0], maximum_shapes, check_types=False)

    # Flatten padded shapes.
    flat_maximum_shapes = nest.flatten(maximum_shapes)
    flat_maximum_shapes = [
        tensor_shape.TensorShape(s) for s in flat_maximum_shapes
    ]

    flat_inputs, padding_maps = _pad_all_input(flat_inputs, flat_maximum_shapes)

    serialized_padding_maps = []
    for padding_map in padding_maps:
      serialized_padding_maps.append(padding_map.SerializeToString())
    metadata_kwargs["padding_map"] = serialized_padding_maps

  metadata_kwargs["step_marker_location"] = getattr(
      computation, "step_marker_location", "STEP_MARK_AT_ENTRY")

  graph = ops.get_default_graph()

  # Fan-in: Builds a TPUReplicatedInput node for each input.
  flat_replicated_inputs = []
  for i in range(0, len(flat_inputs[0])):
    replicas = [flat_inputs[replica][i] for replica in xrange(num_replicas)]
    flat_replicated_inputs.append(
        tpu_ops.tpu_replicated_input(replicas, name="input{}".format(i)))

  cluster_name = graph.unique_name("cluster")
  pivot = control_flow_ops.no_op(name=cluster_name + "/pivot")
  context = TPUReplicateContext(
      name=cluster_name, num_replicas=num_replicas, pivot=pivot)
  try:
    context.Enter()

    metadata = tpu_ops.tpu_replicate_metadata(
        num_replicas=num_replicas, use_tpu=use_tpu, **metadata_kwargs)

    with tpu_function.tpu_shard_context(
        num_replicas), ops.control_dependencies([metadata]):

      # Add identity ops so even unused inputs are "consumed" by the
      # computation. This is to avoid orphaned TPUReplicatedInput nodes.
      # TODO(phawkins): consider instead pruning unused TPUReplicatedInput
      # and eliding trivial TPUReplicatedInput/TPUReplicatedOutput pairs.
      flat_replicated_inputs = [
          array_ops.identity(x, name="replicated_input_{}".format(i))
          for i, x in enumerate(flat_replicated_inputs)
      ]
      for i in flat_replicated_inputs:
        # pylint: disable=protected-access
        # Add an attribute to the identity node so that it can be removed in
        # the encapsulate TPU computation pass if unused. However, we don't
        # remove inputs when dynamic padding is enabled.
        # TODO(rxsang): Use other ways except argument index in padding_map so
        # outside compilation can work with dynamic padding correctly.
        if maximum_shapes is None:
          i.op._set_attr("_tpu_input_identity",
                         attr_value_pb2.AttrValue(b=True))
        # pylint: enable=protected-access

      # Unflatten the computation inputs to match original input structure.
      computation_inputs = nest.pack_sequence_as(
          structure=inputs[0],
          flat_sequence=flat_replicated_inputs[:flat_input_arity])

      # If there is an infeed queue, adds the dequeued values to the
      # computation's inputs.
      if infeed_queue is not None:
        infeed_queue.set_number_of_shards(num_replicas)
        for t in infeed_queue.generate_dequeue_op():
          computation_inputs.append(t)

      # Only resource variables work inside a TPU computation, so turn on
      # resource variables for the computation.
      # TODO(phawkins): consider removing this code. It will
      # be less confusing to clients if they knowingly choose to use resource
      # variables.
      # Partitioned variables are not supported (b/112311320).
      vscope = variable_scope.get_variable_scope()
      saved_use_resource = vscope.use_resource
      saved_custom_getter = vscope.custom_getter

      def custom_getter(getter, name, *args, **kwargs):
        """Variables on TPU have a few restrictions."""
        partitioner = kwargs["partitioner"]
        if partitioner is not None:
          kwargs["partitioner"] = None
          logging.warning(
              "Partitioned variables are not supported on TPU. Got "
              "`partitioner` that is {} for variable {}. "
              "Setting `partitioner` to `None`."
              .format(partitioner, name))
        if saved_custom_getter is None:
          return getter(name, *args, **kwargs)
        else:
          return saved_custom_getter(getter, name, *args, **kwargs)

      vscope.set_use_resource(True)
      vscope.set_custom_getter(custom_getter)

      outputs = computation(*computation_inputs)

      vscope.set_use_resource(saved_use_resource)
      vscope.set_custom_getter(saved_custom_getter)

    outputs_is_flat = xla.is_flat(outputs)
    if outputs_is_flat:
      output_tensors, control_deps = _postprocess_flat_outputs(outputs)
    else:
      output_tensors, control_deps = _postprocess_non_flat_outputs(outputs)

    # tensor_tracer imports tpu.py. Local import to tensor_tracer to avoid
    # import-cycle
    # pylint: disable=g-import-not-at-top
    from tensorflow.python.tpu import tensor_tracer
    # pylint: enable=g-import-not-at-top
    if tensor_tracer.TensorTracer.is_enabled():
      tt = tensor_tracer.TensorTracer()
      output_tensors = tt.trace_tpu(ops.get_default_graph(),
                                    output_tensors, control_deps,
                                    num_replicas)

    context.ExitResult(output_tensors)
  finally:
    context.report_unsupported_operations()
    context.Exit()
    host_compute_core = context.HostComputeCore()

  if host_compute_core:
    attr_value = attr_value_pb2.AttrValue()
    attr_value.list.s.extend([compat.as_bytes(x) for x in host_compute_core])
    metadata._set_attr("host_compute_core", attr_value)  # pylint: disable=protected-access

  with ops.control_dependencies([metadata]):
    if use_tpu:
      compile_status = tpu_ops.tpu_compilation_result()
      op = compile_status.op
      attr_value = attr_value_pb2.AttrValue(s=compat.as_bytes(cluster_name))
      op._set_attr(_TPU_COMPILATION_STATUS_ATTR, attr_value)  # pylint: disable=protected-access
    else:
      compile_status = control_flow_ops.no_op(name="compilation_status")

  if not output_tensors:
    # Returns a list of NoOps dependent on the replication Op, indexed by
    # [replica_num].
    return [
        compile_status,
        [
            control_flow_ops.group(control_deps, name="shard_%d" % i)
            for i in range(num_replicas)
        ]
    ]

  # Fan-out: Builds a TPUReplicatedOutput node for each output.
  replicated_outputs = [[] for i in xrange(num_replicas)]
  for i, t in enumerate(output_tensors):
    ys = tpu_ops.tpu_replicated_output(
        t, num_replicas, name="output{}".format(i))

    # Wraps the outputs in identity operators so the names of any possible
    # `fetch` nodes are preserved by the replication rewrite.
    with ops.control_dependencies(control_deps):
      for replica in xrange(num_replicas):
        replicated_outputs[replica].append(
            array_ops.identity(
                ys[replica], name="output_%d_shard_%d" % (i, replica)))

  if not outputs_is_flat:
    replicated_outputs = [
        nest.pack_sequence_as(outputs, replica_outs)
        for replica_outs in replicated_outputs
    ]

  return [compile_status, replicated_outputs]
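

# Illustrative sketch (hypothetical helper): using the lower-level entry point
# when a separate compilation-status op is useful. Fetching `compile_op` first
# surfaces XLA compilation errors without launching the replicated program;
# the per-replica outputs can then be fetched as usual.
def _example_split_compile_and_run():
  """Sketch: compiles and replicates a squaring computation on two replicas."""

  def computation(x):
    return math_ops.square(x)

  compile_op, per_replica_outputs = split_compile_and_replicate(
      computation, inputs=[[1.0], [2.0]])
  # A caller would typically do:
  #   session.run(compile_op)           # check compilation
  #   session.run(per_replica_outputs)  # run the replicated step
  return compile_op, per_replica_outputs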


def _postprocess_flat_outputs(outputs):
  """Validates flat outputs and adds back device assignments and other attrs.

  Args:
    outputs: Output from `computation` inside `tpu.rewrite`.

  Returns:
    Tensors and Operations extracted from outputs.
  """
  # Following code segment is to preserve legacy behavior. Previously we only
  # supported flat outputs and thus for consistency it was nice to convert even
  # single element into a tuple. But now that we support arbitrary output
  # structure, this is no longer necessary.
  # TODO(b/121383831): Migrate all legacy use cases and delete this special
  # case.
  # If the computation returns `None`, make it an empty tuple.
  if outputs is None:
    outputs = tuple()
  # If the computation only returned one value, makes it a tuple.
  if not isinstance(outputs, collections.Sequence):
    outputs = (outputs,)

  # Append `no_op` here so that fetching any return value of this function
  # will trigger TPUExecute node.
  outputs += (control_flow_ops.no_op(),)
  try:
    with ops.device(core(0)):
      outputs = [
          o if isinstance(o, ops.Operation) else ops.convert_to_tensor(o)
          for o in outputs
      ]
  except Exception as e:
    raise ValueError(
        "TPU function return values must all either be Operations or "
        "convertible to Tensors. Got '%s'" % str(e))

  # Separates the returned Operations and Tensors.
  output_operations = [o for o in outputs if isinstance(o, ops.Operation)]
  output_tensors = [o for o in outputs if not isinstance(o, ops.Operation)]

  if outputs != output_tensors + output_operations:
    raise ValueError(
        "TPU functions must return zero or more Tensor values followed by "
        "zero or more Operations.")

  # Wraps outputs in Identity ops. Otherwise a replicated input copied
  # straight to an output would bypass the replicate(). This would be bad
  # because the TPUReplicatedInput/TPUReplicatedOutput operator would not
  # be rewritten away, leading to a runtime error.
  # TODO(phawkins): extend the rewrite to elide these nodes instead.
  new_output_tensors = []
  for t in output_tensors:
    with ops.device(t.device if t.device else core(0)):
      o = array_ops.identity(t)
      # pylint: disable=protected-access
      o.op._set_attr("_tpu_output_identity", attr_value_pb2.AttrValue(b=True))
      # pylint: enable=protected-access
      new_output_tensors.append(o)
  return new_output_tensors, output_operations


def _postprocess_non_flat_outputs(outputs):
  """Validates non-flat outputs and adds back device assignments and attrs.

  Args:
    outputs: Output from `computation` inside `tpu.rewrite`.

  Returns:
    Tensors extracted from outputs and an empty list, because Operations are
    not allowed in non-flat output structures.
  """

  # Flatten output items.
  flat_outputs = nest.flatten(outputs)

  # Convert all non-Operation outputs to Tensors.
  for i, o in enumerate(flat_outputs):
    if isinstance(o, ops.Operation):
      raise ValueError(
          "tpu.rewrite does not support Operation as return value in non-flat "
          "output structure. You can set returned Operations as control "
          "dependencies of returned Tensors so Operations are triggered when "
          'Tensors are evaluated. Operation found: "%s"' % o.name)

    try:
      o = ops.convert_to_tensor(o)
    except Exception as e:
      raise ValueError(
          "TPU function return values must all either be Operations or "
          'convertible to Tensors. Got error: "%s"' % str(e))

    # Wraps outputs in Identity ops. Otherwise a replicated input copied
    # straight to an output would bypass the replicate(). This would be bad
    # because the TPUReplicatedInput/TPUReplicatedOutput operator would not
    # be rewritten away, leading to a runtime error.
    # TODO(phawkins): extend the rewrite to elide these nodes instead.
    with ops.device(core(0)):
      o = array_ops.identity(o)
      # pylint: disable=protected-access
      o.op._set_attr("_tpu_output_identity", attr_value_pb2.AttrValue(b=True))
      # pylint: enable=protected-access
      flat_outputs[i] = array_ops.identity(o)

  # All flat_outputs are Tensors, and no Operations.
  return flat_outputs, []


def split_compile_and_shard(computation,
                            inputs=None,
                            num_shards=1,
                            input_shard_axes=None,
                            outputs_from_all_shards=True,
                            output_shard_axes=None,
                            infeed_queue=None,
                            device_assignment=None,
                            name=None):
  """Shards `computation` for parallel execution.

  `inputs` must be a list of Tensors or None (equivalent to an empty list), each
  of which has a corresponding split axis (from `input_shard_axes`). Each input
  is split into `num_shards` pieces along the corresponding axis, and
  computation is applied to each shard in parallel.

  Tensors are broadcast to all shards if they are lexically captured by
  `computation`. e.g.,

  x = tf.constant(7)
  def computation():
    return x + 3
  ... = shard(computation, ...)

  If `outputs_from_all_shards` is true, the outputs from all shards of
  `computation` are concatenated back together along their `output_shard_axes`.
  Otherwise, each output is taken from an arbitrary shard.

  Inputs and outputs of the computation must be at least rank-1 Tensors.

  Args:
    computation: A Python function that builds a computation to apply to each
      shard of the input.
    inputs: A list of input tensors or None (equivalent to an empty list). Each
      input tensor has a corresponding shard axis, given by `input_shard_axes`,
      which must have size divisible by `num_shards`.
    num_shards: The number of shards.
    input_shard_axes: A list of dimensions along which to shard `inputs`, or
      `None`. `None` means "shard all inputs along dimension 0". If not `None`,
      there must be one dimension per input.
    outputs_from_all_shards: Boolean or list of boolean. For each output, if
      `True`, outputs from all shards are concatenated along the corresponding
      `output_shard_axes` entry. Otherwise, each output is taken
      from an arbitrary shard. If the argument is a boolean, the argument's
      value is used for each output.
    output_shard_axes: A list of dimensions along which to concatenate the
      outputs of `computation`, or `None`. `None` means "concatenate all outputs
      along dimension 0". If not `None`, there must be one dimension per output.
      Ignored if `outputs_from_all_shards` is False.
    infeed_queue: If not `None`, the `InfeedQueue` to use to augment the inputs
      of `computation`.
    device_assignment: If not `None`, a `DeviceAssignment` describing the
      mapping between logical cores in the computation with physical cores in
      the TPU topology. Uses a default device assignment if `None`. The
      `DeviceAssignment` may be omitted if each shard of the computation uses
      only one core, and there is either only one shard, or the number of shards
      is equal to the number of cores in the TPU system.
    name: (Deprecated) Does nothing.
  Returns:
    A tuple of (compile op, [output tensors]).
  Raises:
    ValueError: If num_shards <= 0
    ValueError: If len(input_shard_axes) != len(inputs)
    ValueError: If len(output_shard_axes) != len(outputs from `computation`)
  """
  # TODO(phawkins): consider adding support for broadcasting Tensors passed as
  # inputs.

  if num_shards <= 0:
    raise ValueError("num_shards must be a positive integer.")

  inputs = [] if inputs is None else inputs
  if not isinstance(inputs, list):
    raise TypeError("tpu.shard()'s inputs must be a list of Tensors or None.")

  # Converts inputs to Tensors.
  inputs = [ops.convert_to_tensor(x) for x in inputs]

  if input_shard_axes is None:
    input_shard_axes = [0] * len(inputs)
  if len(inputs) != len(input_shard_axes):
    raise ValueError("Length of input_shard_axes must be equal to the number "
                     "of inputs.")

  if inputs:
    # Splits the `inputs` along the corresponding `input_shard_axes`, giving
    # lists with layout [input][shard]
    split_inputs = [
        array_ops.split(x, num_shards, axis=axis)
        for (axis, x) in zip(input_shard_axes, inputs)]

    # Transposes the input lists to have layout [shard][input]
    transposed_inputs = [list(i) for i in zip(*split_inputs)]
  else:
    transposed_inputs = [[]] * num_shards

  compile_op, outputs = split_compile_and_replicate(
      computation,
      transposed_inputs,
      infeed_queue=infeed_queue,
      device_assignment=device_assignment,
      name=name)

  # There must be at least one shard since num_shards > 0.
  # TODO(b/36647078) remove disable when pylint bug is fixed.
  # pylint: disable=indexing-exception
  if isinstance(outputs[0], ops.Operation):
    # pylint: enable=indexing-exception
    # There were no outputs from the computation and replicate returned a list
    # of NoOps with control dependencies on the computation. Return the first
    # one so it can be used as a control dependency or fetch node.
    # TODO(b/36647078) remove disable when pylint bug is fixed.
    # pylint: disable=indexing-exception
    return compile_op, [outputs[0]]
    # pylint: enable=indexing-exception

  # TODO(b/36647078) remove disable when pylint bug is fixed.
  # pylint: disable=indexing-exception
  num_outputs = len(outputs[0])
  # pylint: enable=indexing-exception

  if output_shard_axes is None:
    output_shard_axes = [0] * num_outputs
  if num_outputs != len(output_shard_axes):
    raise ValueError("Length of output_shard_axes must be equal to the number "
                     "of outputs.")

  if isinstance(outputs_from_all_shards, bool):
    outputs_from_all_shards = [outputs_from_all_shards] * num_outputs

  if num_outputs != len(outputs_from_all_shards):
    raise ValueError("Length of outputs_from_all_shards must be equal to the "
                     "number of outputs.")

  results = []
  for (axis, all_shards, x) in zip(output_shard_axes, outputs_from_all_shards,
                                   zip(*outputs)):
    if all_shards:
      # Concatenate all of the outputs together (use stack for scalars).
      shape = x[0].shape
      is_scalar = shape is not None and (shape.ndims == 0)
      results.append((array_ops.stack(list(x)) if is_scalar
                      else array_ops.concat(list(x), axis=axis)))
    else:
      # TODO(phawkins): use a smarter policy, e.g., round-robin across shards.
      results.append(x[0])

  return compile_op, results
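

# Illustrative sketch (hypothetical): sharding with an explicit compile op.
# The outputs are concatenated across shards along dimension 0 because
# `outputs_from_all_shards` defaults to True. The input tensor is assumed to
# have a leading dimension divisible by `num_shards`.
def _example_split_compile_and_shard(features):
  """Sketch: shards `features` along dimension 0 across two cores."""

  def computation(x):
    return math_ops.multiply(x, 2.0)

  compile_op, outputs = split_compile_and_shard(
      computation, inputs=[features], num_shards=2)
  return compile_op, outputs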
1233
1234
1235def shard(computation,
1236          inputs=None,
1237          num_shards=1,
1238          input_shard_axes=None,
1239          outputs_from_all_shards=True,
1240          output_shard_axes=None,
1241          infeed_queue=None,
1242          device_assignment=None,
1243          name=None):
1244  """Shards `computation` for parallel execution.
1245
1246  `inputs` must be a list of Tensors or None (equivalent to an empty list), each
1247  of which has a corresponding split axis (from `input_shard_axes`). Each input
1248  is split into `num_shards` pieces along the corresponding axis, and
1249  computation is applied to each shard in parallel.
1250
1251  Tensors are broadcast to all shards if they are lexically captured by
1252  `computation`. e.g.,
1253
1254  x = tf.constant(7)
1255  def computation():
1256    return x + 3
1257  ... = shard(computation, ...)
1258
1259  TODO(phawkins): consider adding support for broadcasting Tensors passed
1260  as inputs.
1261
1262  If `outputs_from_all_shards` is true, the outputs from all shards of
1263  `computation` are concatenated back together along their `output_shards_axes`.
1264  Otherwise, each output is taken from an arbitrary shard.
1265
1266  Inputs and outputs of the computation must be at least rank-1 Tensors.
1267
1268  Args:
1269    computation: A Python function that builds a computation to apply to each
1270      shard of the input.
1271    inputs: A list of input tensors or None (equivalent to an empty list). Each
1272      input tensor has a corresponding shard axes, given by `input_shard_axes`,
1273      which must have size divisible by `num_shards`.
1274    num_shards: The number of shards.
1275    input_shard_axes: A list of dimensions along which to shard `inputs`, or
1276      `None`. `None` means "shard all inputs along dimension 0". If not `None`,
1277      there must be one dimension per input.
1278    outputs_from_all_shards: Boolean or list of boolean. For each output, if
1279      `True`, outputs from all shards are concatenated along the corresponding
1280      `output_shard_axes` entry. Otherwise, each output is taken
1281      from an arbitrary shard. If the argument is a boolean, the argument's
1282      value is used for each output.
1283    output_shard_axes: A list of dimensions along which to concatenate the
1284      outputs of `computation`, or `None`. `None` means "concatenate all outputs
1285      along dimension 0". If not `None`, there must be one dimension per output.
1286      Ignored if `outputs_from_all_shards` is False.
1287    infeed_queue: If not `None`, the `InfeedQueue` to use to augment the inputs
1288      of `computation`.
1289    device_assignment: If not `None`, a `DeviceAssignment` describing the
1290      mapping between logical cores in the computation and physical cores in
1291      the TPU topology. Uses a default device assignment if `None`. The
1292      `DeviceAssignment` may be omitted if each shard of the computation uses
1293      only one core, and there is either only one shard, or the number of shards
1294      is equal to the number of cores in the TPU system.
1295    name: (Deprecated) Does nothing.
1296  Returns:
1297    A list of output tensors.
1298  Raises:
1299    ValueError: If num_shards <= 0
1300    ValueError: If len(input_shard_axes) != len(inputs)
1301    ValueError: If len(output_shard_axes) != len(outputs from `computation`)
1302  """
1303  return split_compile_and_shard(
1304      computation,
1305      inputs=inputs,
1306      num_shards=num_shards,
1307      input_shard_axes=input_shard_axes,
1308      outputs_from_all_shards=outputs_from_all_shards,
1309      output_shard_axes=output_shard_axes,
1310      infeed_queue=infeed_queue,
1311      device_assignment=device_assignment,
1312      name=name)[1]
1313
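# Example (illustrative sketch, as a comment only): sharding a computation over
# two cores along dimension 0 of a single input. `embeddings` and `dot_fn` are
# hypothetical names; the module is assumed to be imported as `tpu`.
#
#   def dot_fn(x):
#     return tf.reduce_sum(x * x, axis=1)
#
#   embeddings = tf.placeholder(tf.float32, [16, 32])
#   [norms] = tpu.shard(dot_fn, inputs=[embeddings], num_shards=2)
#   # Each shard processes 8 rows; the per-shard [8] outputs are concatenated
#   # back into a single [16] tensor along dimension 0.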
1314
1315def batch_parallel(computation,
1316                   inputs=None,
1317                   num_shards=1,
1318                   infeed_queue=None,
1319                   device_assignment=None,
1320                   name=None):
1321  """Shards `computation` along the batch dimension for parallel execution.
1322
1323  Convenience wrapper around shard().
1324
1325  `inputs` must be a list of Tensors or None (equivalent to an empty list).
1326  Each input is split into `num_shards` pieces along the 0-th dimension, and
1327  computation is applied to each shard in parallel.
1328
1329  Tensors are broadcast to all shards if they are lexically captured by
1330  `computation`. For example:
1331
1332  x = tf.constant(7)
1333  def computation():
1334    return x + 3
1335  ... = batch_parallel(computation, ...)
1336
1337  The outputs from all shards are concatenated back together along their 0-th
1338  dimension.
1339
1340  Inputs and outputs of the computation must be at least rank-1 Tensors.
1341
1342  Args:
1343    computation: A Python function that builds a computation to apply to each
1344      shard of the input.
1345    inputs: A list of input tensors or None (equivalent to an empty list). The
1346      0-th dimension of each Tensor must have size divisible by `num_shards`.
1347    num_shards: The number of shards.
1348    infeed_queue: If not `None`, the `InfeedQueue` whose dequeued tuple of
1349      arguments is appended to the inputs of `computation`.
1350    device_assignment: If not `None`, a `DeviceAssignment` describing the
1351      mapping between logical cores in the computation and physical cores in
1352      the TPU topology. Uses a default device assignment if `None`. The
1353      `DeviceAssignment` may be omitted if each shard of the computation uses
1354      only one core, and there is either only one shard, or the number of shards
1355      is equal to the number of cores in the TPU system.
1356    name: (Deprecated) Does nothing.
1357  Returns:
1358    A list of output tensors.
1359  Raises:
1360    ValueError: If `num_shards <= 0`
1361  """
1362  return shard(
1363      computation,
1364      inputs,
1365      num_shards=num_shards,
1366      infeed_queue=infeed_queue,
1367      device_assignment=device_assignment,
1368      name=name)
1369
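# Example (illustrative sketch, as a comment only): batch-parallel evaluation
# of a model whose batch dimension is divisible by `num_shards`. `images` and
# `model_fn` are hypothetical names; the module is assumed to be imported as
# `tpu`.
#
#   def model_fn(x):
#     return tf.layers.dense(x, 10)
#
#   images = tf.placeholder(tf.float32, [64, 784])
#   [logits] = tpu.batch_parallel(model_fn, inputs=[images], num_shards=8)
#   # Each shard sees an [8, 784] slice; the [8, 10] per-shard outputs are
#   # concatenated back into a [64, 10] tensor along dimension 0.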
1370
1371def rewrite(computation,
1372            inputs=None,
1373            infeed_queue=None,
1374            device_assignment=None,
1375            name=None):
1376  """Rewrites `computation` for execution on a TPU system.
1377
1378  Args:
1379    computation: A Python function that builds a computation to apply to the
1380      input. If the function takes n inputs, 'inputs' should be a list of n
1381      tensors.
1382
1383      `computation` may return a list of operations and tensors. Tensors must
1384      come before operations in the returned list.  The return value of
1385      `rewrite` is a list of tensors corresponding to the tensors from the
1386      output of `computation`.
1387
1388      All `Operation`s constructed during `computation` will be executed when
1389      evaluating any of the returned output tensors, not just the ones returned.
1390    inputs: A list of input tensors or `None` (equivalent to an empty list).
1391      Each input can be a nested structure containing values that are
1392      convertible to tensors. Note that passing an N-dimensional list of
1393      compatible values will result in an N-dimensional list of scalar tensors
1394      rather than a single rank-N tensor. If you need different behavior,
1395      convert parts of `inputs` to tensors with `tf.convert_to_tensor`.
1396    infeed_queue: If not `None`, the `InfeedQueue` whose dequeued tuple of
1397      arguments is appended to the inputs of `computation`.
1398    device_assignment: if not `None`, a `DeviceAssignment` describing the
1399      mapping between logical cores in the computation and physical cores in
1400      the TPU topology. May be omitted for a single-core computation, in which
1401      case the core attached to task 0, TPU device 0 is used.
1402    name: (Deprecated) Does nothing.
1403  Returns:
1404    The same data structure as if computation(*inputs) were called directly,
1405    with some exceptions for correctness:
1406      1) None output: a NoOp is returned that control-depends on
1407         `computation`.
1408      2) Single-value output: a tuple containing the value is returned.
1409      3) Operation-only outputs: a NoOp is returned that control-depends
1410         on `computation`.
1411      TODO(b/121383831): Investigate removing these special cases.
1412  """
1413  # TODO(b/36647078) remove disable when pylint bug is fixed.
1414  # pylint: disable=indexing-exception
1415  return replicate(
1416      computation,
1417      None if inputs is None else [inputs],
1418      infeed_queue=infeed_queue,
1419      device_assignment=device_assignment,
1420      name=name)[0]
1421  # pylint: enable=indexing-exception
1422
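# Example (illustrative sketch, as a comment only): rewriting a simple loss
# computation to run on a single TPU core. `loss_fn`, `features`, and `labels`
# are hypothetical names; the module is assumed to be imported as `tpu`.
#
#   def loss_fn(x, y):
#     logits = tf.layers.dense(x, 10)
#     return tf.losses.sparse_softmax_cross_entropy(labels=y, logits=logits)
#
#   features = tf.placeholder(tf.float32, [32, 784])
#   labels = tf.placeholder(tf.int32, [32])
#   (loss,) = tpu.rewrite(loss_fn, inputs=[features, labels])
#   # As noted in the Returns section above, a single-value output comes back
#   # wrapped in a tuple.
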
1423# Operations that indicate some error in the user's inference graph.
1424_BLACKLISTED_INFERENCE_OPS = set([
1425    "ReadVariableOp",
1426    "AssignVariableOp",
1427    "AssignAddVariableOp",
1428    "AssignSubVariableOp",
1429    "VarHandleOp",
1430    "Variable",
1431    "VariableV2",
1432])
1433
1434
1435def under_tpu_inference_context():
1436  """Check if it is currently under `tpu.rewrite_for_inference()`."""
1437  graph = ops.get_default_graph()
1438
1439  context = graph._get_control_flow_context()  # pylint: disable=protected-access
1440  while context:
1441    if isinstance(context, _TPUInferenceContext):
1442      return True
1443    context = context.outer_context
1444
1445  return False
1446
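# Example (illustrative sketch, as a comment only): callers can branch on the
# inference context, e.g. to wrap a variable read in GuaranteeConst only when
# building an inference graph. `value` is a hypothetical tensor.
#
#   if under_tpu_inference_context():
#     value = array_ops.guarantee_const(value)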
1447
1448class _TPUInferenceContext(control_flow_ops.XLAControlFlowContext):
1449  """A `ControlFlowContext` for nodes inside a TPU inference computation.
1450
1451  The primary role of `_TPUInferenceContext` is to sanity check operators
1452  inside a tpu.rewrite_for_inference() computation.
1453  """
1454
1455  def __init__(self, name):
1456    super(_TPUInferenceContext, self).__init__()
1457    self._name = name
1458
1459  def AddOp(self, op):
1460    self._AddOpInternal(op)
1461
1462  def _AddOpInternal(self, op):
1463    # pylint: disable=protected-access
1464    if op.type in _BLACKLISTED_INFERENCE_OPS:
1465      raise NotImplementedError(
1466          "Operation of type %s (%s) is not supported on the TPU for inference."
1467          " Execution will fail if this op is used in the graph. Make sure your"
1468          " variables are using variable_scope." % (op.type, op.name))
1469    if self._outer_context:
1470      self._outer_context.AddInnerOp(op)
1471
1472  def AddValue(self, val):
1473    result = val
1474    if self._outer_context:
1475      result = self._outer_context.AddValue(val)
1476    return result
1477
1478  def AddInnerOp(self, op):
1479    self._AddOpInternal(op)
1480
1481  @property
1482  def grad_state(self):
1483    return None
1484
1485
1486def validate_inference_rewrite_for_variables(graph):
1487  """Validates whether rewrite_for_inference() 'worked' for variables.
1488
1489     The rewrite_for_inference() method is supposed to append GuaranteeConstOps
1490     after ReadVariableOps, but this mechanism works only if you are using
1491     tf.get_variable() to create and access variables in your tpu computation.
1492     This validation method can be called immediately after calling
1493     tpu.rewrite_for_inference() to check whether GuaranteeConstOps were added
1494     to the graph.
1495
1496     Typical usages:
1497       tpu.validate_inference_rewrite_for_variables(tf.get_default_graph())
1498
1499       tpu.validate_inference_rewrite_for_variables(sess.graph)
1500
1501  Args:
1502    graph: The graph which needs to be validated.
1503  Raises:
1504    RuntimeError: if validation failed.
1505  """
1506  if not any(x.type == "GuaranteeConst" for x in graph.get_operations()):
1507    raise RuntimeError(
1508        "No GuaranteeConst ops found in the graph after running "
1509        "tpu.rewrite_for_inference(...). Please check that you are using "
1510        "tf.get_variable() to create and access variables in your tpu "
1511        "computation.")
1512
1513
1514def rewrite_for_inference(computation,
1515                          inputs=None,
1516                          infeed_queue=None,
1517                          device_assignment=None,
1518                          name=None):
1519  """Rewrites `computation` for inference on a TPU system.
1520
1521     In addition to rewriting the computation to run on a TPU, if the
1522     computation uses variables, this method moves the ReadVariableOps outside
1523     the TPU computation and adds GuaranteeConst ops just after the
1524     ReadVariableOps. This mechanism works only if you are using
1525     tf.get_variable() to create and access variables in your tpu computation.
1526     You can validate whether this worked by calling
1527     validate_inference_rewrite_for_variables() immediately after this method
1528     to check whether GuaranteeConstOps were added to the graph.
1529
1530  Args:
1531    computation: A Python function that builds a computation to apply to the
1532      input. If the function takes n inputs, 'inputs' should be a list of n
1533      tensors. If the function returns m outputs, rewrite will return a list of
1534      m tensors.
1535    inputs: A list of input tensors or `None` (equivalent to an empty list).
1536    infeed_queue: If not `None`, the `InfeedQueue` whose dequeued tuple of
1537      arguments is appended to the inputs of `computation`.
1538    device_assignment: if not `None`, a `DeviceAssignment` describing the
1539      mapping between logical cores in the computation and physical cores in
1540      the TPU topology. May be omitted for a single-core computation, in which
1541      case the core attached to task 0, TPU device 0 is used.
1542    name: The name of the operator.
1543  Returns:
1544    A list of output tensors.
1545  """
1546
1547  def guarantee_const_getter(getter, name, *args, **kwargs):
1548    with ops.control_dependencies(None):
1549      return array_ops.guarantee_const(
1550          getter(name, *args, **kwargs), name=name + "/GuaranteeConst")
1551
1552  def wrapped_computation(*args, **kwargs):
1553    """Execute computation under `_TPUInferenceContext`."""
1554    context = _TPUInferenceContext(
1555        name=ops.get_default_graph().unique_name("rewrite_for_inference"))
1556    try:
1557      context.Enter()
1558
1559      vscope = variable_scope.get_variable_scope()
1560      prev_custom_getter = vscope.custom_getter
1561      prev_caching_device = vscope.caching_device
1562      vscope.set_custom_getter(guarantee_const_getter)
1563      vscope.set_caching_device(lambda op: op.device)
1564
1565      result = computation(*args, **kwargs)
1566
1567      vscope.set_custom_getter(prev_custom_getter)
1568      vscope.set_caching_device(prev_caching_device)
1569    finally:
1570      context.Exit()
1571    return result
1572
1573  # pylint: disable=undefined-variable
1574  return rewrite(
1575      wrapped_computation,
1576      inputs=inputs,
1577      infeed_queue=infeed_queue,
1578      device_assignment=device_assignment,
1579      name=name)
1580  # pylint: enable=undefined-variable
1581
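# Example (illustrative sketch, as a comment only): rewriting an inference
# computation and then validating that GuaranteeConst ops were inserted after
# the variable reads. `inference_fn` and `features` are hypothetical names; the
# module is assumed to be imported as `tpu`.
#
#   def inference_fn(x):
#     w = tf.get_variable("w", [784, 10])
#     return tf.matmul(x, w)
#
#   features = tf.placeholder(tf.float32, [1, 784])
#   outputs = tpu.rewrite_for_inference(inference_fn, inputs=[features])
#   tpu.validate_inference_rewrite_for_variables(tf.get_default_graph())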
1582
1583def prune_unconnected_ops_from_xla(prune_graph):
1584  """Prunes unconnected ops as listed in _UNCONNECTED_OPS_TO_PRUNE.
1585
1586  Args:
1587    prune_graph: A tensorflow graph from which we wish to prune unconnected ops
1588      as listed in _UNCONNECTED_OPS_TO_PRUNE.  In general, these ops should have
1589      no inputs and no consumers. These can often be left behind due to graph
1590      construction rewiring (for instance TF-Hub). While they never execute,
1591      they will cause XLA compilation to fail, so we strip them from the XLA
1592      compilation by removing the _tpu_replicate attribute.
1593  """
1594  # Scan over the top level graph and all function graphs.
1595  for graph in [prune_graph] + list(prune_graph._functions.values()):  # pylint: disable=protected-access
1596    for op in graph.get_operations():
1597      if op.type not in _UNCONNECTED_OPS_TO_PRUNE:
1598        continue
1599      outputs_consumed = False
1600      for output in op.outputs:
1601        if output.consumers():
1602          outputs_consumed = True
1603          break
1604      if not outputs_consumed:
1605        logging.info(
1606            "Pruning OP %s of type %s from XLA Compile due to "
1607            "it being disconnected.", op.name, op.type)
1608        op._clear_attr(_TPU_REPLICATE_ATTR)  # pylint: disable=protected-access
1609
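# Example (illustrative sketch, as a comment only): stripping stray,
# consumer-less ops (e.g. left behind by graph rewiring) from the default graph
# before triggering XLA compilation.
#
#   prune_unconnected_ops_from_xla(ops.get_default_graph())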