1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16"""Library of TPU helper functions."""
17
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22from absl import logging
23import numpy as np
24from six.moves import xrange  # pylint: disable=redefined-builtin
25
26from tensorflow.core.framework import attr_value_pb2
27from tensorflow.core.protobuf.tpu import dynamic_padding_pb2 as dynamic_padding
28from tensorflow.python.client import pywrap_tf_session
29from tensorflow.python.compiler.xla import xla
30from tensorflow.python.distribute import device_util
31from tensorflow.python.distribute import distribution_strategy_context
32from tensorflow.python.framework import auto_control_deps
33from tensorflow.python.framework import config
34from tensorflow.python.framework import device as pydev
35from tensorflow.python.framework import dtypes
36from tensorflow.python.framework import errors
37from tensorflow.python.framework import func_graph
38from tensorflow.python.framework import function
39from tensorflow.python.framework import ops
40from tensorflow.python.framework import tensor_shape
41from tensorflow.python.ops import array_ops
42from tensorflow.python.ops import control_flow_ops
43from tensorflow.python.ops import math_ops
44from tensorflow.python.ops import variable_scope
45from tensorflow.python.tpu import tpu_function
46from tensorflow.python.tpu.ops import tpu_ops
47from tensorflow.python.util import compat
48from tensorflow.python.util import nest
49from tensorflow.python.util.compat import collections_abc
50from tensorflow.python.util.tf_export import tf_export
51
52ops.NotDifferentiable("TPUReplicatedInput")
53
54# Operations that indicate some error in the user's graph, e.g. a placeholder
55# that's introduced outside of the infeed.
56_BLACKLISTED_OPS = set([
57    "Placeholder",
58])
59
60# XLA doesn't currently support reading of intermediate tensors, thus some ops
61# are not supported.
62_UNSUPPORTED_OPS = set([
63    "AudioSummary",
64    "AudioSummaryV2",
65    "HistogramSummary",
66    "ImageSummary",
67    "MergeSummary",
68    "Print",
69    "ScalarSummary",
70    "TensorSummary",
71    "TensorSummaryV2",
72    ])
73
74# Ops which can be safely pruned from XLA compile if they have no consumers.
75#  These ops should also have no inputs.
76_UNCONNECTED_OPS_TO_PRUNE = set(["Placeholder", "VarHandleOp"])
77
78_MAX_WARNING_LINES = 5
79
80_TPU_REPLICATE_ATTR = "_tpu_replicate"
81_POST_DEVICE_REWRITE_ATTR = "_post_device_rewrite"
82_TPU_COMPILATION_STATUS_ATTR = "_tpu_compilation_status"
83_OUTSIDE_COMPILATION_ATTR = "_xla_outside_compilation"
84
85
86def _tpu_system_device_name(job):
87  """Returns the device name for the TPU_SYSTEM device of `job`."""
88  if job is None:
89    return "/device:TPU_SYSTEM:0"
90  else:
91    return "/job:%s/device:TPU_SYSTEM:0" % job
92
93
94@tf_export(v1=["tpu.initialize_system"])
95def initialize_system(embedding_config=None,
96                      job=None,
97                      compilation_failure_closes_chips=True):
98  """Initializes a distributed TPU system for use with TensorFlow.
99
100  Args:
101    embedding_config: If not None, a `TPUEmbeddingConfiguration` proto
102      describing the desired configuration of the hardware embedding lookup
103      tables. If embedding_config is None, no hardware embeddings can be used.
104    job: The job (the XXX in TensorFlow device specification /job:XXX) that
105      contains the TPU devices that will be initialized. If job=None it is
106      assumed there is only one job in the TensorFlow flock, and an error will
107      be returned if this assumption does not hold.
108    compilation_failure_closes_chips: Whether to close TPU chips when there
109      is a compilation failure.
110  Returns:
111    A serialized `TopologyProto` that describes the TPU system. Note:
112      the topology must be evaluated using `Session.run` before it can be used.
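
  For example, a minimal sketch assuming a TF1 session whose target points at
  the TPU worker (the address below is hypothetical):

    topology_proto = tf.compat.v1.tpu.initialize_system()
    with tf.compat.v1.Session("grpc://tpu-worker:8470") as sess:
      serialized_topology = sess.run(topology_proto)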
113  """
114  config_string = ("" if embedding_config is None else
115                   embedding_config.SerializeToString())
116  with ops.device(_tpu_system_device_name(job)):
117    return tpu_ops.configure_distributed_tpu(
118        embedding_config=config_string,
119        compilation_failure_closes_chips=compilation_failure_closes_chips)
120
121
122def initialize_system_for_tpu_embedding(embedding_config, job=None):
123  """Initializes a distributed TPU Embedding system for use with TensorFlow.
124
125  The following two are equivalent:
126  1. initialize_system() with embedding_config.
127  2. initialize_system() without embedding_config, then
128     initialize_system_for_tpu_embedding().
129  initialize_system() should not be called with embedding_config if
130  initialize_system_for_tpu_embedding() is meant to be called later.
131
132  Args:
133    embedding_config: a `TPUEmbeddingConfiguration` proto describing the desired
134      configuration of the hardware embedding lookup tables.
135    job: The job (the XXX in TensorFlow device specification /job:XXX) that
136      contains the TPU devices that will be initialized. If job=None it is
137      assumed there is only one job in the TensorFlow flock, and an error will
138      be returned if this assumption does not hold.
139
140  Returns:
141    A no-op.
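
  For example, a sketch of the two-step pattern described above (the
  `embedding_config` proto is assumed to be built elsewhere):

    init_tpu_op = tf.compat.v1.tpu.initialize_system()  # no embedding_config
    init_embedding_op = initialize_system_for_tpu_embedding(embedding_config)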
142  """
143  config_string = embedding_config.SerializeToString()
144  with ops.device(_tpu_system_device_name(job)):
145    return tpu_ops.configure_tpu_embedding(config=config_string)
146
147
148@tf_export(v1=["tpu.shutdown_system"])
149def shutdown_system(job=None):
150  """Shuts down a running a distributed TPU system.
151
152  Args:
153    job: The job (the XXX in TensorFlow device specification /job:XXX) that
154      contains the TPU devices that will be shutdown. If job=None it is
155      assumed there is only one job in the TensorFlow flock, and an error will
156      be returned if this assumption does not hold.
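
  For example, a minimal sketch mirroring `initialize_system` usage (the
  session target below is hypothetical):

    with tf.compat.v1.Session("grpc://tpu-worker:8470") as sess:
      sess.run(tf.compat.v1.tpu.shutdown_system())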
157  """
158  with ops.device(_tpu_system_device_name(job)):
159    shutdown_distributed_tpu = tpu_ops.shutdown_distributed_tpu()
160  return shutdown_distributed_tpu
161
162
163@tf_export(v1=["tpu.core"])
164def core(num):
165  """Returns the device name for a core in a replicated TPU computation.
166
167  Args:
168    num: the virtual core number within each replica to which operators should
169    be assigned.
170  Returns:
171    A device name, suitable for passing to `tf.device()`.
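
  For example, a minimal sketch (assumes it runs inside a computation passed to
  `tpu.replicate`):

    def computation(x):
      with tf.device(tf.compat.v1.tpu.core(1)):
        # Ops created here are assigned to virtual core 1 of each replica.
        return x + 1.0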
172  """
173  return "device:TPU_REPLICATED_CORE:{}".format(num)
174
175
176def _enclosing_tpu_context_and_graph():
177  """Returns the TPUReplicateContext and its associated graph."""
178  graph = ops.get_default_graph()
179  while graph is not None:
180    # pylint: disable=protected-access
181    context_ = graph._get_control_flow_context()
182    # pylint: enable=protected-access
183    while context_ is not None:
184      if isinstance(context_, TPUReplicateContext):
185        return context_, graph
186      context_ = context_.outer_context
187    graph = getattr(graph, "outer_graph", None)
188  raise ValueError("get_replicated_var_handle() called without "
189                   "TPUReplicateContext. This shouldn't happen. Please file "
190                   "a bug.")
191
192
193def is_tpu_strategy(strategy):
194  is_tpu_strat = lambda k: k.__name__.startswith("TPUStrategy")
195  clz = strategy.__class__
196  return is_tpu_strat(clz) or any(map(is_tpu_strat, clz.__bases__))
197
198
199def _enclosing_tpu_device_assignment():
200  if not distribution_strategy_context.has_strategy():
201    return None
202  strategy = distribution_strategy_context.get_strategy()
203  if not is_tpu_strategy(strategy):
204    return None
205  return strategy.extended._device_assignment  # pylint: disable=protected-access
206
207
208@auto_control_deps.register_acd_resource_resolver
209def tpu_replicated_input_resolver(op, resource_inputs):
210  """Replaces TPUReplicatedInput outputs with its inputs in resource_inputs."""
211  # Ignore TPUReplicatedInput for ACD purposes since we will be directly adding
212  # control deps on the replicated inputs.
213  if op.type == "TPUReplicatedInput":
214    if resource_inputs:
215      resource_inputs.clear()
216      return True
217    else:
218      return False
219  # Replace tensors in `resource_inputs` which are outputs of TPUReplicatedInput
220  # with the actual replicated inputs. This allows ACD to correctly add control
221  # deps when there are multiple calls to `experimental_run_v2` in a
222  # `tf.function`.
223  to_remove = []
224  to_add = []
225  for resource in resource_inputs:
226    if resource.op.type == "TPUReplicatedInput":
227      to_remove.append(resource)
228      to_add.extend(resource.op.inputs)
229  if not to_add and not to_remove:
230    return False
231  for t in to_remove:
232    resource_inputs.discard(t)
233  resource_inputs.update(to_add)
234  return True
235
236
237class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
238  """A `ControlFlowContext` for nodes inside a TPU computation.
239
240  The primary role of `TPUReplicateContext` is to mark operators inside a
241  tpu.replicate() computation with the attribute "_tpu_replicate=XYZ", where XYZ
242  is a unique name.
243
244  We use a `ControlFlowContext` to perform the annotation since it integrates
245  with Tensorflow constructs like ResourceVariables. For example, if a
246  `ResourceVariable` is constructed inside a tpu.replicate() block, the
247  `ResourceVariable` implementation can use
248  `with ops.control_dependencies(None)` to build the variable's definition
249  outside the replicated computation.
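
  A minimal sketch of how this context is used (mirroring
  `split_compile_and_replicate` later in this file):

    pivot = control_flow_ops.no_op(name="cluster/pivot")
    context = TPUReplicateContext(name="cluster", num_replicas=2, pivot=pivot)
    context.Enter()
    # Ops created here are annotated with the attribute _tpu_replicate="cluster".
    context.Exit()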
250  """
251
252  class _TFBufferWrapper(object):
253    """An internal class to help manage the TF_Buffer lifetime."""
254
255    def __init__(self, buf_string):
256      self._buffer = pywrap_tf_session.TF_NewBufferFromString(
257          compat.as_bytes(buf_string))
258
259    def __del__(self):
260      pywrap_tf_session.TF_DeleteBuffer(self._buffer)
261
262  def __init__(self, name, num_replicas, pivot):
263    """Builds a new TPUReplicateContext.
264
265    Args:
266      name: a unique name for the context, used to populate the `_tpu_replicate`
267        attribute.
268      num_replicas: an integer that gives the number of replicas for the
269        computation.
270      pivot: a pivot node. Nodes in the TPUReplicateContext that do not have any
271        inputs will have a control dependency on the pivot node. This ensures
272        that nodes are correctly included in any enclosing control flow
273        contexts.
274    """
275    super(TPUReplicateContext, self).__init__()
276    self._num_replicas = num_replicas
277    self._outer_device_function_stack = None
278    self._oc_dev_fn_stack = None
279    self._outside_compilation_cluster = None
280    self._outside_compilation_counter = 0
281    self._in_gradient_colocation = None
282    self._gradient_colocation_stack = []
283    self._host_compute_core = []
284    self._name = name
285    self._name_as_bytes = compat.as_bytes(name)
286    self._tpu_replicate_attr_buf = self._TFBufferWrapper(
287        attr_value_pb2.AttrValue(s=self._name_as_bytes).SerializeToString())
288    self._unsupported_ops = []
289    self._pivot = pivot
290    self._replicated_vars = {}
291
292  def get_replicated_var_handle(self, name, vars_, is_mirrored=False):
293    """Returns a variable handle for replicated TPU variable 'var'.
294
295    This is a method used by an experimental replicated variable implementation
296    and is not intended as a public API.
297
298    Args:
299      name: The common name of the variable.
300      vars_: The replicated TPU variables.
301      is_mirrored: Whether the variables are mirrored, which guarantees the
302        values in each replica are always the same.
303
304    Returns:
305      The handle of the TPU replicated input node.
306    """
307    device_assignment = _enclosing_tpu_device_assignment()
308    # We don't need to put device assignment as part of the replicated_vars key
309    # because each TPUReplicateContext will only have one device assignment.
310    handle = self._replicated_vars.get(name)
311    if handle is not None:
312      return handle
313
314    if device_assignment is not None:
315      # Find a variable copy for each replica in the device assignment.
316      # Note that the order of devices for replicas for the variable and the
317      # device assignment might not match.
318      job_name = pydev.DeviceSpec.from_string(vars_[0].device).job
319      devices_to_vars = {v.device: v for v in vars_}
320      replicated_vars = []
321      for replica_id in range(device_assignment.num_replicas):
322        for logical_core in range(device_assignment.num_cores_per_replica):
323          device = device_util.canonicalize(
324              device_assignment.tpu_device(
325                  replica=replica_id, logical_core=logical_core, job=job_name))
326          if device in devices_to_vars:
327            replicated_vars.append(devices_to_vars[device])
328            break
329        else:
330          raise ValueError(
331              "Failed to find a variable on any device in replica {} for "
332              "current device assignment".format(replica_id))
333    else:
334      replicated_vars = vars_
335
336    # Builds a TPUReplicatedInput node for the variable, if one does not already
337    # exist. The TPUReplicatedInput node must belong to the enclosing
338    # control-flow scope of the TPUReplicateContext.
339    # TODO(phawkins): consider changing the contract of the TPU encapsulation
340    # so the TPUReplicatedInput nodes go inside the TPUReplicateContext scope
341    # instead.
342
343    _, graph = _enclosing_tpu_context_and_graph()
344    with graph.as_default():
345      # pylint: disable=protected-access
346      saved_context = graph._get_control_flow_context()
347      graph._set_control_flow_context(self.outer_context)
348      handle = tpu_ops.tpu_replicated_input([v.handle for v in replicated_vars],
349                                            name=name + "/handle",
350                                            is_mirrored_variable=is_mirrored)
351      graph._set_control_flow_context(saved_context)
352      # pylint: enable=protected-access
353    self._replicated_vars[name] = handle
354    return handle
355
356  def report_unsupported_operations(self):
357    if self._unsupported_ops:
358      op_str = "\n".join("  %s (%s)" % (op.type, op.name)
359                         for op in self._unsupported_ops[:_MAX_WARNING_LINES])
360      logging.warning("%d unsupported operations found: \n%s",
361                      len(self._unsupported_ops), op_str)
362      if len(self._unsupported_ops) > _MAX_WARNING_LINES:
363        logging.warning("... and %d more" %
364                        (len(self._unsupported_ops) - _MAX_WARNING_LINES))
365
366  def EnterGradientColocation(self, op, gradient_uid):
367    if op is not None:
368      self._gradient_colocation_stack.append(op)
369      if not self._outside_compilation_cluster:
370        try:
371          outside_attr = op.get_attr(_OUTSIDE_COMPILATION_ATTR).decode("ascii")
372          if self._in_gradient_colocation:
373            raise NotImplementedError(
374                "Cannot nest gradient colocation operations outside compilation"
375            )
376          if gradient_uid == "__unsupported__":
377            raise NotImplementedError(
378                "No gradient_uid calling gradient within outside_compilation")
379          # When we take the gradient of an op X in an outside_compilation
380          # cluster C in a forward computation we would like to put the ops
381          # corresponding to the gradient of X into a new outside_compilation
382          # cluster C'. However, if we take the gradient of X twice, the second
383          # one should get yet another new outside_compilation cluster C''.
384          #
385          # The mechanism we adopt is to use a 'root_cluster' which is the
386          # cluster that X was in before we took gradients, and a 'gradient_uid'
387          # which is different for every invocation of gradients, and put the
388          # gradient of X in cluster 'root_cluster.gradient_uid'.
389          #
390          # When taking a gradient of a gradient, some ops will be colocated
391          # with Op in the forward pass (e.g., cluster root_cluster) and some in
392          # the backward pass (e.g., cluster root_cluster.initial_gradient_uid).
393          # We need all of the grad-of-grad ops to be in the same cluster to
394          # avoid cyclic dependencies between clusters. We adopt a heuristic
395          # that puts any op clustered with root_cluster.<xxx> in
396          # root_cluster.gradient_uid, even if xxx was initial_gradient_uid.
397          self._in_gradient_colocation = op
398          parts = outside_attr.split(".")
399          cluster = parts[0] + "." + gradient_uid
400          self._EnterOutsideCompilationScope(cluster=cluster)
401        except ValueError:
402          # The attr was not present: do nothing.
403          pass
404
405  def ExitGradientColocation(self, op, gradient_uid):
406    if op is not None:
407      if not self._gradient_colocation_stack:
408        raise errors.InternalError(
409            op.node_def, op,
410            "Badly nested gradient colocation: empty stack when popping Op " +
411            op.name)
412      last_op = self._gradient_colocation_stack.pop()
413      if op is last_op:
414        if op is self._in_gradient_colocation:
415          self._in_gradient_colocation = None
416          self._ExitOutsideCompilationScope()
417      else:
418        raise errors.InternalError(
419            op.node_def, op, "Badly nested gradient colocation, expected " +
420            last_op.name + ", got " + op.name)
421
422  def _EnterOutsideCompilationScope(self, cluster=None):
423
424    class FakeOp(object):
425      """A helper class to determine the current device.
426
427      Supports only the type and device set/get methods needed to run the
428      graph's _apply_device_function method.
429      """
430
431      def __init__(self):
432        self._device = ""
433
434      @property
435      def type(self):
436        return "FakeOp"
437
438      @property
439      def device(self):
440        return self._device
441
442      def _set_device(self, device):
443        if isinstance(device, pydev.DeviceSpec):
444          self._device = device.to_string()
445        else:
446          self._device = device
447
448      def _set_device_from_string(self, device_str):
449        self._device = device_str
450
451    if self._outside_compilation_cluster:
452      raise NotImplementedError("Cannot nest outside_compilation clusters")
453    if cluster:
454      self._outside_compilation_cluster = cluster
455    else:
456      self._outside_compilation_cluster = str(self._outside_compilation_counter)
457      self._outside_compilation_counter += 1
458    graph = ops.get_default_graph()
459    fake_op = FakeOp()
460    graph._apply_device_functions(fake_op)  # pylint: disable=protected-access
461    device = pydev.DeviceSpec.from_string(fake_op.device)
462    if (device.device_type == "TPU_REPLICATED_CORE" and
463        device.device_index is not None):
464      self._host_compute_core.append(self._outside_compilation_cluster + ":" +
465                                     str(device.device_index))
466    self._oc_dev_fn_stack = graph._device_function_stack  # pylint: disable=protected-access
467    graph._device_function_stack = self._outer_device_function_stack  # pylint: disable=protected-access
468
469  def _ExitOutsideCompilationScope(self):
470    if not self._outside_compilation_cluster:
471      raise NotImplementedError(
472          "Attempted to exit outside_compilation scope when not in scope")
473    self._outside_compilation_cluster = None
474    graph = ops.get_default_graph()
475    graph._device_function_stack = self._oc_dev_fn_stack  # pylint: disable=protected-access
476
477  def Enter(self):
478    if not self._outer_device_function_stack:
479      # Capture the device function stack at the time of first entry
480      # since that is the stack that will be used outside_compilation.
481      graph = ops.get_default_graph()
482      # pylint: disable=protected-access
483      self._outer_device_function_stack = graph._device_function_stack.copy()
484      # pylint: enable=protected-access
485    super(TPUReplicateContext, self).Enter()
486
487  def HostComputeCore(self):
488    return self._host_compute_core
489
490  def _RemoveExternalControlEdges(self, op):
491    """Remove any external control dependency on this op."""
492    internal_control_inputs = []
493    external_control_inputs = []
494    for x in op.control_inputs:
495      # pylint: disable=protected-access
496      is_internal_op = False
497      ctxt = x._get_control_flow_context()
498      while ctxt is not None:
499        if ctxt == self:
500          is_internal_op = True
501          break
502        ctxt = ctxt._outer_context
503      if is_internal_op:
504        internal_control_inputs.append(x)
505      else:
506        external_control_inputs.append(x)
507      # pylint: enable=protected-access
508    # pylint: disable=protected-access
509    op._remove_all_control_inputs()
510    op._add_control_inputs(internal_control_inputs)
511    # pylint: enable=protected-access
512    return internal_control_inputs, external_control_inputs
513
514  def AddOp(self, op):
515    # pylint: disable=protected-access
516    if op.type in _BLACKLISTED_OPS:
517      logging.error("Operation of type %s (%s) is not supported on the TPU. "
518                    "Execution will fail if this op is used in the graph. " %
519                    (op.type, op.name))
520
521    if op.type in _UNSUPPORTED_OPS:
522      self._unsupported_ops.append(op)
523
524    if any(x.dtype._is_ref_dtype for x in op.inputs):
525      raise NotImplementedError(
526          "Non-resource Variables are not supported inside TPU computations "
527          "(operator name: %s)" % op.name)
528
529    # TensorFlowOpLayer may clone nodes that are in tpu.rewrite()s. It'll add
530    # the "_cloned" attribute and we should continue in that case.
531    if (_TPU_REPLICATE_ATTR in op.node_def.attr and
532        "_cloned" not in op.node_def.attr):
533      raise ValueError("TPU computations cannot be nested on op (%s)" %
534                       op)
535    op._set_attr_with_buf(
536        _TPU_REPLICATE_ATTR, self._tpu_replicate_attr_buf._buffer)
537    if self._outside_compilation_cluster:
538      op._set_attr(
539          _OUTSIDE_COMPILATION_ATTR,
540          attr_value_pb2.AttrValue(
541              s=compat.as_bytes(self._outside_compilation_cluster)))
542    if self._num_replicas > 1 or not self._outside_compilation_cluster:
543      # Prevent feeding or fetching anything that is being compiled,
544      # and any replicated outside_compilation Op.
545      op.graph.prevent_feeding(op)
546      op.graph.prevent_fetching(op)
547
548    # Remove any control edges from outer control flow contexts. These may cause
549    # mismatched frame errors.
550    (internal_control_inputs,
551     external_control_inputs) = self._RemoveExternalControlEdges(op)
552
553    if not op.inputs:
554      # Add a control edge from the control pivot to this op.
555      if not internal_control_inputs:
556        # pylint: disable=protected-access
557        op._add_control_input(self.GetControlPivot())
558        # pylint: enable=protected-access
559    else:
560      for index in xrange(len(op.inputs)):
561        x = op.inputs[index]
562        real_x = self.AddValue(x)
563        if real_x is not x:
564          op._update_input(index, real_x)  # pylint: disable=protected-access
565
566    if external_control_inputs:
567      # Use an identity to pull control inputs as data inputs. Note that we
568      # ignore ops which don't have outputs. TODO(phawkins): fix that.
569      with ops.control_dependencies(None):
570        self.Enter()
571        external_control_inputs = [
572            array_ops.identity(x.outputs[0]).op
573            for x in external_control_inputs
574            if x.outputs
575        ]
576        self.Exit()
577      # pylint: disable=protected-access
578      op._add_control_inputs(external_control_inputs)
579      # pylint: enable=protected-access
580
581    # Mark op's outputs as seen by this context and any outer contexts.
582    output_names = [x.name for x in op.outputs]
583    context = self
584    while context is not None:
585      # pylint: disable=protected-access
586      context._values.update(output_names)
587      context = context._outer_context
588      # pylint: enable=protected-access
589
590    if self._outer_context:
591      self._outer_context.AddInnerOp(op)
592
593  def AddValue(self, val):
594    """Add `val` to the current context and its outer context recursively."""
595    if val.name in self._values:
596      # Use the real value if it comes from outer context.
597      result = self._external_values.get(val.name)
598      return val if result is None else result
599
600    result = val
601    self._values.add(val.name)
602    if self._outer_context:
603      result = self._outer_context.AddValue(val)
604      self._values.add(result.name)
605
606    self._external_values[val.name] = result
607
608    return result
609
610  def AddInnerOp(self, op):
611    self.AddOp(op)
612    if self._outer_context:
613      self._outer_context.AddInnerOp(op)
614
615  @property
616  def grad_state(self):
617    # The gradient loop state associated with the TPUReplicateContext is None
618    # because the TPUReplicateContext does not get nested, and the grad_state
619    # outside the TPUReplicateContext does not affect the graph inside, so the
620    # grad_state should behave as if this were the top-level gradient state.
621    return None
622
623  @property
624  def back_prop(self):
625    """Forwards to the enclosing while context, if any."""
626    if self.GetWhileContext():
627      return self.GetWhileContext().back_prop
628    return False
629
630  def GetControlPivot(self):
631    return self._pivot
632
633
634class OutsideCompilationV2Context(control_flow_ops.ControlFlowContext):
635  """The context for outside compilation in Tensorflow 2.0.
636
637  Every op added in this context will be assigned an _xla_outside_compilation
638  attribute.
639  """
640
641  def __init__(self, name):
642    control_flow_ops.ControlFlowContext.__init__(self)
643    self._name = name
644
645  def AddOp(self, op):
646    if self._outer_context:
647      self._outer_context.AddOp(op)
648    # pylint: disable=protected-access
649    op._set_attr("_xla_outside_compilation",
650                 attr_value_pb2.AttrValue(s=compat.as_bytes(self._name)))
651    # pylint: enable=protected-access
652
653  def AddInnerOp(self, op):
654    if self._outer_context:
655      self._outer_context.AddInnerOp(op)
656    # pylint: disable=protected-access
657    op._set_attr("_xla_outside_compilation",
658                 attr_value_pb2.AttrValue(s=compat.as_bytes(self._name)))
659    # pylint: enable=protected-access
660
661  def to_control_flow_context_def(self, context_def, export_scope=None):
662    raise NotImplementedError("to_control_flow_context_def not implemented")
663
664
665@tf_export(v1=["tpu.outside_compilation"])
666def outside_compilation(computation, *args, **kwargs):
667  """Builds part of a computation outside any current TPU replicate scope.
668
669  `tf.tpu.outside_compilation()` is used to run ops in `computation` on CPU
670  instead of running on TPU. For example, users can run ops that are not
671  supported on TPUs (e.g. tf.summary.write()) by explicitly placing those
672  ops on CPUs. The usage of outside compilation below places the ops in
673  `computation_with_string_ops` on the CPU.
674
675  def computation_with_string_ops(x):
676    # string types are not supported on TPUs and the ops below must
677    # run on CPU instead.
678    output = tf.strings.format('1{}', x)
679    return tf.strings.to_number(output)
680
681  def tpu_computation():
682    # Expected output is 11.
683    output = tf.tpu.outside_compilation(computation_with_string_ops, 1)
684
685  Outside compilation should be called inside TPUReplicateContext. That is,
686  `tf.tpu.outside_compilation()` should be called inside a function that is
687  passed to `tpu.split_compile_and_replicate()` -- this is implied when
688  outside compilation is invoked inside a function passed to TPUStrategy
689  `experimental_run_v2()`. If invoked outside of TPUReplicateContext,
690  then this simply returns the result of `computation`, and therefore,
691  would be a no-op. Note that outside compilation is different from
692  `tf.distribute.experimental.TPUStrategy.merge_call()` as logic in
693  outside compilation is replicated and executed separately for each
694  replica. On the other hand, `merge_call()` requires a `merge_fn`
695  to aggregate the inputs from different replicas and is executed only
696  once.
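
  For example, a sketch assuming a `tf.distribute.experimental.TPUStrategy`
  instance named `strategy` (hypothetical; setup omitted):

    def tpu_fn(x):
      y = x * 2.0
      # Run the print on the host via outside compilation.
      tf.tpu.outside_compilation(lambda t: tf.print("y:", t), y)
      return y

    strategy.experimental_run_v2(tpu_fn, args=(tf.constant(1.0),))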
697
698  For variables placed on a TPU device, which includes variables created inside
699  a TPUStrategy scope, outside compilation logic must not include variable
700  reads/writes. For variables placed on the host, which is the case for
701  variables created via TPUEstimator, variable reads/writes are only allowed if
702  the variable is not accessed by any other ops in the TPU computation. Variable
703  reads/writes from an outside compilation cluster are not visible from the TPU
704  computation and vice versa. Therefore, if outside compilation logic contains
705  such host variable read/write ops and the variables are also accessed by the
706  TPU computation, this may lead to deadlock.
707
708  Internally, `tf.tpu.outside_compilation()` adds outside compilation
709  attributes to all ops in `computation`. During a later graph pass, the ops
710  with the outside compilation attribute are extracted and replicated into a
711  host-side graph. Inputs to this extracted host-side graph are sent from the
712  TPU computation graph to the host graph via a pair of XlaSendToHost and
713  XlaRecvFromHost ops. Note that using `tf.tpu.outside_compilation()`
714  may result in tensor transfers between TPU and CPU, leading to non-trivial
715  performance impact.
716
717  Args:
718    computation: A Python function that builds the computation to
719      place on the host.
720    *args: the positional arguments for the computation.
721    **kwargs: the keyword arguments for the computation.
722
723  Returns:
724    The Tensors returned by computation.
725  """
726  args = [] if args is None else args
727  graph = ops.get_default_graph()
728
729  # If we are in TF 2 functions (control flow V2 functions, or tf.function()),
730  # we need to attach _xla_outside_compilation attribute directly because we are
731  # not in TPUReplicateContext.
732  if isinstance(graph, func_graph.FuncGraph):
733    try:
734      tpu_context, _ = _enclosing_tpu_context_and_graph()
735    except ValueError:
736      logging.warning(
737          "Outside compilation attempted outside TPUReplicateContext "
738          "scope. As no enclosing TPUReplicateContext can be found, "
739          "returning the result of `computation` as is.")
740      return computation(*args, **kwargs)
741
742    # pylint: disable=protected-access
743    outside_compilation_name = str(tpu_context._outside_compilation_counter)
744    tpu_context._outside_compilation_counter = (
745        tpu_context._outside_compilation_counter + 1)
746    # pylint: enable=protected-access
747
748    outside_compilation_context = OutsideCompilationV2Context(
749        outside_compilation_name)
750    outside_compilation_context.Enter()
751    args = [] if args is None else args
752    retval = computation(*args, **kwargs)
753    outside_compilation_context.Exit()
754    return retval
755
756  # If we are in a TPUReplicateContext, signal that we are now
757  # outside_compilation
758  initial_context = graph._get_control_flow_context()  # pylint: disable=protected-access
759  context = initial_context
760  while context:
761    if isinstance(context, TPUReplicateContext):
762      context._EnterOutsideCompilationScope()  # pylint: disable=protected-access
763    context = context.outer_context
764
765  retval = computation(*args, **kwargs)
766
767  # If we are in a TPUReplicateContext, signal that we are no longer
768  # outside_compilation
769  final_context = graph._get_control_flow_context()  # pylint: disable=protected-access
770  if initial_context is not final_context:
771    raise NotImplementedError(
772        "Control-flow context cannot be different at start and end of an "
773        "outside_compilation scope")
774  context = initial_context
775  while context:
776    if isinstance(context, TPUReplicateContext):
777      context._ExitOutsideCompilationScope()  # pylint: disable=protected-access
778    context = context.outer_context
779
780  return retval
781
782
783@tf_export(v1=["tpu.replicate"])
784def replicate(computation,
785              inputs=None,
786              infeed_queue=None,
787              device_assignment=None,
788              name=None,
789              maximum_shapes=None):
790  """Builds a graph operator that runs a replicated TPU computation.
791
792  Args:
793    computation: A Python function that builds the computation to replicate.
794    inputs: A list of lists of input tensors or `None` (equivalent to
795      `[[]]`), indexed by `[replica_num][input_num]`. All replicas must
796      have the same number of inputs. Each input can be a nested structure
797      containing values that are convertible to tensors. Note that passing an
798      N-dimension list of compatible values will result in an N-dimension list
799      of scalar tensors rather than a single rank-N tensor. If you need different
800      behavior, convert part of inputs to tensors with `tf.convert_to_tensor`.
801    infeed_queue: If not `None`, the `InfeedQueue` from which to append a tuple
802      of arguments as inputs to computation.
803    device_assignment: If not `None`, a `DeviceAssignment` describing the
804      mapping between logical cores in the computation with physical cores in
805      the TPU topology. Uses a default device assignment if `None`. The
806      `DeviceAssignment` may be omitted if each replica of the computation uses
807      only one core, and there is either only one replica, or the number of
808      replicas is equal to the number of cores in the TPU system.
809    name: (Deprecated) Does nothing.
810    maximum_shapes: A nested structure of tf.TensorShape representing the shape
811      to which the respective component of each input element in each replica
812      should be padded. Any unknown dimensions (e.g.
813      tf.compat.v1.Dimension(None) in a tf.TensorShape or -1 in a tensor-like
814      object) will be padded to the maximum size of that dimension over all
815      replicas. The structure of `maximum_shapes` needs to be the same as
816      `inputs[0]`.
817  Returns:
818    A list of outputs, indexed by `[replica_num]`. Each output can be a nested
819    structure, the same as what computation() returns, with a few exceptions.
820
821    Exceptions include:
822      1) None output: a NoOp would be returned which control-depends on
823         computation.
824      2) Single value output: A tuple containing the value would be returned.
825      3) Operation-only outputs: a NoOp would be returned which
826         control-depends on computation.
827      TODO(b/121383831): Investigate into removing these special cases.
828
829  Raises:
830    ValueError: If all replicas do not have equal numbers of input tensors.
831    ValueError: If the number of inputs per replica does not match
832      the number of formal parameters to `computation`.
833    ValueError: If the static `inputs` dimensions don't match with the values
834      given in `maximum_shapes`.
835    ValueError: If the structure of inputs per replica does not match
836      the structure of `maximum_shapes`.
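
  For example, a minimal sketch with two replicas (TPU system initialization
  and session setup omitted):

    def computation(x):
      return x * 2.0

    inputs = [[tf.constant(1.0)], [tf.constant(2.0)]]  # [replica][input]
    outputs = tf.compat.v1.tpu.replicate(computation, inputs)
    # outputs[i] is the (tuple of) result tensors from replica i.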
837  """
838  return split_compile_and_replicate(
839      computation,
840      inputs,
841      infeed_queue,
842      device_assignment,
843      name,
844      maximum_shapes=maximum_shapes)[1]
845
846
847def _pad_all_input(inputs, padded_shapes):
848  """Pad all input tensors given padded_shapes.
849
850  The real shape tensors will be concatenated with the padded original inputs.
851
852  Args:
853    inputs: The original inputs.
854    padded_shapes: A list of padded shapes for each input.
855
856  Returns:
857    The padded inputs and a PaddingMap list which maps the padded input
858    dimension to the real shape argument index.
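
  For example (an illustrative sketch, not actual values): if dimension 1 of
  input 0 differs across replicas, a real-shape scalar is appended to each
  replica's inputs and the corresponding PaddingMap would carry roughly
  arg_index=0, shape_index=1, and padding_arg_index pointing at that appended
  scalar.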
859  """
860  # maximum_static_shapes[idx][i] indicates the maximum static size of ith
861  # dimension of the idx input among all the replicas.
862  maximum_static_shapes = []
863  # need_padding[idx][i] indicates whether the ith dimension of the idx input
864  # needs padding.
865  need_padding = []
866  input_shape_tensors = []
867  for core_idx, inputs_per_core in enumerate(inputs):
868    for idx, input_tensor in enumerate(inputs_per_core):
869      input_shape = input_tensor.get_shape().as_list()
870      if core_idx == 0:
871        input_shape_tensors.append([])
872        maximum_static_shapes.append(input_shape)
873        need_padding.append(np.full_like(input_shape, False, dtype=bool))
874      else:
875        for i, s in enumerate(input_shape):
876          if not s or s != maximum_static_shapes[idx][i]:
877            need_padding[idx][i] = True
878        maximum_static_shapes[idx] = max(input_shape,
879                                         maximum_static_shapes[idx])
880
881      # Append _POST_DEVICE_REWRITE_ATTR attributes to the real shape ops.
882      real_input_shape = array_ops.shape(input_tensor)
883      real_input_shape.op._set_attr(  # pylint: disable=protected-access
884          _POST_DEVICE_REWRITE_ATTR,
885          attr_value_pb2.AttrValue(b=True))
886      input_shape_tensors[idx].append(real_input_shape)
887
888  maximum_shapes = []
889  for shapes_per_input in input_shape_tensors:
890    maximum_shapes.append(
891        math_ops.reduce_max(array_ops.stack(shapes_per_input), axis=0))
892
893  padded_inputs = []
894  real_shapes = []
895  padding_maps = []
896  for core_idx, inputs_per_core in enumerate(inputs):
897    padded_inputs.append([])
898    real_shapes.append([])
899    real_shape_idx = len(inputs_per_core) - 1
900    for idx, input_tensor in enumerate(inputs_per_core):
901      input_shape_tensor = input_shape_tensors[idx][core_idx]
902      input_shape = input_tensor.get_shape().as_list()
903      padded_shape = padded_shapes[idx]
904
905      if any(need_padding[idx]):
906        for i, s in enumerate(input_shape):
907          if need_padding[idx][i]:
908            if core_idx == 0:
909              real_shape_idx += 1
910              padding_map = dynamic_padding.PaddingMap()
911              padding_map.arg_index = idx
912              padding_map.shape_index = i
913              padding_map.padding_arg_index = real_shape_idx
914              padding_maps.append(padding_map)
915            real_shapes[core_idx].append(
916                math_ops.cast(input_shape_tensor[i], dtypes.int32))
917
918        paddings = []
919        for i, s in enumerate(padded_shape.dims):
920          if need_padding[idx][i]:
921            # The minimum padded dimension size is 2 as XLA doesn't support size
922            # 1 dynamic size.
923            minimum_dynamic_dim_size = 2
924            if s.value:
925              # Pad to the given maximum value.
926              max_dim_size = max(s.value, minimum_dynamic_dim_size)
927            else:
928              # If maximum value is not given, then pad to the maximum dimension
929              # among all the cores.
930              max_dim_size = math_ops.maximum(maximum_shapes[idx][i],
931                                              minimum_dynamic_dim_size)
932            # Pad to the given maximum value.
933            padding = [0, max_dim_size - input_shape_tensor[i]]
934          else:
935            padding = [0, 0]
936          paddings.append(padding)
937
938        if input_tensor.get_shape().is_fully_defined():
939          # TODO(rxsang): This is a hack to make sure padded_input has dynamic
940          # shapes, so any tf.size/tf.shape op performed on it won't be constant
941          # folded. Do we have better ways to do it?
942          padded_input = control_flow_ops.cond(
943              array_ops.constant(True),
944              lambda: array_ops.pad(input_tensor, paddings),  # pylint: disable=cell-var-from-loop
945              lambda: input_tensor)
946        else:
947          padded_input = array_ops.pad(input_tensor, paddings)
948
949        # Append _POST_DEVICE_REWRITE_ATTR attributes to all padded inputs.
950        padded_input.op._set_attr(  # pylint: disable=protected-access
951            _POST_DEVICE_REWRITE_ATTR,
952            attr_value_pb2.AttrValue(b=True))
953
954        padded_inputs[core_idx].append(padded_input)
955      else:
956        padded_inputs[core_idx].append(input_tensor)
957
958  num_replicas = len(padded_inputs)
959  for i in range(num_replicas):
960    padded_inputs[i].extend(real_shapes[i])
961
962  return padded_inputs, padding_maps
963
964
965def split_compile_and_replicate(computation,
966                                inputs=None,
967                                infeed_queue=None,
968                                device_assignment=None,
969                                name=None,
970                                use_tpu=True,
971                                maximum_shapes=None):
972  """Builds graph operators that runs compilation and replicated computation.
973
974  This is a lower-level interface than `replicate` that returns separate compile
975  and execute output tensors. In the generated graph the compile op feeds into
976  the execute op and no additional compilation is incurred when running the
977  compile op before the execute op. The compile op returns additional
978  information about the compilation but does not return the compiled program.
979
980  Args:
981    computation: A Python function that builds the computation to replicate.
982    inputs: A list of lists of input tensors or `None` (equivalent to
983      `[[]]`), indexed by `[replica_num][input_num]`. All replicas must
984      have the same number of inputs. Each input can be a nested structure
985      containing values that are convertible to tensors. Note that passing an
986      N-dimension list of compatible values will result in an N-dimension list
987      of scalar tensors rather than a single rank-N tensor. If you need different
988      behavior, convert part of inputs to tensors with `tf.convert_to_tensor`.
989    infeed_queue: If not `None`, the `InfeedQueue` from which to append a tuple
990      of arguments as inputs to computation.
991    device_assignment: If not `None`, a `DeviceAssignment` describing the
992      mapping between logical cores in the computation with physical cores in
993      the TPU topology. Uses a default device assignment if `None`. The
994      `DeviceAssignment` may be omitted if each replica of the computation uses
995      only one core, and there is either only one replica, or the number of
996      replicas is equal to the number of cores in the TPU system.
997    name: (Deprecated) Does nothing.
998    use_tpu: When false, the input `computation` is executed on the XLA CPU/GPU
999      backends. Currently, only supports a default placement (computation is
1000      placed on GPU if one is available, and on CPU if not).
1001    maximum_shapes: A nested structure of tf.TensorShape representing the shape
1002      to which the respective component of each input element in each replica
1003      should be padded. Any unknown dimensions (e.g.
1004      tf.compat.v1.Dimension(None) in a tf.TensorShape or -1 in a tensor-like
1005      object) will be padded to the maximum size of that dimension over all
1006      replicas. The structure of `maximum_shapes` needs to be the same as
1007      `inputs[0]`.
1008
1009  Returns:
1010    A list of lists with the first list corresponding to the compile op and the
1011    second a list of output tensors, indexed by `[replica_num][output_num]`.
1012  Raises:
1013    ValueError: If all replicas do not have equal numbers of input tensors.
1014    ValueError: If the number of inputs per replica does not match
1015      the number of formal parameters to `computation`.
1016    ValueError: If the static `inputs` dimensions don't match with the values
1017      given in `maximum_shapes`.
1018    ValueError: If the structure of inputs per replica does not match
1019      the structure of `maximum_shapes`.
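
  For example, a sketch showing how the compile op can be evaluated before the
  per-replica outputs (session setup omitted):

    compile_op, replica_outputs = split_compile_and_replicate(
        lambda x: x * 2.0, inputs=[[tf.constant(1.0)]])
    # Running `compile_op` first surfaces compilation errors early; the compile
    # op feeds the execute op, so no extra compilation happens afterwards.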
1020  """
1021  del name
1022  inputs = [[]] if inputs is None else inputs
1023
1024  metadata_kwargs = {}
1025  if device_assignment is not None:
1026    # Turn the Numpy array into a flattened list so we can pass it as an
1027    # operator attribute.
1028    metadata_kwargs = {
1029        "topology":
1030            device_assignment.topology.serialized(),
1031        "device_assignment":
1032            device_assignment.core_assignment.flatten().tolist()
1033    }
1034    metadata_kwargs["num_cores_per_replica"] = (
1035        device_assignment.num_cores_per_replica)
1036  # This entry is used for enabling automatic outside compilation.
1037  metadata_kwargs["allow_soft_placement"] = config.get_soft_device_placement()
1038
1039  if ((not isinstance(inputs, list)) or
1040      any(not isinstance(inp, (list, tuple)) for inp in inputs)):
1041    raise TypeError("tpu.replicate() inputs must be a list of lists/tuples")
1042
1043  num_replicas = len(inputs)
1044
1045  # No replicas? Nothing to do.
1046  if num_replicas == 0:
1047    return []
1048
1049  # Checks all replicas have the same structure.
1050  for i in xrange(1, num_replicas):
1051    nest.assert_same_structure(inputs[0], inputs[i])
1052
1053  # Flatten inputs.
1054  flat_inputs = [
1055      nest.flatten(per_replica_input) for per_replica_input in inputs
1056  ]
1057  # Converts inputs to Tensors.
1058  flat_inputs = [[ops.convert_to_tensor(x) for x in inp] for inp in flat_inputs]
1059
1060  # Verifies that all replicas have matching numbers and types of inputs
1061  flat_input_types = [x.dtype for x in flat_inputs[0]]
1062  input_arity = len(inputs[0])
1063  flat_input_arity = len(flat_input_types)
1064  for i in range(num_replicas):
1065    if len(inputs[i]) != input_arity:
1066      raise ValueError("Replicas must have the same number of inputs. "
1067                       "Replica 0 had {} inputs, replica {} had {} "
1068                       "inputs.".format(input_arity, i, len(inputs[i])))
1069
1070    types = [x.dtype for x in flat_inputs[i]]
1071    if types != flat_input_types:
1072      raise ValueError("Replicas must have matching input types. Replica 0 had "
1073                       "input types {}, replica {} had input types {}".format(
1074                           flat_input_types, i, types))
1075
1076  arg_error = xla.check_function_argument_count(
1077      computation, input_arity, infeed_queue)
1078  if arg_error is not None:
1079    if infeed_queue is None:
1080      raise TypeError(
1081          "Supplied computation cannot be called with the specified inputs. "
1082          "You specified %d inputs: %s, but the computation needs %s" % (
1083              input_arity, str([i.name for i in inputs[0]]), arg_error))
1084    else:
1085      raise TypeError(
1086          "Supplied computation cannot be called with the specified inputs. "
1087          "You specified %d inputs: %s and %d additional inputs from infeed,"
1088          " but the computation needs %s" % (input_arity, str(
1089              [i.name
1090               for i in inputs[0]]), infeed_queue.number_of_tuple_elements,
1091                                             arg_error))
1092
1093  if maximum_shapes:
1094    if infeed_queue:
1095      raise ValueError(
1096          "Dynamic input shapes are not supported with infeed queues")
1097
1098    # Make sure maximum_shapes has the same structure as inputs.
1099    nest.assert_same_structure(inputs[0], maximum_shapes, check_types=False)
1100
1101    # Flatten padded shapes.
1102    flat_maximum_shapes = nest.flatten(maximum_shapes)
1103    flat_maximum_shapes = [
1104        tensor_shape.TensorShape(s) for s in flat_maximum_shapes
1105    ]
1106
1107    flat_inputs, padding_maps = _pad_all_input(flat_inputs, flat_maximum_shapes)
1108
1109    serialized_padding_maps = []
1110    for padding_map in padding_maps:
1111      serialized_padding_maps.append(padding_map.SerializeToString())
1112    metadata_kwargs["padding_map"] = serialized_padding_maps
1113
1114  metadata_kwargs["step_marker_location"] = getattr(
1115      computation, "step_marker_location", "STEP_MARK_AT_ENTRY")
1116
1117  graph = ops.get_default_graph()
1118
1119  # Fan-in: Builds a TPUReplicatedInput node for each input.
1120  flat_replicated_inputs = []
1121  for i in range(0, len(flat_inputs[0])):
1122    replicas = [flat_inputs[replica][i] for replica in xrange(num_replicas)]
1123    flat_replicated_inputs.append(
1124        tpu_ops.tpu_replicated_input(
1125            replicas, name="input{}".format(i), index=i))
1126  if isinstance(graph, func_graph.FuncGraph):
1127    # When we are in Tensorflow 2.0 function, 'graph' will be a FuncGraph
1128    # object. If both outside graph and this function have a TPU cluster,
1129    # they will have the same cluster name and it will cause problems (because
1130    # we lower functional ops in Tensorflow 2.0). Append function name to
1131    # 'cluster_name' to avoid cluster name collision.
1132    cluster_name = graph.unique_name("cluster_" + graph.name)
1133  else:
1134    cluster_name = graph.unique_name("cluster")
1135  pivot = control_flow_ops.no_op(name=cluster_name + "/pivot")
1136  context = TPUReplicateContext(
1137      name=cluster_name, num_replicas=num_replicas, pivot=pivot)
1138  try:
1139    context.Enter()
1140
1141    metadata = tpu_ops.tpu_replicate_metadata(
1142        num_replicas=num_replicas, use_tpu=use_tpu, **metadata_kwargs)
1143
1144    with tpu_function.tpu_shard_context(
1145        num_replicas), ops.control_dependencies([metadata]):
1146
1147      # Add identity ops so even unused inputs are "consumed" by the
1148      # computation. This is to avoid orphaned TPUReplicatedInput nodes.
1149      # TODO(phawkins): consider instead pruning unused TPUReplicatedInput
1150      # and eliding trivial TPUReplicatedInput/TPUReplicatedOutput pairs.
1151      flat_replicated_inputs = [
1152          array_ops.identity(x, name="replicated_input_{}".format(i))
1153          for i, x in enumerate(flat_replicated_inputs)
1154      ]
1155      for i in flat_replicated_inputs:
1156        # pylint: disable=protected-access
1157        # Add an attribute to the identity node so that it can be removed in the
1158        # encapsulate TPU computation pass if unused. However we don't remove
1159        # inputs when dynamic padding is enabled.
1160        # TODO(rxsang): Use other ways except argument index in padding_map so
1161        # outside compilation can work with dynamic padding correctly.
1162        if maximum_shapes is None:
1163          i.op._set_attr("_tpu_input_identity",
1164                         attr_value_pb2.AttrValue(b=True))
1165        # pylint: enable=protected-access
1166
1167      # Unflatten the computation inputs to match original input structure.
1168      computation_inputs = nest.pack_sequence_as(
1169          structure=inputs[0],
1170          flat_sequence=flat_replicated_inputs[:flat_input_arity])
1171
1172      # If there is an infeed queue, adds the dequeued values to the
1173      # computation's inputs.
1174      if infeed_queue is not None:
1175        infeed_queue.set_number_of_shards(num_replicas)
1176        for t in infeed_queue.generate_dequeue_op():
1177          computation_inputs.append(t)
1178
1179      # Only resource variables work inside a TPU computation, so turn on
1180      # resource variables for the computation.
1181      # TODO(phawkins): consider removing this code. It will
1182      # be less confusing to clients if they knowingly choose to use resource
1183      # variables.
1184      # Partitioned variables is not supported (b/112311320).
1185      vscope = variable_scope.get_variable_scope()
1186      saved_use_resource = vscope.use_resource
1187      saved_custom_getter = vscope.custom_getter
1188
1189      def custom_getter(getter, name, *args, **kwargs):
1190        """Variables on TPU have a few restrictions."""
1191        partitioner = kwargs["partitioner"]
1192        if partitioner is not None:
1193          kwargs["partitioner"] = None
1194          logging.warning(
1195              "Partitioned variables are not supported on TPU. Got "
1196              "`partitioner` that is {} for variable {}. "
1197              "Setting `partitioner` to `None`."
1198              .format(partitioner, name))
1199        if saved_custom_getter is None:
1200          return getter(name, *args, **kwargs)
1201        else:
1202          return saved_custom_getter(getter, name, *args, **kwargs)
1203
1204      vscope.set_use_resource(True)
1205      vscope.set_custom_getter(custom_getter)
1206
1207      outputs = computation(*computation_inputs)
1208
1209      vscope.set_use_resource(saved_use_resource)
1210      vscope.set_custom_getter(saved_custom_getter)
1211
1212    outputs_is_flat = xla.is_flat(outputs)
1213    if outputs_is_flat:
1214      output_tensors, control_deps = _postprocess_flat_outputs(outputs)
1215    else:
1216      output_tensors, control_deps = _postprocess_non_flat_outputs(outputs)
1217
1218    # tensor_tracer imports tpu.py. Local import to tensor_tracer to avoid
1219    # import-cycle
1220    # pylint: disable=g-import-not-at-top
1221    from tensorflow.python.tpu import tensor_tracer
1222    # pylint: enable=g-import-not-at-top
1223    if tensor_tracer.TensorTracer.is_enabled():
1224      tt = tensor_tracer.TensorTracer()
1225      output_tensors = tt.trace_tpu(ops.get_default_graph(),
1226                                    output_tensors, control_deps,
1227                                    num_replicas)
1228
1229    context.ExitResult(output_tensors)
1230  finally:
1231    context.report_unsupported_operations()
1232    context.Exit()
1233    host_compute_core = context.HostComputeCore()
1234
1235  if host_compute_core:
1236    attr_value = attr_value_pb2.AttrValue()
1237    attr_value.list.s.extend(compat.as_bytes(x) for x in host_compute_core)
1238    metadata._set_attr("host_compute_core", attr_value)  # pylint: disable=protected-access
1239
1240  with ops.control_dependencies([metadata]):
1241    if use_tpu:
1242      compile_status = tpu_ops.tpu_compilation_result()
1243      op = compile_status.op
1244      attr_value = attr_value_pb2.AttrValue(s=compat.as_bytes(cluster_name))
1245      op._set_attr(_TPU_COMPILATION_STATUS_ATTR, attr_value)  # pylint: disable=protected-access
1246    else:
1247      compile_status = control_flow_ops.no_op(name="compilation_status")
1248
1249  if not output_tensors:
1250    # Returns a list of NoOps dependent on the replication Op, indexed by
1251    # [replica_num].
1252    return [
1253        compile_status,
1254        [
1255            control_flow_ops.group(control_deps, name="shard_%d" % i)
1256            for i in range(num_replicas)
1257        ]
1258    ]
1259
1260  # Fan-out: Builds a TPUReplicatedOutput node for each output.
1261  replicated_outputs = [[] for i in xrange(num_replicas)]
1262  for i, t in enumerate(output_tensors):
1263    # Fan-out: Builds a TPUReplicatedOutput node for each output.
1264    ys = tpu_ops.tpu_replicated_output(
1265        t, num_replicas, name="output{}".format(i))
1266
1267    # Wraps the outputs in identity operators so the names of any possible
1268    # `fetch` nodes are preserved by the replication rewrite.
1269    with ops.control_dependencies(control_deps):
1270      for replica in xrange(num_replicas):
1271        replicated_outputs[replica].append(
1272            array_ops.identity(
1273                ys[replica], name="output_%d_shard_%d" % (i, replica)))
1274
1275  if not outputs_is_flat:
1276    replicated_outputs = [
1277        nest.pack_sequence_as(outputs, replica_outs)
1278        for replica_outs in replicated_outputs
1279    ]
1280
1281  return [compile_status, replicated_outputs]
1282
1283
1284def _postprocess_flat_outputs(outputs):
1285  """Validates flat outputs, adds back device assignments and other attrs.
1286
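  As an illustration of the ordering rule enforced below (a sketch only; the
  names are hypothetical, `tf` is assumed to be `tensorflow.compat.v1`, and
  TPU/session setup is omitted), a computation body such as

  def computation(x):
    y = x + 1.0              # a Tensor
    update_op = tf.group(y)  # an Operation
    return y, update_op      # Tensors first, then Operations

  is accepted, whereas returning `(update_op, y)` raises a ValueError because
  an Operation precedes a Tensor.
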
1287  Args:
1288    outputs: Output from `computation` inside `tpu.rewrite`.
1289
1290  Returns:
1291    Tensors and Operations extracted from outputs.
1292  """
1293  # Following code segment is to preserve legacy behavior. Previously we only
1294  # supported flat outputs and thus for consistency it was nice to convert even
1295  # single element into a tuple. But now that we support arbitrary output
1296  # structure, this is no longer necessary.
1297  # TODO(b/121383831): Migrate all legacy use cases and delete this special
1298  # case.
1299  # If the computation returns `None`, make it an empty tuple.
1300  if outputs is None:
1301    outputs = tuple()
1302  # If the computation only returned one value, makes it a tuple.
1303  if not isinstance(outputs, collections_abc.Sequence):
1304    outputs = (outputs,)
1305
1306  # Append `no_op` here so that fetching any return value of this function
1307  # will trigger TPUExecute node.
1308  outputs += (control_flow_ops.no_op(),)
1309  try:
1310    with ops.device(core(0)):
1311      outputs = [
1312          o if isinstance(o, ops.Operation) else ops.convert_to_tensor(o)
1313          for o in outputs
1314      ]
1315  except Exception as e:
1316    raise ValueError(
1317        "TPU function return values must all either be Operations or "
1318        "convertible to Tensors. Got '%s'" % str(e))
1319
1320  # Separates the returned Operations and Tensors.
1321  output_operations = [o for o in outputs if isinstance(o, ops.Operation)]
1322  output_tensors = [o for o in outputs if not isinstance(o, ops.Operation)]
1323
1324  if outputs != output_tensors + output_operations:
1325    raise ValueError(
1326        "TPU functions must return zero or more Tensor values followed by "
1327        "zero or more Operations.")
1328
1329  # Wraps outputs in Identity ops. Otherwise a replicated input copied
1330  # straight to an output would bypass the replicate(). This would be bad
1331  # because the TPUReplicatedInput/TPUReplicatedOutput operator would not
1332  # be rewritten away, leading to a runtime error.
1333  # TODO(phawkins): extend the rewrite to elide these nodes instead.
1334  new_output_tensors = []
1335  for t in output_tensors:
1336    with ops.device(t.device if t.device else core(0)):
1337      o = array_ops.identity(t)
1338      # pylint: disable=protected-access
1339      o.op._set_attr("_tpu_output_identity", attr_value_pb2.AttrValue(b=True))
1340      # pylint: enable=protected-access
1341      new_output_tensors.append(o)
1342  return new_output_tensors, output_operations
1343
1344
1345def _postprocess_non_flat_outputs(outputs):
1346  """Validates non-flat outputs, adds back device assignments and other attrs.
1347
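  For example (a sketch; `tf` is assumed to be `tensorflow.compat.v1` and
  TPU/session setup is omitted), a computation with structured outputs such as

  def computation(x):
    logits = tf.matmul(x, tf.ones([4, 2]))
    return {"logits": logits, "probs": tf.nn.softmax(logits)}

  flows through this function: each dict value is converted to a Tensor and
  wrapped in an Identity op on core 0, and the caller (e.g. `tpu.rewrite`)
  packs the per-replica results back into the same dict structure.
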
1348  Args:
1349    outputs: Output from `computation` inside `tpu.rewrite`.
1350
1351  Returns:
1352    Tensors extracted from outputs and an empty list because Operations are not
1353    allowed in non-flat outputs.
1354  """
1355
1356  # Flatten output items.
1357  flat_outputs = nest.flatten(outputs)
1358
1359  # Convert all non-Operation outputs to Tensors.
1360  for i, o in enumerate(flat_outputs):
1361    if isinstance(o, ops.Operation):
1362      raise ValueError(
1363          "tpu.rewrite does not support Operation as return value in non-flat "
1364          "output structure. You can set returned Operations as control "
1365          "dependencies of returned Tensors so Operations are triggered when "
1366          'Tensors are evaluated. Operation found: "%s"' % o.name)
1367
1368    try:
1369      o = ops.convert_to_tensor(o)
1370    except Exception as e:
1371      raise ValueError(
1372          "TPU function return values must all either be Operations or "
1373          'convertible to Tensors. Got error: "%s"' % str(e))
1374
1375    # Wraps outputs in Identity ops. Otherwise a replicated input copied
1376    # straight to an output would bypass the replicate(). This would be bad
1377    # because the TPUReplicatedInput/TPUReplicatedOutput operator would not
1378    # be rewritten away, leading to a runtime error.
1379    # TODO(phawkins): extend the rewrite to elide these nodes instead.
1380    with ops.device(core(0)):
1381      o = array_ops.identity(o)
1382      # pylint: disable=protected-access
1383      o.op._set_attr("_tpu_output_identity", attr_value_pb2.AttrValue(b=True))
1384      # pylint: enable=protected-access
1385      flat_outputs[i] = array_ops.identity(o)
1386
1387  # All flat_outputs are Tensors, and no Operations.
1388  return flat_outputs, []
1389
1390
1391def split_compile_and_shard(computation,
1392                            inputs=None,
1393                            num_shards=1,
1394                            input_shard_axes=None,
1395                            outputs_from_all_shards=True,
1396                            output_shard_axes=None,
1397                            infeed_queue=None,
1398                            device_assignment=None,
1399                            name=None):
1400  """Shards `computation` for parallel execution.
1401
1402  `inputs` must be a list of Tensors or None (equivalent to an empty list), each
1403  of which has a corresponding split axis (from `input_shard_axes`). Each input
1404  is split into `num_shards` pieces along the corresponding axis, and
1405  computation is applied to each shard in parallel.
1406
1407  Tensors are broadcast to all shards if they are lexically captured by
1408  `computation`. e.g.,
1409
1410  x = tf.constant(7)
1411  def computation():
1412    return x + 3
1413  ... = split_compile_and_shard(computation, ...)
1414
1415  If `outputs_from_all_shards` is true, the outputs from all shards of
1416  `computation` are concatenated back together along their `output_shard_axes`.
1417  Otherwise, each output is taken from an arbitrary shard.
1418
1419  Inputs and outputs of the computation must be at least rank-1 Tensors.
1420
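  A minimal usage sketch (assuming `tf` is `tensorflow.compat.v1`, the TPU
  system has already been initialized, and session handling is omitted):

  def computation(x):
    return tf.reduce_sum(x, axis=1, keepdims=True)

  x = tf.ones([8, 4])
  compile_op, outputs = split_compile_and_shard(
      computation, inputs=[x], num_shards=2)
  # Each shard computes a [4, 1] result; `outputs[0]` is the [8, 1]
  # concatenation along the default output_shard_axes of 0.
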
1421  Args:
1422    computation: A Python function that builds a computation to apply to each
1423      shard of the input.
1424    inputs: A list of input tensors or None (equivalent to an empty list). Each
1425      input tensor has a corresponding shard axis, given by `input_shard_axes`,
1426      and its size along that axis must be divisible by `num_shards`.
1427    num_shards: The number of shards.
1428    input_shard_axes: A list of dimensions along which to shard `inputs`, or
1429      `None`. `None` means "shard all inputs along dimension 0". If not `None`,
1430      there must be one dimension per input.
1431    outputs_from_all_shards: Boolean or list of boolean. For each output, if
1432      `True`, outputs from all shards are concatenated along the corresponding
1433      `output_shard_axes` entry. Otherwise, each output is taken
1434      from an arbitrary shard. If the argument is a boolean, the argument's
1435      value is used for each output.
1436    output_shard_axes: A list of dimensions along which to concatenate the
1437      outputs of `computation`, or `None`. `None` means "concatenate all outputs
1438      along dimension 0". If not `None`, there must be one dimension per output.
1439      Ignored if `outputs_from_all_shards` is False.
1440    infeed_queue: If not `None`, the `InfeedQueue` to use to augment the inputs
1441      of `computation`.
1442    device_assignment: If not `None`, a `DeviceAssignment` describing the
1443      mapping between logical cores in the computation and physical cores in
1444      the TPU topology. Uses a default device assignment if `None`. The
1445      `DeviceAssignment` may be omitted if each shard of the computation uses
1446      only one core, and there is either only one shard, or the number of shards
1447      is equal to the number of cores in the TPU system.
1448    name: (Deprecated) Does nothing.
1449  Returns:
1450    A tuple of (compile op, [output tensors]).
1451  Raises:
1452    ValueError: If num_shards <= 0
1453    ValueError: If len(input_shard_axes) != len(inputs)
1454    ValueError: If len(output_shard_axes) != len(outputs from `computation`)
1455  """
1456  # TODO(phawkins): consider adding support for broadcasting Tensors passed as
1457  # inputs.
1458
1459  if num_shards <= 0:
1460    raise ValueError("num_shards must be a positive integer.")
1461
1462  inputs = [] if inputs is None else inputs
1463  if not isinstance(inputs, list):
1464    raise TypeError("tpu.shard()'s inputs must be a list of Tensors or None.")
1465
1466  # Converts inputs to Tensors.
1467  inputs = [ops.convert_to_tensor(x) for x in inputs]
1468
1469  if input_shard_axes is None:
1470    input_shard_axes = [0] * len(inputs)
1471  if len(inputs) != len(input_shard_axes):
1472    raise ValueError("Length of input_shard_axes must be equal to the number "
1473                     "of inputs.")
1474
1475  if inputs:
1476    # Splits the `inputs` along the corresponding `input_shard_axes`, giving
1477    # lists with layout [input][shard]
1478    split_inputs = [
1479        array_ops.split(x, num_shards, axis=axis)
1480        for (axis, x) in zip(input_shard_axes, inputs)]
1481
1482    # Transposes the input lists to have layout [shard][input]
1483    transposed_inputs = [list(i) for i in zip(*split_inputs)]
1484  else:
1485    transposed_inputs = [[] for _ in range(num_shards)]  # independent list per shard
1486
1487  compile_op, outputs = split_compile_and_replicate(
1488      computation,
1489      transposed_inputs,
1490      infeed_queue=infeed_queue,
1491      device_assignment=device_assignment,
1492      name=name)
1493
1494  # There must be at least one shard since num_shards > 0.
1495  # TODO(b/36647078) remove disable when pylint bug is fixed.
1496  # pylint: disable=indexing-exception
1497  if isinstance(outputs[0], ops.Operation):
1498    # pylint: enable=indexing-exception
1499    # There were no outputs from the computation and replicate returned a list
1500    # of NoOps with control dependencies on the computation. Return the first
1501    # one so it can be used as a control dependency or fetch node.
1502    # TODO(b/36647078) remove disable when pylint bug is fixed.
1503    # pylint: disable=indexing-exception
1504    return compile_op, [outputs[0]]
1505    # pylint: enable=indexing-exception
1506
1507  # TODO(b/36647078) remove disable when pylint bug is fixed.
1508  # pylint: disable=indexing-exception
1509  num_outputs = len(outputs[0])
1510  # pylint: enable=indexing-exception
1511
1512  if output_shard_axes is None:
1513    output_shard_axes = [0] * num_outputs
1514  if num_outputs != len(output_shard_axes):
1515    raise ValueError("Length of output_shard_axes must be equal to the number "
1516                     "of outputs.")
1517
1518  if isinstance(outputs_from_all_shards, bool):
1519    outputs_from_all_shards = [outputs_from_all_shards] * num_outputs
1520
1521  if num_outputs != len(outputs_from_all_shards):
1522    raise ValueError("Length of outputs_from_all_shards must be equal to the "
1523                     "number of outputs.")
1524
1525  results = []
1526  for (axis, all_shards, x) in zip(output_shard_axes, outputs_from_all_shards,
1527                                   zip(*outputs)):
1528    if all_shards:
1529      # Concatenate all of the outputs together (use stack for scalars).
1530      shape = x[0].shape
1531      is_scalar = shape is not None and (shape.ndims == 0)
1532      results.append((array_ops.stack(list(x)) if is_scalar
1533                      else array_ops.concat(list(x), axis=axis)))
1534    else:
1535      # TODO(phawkins): use a smarter policy, e.g., round-robin across shards.
1536      results.append(x[0])
1537
1538  return compile_op, results
1539
1540
1541@tf_export(v1=["tpu.shard"])
1542def shard(computation,
1543          inputs=None,
1544          num_shards=1,
1545          input_shard_axes=None,
1546          outputs_from_all_shards=True,
1547          output_shard_axes=None,
1548          infeed_queue=None,
1549          device_assignment=None,
1550          name=None):
1551  """Shards `computation` for parallel execution.
1552
1553  `inputs` must be a list of Tensors or None (equivalent to an empty list), each
1554  of which has a corresponding split axis (from `input_shard_axes`). Each input
1555  is split into `num_shards` pieces along the corresponding axis, and
1556  computation is applied to each shard in parallel.
1557
1558  Tensors are broadcast to all shards if they are lexically captured by
1559  `computation`. e.g.,
1560
1561  x = tf.constant(7)
1562  def computation():
1563    return x + 3
1564  ... = shard(computation, ...)
1565
1566  TODO(phawkins): consider adding support for broadcasting Tensors passed
1567  as inputs.
1568
1569  If `outputs_from_all_shards` is true, the outputs from all shards of
1570  `computation` are concatenated back together along their `output_shard_axes`.
1571  Otherwise, each output is taken from an arbitrary shard.
1572
1573  Inputs and outputs of the computation must be at least rank-1 Tensors.
1574
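  A minimal usage sketch (assuming `tf` is `tensorflow.compat.v1`, the TPU
  system has already been initialized, and session handling is omitted):

  def computation(x):
    return tf.reduce_mean(x, axis=1, keepdims=True)

  x = tf.ones([8, 4])
  [y] = shard(computation, inputs=[x], num_shards=2)
  # Each shard sees a [4, 4] slice of `x`; `y` has shape [8, 1].
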
1575  Args:
1576    computation: A Python function that builds a computation to apply to each
1577      shard of the input.
1578    inputs: A list of input tensors or None (equivalent to an empty list). Each
1579      input tensor has a corresponding shard axis, given by `input_shard_axes`,
1580      and its size along that axis must be divisible by `num_shards`.
1581    num_shards: The number of shards.
1582    input_shard_axes: A list of dimensions along which to shard `inputs`, or
1583      `None`. `None` means "shard all inputs along dimension 0". If not `None`,
1584      there must be one dimension per input.
1585    outputs_from_all_shards: Boolean or list of boolean. For each output, if
1586      `True`, outputs from all shards are concatenated along the corresponding
1587      `output_shard_axes` entry. Otherwise, each output is taken
1588      from an arbitrary shard. If the argument is a boolean, the argument's
1589      value is used for each output.
1590    output_shard_axes: A list of dimensions along which to concatenate the
1591      outputs of `computation`, or `None`. `None` means "concatenate all outputs
1592      along dimension 0". If not `None`, there must be one dimension per output.
1593      Ignored if `outputs_from_all_shards` is False.
1594    infeed_queue: If not `None`, the `InfeedQueue` to use to augment the inputs
1595      of `computation`.
1596    device_assignment: If not `None`, a `DeviceAssignment` describing the
1597      mapping between logical cores in the computation and physical cores in
1598      the TPU topology. Uses a default device assignment if `None`. The
1599      `DeviceAssignment` may be omitted if each shard of the computation uses
1600      only one core, and there is either only one shard, or the number of shards
1601      is equal to the number of cores in the TPU system.
1602    name: (Deprecated) Does nothing.
1603  Returns:
1604    A list of output tensors.
1605  Raises:
1606    ValueError: If num_shards <= 0
1607    ValueError: If len(input_shard_axes) != len(inputs)
1608    ValueError: If len(output_shard_axes) != len(outputs from `computation`)
1609  """
1610  return split_compile_and_shard(
1611      computation,
1612      inputs=inputs,
1613      num_shards=num_shards,
1614      input_shard_axes=input_shard_axes,
1615      outputs_from_all_shards=outputs_from_all_shards,
1616      output_shard_axes=output_shard_axes,
1617      infeed_queue=infeed_queue,
1618      device_assignment=device_assignment,
1619      name=name)[1]
1620
1621
1622@tf_export(v1=["tpu.batch_parallel"])
1623def batch_parallel(computation,
1624                   inputs=None,
1625                   num_shards=1,
1626                   infeed_queue=None,
1627                   device_assignment=None,
1628                   name=None):
1629  """Shards `computation` along the batch dimension for parallel execution.
1630
1631  Convenience wrapper around shard().
1632
1633  `inputs` must be a list of Tensors or None (equivalent to an empty list).
1634  Each input is split into `num_shards` pieces along the 0-th dimension, and
1635  computation is applied to each shard in parallel.
1636
1637  Tensors are broadcast to all shards if they are lexically captured by
1638  `computation`. e.g.,
1639
1640  x = tf.constant(7)
1641  def computation():
1642    return x + 3
1643  ... = shard(computation, ...)
1644
1645  The outputs from all shards are concatenated back together along their 0-th
1646  dimension.
1647
1648  Inputs and outputs of the computation must be at least rank-1 Tensors.
1649
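  A minimal usage sketch (assuming `tf` is `tensorflow.compat.v1`, the TPU
  system has already been initialized, and session handling is omitted):

  def model(images):
    return tf.nn.relu(images)

  images = tf.zeros([128, 784])
  [activations] = batch_parallel(model, inputs=[images], num_shards=2)
  # Each shard processes a [64, 784] slice; the results are concatenated back
  # along dimension 0, so `activations` has shape [128, 784].
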
1650  Args:
1651    computation: A Python function that builds a computation to apply to each
1652      shard of the input.
1653    inputs: A list of input tensors or None (equivalent to an empty list). The
1654      0-th dimension of each Tensor must have size divisible by `num_shards`.
1655    num_shards: The number of shards.
1656    infeed_queue: If not `None`, the `InfeedQueue` from which to append a tuple
1657      of arguments as inputs to `computation`.
1658    device_assignment: If not `None`, a `DeviceAssignment` describing the
1659      mapping between logical cores in the computation and physical cores in
1660      the TPU topology. Uses a default device assignment if `None`. The
1661      `DeviceAssignment` may be omitted if each shard of the computation uses
1662      only one core, and there is either only one shard, or the number of shards
1663      is equal to the number of cores in the TPU system.
1664    name: (Deprecated) Does nothing.
1665  Returns:
1666    A list of output tensors.
1667  Raises:
1668    ValueError: If `num_shards <= 0`
1669  """
1670  return shard(
1671      computation,
1672      inputs,
1673      num_shards=num_shards,
1674      infeed_queue=infeed_queue,
1675      device_assignment=device_assignment,
1676      name=name)
1677
1678
1679@tf_export(v1=["tpu.rewrite"])
1680def rewrite(computation,
1681            inputs=None,
1682            infeed_queue=None,
1683            device_assignment=None,
1684            name=None):
1685  """Rewrites `computation` for execution on a TPU system.
1686
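  A minimal usage sketch (assuming `tf` is `tensorflow.compat.v1`, the TPU
  system has already been initialized, and session handling is omitted):

  def computation(x, y):
    return tf.matmul(x, y)

  x = tf.ones([2, 3])
  y = tf.ones([3, 4])
  z = rewrite(computation, inputs=[x, y])[0]
  # `z` is the [2, 4] product, computed on the single core selected by the
  # default device assignment.
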
1687  Args:
1688    computation: A Python function that builds a computation to apply to the
1689      input. If the function takes n inputs, 'inputs' should be a list of n
1690      tensors.
1691
1692      `computation` may return a list of operations and tensors. Tensors must
1693      come before operations in the returned list.  The return value of
1694      `rewrite` is a list of tensors corresponding to the tensors from the
1695      output of `computation`.
1696
1697      All `Operation`s constructed during `computation` will be executed when
1698      evaluating any of the returned output tensors, not just the ones returned.
1699    inputs: A list of input tensors or `None` (equivalent to an empty list).
1700      Each input can be a nested structure containing values that are
1701      convertible to tensors. Note that passing an N-dimension list of
1702      compatible values will result in an N-dimension list of scalar tensors
1703      rather than a single rank-N tensor. If you need different behavior,
1704      convert part of inputs to tensors with `tf.convert_to_tensor`.
1705    infeed_queue: If not `None`, the `InfeedQueue` from which to append a tuple
1706      of arguments as inputs to `computation`.
1707    device_assignment: if not `None`, a `DeviceAssignment` describing the
1708      mapping between logical cores in the computation and physical cores in
1709      the TPU topology. May be omitted for a single-core computation, in which
1710      case the core attached to task 0, TPU device 0 is used.
1711    name: (Deprecated) Does nothing.
1712  Returns:
1713    Same data structure as if computation(*inputs) is called directly with some
1714    exceptions for correctness. Exceptions include:
1715      1) None output: a NoOp would be returned which control-depends on
1716         computation.
1717      2) Single value output: A tuple containing the value would be returned.
1718      3) Operation-only outputs: a NoOp would be returned which
1719         control-depends on computation.
1720      TODO(b/121383831): Investigate removing these special cases.
1721  """
1722  # TODO(b/36647078) remove disable when pylint bug is fixed.
1723  # pylint: disable=indexing-exception
1724  return replicate(
1725      computation,
1726      None if inputs is None else [inputs],
1727      infeed_queue=infeed_queue,
1728      device_assignment=device_assignment,
1729      name=name)[0]
1730  # pylint: enable=indexing-exception
1731
1732# Operations that indicate some error in the user's inference graph.
1733_BLACKLISTED_INFERENCE_OPS = set([
1734    "ReadVariableOp",
1735    "AssignVariableOp",
1736    "AssignAddVariableOp",
1737    "AssignSubVariableOp",
1738    "VarHandleOp",
1739    "Variable",
1740    "VariableV2",
1741])
1742
1743
1744def under_tpu_inference_context():
1745  """Check if it is currently under `_TPUInferenceContext`."""
1746  graph = ops.get_default_graph()
1747  while graph:
1748    context = graph._get_control_flow_context()  # pylint: disable=protected-access
1749    while context:
1750      if isinstance(context, _TPUInferenceContext):
1751        return True
1752      context = context.outer_context
1753    if isinstance(graph, function._FuncGraph):  # pylint: disable=protected-access
1754      graph = graph._outer_graph  # pylint: disable=protected-access
1755    elif isinstance(graph, func_graph.FuncGraph):
1756      graph = graph.outer_graph
1757    else:
1758      return False
1759
1760
1761class _TPUInferenceContext(control_flow_ops.XLAControlFlowContext):
1762  """A `ControlFlowContext` for nodes inside a TPU inference computation.
1763
1764  The primary role of `_TPUInferenceContext` is to indicate the mode of
1765  operation and possibly sanity check operators inside a
1766  tpu.rewrite_for_inference() computation.
1767  """
1768
1769  def __init__(self, name, check_ops=True):
1770    super(_TPUInferenceContext, self).__init__()
1771    self._name = name
1772    self._check_ops = check_ops
1773
1774  def AddOp(self, op):
1775    self._AddOpInternal(op)
1776
1777  def _AddOpInternal(self, op):
1778    # pylint: disable=protected-access
1779    if self._check_ops and op.type in _BLACKLISTED_INFERENCE_OPS:
1780      raise NotImplementedError(
1781          "Operation of type %s (%s) is not supported on the TPU for inference."
1782          " Execution will fail if this op is used in the graph. Make sure your"
1783          " variables are using variable_scope." % (op.type, op.name))
1784    if self._outer_context:
1785      self._outer_context.AddInnerOp(op)
1786
1787  def AddValue(self, val):
1788    result = val
1789    if self._outer_context:
1790      result = self._outer_context.AddValue(val)
1791    return result
1792
1793  def AddInnerOp(self, op):
1794    self._AddOpInternal(op)
1795
1796  @property
1797  def grad_state(self):
1798    return None
1799
1800
1801def validate_inference_rewrite_for_variables(graph):
1802  """Validates whether rewrite_for_inference() 'worked' for variables.
1803
1804     The rewrite_for_inference() method is supposed to append GuaranteeConstOps
1805     after ReadVariableOps, but this mechanism works only if you are using
1806     tf.compat.v1.get_variable() to create and access variables in your tpu
1807     computation. This validation method can be called immediately after calling
1808     tpu.rewrite_for_inference() to check whether GuaranteeConstOps were added
1809     to the graph.
1810
1811     Typical usages:
1812       tpu.validate_inference_rewrite_for_variables(
1813           tf.compat.v1.get_default_graph())
1814
1815       tpu.validate_inference_rewrite_for_variables(sess.graph)
1816
1817  Args:
1818    graph: The graph which needs to be validated.
1819  Raises:
1820    RuntimeError: if validation failed.
1821  """
1822  if not any(x.type == "GuaranteeConst" for x in graph.get_operations()):
1823    raise RuntimeError(
1824        "No GuaranteeConst ops found in the graph after running "
1825        "tpu.rewrite_for_inference(...). Please check that you are using "
1826        "tf.get_variable() to create and access variables in your tpu "
1827        "computation.")
1828
1829
1830def rewrite_for_inference(computation,
1831                          inputs=None,
1832                          infeed_queue=None,
1833                          device_assignment=None,
1834                          name=None):
1835  """Rewrites `computation` for inference on a TPU system.
1836
1837     In addition to rewriting the computation to run on a TPU, if your
1838     computation uses variables, this moves the ReadVariableOps outside the TPU
1839     computation and adds GuaranteeConst ops just after the ReadVariableOps.
1840     This mechanism works only if you are using tf.compat.v1.get_variable() to
1841     create and access variables in your tpu computation. You can validate
1842     whether this worked by calling validate_inference_rewrite_for_variables()
1843     immediately after this method to check whether GuaranteeConstOps were
1844     added to the graph.
1845
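     A rough usage sketch (assuming `tf` is `tensorflow.compat.v1`, variables
     are created with tf.compat.v1.get_variable, and TPU system initialization
     and session handling are omitted):

       def model(features):
         w = tf.get_variable("w", [4, 2])
         return tf.matmul(features, w)

       features = tf.placeholder(tf.float32, [8, 4])
       [logits] = rewrite_for_inference(model, inputs=[features])
       validate_inference_rewrite_for_variables(tf.get_default_graph())
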
1846  Args:
1847    computation: A Python function that builds a computation to apply to the
1848      input. If the function takes n inputs, 'inputs' should be a list of n
1849      tensors. If the function returns m outputs, rewrite will return a list of
1850      m tensors.
1851    inputs: A list of input tensors or `None` (equivalent to an empty list).
1852    infeed_queue: If not `None`, the `InfeedQueue` from which to append a tuple
1853      of arguments as inputs to `computation`.
1854    device_assignment: if not `None`, a `DeviceAssignment` describing the
1855      mapping between logical cores in the computation and physical cores in
1856      the TPU topology. May be omitted for a single-core computation, in which
1857      case the core attached to task 0, TPU device 0 is used.
1858    name: The name of the operator.
1859  Returns:
1860    A list of output tensors.
1861  """
1862
1863  def guarantee_const_getter(getter, name, *args, **kwargs):
1864    with ops.control_dependencies(None):
1865      return array_ops.guarantee_const(
1866          getter(name, *args, **kwargs), name=name + "/GuaranteeConst")
1867
1868  def wrapped_computation(*args, **kwargs):
1869    """Execute computation under `_TPUInferenceContext`."""
1870    context = _TPUInferenceContext(
1871        name=ops.get_default_graph().unique_name("rewrite_for_inference"))
1872    try:
1873      context.Enter()
1874
1875      vscope = variable_scope.get_variable_scope()
1876      prev_custom_getter = vscope.custom_getter
1877      prev_caching_device = vscope.caching_device
1878      vscope.set_custom_getter(guarantee_const_getter)
1879      vscope.set_caching_device(lambda op: op.device)
1880
1881      result = computation(*args, **kwargs)
1882
1883      vscope.set_custom_getter(prev_custom_getter)
1884      vscope.set_caching_device(prev_caching_device)
1885    finally:
1886      context.Exit()
1887    return result
1888
1889  # pylint: disable=undefined-variable
1890  return rewrite(
1891      wrapped_computation,
1892      inputs=inputs,
1893      infeed_queue=infeed_queue,
1894      device_assignment=device_assignment,
1895      name=name)
1896  # pylint: enable=undefined-variable
1897
1898
1899def prune_unconnected_ops_from_xla(prune_graph):
1900  """Prunes unconnected ops as listed in _UNCONNECTED_OPS_TO_PRUNE.
1901
1902  Args:
1903    prune_graph: A tensorflow graph from which we wish to prune unconnected ops
1904      as listed in _UNCONNECTED_OPS_TO_PRUNE.  In general, these ops should have
1905      no inputs and no consumers. These can often be left behind due to graph
1906      construction rewiring (for instance TF-Hub). While they never execute,
1907      they will cause XLA compile to fail so we strip them from XLA compile by
1908      removing the tpu_replicate attribute.
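
  A typical call (a sketch; `tf` is assumed to be `tensorflow.compat.v1`) runs
  on the graph that holds the TPU computation:

  prune_unconnected_ops_from_xla(tf.get_default_graph())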
1909  """
1910  # Scan over the top level graph and all function graphs.
1911  for graph in [prune_graph] + [
1912      f for f in prune_graph._functions.values()  # pylint: disable=protected-access
1913  ]:
1914    if not isinstance(graph, ops.Graph):
1915      continue
1916    for op in graph.get_operations():
1917      if op.type not in _UNCONNECTED_OPS_TO_PRUNE:
1918        continue
1919      outputs_consumed = False
1920      for output in op.outputs:
1921        if output.consumers():
1922          outputs_consumed = True
1923          break
1924      if not outputs_consumed:
1925        logging.info(
1926            "Pruning OP %s of type %s from XLA Compile because it is "
1927            "disconnected.", op.name, op.type)
1928        op._clear_attr(_TPU_REPLICATE_ATTR)  # pylint: disable=protected-access
1929