# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Compiled parallel-for loop."""
16# pylint: disable=missing-docstring,g-direct-tensorflow-import
17
18import collections
19import string
20import sys
21import traceback
22
23import numpy as np
24from functools import partial
25
26from tensorflow.compiler.tf2xla.python import xla
27from tensorflow.core.framework import full_type_pb2
28from tensorflow.python.eager import context
29from tensorflow.python.eager import def_function
30from tensorflow.python.eager import execute
31from tensorflow.python.framework import constant_op
32from tensorflow.python.framework import dtypes
33from tensorflow.python.framework import func_graph
34from tensorflow.python.framework import ops
35from tensorflow.python.framework import smart_cond
36from tensorflow.python.framework import sparse_tensor
37from tensorflow.python.framework import tensor_shape
38from tensorflow.python.framework import tensor_spec
39from tensorflow.python.framework import tensor_util
40from tensorflow.python.ops import array_ops
41from tensorflow.python.ops import control_flow_ops
42from tensorflow.python.ops import data_flow_ops
43from tensorflow.python.ops import gen_array_ops
44from tensorflow.python.ops import gen_dataset_ops
45from tensorflow.python.ops import gen_image_ops
46from tensorflow.python.ops import gen_linalg_ops
47from tensorflow.python.ops import gen_list_ops
48from tensorflow.python.ops import gen_math_ops
49from tensorflow.python.ops import gen_nn_ops
50from tensorflow.python.ops import gen_parsing_ops
51from tensorflow.python.ops import gen_random_ops
52from tensorflow.python.ops import gen_sparse_ops
53from tensorflow.python.ops import gen_spectral_ops
54from tensorflow.python.ops import handle_data_util
55from tensorflow.python.ops import linalg_ops
56from tensorflow.python.ops import list_ops
57from tensorflow.python.ops import manip_ops
58from tensorflow.python.ops import map_fn
59from tensorflow.python.ops import math_ops
60from tensorflow.python.ops import nn_ops
61from tensorflow.python.ops import parsing_ops
62from tensorflow.python.ops import resource_variable_ops
63from tensorflow.python.ops import sparse_ops
64from tensorflow.python.ops import special_math_ops
65from tensorflow.python.ops import tensor_array_ops
66from tensorflow.python.platform import flags
67from tensorflow.python.platform import tf_logging as logging
68from tensorflow.python.util import compat
69from tensorflow.python.util import nest
70from tensorflow.python.util import object_identity
71

# TODO(agarwal): remove flag.
flags.DEFINE_bool(
    "op_conversion_fallback_to_while_loop", True,
    "DEPRECATED: Flag is ignored.")


def _variant_handle_data(t):
  """Fetches handle data for a variant tensor `t`, or None if unavailable."""
  handle_data = resource_variable_ops.get_eager_safe_handle_data(t)
  if not handle_data.is_set:
    return None
  return handle_data.shape_and_type


def _variant_type_id(t):
  """Returns the full_type_pb2 type of `t`, or None if it is not available."""
  if t.dtype != dtypes.variant:
    return None
  shapes_and_types = _variant_handle_data(t)
  if shapes_and_types is None or not shapes_and_types:
    # TODO(b/169968286): Identify all variant tensors (e.g. maps) and we can
    # make this an error instead of assuming TensorLists have handle data.
    return None  # Presumed not a TensorList/Optional
  return shapes_and_types[0].type.type_id


_INTERNAL_STACKING_TYPE_IDS = (
    full_type_pb2.TFT_ARRAY,
    full_type_pb2.TFT_OPTIONAL)

def _is_variant_with_internal_stacking(t):
  """Identifies variant tensors which pfor always maintains as scalars.

  For these, the pfor tensor is recorded as "stacked" if the contents of the
  variant tensor (e.g. the elements of a TensorList) are all stacked.

  Args:
    t: A tensor to identify.
  Returns:
    True if `t` is a TensorList/Optional, False otherwise.
  """
  type_id = _variant_type_id(t)
  return type_id in _INTERNAL_STACKING_TYPE_IDS


def _parse_variant_shapes_and_types(t):
  """Extracts shape and dtype information from a variant tensor `t`."""
  shapes_and_types = _variant_handle_data(t)
  if shapes_and_types is None or not shapes_and_types:
    raise ValueError("Required handle data not set for {!r}".format(t))
  if shapes_and_types[0].type.type_id == full_type_pb2.TFT_ARRAY:
    return shapes_and_types
  elif shapes_and_types[0].type.type_id == full_type_pb2.TFT_UNSET:
    return shapes_and_types
  else:
    raise ValueError(
        "Attempted to stack a variant-dtype tensor with no type set ({!r})"
        .format(t))

def _stack(t, length):
  """Stacks `t` `length` times."""
  # Note that this stacking may currently be triggered, for example, when a
  # loop invariant tensor with dtype variant is input to a while_loop which then
  # produces a loop dependent output. Simply stacking the variants may not be
  # suitable since operations on stacked handles may expect a vectorized version
  # of the variant.
  if t.dtype == dtypes.variant:
    shapes_and_types = _parse_variant_shapes_and_types(t)
    if shapes_and_types[0].type.type_id == full_type_pb2.TFT_ARRAY:
      if len(shapes_and_types) != 1:
        raise ValueError(
            f"Expected handle data of length 1, got {shapes_and_types!r} of "
            f"length {len(shapes_and_types)}.")
      return wrap(
          _stack_tensor_list(t, shapes_and_types[0].dtype, length),
          True)
    else:
      raise ValueError(
          "Attempted to stack an unhandled variant-dtype tensor of "
          f"type {shapes_and_types[0].type!r} ({t!r}).")
  ones = array_ops.ones_like(array_ops.shape(t))
  ones = array_ops.reshape(ones, [-1])
  length = array_ops.reshape(length, [-1])
  multiples = array_ops.concat([length, ones], 0)
  t = array_ops.tile(array_ops.expand_dims(t, 0), multiples)
  return wrap(t, True)

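# For a dense tensor, `_stack` simply tiles a new leading dimension. As an
# illustrative sketch (shapes only, assuming `length == [3]`):
#
#   t = constant_op.constant([1.0, 2.0])        # shape (2,)
#   expanded = array_ops.expand_dims(t, 0)      # shape (1, 2)
#   stacked = array_ops.tile(expanded, [3, 1])  # shape (3, 2); rows copy t
#
# which matches the expand_dims + tile sequence above, where
# multiples = concat([length, ones_like(shape(t))], 0).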

# The following stateful ops can be safely called once, and with the same
# signature as the unconverted version, if their inputs are loop invariant.
# TODO(agarwal): implement a strategy for converting Variable reads/writes. The
# plan is to map each read/write in the loop_fn to a corresponding merged
# read/write in the converted graph. Writes need to be mergeable (e.g.
# AssignAdd) to be used in `pfor`. Given a certain read/write order in the
# loop_fn, doing a one-to-one conversion will simulate executing such
# instructions in lock-step across all iterations.
passthrough_stateful_ops = set([
    "VariableV2",
    "VarHandleOp",
    "VariableShape",
    "ReadVariableOp",
    "StackV2",
    "TensorArrayWriteV3",
    "TensorArrayReadV3",
    "TensorArraySizeV3",
])


# Ops which we will treat as stateful for the purpose of vectorization.
# Typically this is used to force pfor converters to run for these ops.
force_stateful_ops = set([
    # We vectorize this since we need to change the element shape set on the
    # list.
    "TensorListReserve",
])


def _is_stateful_pfor_op(op):
  if isinstance(op, WhileOp):
    return op.is_stateful
  if op.type == "Const":
    # Const doesn't have an op_def.
    return False
  if op.type in passthrough_stateful_ops:
    return False
  if op.type in force_stateful_ops:
    return True
  assert hasattr(op, "op_def") and op.op_def is not None, op
  return op.op_def.is_stateful


# pylint: disable=protected-access
class WhileOp:
  """Object for storing state for converting the outputs of a while_loop."""

  def __init__(self, exit_node, pfor_ops, fallback_to_while_loop, pfor_config):
    """Initializer.

    Args:
      exit_node: A tensor output from the while_loop.
      pfor_ops: list of ops inside the current pfor loop.
      fallback_to_while_loop: If True, fall back to a while loop when conversion
        of an op is not supported.
      pfor_config: PForConfig object used while constructing loop body.
    """
    self._fallback_to_while_loop = fallback_to_while_loop
    self._pfor_config = pfor_config
    self._pfor_ops = set(pfor_ops)
    self._pfor_op_ids = set(x._id for x in pfor_ops)
    assert isinstance(exit_node, ops.Tensor)
    self._while_context = exit_node.op._get_control_flow_context()
    assert isinstance(self._while_context, control_flow_ops.WhileContext)
    self._context_name = self._while_context.name
    self._condition = self._while_context.pivot.op.inputs[0]
    # Parts of an external while_loop could be created inside a pfor loop.
    # However, for the purpose here, we declare such loops to be external. Also
    # note that we check if the condition was created inside or outside to
    # determine if the while_loop was first created inside or outside.
    # TODO(agarwal): check that the Enter and Exit of this loop are unstacked.
    self._is_inside_loop = self.op_is_inside_loop(self._condition.op)
    if self._is_inside_loop:
      for e in self._while_context.loop_exits:
        assert self.op_is_inside_loop(e.op)

    # Note the code below tries to reverse engineer an existing while_loop graph
    # by assuming the following pattern of nodes.
    #
    #          NextIteration <---- Body <--- Enter
    #              |                ^
    #              V             ___| Y
    #    Enter -> Merge -> Switch___
    #                       ^       | N
    #                       |       V
    #                  LoopCond    Exit

    # Note that elements in the lists below correspond one-to-one with each
    # other, i.e. these lists are the same size, and the i_th entry corresponds
    # to different Operations/Tensors of a single cycle as illustrated above.
    # List of Switch ops (ops.Operation) that feed into an Exit Node.
    self._exit_switches = []
    # List of inputs (ops.Tensor) to NextIteration.
    self._body_outputs = []
    # List of list of control inputs of the NextIteration nodes.
    self._next_iter_control_inputs = []
    # List of Merge ops (ops.Operation).
    self._enter_merges = []
    # List of output (ops.Tensor) of Exit nodes.
    self._outputs = []

    # List of Enter Tensors.
    # There are two types of Enter nodes:
    # - The Enter nodes that are used in the `loop_vars` argument to
    # `while_loop` (see
    # https://www.tensorflow.org/api_docs/python/tf/while_loop). We collect
    # these Enter nodes immediately below by tracing backwards from the Exit
    # nodes via Exit <- Switch <- Merge <- Enter. You can see this chain in the
    # diagram above. This allows us to have a 1:1 correspondence between the
    # self._outputs and the first elements in self._enters.
    # - The Enter nodes that are used only by the body. They don't appear in the
    # `loop_vars` and are not returned from the `while_loop`. In Python code,
    # they are usually captured by the body lambda. We collect them below by
    # iterating over all the ops in the graph. They are appended to the end of
    # self._enters or self._direct_enters, and don't correspond to any outputs
    # in self._outputs. Note that we keep the resource/variant Enter nodes in
    # self._direct_enters and the constructed while_loop's body uses them
    # directly as opposed to passing them as loop variables. This is done
    # because the while_body cannot partition the resource/variant Tensors, so
    # it has to leave them unchanged.
    self._enters = []
    self._direct_enters = []

    for e in self._while_context.loop_exits:
      self._outputs.append(e.op.outputs[0])
      switch = e.op.inputs[0].op
      assert switch.type == "Switch", switch
      self._exit_switches.append(switch)
      merge = switch.inputs[0].op
      assert merge.type == "Merge", merge
      self._enter_merges.append(merge)
      enter = merge.inputs[0].op
      assert enter.type == "Enter", enter
      self._enters.append(enter.outputs[0])
      next_iter = merge.inputs[1].op
      assert next_iter.type == "NextIteration", next_iter
      self._body_outputs.append(next_iter.inputs[0])
      self._next_iter_control_inputs.append(next_iter.control_inputs)

    # Collect all the Enter nodes that are not part of `loop_vars`, the second
    # category described above.
    # Also track whether the loop body has any stateful ops.
    self._is_stateful = False
    for op in ops.get_default_graph().get_operations():
      # TODO(agarwal): make sure this works with nested case.
      control_flow_context = op._get_control_flow_context()
      if control_flow_context is None:
        continue
      if control_flow_context.name == self._context_name:
        self._is_stateful |= _is_stateful_pfor_op(op)
        if op.type == "Enter":
          output = op.outputs[0]
          if output not in self._enters:
            if output.dtype in (dtypes.resource, dtypes.variant):
              if output not in self._direct_enters:
                self._direct_enters.append(output)
            else:
              self._enters.append(output)

  def __str__(self):
    """String representation."""
    return "while_loop(%s)" % self.name

  @property
  def inputs(self):
    """Input to all the Enter nodes."""
    return [x.op.inputs[0] for x in self._enters + self._direct_enters]

  @property
  def control_inputs(self):
    """Control input to all the Enter nodes."""
    control_inputs = []
    for x in self._enters + self._direct_enters:
      control_inputs.extend(x.op.control_inputs)
    return control_inputs

  @property
  def outputs(self):
    """Outputs of all the Exit nodes."""
    return self._outputs

  @property
  def name(self):
    """Context name for the while loop."""
    return self._context_name

  @property
  def is_inside_loop(self):
    """Returns true if the while_loop was created inside the pfor."""
    return self._is_inside_loop

  def op_is_inside_loop(self, op):
    """True if op was created inside the pfor loop body."""
    assert isinstance(op, ops.Operation)
    # Note that we use self._pfor_op_ids for the check and not self._pfor_ops
    # since it appears the TensorFlow API could return different Python objects
    # representing the same Operation node.
    return op._id in self._pfor_op_ids

  @property
  def is_stateful(self):
    return self._is_stateful

  @property
  def pfor_converter(self):
    """Return a converter for the while loop."""
    return self

  def _init_pfor(self, parent_pfor, indices, cond_stacked, inputs,
                 inputs_stacked):
    """Create a PFor object for converting parts of the while_loop.

    Args:
      parent_pfor: PFor object being used for converting the while_loop.
      indices: int32 Tensor of ids for the iterations that are still active
        (i.e. did not exit the while_loop).
      cond_stacked: True if the while_loop condition is stacked.
      inputs: list of input Tensors corresponding 1-to-1 with self._enters. Note
        that these Tensors are a subset of the loop variables for the generated
        while_loop.
      inputs_stacked: List of booleans corresponding 1-to-1 with `inputs`,
        indicating if the value is stacked or not.

    Returns:
      A PFor instance. The instance is initialized by adding conversion mappings
        of nodes that will be external to the conversion that the returned
        instance will be used for. e.g. Enter nodes as well as Merge and Switch
        outputs are mapped to converted values.
    """
    num_outputs = len(self._outputs)
    assert len(inputs) == len(self._enters)
    assert len(inputs_stacked) == len(self._enters)
    loop_var = parent_pfor.loop_var
    loop_len = array_ops.size(indices)
    pfor = PFor(
        loop_var,
        loop_len,
        pfor_ops=self._pfor_ops,
        all_indices=indices,
        all_indices_partitioned=cond_stacked,
        fallback_to_while_loop=self._fallback_to_while_loop,
        pfor_config=self._pfor_config)
    # Map all inputs of Enter nodes in self._direct_enters to their converted
    # values.
    for enter in self._direct_enters:
      enter_input = enter.op.inputs[0]
      converted_enter, stacked, is_sparse_stacked = parent_pfor._convert_helper(
          enter_input)
      # Since these are resources / variants, they should be unstacked.
      assert not stacked and not is_sparse_stacked, (enter, converted_enter)
      pfor._add_conversion(enter, wrap(converted_enter, False))

    # Map all Enter nodes to the inputs.
    for enter, inp, stacked in zip(self._enters, inputs, inputs_stacked):
      pfor._add_conversion(enter, wrap(inp, stacked))
    # Map outputs of Switch and Merge.
    for i in range(num_outputs):
      wrapped_inp = wrap(inputs[i], inputs_stacked[i])
      merge = self._enter_merges[i]
      pfor._add_conversion(merge.outputs[0], wrapped_inp)
      # Note that second output of Merge is typically not used, except possibly
      # as a control dependency. To avoid trying to output the correct value, we
      # employ a hack here. We output a dummy invalid value with an incorrect
      # dtype. This will allow control dependency to work but if using it as an
      # input, it should typically lead to errors during graph construction due
      # to dtype mismatch.
      # TODO(agarwal): Check in the original graph to see if there are any
      # consumers of this Tensor that use it as an input.
      pfor._add_conversion(merge.outputs[1],
                           wrap(constant_op.constant(-1.0), False))
      switch = self._exit_switches[i]
      # Don't need to worry about switch.output[0] which will feed to Exit node.
      pfor._add_conversion(switch.outputs[1], wrapped_inp)
    return pfor

  def _convert_enter(self, parent_pfor, enter):
    """Converts an Enter node."""
    inp, stacked, _ = parent_pfor._convert_helper(enter.op.inputs[0])
    control_inputs = []
    for x in enter.op.control_inputs:
      converted = parent_pfor._convert_helper(x)
      if not isinstance(converted, ops.Operation):
        converted = converted.t
      control_inputs.append(converted)
    if control_inputs:
      with ops.control_dependencies(control_inputs):
        inp = array_ops.identity(inp)
    return inp, stacked

  def _maybe_stacked(self, cache, inp):
    """Heuristic to figure out if converting `inp` leads to a stacked value.

    Args:
      cache: map from Tensor to boolean indicating stacked/unstacked.
      inp: input Tensor.

    Returns:
      True if `inp` could get stacked. If the function returns False, the
      converted value should be guaranteed to be unstacked. If returning True,
      it may or may not be stacked.
    """
    if inp in cache:
      return cache[inp]
    if not self.op_is_inside_loop(inp.op):
      return False
    op = inp.op
    output = False
    if op.type in [
        "Shape",
        "Rank",
        "ShapeN",
        "ZerosLike",
        "TensorArrayV3",
        "TensorArraySizeV3",
    ]:
      output = False
    elif _is_stateful_pfor_op(op):
      # This may be fairly aggressive.
      output = True
    elif op.type == "Exit":
      # This may be fairly aggressive.
      output = True
    else:
      for t in op.inputs:
        if self._maybe_stacked(cache, t):
          output = True
          break
    cache[inp] = output
    return output

  def _create_init_values(self, pfor_input):
    """Create arguments passed to converted while_loop."""
    with ops.name_scope("while_init"):
      loop_len_vector = pfor_input.pfor.loop_len_vector
      loop_len = loop_len_vector[0]
      num_outputs = len(self._outputs)

      inputs = []
      maybe_stacked_cache = {}
      # Convert all the Enters. Need to do this before checking for stacking
      # below.
      for i, enter in enumerate(self._enters):
        inp, stacked = self._convert_enter(pfor_input.pfor, enter)
        inputs.append(inp)
        maybe_stacked_cache[enter] = stacked
        # Since this enter node is part of the `loop_vars`, it corresponds to an
        # output and its preceding switch. We mark this switch's output with the
        # same stackness, to act as the base case for the logic below. Below, we
        # will be going through the body figuring out which inputs might need to
        # be stacked and which inputs can safely remain unstacked.
        if i < num_outputs:
          maybe_stacked_cache[self._exit_switches[i].outputs[1]] = stacked

      # Shape invariants for init_values corresponding to self._enters.
      input_shape_invariants = []
      # TensorArrays for outputs of converted while loop
      output_tas = []
      # Shape invariants for output TensorArrays.
      ta_shape_invariants = []
      # List of booleans indicating stackness of inputs, i.e. tensors
      # corresponding to self._enters.
      inputs_stacked = []
      for i, inp in enumerate(inputs):
        enter = self._enters[i]
        inp_stacked = self._maybe_stacked(maybe_stacked_cache, enter)
        # Note that even when an input is unstacked, the body could make it
        # stacked. We use a heuristic below to figure out if the body may be
        # making it stacked.
        if i < num_outputs:
          body_output = self._body_outputs[i]
          if enter.op in self._pfor_ops:
            body_output_stacked = self._maybe_stacked(maybe_stacked_cache,
                                                      body_output)
          else:
            # If constructed outside of pfor loop, then the output would not be
            # stacked.
            body_output_stacked = False
          if body_output_stacked and not inp_stacked:
            inp = _stack(inp, loop_len_vector).t
            inputs[i] = inp
            inp_stacked = True
          # TODO(agarwal): other attributes for the TensorArray ?
          output_tas.append(tensor_array_ops.TensorArray(inp.dtype, loop_len))
          ta_shape_invariants.append(tensor_shape.TensorShape(None))

        inputs_stacked.append(inp_stacked)
        input_shape_invariants.append(tensor_shape.TensorShape(None))

      # See documentation for __call__ for the structure of init_values.
      init_values = [True, pfor_input.pfor.all_indices] + inputs + output_tas
      # TODO(agarwal): try stricter shape invariants
      shape_invariants = (
          [tensor_shape.TensorShape(None),
           tensor_shape.TensorShape(None)] + input_shape_invariants +
          ta_shape_invariants)

      return init_values, inputs_stacked, shape_invariants

  def _process_cond_unstacked(self, conditions, indices, inputs, output_tas):
    """Handles case when condition is unstacked.

    Note that all iterations end together. So we don't need to partition the
    inputs. When all iterations are done, we write the inputs to the
    TensorArrays. Note that we only write to index 0 of output_tas. Since all
    iterations end together, they can all be output together.
    """
    not_all_done = array_ops.reshape(conditions, [])
    new_output_tas = []
    # pylint: disable=cell-var-from-loop
    for i, out_ta in enumerate(output_tas):
      inp = inputs[i]
      new_output_tas.append(
          control_flow_ops.cond(not_all_done, lambda: out_ta,
                                lambda: out_ta.write(0, inp)))
    # pylint: enable=cell-var-from-loop
    return not_all_done, indices, inputs, new_output_tas

  def _process_cond_stacked(self, conditions, indices, inputs, inputs_stacked,
                            output_tas):
    num_outputs = len(self._outputs)
    # Compute if all iterations are done.
    not_all_done = math_ops.reduce_any(conditions)
    conditions_int = math_ops.cast(conditions, dtypes.int32)
    # Partition the indices.
    done_indices, new_indices = data_flow_ops.dynamic_partition(
        indices, conditions_int, 2)

    new_inputs = []
    new_output_tas = []
    for i, (inp, stacked) in enumerate(zip(inputs, inputs_stacked)):
      # Partition the inputs.
      if stacked:
        done_inp, new_inp = data_flow_ops.dynamic_partition(
            inp, conditions_int, 2)
      else:
        # TODO(agarwal): avoid this stacking. See TODO earlier in
        # _process_cond_unstacked.
        done_inp = _stack(inp, [array_ops.size(done_indices)]).t
        new_inp = inp
      new_inputs.append(new_inp)
      # For iterations that are done, write them to TensorArrays.
      if i < num_outputs:
        out_ta = output_tas[i]
        # Note that done_indices can be empty. done_inp should also be empty in
        # that case.
        new_output_tas.append(out_ta.scatter(done_indices, done_inp))
    return not_all_done, new_indices, new_inputs, new_output_tas

  def _process_body(self, pfor_input, inputs_stacked, new_indices, cond_stacked,
                    new_inputs, not_all_done):
    """Convert the body function."""

    def true_fn(control_inputs, body_pfor, body_output, stacked):
      """Converts the body function for all but last iteration.

      This essentially converts body_output. Additionally, it needs to handle
      any control dependencies on the NextIteration node. So it creates another
      Identity node with the converted dependencies.
      """
      converted_control_inp = []
      for x in control_inputs:
        for t in x.outputs:
          converted_control_inp.append(body_pfor._convert_helper(t).t)
      if stacked:
        # Note convert always does the stacking.
        output = body_pfor.convert(body_output)
      else:
        output, convert_stacked, _ = body_pfor._convert_helper(body_output)
        assert convert_stacked == stacked, body_output
      with ops.control_dependencies(converted_control_inp):
        return array_ops.identity(output)

    body_pfor = self._init_pfor(pfor_input.pfor, new_indices, cond_stacked,
                                new_inputs, inputs_stacked)
    new_outputs = []

    for i, (body_output,
            stacked) in enumerate(zip(self._body_outputs, inputs_stacked)):
      control_inp = self._next_iter_control_inputs[i]
      out_dtype = body_output.dtype
      # Note that we want to run the body only if not all pfor iterations are
      # done. If all are done, we return empty tensors since these values will
      # not be used. Notice that the value returned by the loop is based on
      # TensorArrays and not directly on these returned values.
      # pylint: disable=cell-var-from-loop
      new_output = control_flow_ops.cond(
          not_all_done,
          lambda: true_fn(control_inp, body_pfor, body_output, stacked),
          lambda: constant_op.constant([], dtype=out_dtype))
      # pylint: enable=cell-var-from-loop
      new_outputs.append(new_output)
    return new_outputs

  def __call__(self, pfor_input):
    """Converter for the while_loop.

    The conversion of a while_loop is another while_loop.

    The arguments to this converted while_loop are as follows:
    not_all_done: Boolean scalar Tensor indicating if all the pfor iterations
      are done.
    indices: int32 1-D Tensor storing the id of the iterations that are not
      done.
    args: Remaining arguments. These can be divided into 3 categories:
      - The first set of arguments are the tensors that correspond to the
        initial elements of self._enters, i.e. the elements that appear in the
        original while loop's `loop_vars`.
      - The second set of arguments are the tensors that correspond to the
        remaining elements of self._enters. These are the tensors that directly
        enter the original while loop body.
      - Finally, the last set of arguments are TensorArrays. These TensorArrays
        correspond to the outputs of the original while_loop, i.e. to the
        elements in self._outputs. Each TensorArray has `PFor.loop_len`
        elements, i.e. the number of pfor iterations. At the end, the i'th
        element of each TensorArray will contain the output computed by the
        i'th iteration of pfor. Note that elements can be written into these
        tensor arrays in any order, depending on when the corresponding pfor
        iteration is done.
      If the original while_loop had `k` tensors in its `loop_vars` and its body
      directly captured `m` tensors, the `args` will contain `2 * k + m` values.

    In each iteration, the while_loop body recomputes the condition for all
    active pfor iterations to see which of them are now done. It then partitions
    all the inputs and passes them along to the converted body. Values for all
    the iterations that are done are written to TensorArrays indexed by the pfor
    iteration number. When all iterations are done, the TensorArrays are stacked
    to get the final value.

    Args:
      pfor_input: A PForInput object corresponding to the output of any Exit
        node from this while loop.

    Returns:
      List of converted outputs.
    """
    # Create init_values that will be passed to the while_loop.
    init_values, inputs_stacked, shape_invariants = self._create_init_values(
        pfor_input)
    # Note that we use a list as a hack since we need the nested function body
    # to set the value of cond_is_stacked. python2.x doesn't support nonlocal
    # variables.
    cond_is_stacked = [None]

    def cond(not_all_done, *_):
      return not_all_done

    def body(not_all_done, indices, *args):
      # See documentation for __call__ for the structure of *args.
      num_enters = len(self._enters)
      inputs = args[:num_enters]
      output_tas = args[num_enters:]
      # TODO(agarwal): see which outputs have consumers and only populate the
      # TensorArrays corresponding to those. Or do those paths get trimmed out
      # from inside the while_loop body?
      assert len(inputs) >= len(output_tas)
      assert len(inputs) == len(inputs_stacked)

      # Convert condition
      with ops.name_scope("while_cond"):
        # Note that we set cond_stacked to True here. At this point we don't
        # know if it could be loop invariant, hence the conservative value is
        # to assume stacked.
        cond_pfor = self._init_pfor(
            pfor_input.pfor,
            indices,
            cond_stacked=True,
            inputs=inputs,
            inputs_stacked=inputs_stacked)
        conditions, cond_stacked, _ = cond_pfor._convert_helper(self._condition)
        cond_is_stacked[0] = cond_stacked

      # Recompute the new condition, write outputs of done iterations, and
      # partition the inputs if needed.
      if not cond_stacked:
        (not_all_done, new_indices, new_inputs,
         new_output_tas) = self._process_cond_unstacked(conditions, indices,
                                                        inputs, output_tas)
      else:
        (not_all_done, new_indices, new_inputs,
         new_output_tas) = self._process_cond_stacked(conditions, indices,
                                                      inputs, inputs_stacked,
                                                      output_tas)

      # Convert body
      with ops.name_scope("while_body"):
        #  Compute the outputs from the body.
        new_outputs = self._process_body(pfor_input, inputs_stacked,
                                         new_indices, cond_stacked, new_inputs,
                                         not_all_done)

      # Note that the first num_outputs new values of inputs are computed using
      # the body. Rest of them were direct Enters into the condition/body and
      # the partitioning done earlier is sufficient to give the new value.
      num_outputs = len(self._outputs)
      new_args = ([not_all_done, new_indices] + new_outputs +
                  list(new_inputs[num_outputs:]) + new_output_tas)
      return tuple(new_args)

    while_outputs = control_flow_ops.while_loop(
        cond, body, init_values, shape_invariants=shape_invariants)
    output_tas = while_outputs[-len(self._outputs):]
    outputs = []
    assert cond_is_stacked[0] is not None
    for inp_stacked, ta in zip(inputs_stacked, output_tas):
      if cond_is_stacked[0]:
        outputs.append(wrap(ta.stack(), True))
      else:
        # Note that if while_loop condition is unstacked, all iterations exit at
        # the same time and we wrote those outputs in index 0 of the tensor
        # array.
        outputs.append(wrap(ta.read(0), inp_stacked))
    return outputs


class ConversionNotImplementedError(Exception):
  pass


class _PforInput:
  """Input object passed to registered pfor converters."""

  __slots__ = ["pfor", "_op", "_inputs"]

  def __init__(self, pfor, op, inputs):
    """Creates a _PforInput object.

    Args:
      pfor: PFor converter object.
      op: the Operation object that is being converted.
      inputs: list of WrappedTensor objects representing converted values of the
        inputs of `op`.
    """
    self.pfor = pfor
    self._op = op
    self._inputs = inputs

  def stack_inputs(self, stack_indices=None, tile_variants=False):
    """Stacks unstacked inputs at `stack_indices`.

    Args:
      stack_indices: indices of inputs at which stacking is done. If None,
        stacking is done at all indices.
      tile_variants: If True, affected indices which have a variant dtype will
        be tiled after this operation to match the expected shape of a
        vectorized tensor. Variants generally need to be un-tiled when they are
        inputs to operations and tiled when returned.
    """
    if stack_indices is None:
      stack_indices = range(len(self._inputs))
    length = self.pfor.loop_len_vector
    for i in stack_indices:
      inp = self._inputs[i]
      is_variant = inp.t.dtype == dtypes.variant
      if not inp.is_stacked:
        self._inputs[i] = _stack(inp.t, length)
        if tile_variants and is_variant:
          self._inputs[i] = wrap(
              _tile_variant_with_length(self._inputs[i].t, length), True)
      elif not tile_variants and is_variant:
        self._inputs[i] = wrap(_untile_variant(self._inputs[i].t), True)

  def expanddim_inputs_for_broadcast(self):
    """Reshapes stacked inputs to prepare them for broadcast.

    Since stacked inputs have an extra leading dimension, automatic broadcasting
    rules could incorrectly try to expand dimensions before that leading
    dimension. To avoid that, we reshape these stacked inputs to the maximum
    rank they will need to be broadcasted to.
    """
    if not self._inputs:
      return

    # Find max rank
    def _get_rank(x):
      rank = array_ops.rank(x.t)
      if not x.is_stacked:
        rank += 1
      return rank

    ranks = [_get_rank(x) for x in self._inputs]
    max_rank = ranks[0]
    for rank in ranks[1:]:
      max_rank = math_ops.maximum(rank, max_rank)

    for i, inp in enumerate(self._inputs):
      if inp.is_stacked:
        shape = array_ops.shape(inp.t)
        rank_diff = array_ops.reshape(max_rank - ranks[i], [1])
        ones = array_ops.tile([1], rank_diff)
        new_shape = array_ops.concat([shape[:1], ones, shape[1:]], axis=0)
        self._inputs[i] = wrap(array_ops.reshape(inp.t, new_shape), True)

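  # Shape sketch for the reshape above (illustrative only): with two converted
  # inputs, a stacked one of shape (N, 4) and an unstacked one of shape
  # (2, 3, 4), max_rank is 4 (the unstacked rank 3 counts as 4). The stacked
  # input is reshaped to (N, 1, 1, 4) so that broadcasting keeps the loop
  # dimension N in front and aligns the trailing dimensions as usual.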
  @property
  def inputs(self):
    return self._inputs

  @property
  def num_inputs(self):
    return len(self._inputs)

  def input(self, index):
    assert len(self._inputs) > index, (index, self._inputs)
    return self._inputs[index]

  def stacked_input(self, index):
    t, is_stacked, _ = self.input(index)
    if not is_stacked:
      op_type = self.op_type
      op_def = getattr(self._op, "op_def", None)
      if op_def is None:
        input_name = "at index %d" % index
      else:
        input_name = "\"%s\"" % op_def.input_arg[index].name
      raise ConversionNotImplementedError(
          f"Input {input_name} of op '{op_type}' expected to be not loop "
          "invariant.")
    return t

  def unstacked_input(self, index):
    t, is_stacked, _ = self.input(index)
    if is_stacked:
      op_type = self.op_type
      op_def = getattr(self._op, "op_def", None)
      if op_def is None:
        input_name = "at index %d" % index
      else:
        input_name = "\"%s\"" % op_def.input_arg[index].name
      raise ConversionNotImplementedError(
          f"Input {input_name} of op '{op_type}' expected to be loop "
          "invariant.")
    return t

  @property
  def op(self):
    return self._op

  @property
  def op_type(self):
    return self._op.type

  def get_attr(self, attr):
    return self._op.get_attr(attr)

  @property
  def outputs(self):
    return self._op.outputs

  def output(self, index):
    assert index < len(self._op.outputs)
    return self._op.outputs[index]

_pfor_converter_registry = {}


class RegisterPFor:
  """Utility to register converters for pfor.

  Usage:
  @RegisterPFor(foo_op_type)
  def _foo_converter(pfor_input):
    ...

  The above will register conversion function `_foo_converter` for handling
  conversion of `foo_op_type`. These converters are called during vectorization
  of a `pfor` loop body. For each operation node in this loop body,
  the vectorization process will call the converter corresponding to the
  operation type of the node.

  During conversion, the registered function will be called with a single
  argument `pfor_input`, of type `PForInput`, which will contain state needed
  for the conversion.  When the converter is called for a node, all its inputs
  should already have been converted and these converted values are stored in
  `pfor_input.inputs`.  This registered function should output a list of
  WrappedTensor objects with the same length as the number of outputs of the
  node being converted. If the node had zero outputs, then it should return an
  ops.Operation object.  These new sets of nodes should implement the
  functionality of running that operation for the number of iterations specified
  by `pfor_input.pfor.loop_len_vector[0]` where the inputs of the node for each
  iteration are picked from `pfor_inputs.inputs()`.

  One tricky aspect of the conversion process is keeping track of, and
  leveraging, loop invariance of computation. Each converted input is a
  WrappedTensor which indicates whether the input was loop invariant or not. If
  the converted value is loop invariant, its rank should match the rank of the
  corresponding tensor in the loop body, else its rank is larger by 1. The
  converter should look at the loop invariance of the inputs and generate new
  nodes based on that. Note that the converter will not be called if all inputs
  are loop invariant and the operation is not stateful. The converter should
  determine if its own output is loop invariant and `wrap` its output
  accordingly.

  Example:

  Here, the converter is trying to convert a Reshape node in the loop body. This
  node will have two inputs: the tensor to reshape, and the new shape.  The
  example here only handles the case where the shape is loop invariant.

  @RegisterPFor("Reshape")
  def _convert_reshape(pfor_input):
    # We assume that input is not loop invariant. Call to `stacked_input`
    # asserts that and returns the converted value. This value will have a rank
    # larger by 1 compared to the rank of the input in the loop body.
    t = pfor_input.stacked_input(0)

    # We assume that shape input is loop invariant. Call to `unstacked_input`
    # asserts that and returns the converted value.
    shape = pfor_input.unstacked_input(1)

    # We compute `new_shape` by prepending the number of iterations to the
    # original shape.
    new_shape = array_ops.concat([pfor_input.pfor.loop_len_vector, shape],
                                 axis=0)

    # The vectorized output involves reshaping the converted input `t` using
    # `new_shape`.
    new_output = array_ops.reshape(t, new_shape)

    # The converted output is marked as not loop invariant using the call to
    # wrap.
    return wrap(new_output, True)
  """

  def __init__(self, op_type):
    """Creates an object to register a converter for op with type `op_type`."""
    self.op_type = op_type

  def __call__(self, converter):
    name = self.op_type
    assert name not in _pfor_converter_registry, "Re-registering %s " % name
    _pfor_converter_registry[name] = converter
    return converter


class RegisterPForWithArgs(RegisterPFor):
  """Utility to register converters for pfor.

  Usage:
  @RegisterPForWithArgs(foo_op_type, foo=value, ....)
  def _foo_converter(pfor_input, foo=None, ....):
    ...

  See RegisterPFor for details on the conversion function.
  `RegisterPForWithArgs` allows binding extra arguments to the
  conversion function at registration time.
  """

  def __init__(self, op_type, *args, **kw_args):
    super(RegisterPForWithArgs, self).__init__(op_type)
    self._args = args
    self._kw_args = kw_args

  def __call__(self, converter):

    def _f(pfor_input):
      return converter(pfor_input, self.op_type, *self._args, **self._kw_args)

    super(RegisterPForWithArgs, self).__call__(_f)
    return converter


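# Illustrative (hypothetical) use of RegisterPForWithArgs: a single converter
# body shared by several elementwise ops, with the op type forwarded to it by
# the `_f` wrapper above. The op names and helper below are placeholders, not
# registrations made by this module.
#
#   @RegisterPForWithArgs("Erf")
#   @RegisterPForWithArgs("Erfc")
#   def _convert_unary_cwise(pfor_input, op_type):
#     # Elementwise ops vectorize trivially: run the same op on the stacked
#     # input and keep the result stacked.
#     t = pfor_input.stacked_input(0)
#     out = _create_op(op_type, [t], [x.dtype for x in pfor_input.outputs],
#                      attrs=pfor_input.op.node_def.attr).outputs[0]
#     return wrap(out, True)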
# TODO(agarwal): call raw_ops instead of calling these low level routines.
def _create_op(op_type, inputs, op_dtypes, attrs=None):
  """Utility to create an op."""
  op = ops.get_default_graph().create_op(
      op_type, inputs, op_dtypes, attrs=attrs, compute_device=True)
  flat_attrs = []
  # The tape expects an alternating flat list of names and attribute values.
  for a in attrs:
    flat_attrs.append(str(a))
    flat_attrs.append(op.get_attr(str(a)))
  execute.record_gradient(op_type, op.inputs, tuple(flat_attrs), op.outputs[:])
  return op


WrappedTensor = collections.namedtuple("WrappedTensor",
                                       ["t", "is_stacked", "is_sparse_stacked"])
"""Wrapper around the result of a Tensor conversion.

The additional fields are useful for keeping track of the conversion state as
data flows through the ops in the loop body. For every op whose output is a
Tensor, its converter should return either a WrappedTensor or a list of
WrappedTensors.

Args:
  t: The converted tensor
  is_stacked: True if the tensor is stacked, i.e. represents the results of all
    the iterations of the loop, where each row i of the tensor corresponds to
    that op's output on iteration i of the loop. False if the tensor is not
    stacked, i.e. represents the result of the op for a single iteration of
    the loop, where the result does not vary between iterations.
  is_sparse_stacked: True if the tensor corresponds to a component tensor
    (indices, values, or dense_shape) of a sparse tensor, and has been logically
    stacked via a sparse conversion.
"""


def wrap(tensor, is_stacked=True, is_sparse_stacked=False):
  """Helper to create a WrappedTensor object."""
  assert isinstance(is_stacked, bool)
  assert isinstance(is_sparse_stacked, bool)
  assert isinstance(tensor, ops.Tensor)
  assert not is_sparse_stacked or is_stacked, ("If the wrapped tensor is "
                                               "stacked via a sparse "
                                               "conversion, it must also be "
                                               "stacked.")
  return WrappedTensor(tensor, is_stacked, is_sparse_stacked)

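# Rough picture of the two stacking states (illustrative shapes only): with
# N = loop_len and a per-iteration value of shape (3,),
#   wrap(t, True)   -> t has shape (N, 3); row i is the value for iteration i.
#   wrap(t, False)  -> t has shape (3,); the value is shared by all iterations.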

def _wrap_and_tile_variants(tensor, length):
  if tensor.dtype == dtypes.variant:
    tensor = _tile_variant_with_length(tensor, length)
  return wrap(tensor)


def _fallback_converter(pfor_input, root_cause="", warn=True):
  if warn:
    logging.warning("Using a while_loop for converting %s because %s",
                    pfor_input.op_type, root_cause)
  output_dtypes = [x.dtype for x in pfor_input.outputs]
  iters = pfor_input.pfor.loop_len_vector[0]

  def while_body(i, *ta_list):
    """Body of while loop."""
    inputs = [
        x[i, ...] if stacked else x for x, stacked, _ in pfor_input.inputs
    ]
    op_outputs = _create_op(
        pfor_input.op_type,
        inputs,
        output_dtypes,
        attrs=pfor_input.op.node_def.attr).outputs

    outputs = []
    # TODO(agarwal): Add tf.debugging asserts to check that the shapes across
    # the different iterations are the same.
    for out, ta in zip(op_outputs, ta_list):
      assert isinstance(out, ops.Tensor)
      outputs.append(ta.write(i, array_ops.expand_dims(out, 0)))
    return tuple([i + 1] + outputs)

  ta_list = control_flow_ops.while_loop(
      lambda i, *ta: i < iters, while_body, [0] +
      [tensor_array_ops.TensorArray(dtype, iters) for dtype in output_dtypes
      ])[1:]
  return tuple([wrap(ta.concat(), True) for ta in ta_list])

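# Sketch of what the fallback produces (illustrative): for an op with a single
# output of per-iteration shape (3,), iteration i slices row i out of each
# stacked input, runs the original op once, and writes the result (expanded to
# shape (1, 3)) into a TensorArray; `ta.concat()` then yields the stacked
# (N, 3) output, which is wrapped as stacked.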

class PForConfig:
  """A configuration object used to communicate with loop body function."""

  def __init__(self):
    # This may be set to the number of iterations.
    self._maybe_iters = None
    # Map from reduction node, created by `reduce`, to the bundle of reduction
    # function and arguments.
    self._reduce_map = {}

  def _has_reductions(self):
    """True if some reductions were performed by the loop body."""
    return len(self._reduce_map)

  def _set_iters(self, iters):
    """Set number of pfor iterations."""
    if isinstance(iters, ops.Tensor):
      iters = tensor_util.constant_value(iters)
    self._maybe_iters = iters

  def reduce(self, fn, *args):
    """Performs reduction `fn` on `args` vectorized across pfor iterations.

    Note that `fn` is traced once inside the loop function context. Hence any
    captures or side-effects will happen in that context. The call to the
    traced version of `fn` happens during the construction of the vectorized
    code.

    Note that this currently may not work inside a control flow construct.
    Args:
      fn: a reduction function. It will be called with arguments that have the
        same structure as *args but with individual values whose rank may be
        higher by 1 since they represent loop invariant vectorized versions of
        the corresponding Tensors in *args.
      *args: unvectorized Tensors.

    Returns:
      The result of running `fn` on the vectorized versions of `*args`. These
      outputs will be available as loop invariant values to all the iterations.
    """
    assert not context.executing_eagerly()
    # Creates a concrete function that will be used for reduction.
    tensor_specs = []
    for arg in args:
      if not isinstance(arg, ops.Tensor):
        raise ValueError(f"Got a non-Tensor argument {arg} in reduce.")
      batched_shape = tensor_shape.TensorShape([self._maybe_iters
                                               ]).concatenate(arg.shape)
      tensor_specs.append(
          tensor_spec.TensorSpec(shape=batched_shape, dtype=arg.dtype))
    concrete_function = def_function.function(fn).get_concrete_function(
        *tensor_specs)

    # Creates PlaceholderWithDefault and IdentityN nodes corresponding to the
    # reduction.
    pl_outputs = []
    with ops.control_dependencies(args):
      for output in concrete_function.outputs:
        if not isinstance(output, ops.Tensor):
          raise ValueError(f"Got a non-Tensor output {output} while running "
                           "reduce.")
        # Note that we use placeholder_with_default just to make XLA happy since
        # it does not like placeholder ops.
        if output.shape.is_fully_defined():
          dummy = array_ops.zeros(output.shape.as_list(), dtype=output.dtype)
          pl_outputs.append(
              array_ops.placeholder_with_default(dummy, shape=output.shape))
        else:
          # TODO(agarwal): support case when under XLA and output.shape is not
          # fully defined.
          pl_outputs.append(
              array_ops.placeholder(output.dtype, shape=output.shape))

      reduction_op = array_ops.identity_n(pl_outputs)[0].op
    self._reduce_map[reduction_op] = (concrete_function, args)
    if len(reduction_op.outputs) == 1:
      return reduction_op.outputs[0]
    else:
      return tuple(reduction_op.outputs)

  # TODO(agarwal): handle reductions inside control flow constructs.
  def reduce_concat(self, x):
    """Performs a concat reduction on `x` across pfor iterations.

    Note that this currently may not work inside a control flow construct.
    Args:
      x: an unvectorized Tensor.

    Returns:
      A Tensor that has rank one higher than `x`. The value is the vectorized
      version of `x`, i.e. stacking the value of `x` across different pfor
      iterations.
    """
    return self.reduce(lambda y: y, x)

  def reduce_mean(self, x):
    """Performs a mean reduction on `x` across pfor iterations.

    Note that this currently may not work inside a control flow construct.
    Args:
      x: an unvectorized Tensor.

    Returns:
      A Tensor that has same rank as `x`. The value is the mean of the values
      of `x` across the pfor iterations.
    """
    return self.reduce(lambda y: math_ops.reduce_mean(y, axis=0), x)

  def reduce_sum(self, x):
    """Performs a sum reduction on `x` across pfor iterations.

    Note that this currently may not work inside a control flow construct.
    Args:
      x: an unvectorized Tensor.

    Returns:
      A Tensor that has same rank as `x`. The value is the sum of the values
      of `x` across the pfor iterations.
    """
    return self.reduce(lambda y: math_ops.reduce_sum(y, axis=0), x)
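  # Shape sketch for the reductions above (illustrative): if each pfor
  # iteration produces a Tensor `x` of shape (3,) and there are N iterations,
  # `fn` in `reduce` sees a stacked value of shape (N, 3). reduce_concat
  # returns that (N, 3) value, while reduce_mean and reduce_sum reduce over
  # the leading pfor dimension and return a value of shape (3,).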

  def _lookup_reduction(self, t):
    """Looks up Tensor `t` in the reduction maps."""
    assert isinstance(t, ops.Tensor), t
    return self._reduce_map.get(t.op)


class PFor:
  """Implementation of rewrite of parallel-for loops.

  This class takes a DAG or a set of DAGs representing the body of a
  parallel-for loop, and adds new operations to the graph that implement
  functionality equivalent to running that loop body for a specified number of
  iterations. This new set of nodes may or may not use a tensorflow loop
  construct.

  The process of conversion does not delete or change any existing operations.
  It only adds operations that efficiently implement the equivalent
  functionality. We refer to the added ops as "converted ops".

  The conversion process uses a simple greedy heuristic. It walks the loop body
  and tries to express the functionality of running each node in a loop with a
  new set of nodes. When converting an op, several cases are possible:
  - The op is not inside the loop body. Hence it can be used as is.
  - The op does not depend on the iteration number and is stateless. In this
    case, it can be used as is.
  - The op is not stateful, and depends on iteration number only through control
    dependencies. In this case, we can create a single op with same inputs and
    attributes, but with "converted" control dependencies.
  - The op is not stateful, and all its inputs are loop invariant. In this
    case, similar to above, we can create a single op with same inputs and
    attributes, but with "converted" control dependencies.
  - The op is stateful or at least one of the inputs is not loop invariant. In
    this case, we run the registered converter for that op to create a set of
    converted ops. All nodes in the set will have converted control dependencies
    corresponding to control dependencies of the original op. If the op returned
    multiple outputs, "converted outputs" could be produced by different ops in
    this set.
  """

  def __init__(self,
               loop_var,
               loop_len,
               pfor_ops,
               fallback_to_while_loop,
               all_indices=None,
               all_indices_partitioned=False,
               pfor_config=None,
               warn=False):
    """Creates an object to rewrite a parallel-for loop.

    Args:
      loop_var: ops.Tensor output of a PlaceholderWithDefault operation. The
        value should be an int32 scalar representing the loop iteration number.
      loop_len: A scalar or scalar Tensor representing the number of iterations
        the loop is run for.
      pfor_ops: List of all ops inside the loop body.
      fallback_to_while_loop: If True, on failure to vectorize an op, a while
        loop is used to sequentially execute that op.
      all_indices: If not None, an int32 vector with size `loop_len`
        representing the iteration ids that are still active. These values
        should be unique and sorted. However they may not be contiguous. This is
        typically the case when inside a control flow construct which has
        partitioned the indices of the iterations that are being converted.
      all_indices_partitioned: If True, this object is being constructed from a
        control flow construct where not all the pfor iterations are guaranteed
        to be active.
      pfor_config: PForConfig object used while constructing the loop body.
      warn: Whether or not to warn on while loop conversions.
    """
    assert isinstance(loop_var, ops.Tensor)
    assert loop_var.op.type == "PlaceholderWithDefault"
    self._loop_var = loop_var
    loop_len_value = tensor_util.constant_value(loop_len)
    if loop_len_value is not None:
      loop_len = loop_len_value
    self._loop_len_vector = array_ops.reshape(loop_len, [1])
    self._all_indices_partitioned = all_indices_partitioned
    if all_indices_partitioned:
      assert all_indices is not None
    self.all_indices = (
        math_ops.range(loop_len) if all_indices is None else all_indices)

    self._conversion_map = object_identity.ObjectIdentityDictionary()
    self._conversion_map[loop_var] = wrap(self.all_indices, True)
    self._pfor_ops = set(pfor_ops)
    self._pfor_op_ids = set(x._id for x in pfor_ops)
    self._fallback_to_while_loop = fallback_to_while_loop
    self._warn = warn
    self._pfor_config = pfor_config

1323  def op_is_inside_loop(self, op):
1324    """True if op was created inside the pfor loop body."""
1325    assert isinstance(op, ops.Operation)
1326    # Note that we use self._pfor_op_ids for the check and not self._pfor_ops
1327    # since it appears the TensorFlow API could return different Python
1328    # objects representing the same Operation node.
1329    return op._id in self._pfor_op_ids
1330
1331  def _convert_sparse(self, y):
1332    """Returns the converted value corresponding to SparseTensor y.
1333
1334    For SparseTensors, instead of stacking the component tensors separately,
1335    resulting in component tensors with shapes (N, m, rank), (N, m), and (N,
1336    rank) respectively for indices, values, and dense_shape (where N is the loop
1337    length and m is the number of sparse tensor values per loop iter), we want
1338    to logically stack the SparseTensors, to create a SparseTensor whose
1339    components are size (N * m, rank + 1), (N * m, ), and (rank + 1,)
1340    respectively.
1341
1342    Here, we try to get the conversion of each component tensor.
1343    If the tensors are stacked via a sparse conversion, return the resulting
1344    SparseTensor composed of the converted components. Otherwise, the component
1345    tensors are either unstacked or stacked naively. In the latter case, we
1346    unstack the component tensors to reform loop_len SparseTensor elements,
1347    then correctly batch them.
1348
1349    The unstacked tensors must have the same rank. Each dimension of each
1350    SparseTensor will expand to be the largest among all SparseTensor elements
1351    for that dimension. For example, if there are N SparseTensors of rank 3
1352    being stacked, with N dense shapes, where the i_th shape is (x_i, y_i, z_i),
1353    the new dense shape will be (N, max_i(x_i), max_i(y_i), max_i(z_i)).
1354
1355    Args:
1356      y: A tf.sparse.SparseTensor.
1357
1358    Returns:
1359      A tf.sparse.SparseTensor that is the converted value corresponding to y.
1360    """
1361    outputs = [
1362        self._convert_helper(t) for t in (y.indices, y.values, y.dense_shape)
1363    ]
1364    assert all(isinstance(o, WrappedTensor) for o in outputs)
1365
1366    if all(w.is_sparse_stacked for w in outputs):
1367      return sparse_tensor.SparseTensor(*[w.t for w in outputs])
1368
1369    assert not any(w.is_sparse_stacked for w in outputs), (
1370        "Error converting SparseTensor. All components should be logically "
1371        "stacked, or none.")
1372
1373    # If component tensors were not sparsely stacked, they are either unstacked
1374    # or stacked without knowledge that they are components of sparse tensors.
1375    # In this case, we have to restack them.
1376    return self._restack_sparse_tensor_logically(
1377        *[self._unwrap_or_tile(w) for w in outputs])
1378
1379  def _restack_sparse_tensor_logically(self, indices, values, shape):
1380    sparse_tensor_rank = indices.get_shape().dims[-1].value
1381    if sparse_tensor_rank is not None:
1382      sparse_tensor_rank += 1
1383
1384    def fn(args):
1385      res = gen_sparse_ops.serialize_sparse(
1386          args[0], args[1], args[2], out_type=dtypes.variant)
1387      return res
1388
1389    # Applies a map function to the component tensors to serialize each
1390    # sparse tensor element and batch them all, then deserializes the batch.
1391    # TODO(rachelim): Try to do this without map_fn -- add the right offsets
1392    # to shape and indices tensors instead.
1393    result = map_fn.map_fn(fn, [indices, values, shape], dtype=dtypes.variant)
1394    return sparse_ops.deserialize_sparse(
1395        result, dtype=values.dtype, rank=sparse_tensor_rank)
1396
1397  def _unwrap_or_tile(self, wrapped_tensor):
1398    """Given a wrapped tensor, unwrap if stacked. Otherwise, tiles it."""
1399    output, is_stacked = wrapped_tensor.t, wrapped_tensor.is_stacked
1400    if is_stacked:
1401      return output
1402    else:
1403      return _stack(output, self._loop_len_vector).t
1404
1405  def convert(self, y):
1406    """Returns the converted value corresponding to y.
1407
1408    Args:
1409      y: An ops.Tensor or an ops.Operation object. If the latter, y should not
1410        have any outputs.
1411
1412    Returns:
1413      If y does not need to be converted, it returns y as is. Else it returns
1414      the "converted value" corresponding to y.
1415    """
1416    if y is None:
1417      return None
1418    if isinstance(y, sparse_tensor.SparseTensor):
1419      return self._convert_sparse(y)
1420    assert isinstance(y, (ops.Tensor, ops.Operation)), y
1421    output = self._convert_helper(y)
1422    if isinstance(output, WrappedTensor):
1423      assert isinstance(y, ops.Tensor)
1424      return self._unwrap_or_tile(output)
1425    else:
1426      assert isinstance(y, ops.Operation)
1427      assert not y.outputs
1428      assert isinstance(output, ops.Operation)
1429    return output
1430
1431  def _was_converted(self, t):
1432    """True if t is not a conversion of itself."""
1433    converted_t = self._conversion_map[t]
1434    return converted_t.t is not t
1435
1436  def _add_conversion(self, old_output, new_output):
1437    assert isinstance(old_output, (ops.Tensor, ops.Operation)), old_output
1438    assert isinstance(new_output, (WrappedTensor, ops.Operation)), new_output
1439    self._conversion_map[old_output] = new_output
1440
1441  def _convert_reduction(self, y):
1442    # Handle reductions.
1443    if self._pfor_config is None or isinstance(y, ops.Operation):
1444      return None
1445    reduction = self._pfor_config._lookup_reduction(y)
1446    if reduction is None:
1447      return None
1448    (reduction_fn, reduction_args) = reduction
1449    batched_args = []
1450    for reduction_arg in reduction_args:
1451      assert isinstance(reduction_arg, ops.Tensor), reduction_arg
1452      # Tensor being reduced should already be converted due to a control
1453      # dependency on the created placeholder.
1454      # Note that in cases where reduction_arg is in an outer context, one
1455      # needs to locate the corresponding Enter node and use that to look up
1456      # the conversion.
1457      # TODO(agarwal): handle reductions inside control flow constructs.
1458      assert reduction_arg in self._conversion_map, (
1459          "Unable to handle reduction of %s, possibly as it was used "
1460          "inside a control flow construct. Note that reductions across "
1461          "pfor iterations are currently not supported inside control flow "
1462          "constructs." % reduction_arg)
1463      batched_arg = self._conversion_map[reduction_arg]
1464      batched_args.append(self._unwrap_or_tile(batched_arg))
1465    outputs = reduction_fn(*batched_args)
1466    return [wrap(output, False) for output in nest.flatten(outputs)]
1467
1468  def _convert_helper(self, op_or_tensor):
1469    stack = collections.deque([op_or_tensor])
1470    while stack:
1471      y = stack[0]
1472      if y in self._conversion_map:
1473        assert isinstance(self._conversion_map[y],
1474                          (WrappedTensor, ops.Operation))
1475        stack.popleft()
1476        continue
1477      if isinstance(y, ops.Operation):
1478        assert not y.outputs, (
1479            "We only support converting Operation objects with no outputs. "
1480            "Got %s", y)
1481        y_op = y
1482      else:
1483        assert isinstance(y, ops.Tensor), y
1484        y_op = y.op
1485
1486      is_while_loop = y_op.type == "Exit"
1487      if is_while_loop:
1488        while_op = WhileOp(
1489            y, pfor_ops=self._pfor_ops,
1490            fallback_to_while_loop=self.fallback_to_while_loop,
1491            pfor_config=self._pfor_config)
1492        is_inside_loop = while_op.is_inside_loop
1493        # If all nodes in the while_loop graph were created inside the pfor, we
1494        # treat the whole loop subgraph as a single op (y_op) and try to convert
1495        # it. For while_loops that are created completely or partially outside,
1496        # we treat them as external and should be able to simply return the Exit
1497        # node output as is without needing any conversion. Note that for
1498        # while_loops that are partially constructed inside, we assume they will
1499        # be loop invariant. If that is not the case, it will create runtime
1500        # errors since the converted graph would depend on the self._loop_var
1501        # placeholder.
1502        if is_inside_loop:
1503          y_op = while_op
1504      else:
1505        is_inside_loop = self.op_is_inside_loop(y_op)
1506
1507      # If this op was not created inside the loop body, we will return as is.
1508      # 1. Convert inputs and control inputs.
1509
1510      def _add_to_stack(x):
1511        if x not in self._conversion_map:
1512          stack.appendleft(x)
1513          return True
1514        else:
1515          return False
1516
1517      if is_inside_loop:
1518        added_to_stack = False
1519        for inp in y_op.inputs:
1520          added_to_stack |= _add_to_stack(inp)
1521        for cinp in y_op.control_inputs:
1522          if cinp.outputs:
1523            for t in cinp.outputs:
1524              added_to_stack |= _add_to_stack(t)
1525          else:
1526            added_to_stack |= _add_to_stack(cinp)
1527        if added_to_stack:
1528          continue
1529
1530        converted_inputs = [self._conversion_map[inp] for inp in y_op.inputs]
1531        some_input_converted = any(self._was_converted(x) for x in y_op.inputs)
1532        some_input_stacked = any(x.is_stacked for x in converted_inputs)
1533
1534        converted_control_ops = set()
1535        some_control_input_converted = False
1536        for cinp in y_op.control_inputs:
1537          if cinp.outputs:
1538            for t in cinp.outputs:
1539              converted_t = self._conversion_map[t]
1540              if self._was_converted(t):
1541                some_control_input_converted = True
1542              converted_control_ops.add(converted_t.t.op)
1543          else:
1544            converted_cinp = self._conversion_map[cinp]
1545            assert isinstance(converted_cinp, ops.Operation)
1546            if converted_cinp != cinp:
1547              some_control_input_converted = True
1548            converted_control_ops.add(converted_cinp)
1549        converted_control_ops = list(converted_control_ops)
1550        is_stateful = _is_stateful_pfor_op(y_op)
1551      else:
1552        converted_inputs = []
1553        converted_control_ops = []
1554      logging.vlog(3, "converting op:%s\ninputs:%s\ncontrol_inputs:%s", y_op,
1555                   converted_inputs, converted_control_ops)
1556
1557      # 2. Convert y_op
1558      # If converting a while_loop, we let the while_loop converter deal with
1559      # putting the control dependencies appropriately.
1560      control_dependencies = [] if is_while_loop else converted_control_ops
1561      with ops.control_dependencies(control_dependencies), ops.name_scope(
1562          y_op.name + "/pfor/"), ops.get_default_graph()._original_op(y_op):
1563        # Op is a placeholder for a reduction.
1564        reduce_output = self._convert_reduction(y)
1565        if reduce_output is not None:
1566          new_outputs = reduce_output
1567        # None of the inputs and control inputs were converted.
1568        elif ((not is_inside_loop or
1569               (not is_stateful and not some_input_converted and
1570                not some_control_input_converted)) and
1571              y.graph == ops.get_default_graph()):
1572          if y is y_op:
1573            assert not isinstance(y_op, WhileOp)
1574            new_outputs = y_op
1575          else:
1576            new_outputs = [wrap(x, False) for x in y_op.outputs]
1577        elif not (is_stateful or is_while_loop or some_input_stacked):
1578          # All inputs are unstacked or unconverted but some control inputs are
1579          # converted.
1580          # TODO(rachelim): Handle the case where some inputs are sparsely
1581          # stacked (i.e. any(x.is_sparse_stacked for x in converted_inputs))
1582          new_op = _create_op(y_op.type, [x.t for x in converted_inputs],
1583                              [x.dtype for x in y_op.outputs],
1584                              y_op.node_def.attr)
1585          if y is y_op:
1586            new_outputs = new_op
1587          else:
1588            new_outputs = []
1589            for old_output, new_output in zip(y_op.outputs, new_op.outputs):
1590              handle_data_util.copy_handle_data(old_output, new_output)
1591              new_outputs.append(wrap(new_output, False))
1592        else:
1593          # Either some inputs are not loop invariant or op is stateful.
1594          if hasattr(y_op, "pfor_converter"):
1595            converter = y_op.pfor_converter
1596          else:
1597            converter = _pfor_converter_registry.get(y_op.type, None)
1598          if converter is None:
1599            root_cause = "there is no registered converter for this op."
1600            has_variant_outputs = any(x.dtype == dtypes.variant for x in
1601                                      y_op.outputs)
1602            has_vectorized_variant_inputs = any(
1603                _is_variant_with_internal_stacking(x) for x in
1604                y_op.inputs)
1605            if (self._fallback_to_while_loop and not has_variant_outputs
1606                and not has_vectorized_variant_inputs):
1607              converter = partial(
1608                  _fallback_converter, root_cause=root_cause, warn=self._warn)
1609            else:
1610              message = (f"No pfor vectorization defined for {y_op.type}\n"
1611                         f"{y_op}\n inputs: {converted_inputs}.")
1612              if not self._fallback_to_while_loop:
1613                message += (" Consider enabling the fallback_to_while_loop "
1614                            "option to pfor, which may run slower.")
1615              raise ValueError(message)
1616          # TODO(rachelim): Handle the case where some inputs are sparsely
1617          # stacked. We should only call the converter if it supports handling
1618          # those inputs.
1619          pfor_inputs = _PforInput(self, y_op, converted_inputs)
1620          try:
1621            try:
1622              new_outputs = converter(pfor_inputs)
1623            except ConversionNotImplementedError as e:
1624              has_vectorized_variant_inputs = any(
1625                  _is_variant_with_internal_stacking(x) for x in
1626                  y_op.inputs)
1627              if (self._fallback_to_while_loop
1628                  and not has_vectorized_variant_inputs):
1629                new_outputs = _fallback_converter(
1630                    pfor_inputs, root_cause=str(e))
1631              else:
1632                raise ValueError(str(e)).with_traceback(sys.exc_info()[2])
1633          except Exception as e:  # pylint: disable=broad-except
1634            logging.error(
1635                f"Got error while pfor was converting op {y_op} with inputs "
1636                f"{y_op.inputs[:]}\n, converted inputs {pfor_inputs.inputs}\n"
1637                f"Here are the pfor conversion stack traces: {e}")
1638            original_op = y_op
1639            while isinstance(original_op, ops.Operation):
1640              logging.error(
1641                  "%s\ncreated at:\n  %s", original_op,
1642                  "  ".join(traceback.format_list(original_op.traceback)))
1643              original_op = original_op._original_op
1644            raise
1645
1646          if isinstance(new_outputs, WrappedTensor):
1647            new_outputs = [new_outputs]
1648          assert isinstance(new_outputs,
1649                            (list, tuple, ops.Operation)), new_outputs
1650        logging.vlog(2, f"converted {y_op} {new_outputs}")
1651
1652        # Insert into self._conversion_map
1653        if y is y_op:
1654          assert isinstance(new_outputs, ops.Operation)
1655          self._add_conversion(y_op, new_outputs)
1656        else:
1657          assert len(y_op.outputs) == len(new_outputs), (y_op, y_op.outputs,
1658                                                         new_outputs)
1659          for old_output, new_output in zip(y_op.outputs, new_outputs):
1660            assert isinstance(new_output, WrappedTensor), (new_output, y, y_op)
1661            assert old_output.dtype == new_output.t.dtype, (new_output, y, y_op)
1662            # Set shape for converted output.
1663            output_shape = old_output.shape
1664            if not new_output.is_sparse_stacked:
1665              if new_output.is_stacked:
1666                loop_len = tensor_util.constant_value(self.loop_len_vector)
1667                if loop_len is None:
1668                  batch_dim = tensor_shape.TensorShape([None])
1669                else:
1670                  batch_dim = tensor_shape.TensorShape(loop_len)
1671                output_shape = batch_dim.concatenate(output_shape)
1672              if _is_variant_with_internal_stacking(new_output.t):
1673                new_output.t.set_shape([])
1674              else:
1675                new_output.t.set_shape(output_shape)
1676            self._add_conversion(old_output, new_output)
1677        stack.popleft()
1678
1679    return self._conversion_map[op_or_tensor]
1680
1681  @property
1682  def loop_len_vector(self):
1683    """Returns a single element vector whose value is number of iterations."""
1684    return self._loop_len_vector
1685
1686  @property
1687  def loop_var(self):
1688    """Returns placeholder loop variable."""
1689    return self._loop_var
1690
1691  @property
1692  def pfor_ops(self):
1693    return self._pfor_ops
1694
1695  @property
1696  def pfor_config(self):
1697    return self._pfor_config
1698
1699  @property
1700  def all_indices_partitioned(self):
1701    """all_indices_partitioned property.
1702
1703    Returns:
1704      True if we are inside a control flow construct and not all pfor iterations
1705      may be active.
1706    """
1707    return self._all_indices_partitioned
1708
1709  @property
1710  def fallback_to_while_loop(self):
1711    return self._fallback_to_while_loop
1712
1713
1714# The code below defines converters for different operations. Please see the
1715# comment for RegisterPFor for how converters should be defined.
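
# As an illustrative sketch (not an actual registration in this file), a
# converter for a hypothetical elementwise op "MyRelu" could look like:
#
#   @RegisterPFor("MyRelu")
#   def _convert_my_relu(pfor_input):
#     # stacked_input(0) has shape [loop_len] + per-iteration shape; an
#     # elementwise kernel can consume the stacked tensor directly.
#     x = pfor_input.stacked_input(0)
#     return wrap(gen_nn_ops.relu(x), True)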
1716
1717
1718# image_ops
1719
1720
1721@RegisterPFor("AdjustContrastv2")
1722def _convert_adjust_contrastv2(pfor_input):
1723  images = pfor_input.stacked_input(0)
1724  contrast_factor = pfor_input.unstacked_input(1)
1725  return wrap(gen_image_ops.adjust_contrastv2(images, contrast_factor), True)
1726
1727
1728@RegisterPFor("AdjustHue")
1729def _convert_adjust_hue(pfor_input):
1730  images = pfor_input.stacked_input(0)
1731  delta = pfor_input.unstacked_input(1)
1732  return wrap(gen_image_ops.adjust_hue(images, delta), True)
1733
1734
1735@RegisterPFor("AdjustSaturation")
1736def _convert_adjust_saturation(pfor_input):
1737  images = pfor_input.stacked_input(0)
1738  scale = pfor_input.unstacked_input(1)
1739  return wrap(gen_image_ops.adjust_saturation(images, scale), True)
1740
1741
1742# nn_ops
1743
1744
1745def _flatten_first_two_dims(x):
1746  """Merges first two dimensions."""
1747  old_shape = array_ops.shape(x)
1748  new_shape = array_ops.concat([[-1], old_shape[2:]], axis=0)
1749  return array_ops.reshape(x, new_shape)
1750
1751
1752def _unflatten_first_dim(x, first_dim):
1753  """Splits first dimension into [first_dim, -1]."""
1754  old_shape = array_ops.shape(x)
1755  new_shape = array_ops.concat([first_dim, [-1], old_shape[1:]], axis=0)
1756  return array_ops.reshape(x, new_shape)
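
# For example (shapes only, illustrative):
#   x = array_ops.zeros([4, 3, 5])
#   y = _flatten_first_two_dims(x)    # shape [12, 5]
#   z = _unflatten_first_dim(y, [4])  # shape [4, 3, 5]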
1757
1758
1759def _inputs_with_flattening(pfor_input, input_indices):
1760  """Stacks and flattens first dim of inputs at indices `input_indices`."""
1761  if input_indices is None:
1762    input_indices = []
1763  pfor_input.stack_inputs(stack_indices=input_indices)
1764  inputs = []
1765  for i in range(pfor_input.num_inputs):
1766    if i in input_indices:
1767      inp = pfor_input.stacked_input(i)
1768      inp = _flatten_first_two_dims(inp)
1769    else:
1770      inp = pfor_input.unstacked_input(i)
1771    inputs.append(inp)
1772  return inputs
1773
1774
1775@RegisterPForWithArgs("Conv2D", dims=[0])
1776@RegisterPForWithArgs("DepthToSpace", dims=[0])
1777@RegisterPForWithArgs("AvgPool", dims=[0])
1778@RegisterPForWithArgs("AvgPool3D", dims=[0])
1779@RegisterPForWithArgs("MaxPool", dims=[0])
1780@RegisterPForWithArgs("MaxPoolV2", dims=[0])
1781@RegisterPForWithArgs("MaxPool3D", dims=[0])
1782@RegisterPForWithArgs("MaxPool3DGrad", dims=[0, 1, 2])
1783@RegisterPForWithArgs("MaxPoolGrad", dims=[0, 1, 2])
1784@RegisterPForWithArgs("MaxPoolGradV2", dims=[0, 1, 2])
1785@RegisterPForWithArgs("MaxPool3DGradGrad", dims=[0, 1, 2])
1786@RegisterPForWithArgs("MaxPoolGradGrad", dims=[0, 1, 2])
1787@RegisterPForWithArgs("MaxPoolGradGradV2", dims=[0, 1, 2])
1788@RegisterPForWithArgs("SoftmaxCrossEntropyWithLogits", dims=[0, 1])
1789@RegisterPForWithArgs("SparseSoftmaxCrossEntropyWithLogits", dims=[0, 1])
1790@RegisterPForWithArgs("SpaceToDepth", dims=[0])
1791def _convert_flatten_batch(pfor_input, op_type, dims):
1792  del op_type
1793  inputs = _inputs_with_flattening(pfor_input, dims)
1794  outputs = _create_op(
1795      pfor_input.op_type,
1796      inputs, [x.dtype for x in pfor_input.outputs],
1797      attrs=pfor_input.op.node_def.attr).outputs
1798  n = pfor_input.pfor.loop_len_vector
1799  outputs = [_unflatten_first_dim(x, n) for x in outputs]
1800  return [wrap(x, True) for x in outputs]
1801
1802
1803_channel_flatten_input_cache = {}
1804
1805
1806@RegisterPFor("BatchToSpaceND")
1807def _convert_batch_to_space_nd(pfor_input):
1808  inp = pfor_input.stacked_input(0)
1809  block_shape = pfor_input.unstacked_input(1)
1810  crops = pfor_input.unstacked_input(2)
1811
1812  inp_shape = array_ops.shape(inp)
1813  n = pfor_input.pfor.loop_len_vector
1814
1815  # Reshape and transpose to move the vectorization axis inside the axes that
1816  # will move to space.
1817  # Reshape to 4D and transpose
1818  block_size = math_ops.reduce_prod(block_shape)
1819  new_shape = [n[0], block_size, inp_shape[1] // block_size, -1]
1820  inp = array_ops.reshape(inp, new_shape)
1821  inp = array_ops.transpose(inp, [1, 0, 2, 3])
1822  # Reshape back to merge the block, vectorization and batch dimension, and
1823  # restore the other dimensions.
1824  new_shape = array_ops.concat([n * inp_shape[1], inp_shape[2:]], axis=0)
1825  inp = array_ops.reshape(inp, new_shape)
1826  # Call batch_to_space and then split the new batch axis.
1827  output = gen_array_ops.batch_to_space_nd(inp, block_shape, crops)
1828  output = _unflatten_first_dim(output, n)
1829  return wrap(output, True)
1830
1831
1832@RegisterPFor("SpaceToBatchND")
1833def _convert_space_to_batch_nd(pfor_input):
1834  inp = pfor_input.stacked_input(0)
1835  block_shape = pfor_input.unstacked_input(1)
1836  paddings = pfor_input.unstacked_input(2)
1837
1838  n = pfor_input.pfor.loop_len_vector
1839  inp_shape = array_ops.shape(inp)
1840  inp = _flatten_first_two_dims(inp)
1841  output = gen_array_ops.space_to_batch_nd(inp, block_shape, paddings)
1842  output_shape = array_ops.shape(output)
1843  block_size = math_ops.reduce_prod(block_shape)
1844  new_shape = [block_size, n[0], -1]
1845  output = array_ops.reshape(output, new_shape)
1846  output = array_ops.transpose(output, [1, 0, 2])
1847  new_shape = array_ops.concat(
1848      [n, block_size * inp_shape[1:2], output_shape[1:]], axis=0)
1849  output = array_ops.reshape(output, new_shape)
1850  return wrap(output, True)
1851
1852
1853def _channel_flatten_input(x, data_format):
1854  """Merge the stack dimension with the channel dimension.
1855
1856  If S is pfor's stacking dimension, then:
1857    - for SNCHW, we transpose to NSCHW. If the N dimension has size 1, the
1858      transpose should be cheap.
1859    - for SNHWC, we transpose to NHWSC.
1860  We then merge the S and C dimensions.
1861
1862  Args:
1863    x: ops.Tensor to transform.
1864    data_format: "NCHW" or "NHWC".
1865
1866  Returns:
1867    A 3-element tuple with the transformed value, the transpose order, and the
1868    reshape shape required to transform back.
1869  """
1870
1871  graph = ops.get_default_graph()
1872  cache_key = (graph, x.ref(), data_format)
1873  if cache_key not in _channel_flatten_input_cache:
1874    x_shape = array_ops.shape(x)
1875    if data_format == b"NCHW":
1876      order = [1, 0, 2, 3, 4]
1877      shape = array_ops.concat([x_shape[1:2], [-1], x_shape[3:]], axis=0)
1878      reverse_order = order
1879    else:
1880      order = [1, 2, 3, 0, 4]
1881      shape = array_ops.concat([x_shape[1:4], [-1]], axis=0)
1882      reverse_order = [3, 0, 1, 2, 4]
1883    # Move S dimension next to C dimension.
1884    x = array_ops.transpose(x, order)
1885    reverse_shape = array_ops.shape(x)
1886    # Reshape to merge the S and C dimension.
1887    x = array_ops.reshape(x, shape)
1888    outputs = x, reverse_order, reverse_shape
1889    _channel_flatten_input_cache[cache_key] = outputs
1890  else:
1891    outputs = _channel_flatten_input_cache[cache_key]
1892  return outputs
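
# For example (shapes only): given a stacked NHWC input of shape
# [S, N, H, W, C], _channel_flatten_input transposes to [N, H, W, S, C] and
# reshapes to [N, H, W, S * C]; the returned reverse order/shape undo this via
# a reshape back to [N, H, W, S, C] followed by a transpose to [S, N, H, W, C].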
1893
1894
1895# Note that with training=True, running FusedBatchNormV3 on individual examples
1896# is very different from running FusedBatchNormV3 on a batch of those examples.
1897# This is because, for the latter case, the operation can be considered as first
1898# computing the mean and variance over all the examples and then using these
1899# to scale all those examples. This creates a data dependency between these
1900# different "iterations" since the inputs to the scaling step depends on the
1901# statistics coming from all these inputs.
1902# As with other kernels, the conversion here effectively runs the kernel
1903# independently for each iteration, and returns outputs by stacking outputs from
1904# each of those iterations.
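# For example, with training=True and 8 iterations each holding one example,
# the converted graph normalizes each example with statistics computed from
# that example alone, matching 8 independent FusedBatchNormV3 calls rather
# than a single call on the batch of 8.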
1905@RegisterPFor("FusedBatchNormV3")
1906def _convert_fused_batch_norm(pfor_input):
1907  is_training = pfor_input.get_attr("is_training")
1908  # When BatchNorm is used with training=False, mean and variance are provided
1909  # externally and used as is by the op. Thus, we can merge the S and N
1910  # dimensions as we do for regular operations.
1911  # When BatchNorm is used with training=True, mean and variance are computed
1912  # for each channel across the batch dimension (first one). If we merge S and N
1913  # dimensions, mean and variances will be computed over a larger set. So, we
1914  # merge the S and C dimensions instead.
1915  if not is_training:
1916    # We return zeros for batch_mean and batch_variance output. Note that CPU
1917    # and GPU seem to have different behavior for those two outputs. CPU outputs
1918    # zero because these values are not used during inference. GPU outputs
1919    # something, probably real means and variances.
1920    inputs = _inputs_with_flattening(pfor_input, [0])
1921    outputs = _create_op(
1922        pfor_input.op_type,
1923        inputs, [x.dtype for x in pfor_input.outputs],
1924        attrs=pfor_input.op.node_def.attr).outputs
1925    y = outputs[0]
1926    n = pfor_input.pfor.loop_len_vector
1927    y = _unflatten_first_dim(y, n)
1928    mean = pfor_input.unstacked_input(3)
1929    zeros = array_ops.zeros_like(mean)
1930    return [wrap(y, True)] + [wrap(zeros, False)] * 5
1931
1932  pfor_input.stack_inputs()
1933  data_format = pfor_input.get_attr("data_format")
1934  # We merge the first dimension with the "C" dimension, run FusedBatchNormV3,
1935  # and then transpose back.
1936  x = pfor_input.stacked_input(0)
1937  x, reverse_order, reverse_shape = _channel_flatten_input(x, data_format)
1938  # Note that we stack all the other inputs as well so that they are the same
1939  # size as the new size of the channel dimension.
1940  inputs = [x] + [
1941      array_ops.reshape(pfor_input.stacked_input(i), [-1])
1942      for i in range(1, pfor_input.num_inputs)
1943  ]
1944  outputs = _create_op(
1945      pfor_input.op_type,
1946      inputs, [x.dtype for x in pfor_input.outputs],
1947      attrs=pfor_input.op.node_def.attr).outputs
1948  y = outputs[0]
1949  y = array_ops.reshape(y, reverse_shape)
1950  y = array_ops.transpose(y, reverse_order)
1951  n = pfor_input.pfor.loop_len_vector
1952  outputs = [_unflatten_first_dim(x, n) for x in outputs[1:]]
1953  outputs = [y] + outputs
1954  return [wrap(x, True) for x in outputs]
1955
1956
1957@RegisterPFor("FusedBatchNormGradV3")
1958def _convert_fused_batch_norm_grad(pfor_input):
1959  pfor_input.stack_inputs()
1960  data_format = pfor_input.get_attr("data_format")
1961  y_backprop = pfor_input.stacked_input(0)
1962  y_backprop, _, _ = _channel_flatten_input(y_backprop, data_format)
1963  x = pfor_input.stacked_input(1)
1964  x, x_reverse_order, x_reverse_shape = _channel_flatten_input(x, data_format)
1965  inputs = [y_backprop, x] + [
1966      array_ops.reshape(pfor_input.stacked_input(i), [-1])
1967      for i in range(2, pfor_input.num_inputs)
1968  ]
1969  outputs = _create_op(
1970      pfor_input.op_type,
1971      inputs, [x.dtype for x in pfor_input.outputs],
1972      attrs=pfor_input.op.node_def.attr).outputs
1973  x_backprop = outputs[0]
1974  x_backprop = array_ops.reshape(x_backprop, x_reverse_shape)
1975  x_backprop = array_ops.transpose(x_backprop, x_reverse_order)
1976  n = pfor_input.pfor.loop_len_vector
1977  outputs = [_unflatten_first_dim(x, n) for x in outputs[1:]]
1978  outputs = [x_backprop] + outputs
1979  return [wrap(output, True) for output in outputs]
1980
1981
1982@RegisterPForWithArgs("Conv2DBackpropInput", flatten_dims=[2], shape_dim=0)
1983@RegisterPForWithArgs("AvgPoolGrad", flatten_dims=[1], shape_dim=0)
1984@RegisterPForWithArgs("AvgPool3DGrad", flatten_dims=[1], shape_dim=0)
1985def _convert_flatten_batch_shape_input(pfor_input, op_type, flatten_dims,
1986                                       shape_dim):
1987  del op_type
1988  inputs = _inputs_with_flattening(pfor_input, flatten_dims)
1989  n = pfor_input.pfor.loop_len_vector
1990  # Adjust the `input_sizes` input.
1991  ones = array_ops.ones([array_ops.shape(inputs[shape_dim])[0] - 1],
1992                        dtype=n.dtype)
1993  inputs[shape_dim] *= array_ops.concat([n, ones], axis=0)
1994  outputs = _create_op(
1995      pfor_input.op_type,
1996      inputs, [x.dtype for x in pfor_input.outputs],
1997      attrs=pfor_input.op.node_def.attr).outputs
1998  outputs = [_unflatten_first_dim(x, n) for x in outputs]
1999  return [wrap(x, True) for x in outputs]
2000
2001
2002@RegisterPFor("Conv2DBackpropFilter")
2003def _convert_conv2d_backprop_filter(pfor_input):
2004  pfor_input.stack_inputs(stack_indices=[2])
2005  inputs, inputs_stacked, _ = pfor_input.input(0)
2006  filter_sizes = pfor_input.unstacked_input(1)
2007  grads = pfor_input.stacked_input(2)
2008  strides = pfor_input.get_attr("strides")
2009  padding = pfor_input.get_attr("padding")
2010  use_cudnn_on_gpu = pfor_input.get_attr("use_cudnn_on_gpu")
2011  data_format = pfor_input.get_attr("data_format")
2012  dilations = pfor_input.get_attr("dilations")
2013  if inputs_stacked:
2014    # TODO(agarwal): Implement this efficiently.
2015    logging.warning("Conv2DBackpropFilter uses a while_loop. Fix that!")
2016
2017    def while_body(i, ta):
2018      inp_i = inputs[i, ...]
2019      grad_i = grads[i, ...]
2020      output = nn_ops.conv2d_backprop_filter(
2021          inp_i,
2022          filter_sizes,
2023          grad_i,
2024          strides=strides,
2025          padding=padding,
2026          use_cudnn_on_gpu=use_cudnn_on_gpu,
2027          data_format=data_format,
2028          dilations=dilations)
2029      return i + 1, ta.write(i, array_ops.expand_dims(output, 0))
2030
2031    n = array_ops.reshape(pfor_input.pfor.loop_len_vector, [])
2032    _, ta = control_flow_ops.while_loop(
2033        lambda i, ta: i < n, while_body,
2034        (0, tensor_array_ops.TensorArray(inputs.dtype, n)))
2035    output = ta.concat()
2036    return wrap(output, True)
2037  else:
2038    # We merge the stack dimension with the channel dimension of the gradients
2039    # and pretend we had a larger filter (see change to filter_sizes below).
2040    # Once the filter backprop is computed, we reshape and transpose back
2041    # appropriately.
2042    grads, _, _ = _channel_flatten_input(grads, data_format)
2043    n = pfor_input.pfor.loop_len_vector
2044    old_filter_sizes = filter_sizes
2045    filter_sizes *= array_ops.concat([[1, 1, 1], n], axis=0)
2046    output = nn_ops.conv2d_backprop_filter(
2047        inputs,
2048        filter_sizes,
2049        grads,
2050        strides=strides,
2051        padding=padding,
2052        use_cudnn_on_gpu=use_cudnn_on_gpu,
2053        data_format=data_format,
2054        dilations=dilations)
2055    new_filter_shape = array_ops.concat([old_filter_sizes[:3], n, [-1]], axis=0)
2056    output = array_ops.reshape(output, new_filter_shape)
2057    output = array_ops.transpose(output, [3, 0, 1, 2, 4])
2058    return wrap(output, True)
2059
2060
2061def _flatten_with_inner_dim(x, dim, x_rank):
2062  """Merges the first dim with the specified dim."""
2063  shape = array_ops.shape(x)
2064  x = array_ops.transpose(x,
2065                          list(range(1, dim)) + [0] + list(range(dim, x_rank)))
2066
2067  if dim < x_rank - 1:
2068    new_shape_pieces = [shape[1:dim], [-1], shape[dim + 1:]]
2069  else:
2070    new_shape_pieces = [shape[1:dim], [-1]]
2071  new_shape = array_ops.concat(new_shape_pieces, axis=0)
2072  return array_ops.reshape(x, new_shape)
2073
2074
2075def _unflatten_with_inner_dim(x, dim, x_rank, stack_size):
2076  """Undoes _flatten_with_inner_dim."""
2077  shape = array_ops.shape(x)
2078  if dim < x_rank - 1:
2079    new_shape_pieces = [shape[:dim], [stack_size], [-1], shape[dim + 1:]]
2080  else:
2081    new_shape_pieces = [shape[:dim], [stack_size], [-1]]
2082  new_shape = array_ops.concat(new_shape_pieces, axis=0)
2083  x = array_ops.reshape(x, new_shape)
2084  dims_permutation = [dim] + list(range(dim)) + list(range(dim + 1, x_rank + 1))
2085  return array_ops.transpose(x, dims_permutation)
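
# For example (shapes only): with a stacked NHWC activation of shape
# [S, N, H, W, C], _flatten_with_inner_dim(x, dim=4, x_rank=5) yields shape
# [N, H, W, S * C], and _unflatten_with_inner_dim(y, dim=3, x_rank=4,
# stack_size=S) restores shape [S, N, H, W, C].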
2086
2087
2088@RegisterPFor("DepthwiseConv2dNative")
2089def _convert_depthwise_conv2d_native(pfor_input):
2090  # The kernel can be vectorized, so folding into the batch dimension does not
2091  # work. We instead fold into the channel dimension because it is parallel.
2092  stack_size = pfor_input.pfor.loop_len_vector[0]
2093  data_format = pfor_input.get_attr("data_format")
2094  c_dim = 1 if data_format == b"NCHW" else 3
2095  t = _flatten_with_inner_dim(pfor_input.stacked_input(0), c_dim + 1, 5)
2096  kernel = _flatten_with_inner_dim(pfor_input.stacked_input(1), 3, 5)
2097  conv = _create_op(
2098      "DepthwiseConv2dNative", [t, kernel],
2099      [x.dtype for x in pfor_input.outputs],
2100      attrs=pfor_input.op.node_def.attr).outputs[0]
2101  return wrap(_unflatten_with_inner_dim(conv, c_dim, 4, stack_size), True)
2102
2103
2104@RegisterPFor("DepthwiseConv2dNativeBackpropInput")
2105def _convert_depthwise_conv2d_native_backprop_input(pfor_input):
2106  stack_size = pfor_input.pfor.loop_len_vector[0]
2107  input_sizes = pfor_input.unstacked_input(0)
2108  data_format = pfor_input.get_attr("data_format")
2109  c_dim = 1 if data_format == b"NCHW" else 3
2110  input_sizes_multipliers = [
2111      constant_op.constant([1] * c_dim, dtype=dtypes.int32), [stack_size]
2112  ]
2113  if c_dim < 3:
2114    input_sizes_multipliers += [
2115        constant_op.constant([1] * (3 - c_dim), dtype=dtypes.int32)
2116    ]
2117  input_sizes *= array_ops.concat(input_sizes_multipliers, axis=0)
2118  kernel = _flatten_with_inner_dim(pfor_input.stacked_input(1), 3, 5)
2119  out_backprop = _flatten_with_inner_dim(
2120      pfor_input.stacked_input(2), c_dim + 1, 5)
2121  result = _create_op(
2122      "DepthwiseConv2dNativeBackpropInput", [input_sizes, kernel, out_backprop],
2123      [x.dtype for x in pfor_input.outputs],
2124      attrs=pfor_input.op.node_def.attr).outputs[0]
2125  return wrap(_unflatten_with_inner_dim(result, c_dim, 4, stack_size), True)
2126
2127
2128@RegisterPFor("DepthwiseConv2dNativeBackpropFilter")
2129def _convert_depthwise_conv2d_native_backprop_filter(pfor_input):
2130  stack_size = pfor_input.pfor.loop_len_vector[0]
2131  data_format = pfor_input.get_attr("data_format")
2132  c_dim = 1 if data_format == b"NCHW" else 3
2133  inputs = _flatten_with_inner_dim(pfor_input.stacked_input(0), c_dim + 1, 5)
2134  filter_sizes = pfor_input.unstacked_input(1)
2135  filter_sizes_multipliers = [
2136      constant_op.constant([1, 1], dtype=dtypes.int32), [stack_size],
2137      constant_op.constant([1], dtype=dtypes.int32)
2138  ]
2139  filter_sizes *= array_ops.concat(filter_sizes_multipliers, axis=0)
2140  out_backprop = _flatten_with_inner_dim(
2141      pfor_input.stacked_input(2), c_dim + 1, 5)
2142  result = _create_op(
2143      "DepthwiseConv2dNativeBackpropFilter",
2144      [inputs, filter_sizes, out_backprop],
2145      [x.dtype for x in pfor_input.outputs],
2146      attrs=pfor_input.op.node_def.attr).outputs[0]
2147  return wrap(_unflatten_with_inner_dim(result, 2, 4, stack_size), True)
2148
2149
2150@RegisterPForWithArgs("LogSoftmax", gen_nn_ops.log_softmax)
2151@RegisterPForWithArgs("Softmax", gen_nn_ops.softmax)
2152def _convert_softmax(pfor_input, op_type, op_func):
2153  del op_type
2154  return wrap(op_func(pfor_input.stacked_input(0)), True)
2155
2156
2157# array_ops
2158
2159
2160@RegisterPForWithArgs("Identity", array_ops.identity)
2161@RegisterPForWithArgs("StopGradient", array_ops.stop_gradient)
2162@RegisterPForWithArgs("MatrixDiag", array_ops.matrix_diag)
2163@RegisterPForWithArgs("MatrixDiagPart", array_ops.matrix_diag_part)
2164@RegisterPForWithArgs("_EagerConst", array_ops.identity)
2165def _convert_identity(pfor_input, op_type, op_func):
2166  del op_type
2167  return wrap(op_func(*[x.t for x in pfor_input.inputs]), True)
2168
2169
2170@RegisterPFor("IdentityN")
2171def _convert_identity_n(pfor_input):
2172  outputs = array_ops.identity_n([x.t for x in pfor_input.inputs])
2173  return [
2174      wrap(out, inp.is_stacked) for out, inp in zip(outputs, pfor_input.inputs)
2175  ]
2176
2177
2178@RegisterPFor("Reshape")
2179def _convert_reshape(pfor_input):
2180  t = pfor_input.stacked_input(0)
2181  shape = pfor_input.unstacked_input(1)
2182  new_shape = array_ops.concat([pfor_input.pfor.loop_len_vector, shape], axis=0)
2183  return wrap(array_ops.reshape(t, new_shape), True)
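
# For example: if each iteration reshapes its value to [2, 3], the converted
# graph reshapes the stacked tensor to [loop_len, 2, 3] in a single op.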
2184
2185
2186@RegisterPFor("Fill")
2187def _convert_fill(pfor_input):
2188  dims = pfor_input.unstacked_input(0)
2189  value = pfor_input.stacked_input(1)
2190  # Expand the rank of `value`
2191  new_shape = array_ops.concat(
2192      [[-1], array_ops.ones([array_ops.size(dims)], dtype=dtypes.int32)],
2193      axis=0)
2194  value = array_ops.reshape(value, new_shape)
2195  # Compute the new output shape
2196  new_dims = array_ops.concat([pfor_input.pfor.loop_len_vector, dims], axis=0)
2197  # Broadcast
2198  return wrap(array_ops.broadcast_to(value, new_dims), True)
2199
2200
2201@RegisterPFor("BroadcastTo")
2202def _convert_broadcast_to(pfor_input):
2203  t = pfor_input.stacked_input(0)
2204  shape = pfor_input.unstacked_input(1)
2205  new_shape = array_ops.concat([pfor_input.pfor.loop_len_vector, shape], axis=0)
2206
2207  # Expand dims of stacked t to broadcast against the new shape.
2208  # TODO(davmre): consider factoring out common code with
2209  # `expanddim_inputs_for_broadcast`, which has similar logic but with
2210  # implicit shapes (of input Tensors) rather than explicit shapes.
2211  rank_diff = array_ops.shape(new_shape)[0] - array_ops.rank(t)
2212  ones = array_ops.tile([1], array_ops.reshape(rank_diff, [1]))
2213  t_shape = array_ops.shape(t)
2214  t_expanded_shape = array_ops.concat([t_shape[:1], ones, t_shape[1:]], axis=0)
2215
2216  return wrap(
2217      array_ops.broadcast_to(array_ops.reshape(t, t_expanded_shape), new_shape),
2218      True)
2219
2220
2221@RegisterPFor("ExpandDims")
2222def _convert_expanddims(pfor_input):
2223  t = pfor_input.stacked_input(0)
2224  dim = pfor_input.unstacked_input(1)
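  # Shift non-negative axes by one to account for the new leading loop
  # dimension; negative axes index from the end and stay valid as-is.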
2225  dim += math_ops.cast(dim >= 0, dim.dtype)
2226  return wrap(array_ops.expand_dims(t, axis=dim), True)
2227
2228
2229@RegisterPForWithArgs("LowerBound", gen_array_ops.lower_bound)
2230@RegisterPForWithArgs("UpperBound", gen_array_ops.upper_bound)
2231def _convert_searchsorted(pfor_input, _, op_func):
2232  pfor_input.stack_inputs()
2233  sorted_inputs = _flatten_first_two_dims(pfor_input.stacked_input(0))
2234  values = _flatten_first_two_dims(pfor_input.stacked_input(1))
2235  out_type = pfor_input.get_attr("out_type")
2236  output = op_func(sorted_inputs, values, out_type)
2237  return wrap(
2238      _unflatten_first_dim(output, pfor_input.pfor.loop_len_vector), True)
2239
2240
2241@RegisterPFor("MatrixBandPart")
2242def _convert_matrix_band_part(pfor_input):
2243  t = pfor_input.stacked_input(0)
2244  num_lower = pfor_input.unstacked_input(1)
2245  num_upper = pfor_input.unstacked_input(2)
2246  return wrap(
2247      array_ops.matrix_band_part(t, num_lower=num_lower, num_upper=num_upper),
2248      True)
2249
2250
2251@RegisterPFor("MatrixSetDiag")
2252def _convert_matrix_set_diag(pfor_input):
2253  pfor_input.stack_inputs()
2254  t = pfor_input.stacked_input(0)
2255  diag = pfor_input.stacked_input(1)
2256  return wrap(array_ops.matrix_set_diag(t, diag), True)
2257
2258
2259# Registrations for Matrix{Diag,DiagPart,SetDiag}V2-3.
2260# The input orders defined in the OpKernel and the actual python API are
2261# different (for compatibility with V1), so we cannot use _convert_identity.
2262# v2 is not compatible with v3 and is never exposed on the public API.
2263@RegisterPFor("MatrixDiagV2")
2264@RegisterPFor("MatrixDiagV3")
2265def _convert_matrix_diag_v2(pfor_input):
2266  params = {
2267      "diagonal": pfor_input.stacked_input(0),
2268      "k": pfor_input.unstacked_input(1),
2269      "num_rows": pfor_input.unstacked_input(2),
2270      "num_cols": pfor_input.unstacked_input(3),
2271      "padding_value": pfor_input.unstacked_input(4)
2272  }
2273  if pfor_input.op_type == "MatrixDiagV2":
2274    return wrap(array_ops.matrix_diag_v2(**params), True)
2275  params["align"] = pfor_input.get_attr("align")
2276  return wrap(array_ops.matrix_diag(**params), True)
2277
2278
2279@RegisterPFor("Diag")
2280def _convert_diag(pfor_input):
2281  diag = pfor_input.stacked_input(0)
2282  if diag.shape.ndims == 2:
2283    # We can use matrix_diag.
2284    return wrap(array_ops.matrix_diag(diag), True)
2285  else:
2286    # It is not clear if we can do better than a while loop here with existing
2287    # kernels.
2288    return _fallback_converter(pfor_input, warn=False)
2289
2290
2291# See notes for MatrixDiagV2
2292@RegisterPFor("MatrixDiagPartV2")
2293@RegisterPFor("MatrixDiagPartV3")
2294def _convert_matrix_diag_part_v2(pfor_input):
2295  params = {
2296      "input": pfor_input.stacked_input(0),
2297      "k": pfor_input.unstacked_input(1),
2298      "padding_value": pfor_input.unstacked_input(2)
2299  }
2300  if pfor_input.op_type == "MatrixDiagPartV2":
2301    return wrap(array_ops.matrix_diag_part_v2(**params), True)
2302  params["align"] = pfor_input.get_attr("align")
2303  return wrap(array_ops.matrix_diag_part(**params), True)
2304
2305
2306# See notes for MatrixDiagV2
2307@RegisterPFor("MatrixSetDiagV2")
2308@RegisterPFor("MatrixSetDiagV3")
2309def _convert_matrix_set_diag_v2(pfor_input):
2310  pfor_input.stack_inputs([0, 1])
2311  params = {
2312      "input": pfor_input.stacked_input(0),
2313      "diagonal": pfor_input.stacked_input(1),
2314      "k": pfor_input.unstacked_input(2)
2315  }
2316  if pfor_input.op_type == "MatrixSetDiagV2":
2317    return wrap(array_ops.matrix_set_diag_v2(**params), True)
2318  params["align"] = pfor_input.get_attr("align")
2319  return wrap(array_ops.matrix_set_diag(**params), True)
2320
2321
2322@RegisterPFor("DiagPart")
2323def _convert_diag_part(pfor_input):
2324  inp = pfor_input.stacked_input(0)
2325  if inp.shape.ndims == 3:
2326    # We can use matrix_diag_part.
2327    return wrap(array_ops.matrix_diag_part(inp), True)
2328  else:
2329    # It is not clear if we can do better than a while loop here with existing
2330    # kernels.
2331    return _fallback_converter(pfor_input, warn=False)
2332
2333
2334@RegisterPFor("OneHot")
2335def _convert_one_hot(pfor_input):
2336  indices = pfor_input.stacked_input(0)
2337  depth = pfor_input.unstacked_input(1)
2338  on_value = pfor_input.unstacked_input(2)
2339  off_value = pfor_input.unstacked_input(3)
2340  axis = pfor_input.get_attr("axis")
2341  if axis >= 0:
2342    axis += 1
2343  return wrap(
2344      array_ops.one_hot(indices, depth, on_value, off_value, axis), True)
2345
2346
2347@RegisterPFor("Slice")
2348def _convert_slice(pfor_input):
2349  t = pfor_input.stacked_input(0)
2350  begin, begin_stacked, _ = pfor_input.input(1)
2351  size = pfor_input.unstacked_input(2)
2352  if not begin_stacked:
2353    begin = array_ops.concat([[0], begin], axis=0)
2354    size = array_ops.concat([[-1], size], axis=0)
2355    return wrap(array_ops.slice(t, begin, size), True)
2356  else:
2357    # Handle negative sizes.
2358    #
2359    # If the `begin` entry corresponding to a negative `size` is loop-variant,
2360    # the output would be ragged. This case is not supported. But `size` having
2361    # some negative values and some loop-variant `begin`s is OK (and it's hard
2362    # to tell the difference statically).
2363    original_unstacked_shape = _stack(
2364        array_ops.shape(t)[1:], pfor_input.pfor.loop_len_vector).t
2365    broadcast_size = _stack(size, pfor_input.pfor.loop_len_vector).t
2366    result_shape = array_ops.where(
2367        math_ops.less(broadcast_size, 0),
2368        original_unstacked_shape - begin + broadcast_size + 1, broadcast_size)
2369    result_shape = math_ops.cast(math_ops.reduce_max(result_shape, axis=0),
2370                                 dtypes.int64)
2371
2372    # Now we enumerate points in the sliced region for each pfor iteration and
2373    # gather them.
2374    cumsize = math_ops.cumprod(result_shape, exclusive=True, reverse=True)
2375    result_num_elements = math_ops.reduce_prod(result_shape)
2376    # Offsets are loop-variant. We first compute loop-invariant gather
2377    # coordinates, then broadcast-add the loop-variant `begin` offsets.
2378    result_base_coordinates = (
2379        math_ops.range(result_num_elements, dtype=dtypes.int64)[:, None]
2380        // cumsize[None, :]) % result_shape[None, :]
2381    result_coordinates = (
2382        begin[:, None, :]
2383        + math_ops.cast(result_base_coordinates, begin.dtype)[None, :, :])
2384    result_flat = array_ops.gather_nd(params=t, indices=result_coordinates,
2385                                      batch_dims=1)
2386    result_stacked_shape = array_ops.concat(
2387        [math_ops.cast(pfor_input.pfor.loop_len_vector, result_shape.dtype),
2388         result_shape],
2389        axis=0)
2390    return wrap(array_ops.reshape(result_flat, result_stacked_shape), True)
2391
2392
2393@RegisterPFor("Tile")
2394def _convert_tile(pfor_input):
2395  t = pfor_input.stacked_input(0)
2396  multiples = pfor_input.unstacked_input(1)
2397  multiples = array_ops.concat([[1], multiples], 0)
2398  return wrap(array_ops.tile(t, multiples), True)
2399
2400
2401@RegisterPFor("Pack")
2402def _convert_pack(pfor_input):
2403  pfor_input.stack_inputs()
2404  axis = pfor_input.get_attr("axis")
2405  if axis >= 0:
2406    axis += 1
2407  return wrap(
2408      array_ops.stack([x.t for x in pfor_input.inputs], axis=axis), True)
2409
2410
2411@RegisterPFor("Unpack")
2412def _convert_unpack(pfor_input):
2413  value = pfor_input.stacked_input(0)
2414  axis = pfor_input.get_attr("axis")
2415  if axis >= 0:
2416    axis += 1
2417  num = pfor_input.get_attr("num")
2418  return [wrap(x, True) for x in array_ops.unstack(value, axis=axis, num=num)]
2419
2420
2421@RegisterPFor("Pad")
2422def _convert_pad(pfor_input):
2423  t = pfor_input.stacked_input(0)
2424  paddings = pfor_input.unstacked_input(1)
2425  paddings = array_ops.concat([[[0, 0]], paddings], 0)
2426  return wrap(array_ops.pad(t, paddings, mode="CONSTANT"), True)
2427
2428
2429@RegisterPFor("PadV2")
2430def _convert_pad_v2(pfor_input):
2431  t = pfor_input.stacked_input(0)
2432  paddings = pfor_input.unstacked_input(1)
2433  paddings = array_ops.concat([[[0, 0]], paddings], 0)
2434  return wrap(array_ops.pad_v2(t, paddings, mode="CONSTANT"), True)
2435
2436
2437@RegisterPFor("Split")
2438def _convert_split(pfor_input):
2439  split_dim = pfor_input.unstacked_input(0)
2440  t = pfor_input.stacked_input(1)
2441  num_split = pfor_input.get_attr("num_split")
2442  split_dim += math_ops.cast(split_dim >= 0, dtypes.int32)
2443  return [wrap(x, True) for x in array_ops.split(t, num_split, axis=split_dim)]
2444
2445
2446@RegisterPFor("SplitV")
2447def _convert_split_v(pfor_input):
2448  t = pfor_input.stacked_input(0)
2449  splits = pfor_input.unstacked_input(1)
2450  split_dim = pfor_input.unstacked_input(2)
2451  split_dim += math_ops.cast(split_dim >= 0, dtypes.int32)
2452  return [wrap(x, True) for x in array_ops.split(t, splits, axis=split_dim)]
2453
2454
2455@RegisterPFor("Squeeze")
2456def _convert_squeeze(pfor_input):
2457  t = pfor_input.stacked_input(0)
2458  squeeze_dims = pfor_input.get_attr("squeeze_dims")
2459  squeeze_dims = [i + 1 if i >= 0 else i for i in squeeze_dims]
2460  return wrap(array_ops.squeeze(t, axis=squeeze_dims), True)
2461
2462
2463@RegisterPFor("ReverseV2")
2464def _convert_reverse(pfor_input):
2465  value = pfor_input.stacked_input(0)
2466  axis = pfor_input.unstacked_input(1)
2467  new_axis = array_ops.where_v2(axis >= 0, axis + 1, axis)
2468  return wrap(gen_array_ops.reverse_v2(value, axis=new_axis), True)
2469
2470
2471@RegisterPForWithArgs("Transpose", gen_array_ops.transpose)
2472@RegisterPForWithArgs("ConjugateTranspose", gen_array_ops.conjugate_transpose)
2473def _convert_transpose(pfor_input, _, op_func):
2474  t = pfor_input.stacked_input(0)
2475  perm = pfor_input.unstacked_input(1)
2476  new_perm = array_ops.concat([[0], perm + 1], axis=0)
2477  return wrap(op_func(t, new_perm), True)
2478
2479
2480@RegisterPFor("ZerosLike")
2481def _convert_zeroslike(pfor_input):
2482  t = pfor_input.stacked_input(0)
2483  shape = array_ops.shape(t)[1:]
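  # Every iteration produces the same zeros tensor, so return it unstacked
  # (per-iteration shape) instead of materializing a stacked copy.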
2484  return wrap(array_ops.zeros(shape, dtype=t.dtype), False)
2485
2486
2487@RegisterPFor("Gather")
2488@RegisterPFor("GatherV2")
2489def _convert_gather(pfor_input):
2490  param, param_stacked, _ = pfor_input.input(0)
2491  indices, indices_stacked, _ = pfor_input.input(1)
2492  batch_dims = pfor_input.get_attr("batch_dims")
2493
2494  op_type = pfor_input.op_type
2495  if op_type == "Gather":
2496    validate_indices = pfor_input.get_attr("validate_indices")
2497    axis = 0
2498  else:
2499    validate_indices = None
2500    # Assume we will never have a Tensor with rank > 2**32.
2501    axis = math_ops.cast(pfor_input.unstacked_input(2), dtypes.int32)
2502    axis_value = tensor_util.constant_value(axis)
2503    if axis_value is not None:
2504      axis = axis_value
2505  if indices_stacked and not param_stacked:
2506    if indices is pfor_input.pfor.all_indices and axis == 0:
2507      param_shape0 = tensor_shape.dimension_value(param.shape[0])
2508      indices_shape0 = tensor_shape.dimension_value(indices.shape[0])
2509      if param_shape0 is not None and indices_shape0 == param_shape0:
2510        # Note that with loops and conditionals, indices may not be contiguous.
2511        # However they will be sorted and unique. So if the shape matches, then
2512        # it must be picking up all the rows of param.
2513        return wrap(param, True)
2514
2515    if batch_dims != 0:
2516      # Convert `batch_dims` to its positive equivalent if necessary.
2517      batch_dims_pos = batch_dims
2518      if batch_dims < 0:
2519        batch_dims_pos += array_ops.rank(indices)
2520      # In order to maintain
2521      #   indices.shape[:batch_dims] == params.shape[:batch_dims]
2522      # with stacked indices, we move the first dimension of `indices` to the
2523      # `batch_dims + 1`th position. The (non-batch) index dimensions will be
2524      # inserted into the shape of `output` at the `axis` dimension, which is
2525      # then transposed to the front (below).
2526      order = array_ops.concat([
2527          math_ops.range(1, batch_dims_pos + 1),
2528          [0],
2529          math_ops.range(batch_dims_pos + 1, array_ops.rank(indices))], axis=0)
2530      indices = array_ops.transpose(indices, order)
2531
2532    output = array_ops.gather(
2533        param, indices, validate_indices=validate_indices, axis=axis,
2534        batch_dims=batch_dims)
2535    if axis != 0:
2536      axis = smart_cond.smart_cond(axis < 0,
2537                                   lambda: axis + array_ops.rank(param),
2538                                   lambda: ops.convert_to_tensor(axis))
2539      order = array_ops.concat(
2540          [[axis],
2541           math_ops.range(axis),
2542           math_ops.range(axis + 1, array_ops.rank(output))],
2543          axis=0)
2544      output = smart_cond.smart_cond(
2545          math_ops.equal(axis, 0), lambda: output,
2546          lambda: array_ops.transpose(output, order))
2547    return wrap(output, True)
2548  if param_stacked:
2549    pfor_input.stack_inputs(stack_indices=[1])
2550    indices = pfor_input.stacked_input(1)
2551    if isinstance(axis, ops.Tensor):
2552      axis = array_ops.where(axis >= 0, axis + 1, axis)
2553    else:
2554      axis = axis + 1 if axis >= 0 else axis
2555    batch_dims = batch_dims + 1 if batch_dims >= 0 else batch_dims
2556    output = array_ops.gather(param, indices, axis=axis, batch_dims=batch_dims)
2557    return wrap(output, True)
2558
2559
2560@RegisterPFor("GatherNd")
2561def _convert_gather_nd(pfor_input):
2562  # TODO(jmenick): Add support for unstacked params.
2563  pfor_input.stack_inputs(stack_indices=[1])
2564  params = pfor_input.stacked_input(0)
2565  indices = pfor_input.stacked_input(1)
2566  stacked_result = array_ops.gather_nd(params, indices, batch_dims=1)
2567  return wrap(stacked_result, True)
2568
2569
2570@RegisterPFor("ConcatV2")
2571def _convert_concatv2(pfor_input):
2572  n = pfor_input.num_inputs
2573  pfor_input.stack_inputs(stack_indices=range(n - 1))
2574  axis = pfor_input.unstacked_input(n - 1)
2575  axis += math_ops.cast(axis >= 0, axis.dtype)
2576  return wrap(
2577      array_ops.concat([x.t for x in pfor_input.inputs[:n - 1]], axis=axis),
2578      True)
2579
2580
2581@RegisterPFor("StridedSlice")
2582def _convert_strided_slice(pfor_input):
2583  inp = pfor_input.stacked_input(0)
2584  begin = pfor_input.unstacked_input(1)
2585  end = pfor_input.unstacked_input(2)
2586  strides = pfor_input.unstacked_input(3)
2587  begin_mask = pfor_input.get_attr("begin_mask")
2588  end_mask = pfor_input.get_attr("end_mask")
2589  ellipsis_mask = pfor_input.get_attr("ellipsis_mask")
2590  new_axis_mask = pfor_input.get_attr("new_axis_mask")
2591  shrink_axis_mask = pfor_input.get_attr("shrink_axis_mask")
2592
2593  begin = array_ops.concat([[0], begin], axis=0)
2594  end = array_ops.concat([[0], end], axis=0)
2595  strides = array_ops.concat([[1], strides], axis=0)
2596  begin_mask = begin_mask << 1 | 1
2597  end_mask = end_mask << 1 | 1
2598  ellipsis_mask <<= 1
2599  new_axis_mask <<= 1
2600  shrink_axis_mask <<= 1
2601  return wrap(
2602      array_ops.strided_slice(
2603          inp,
2604          begin,
2605          end,
2606          strides,
2607          begin_mask=begin_mask,
2608          end_mask=end_mask,
2609          ellipsis_mask=ellipsis_mask,
2610          new_axis_mask=new_axis_mask,
2611          shrink_axis_mask=shrink_axis_mask), True)
2612
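# Worked example (illustrative): the per-iteration slice x[1:3, ::2] has
# begin_mask == end_mask == 0b10 (both bounds elided for the second dimension).
# After prepending the loop dimension, the masks become
#   begin_mask == end_mask == 0b10 << 1 | 1 == 0b101,
# i.e. every original bit moves one position left and bit 0 is set so the new
# leading loop dimension is sliced in full; the ellipsis/new_axis/shrink_axis
# masks are only shifted. StridedSliceGrad below uses the same bookkeeping.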
2613
2614@RegisterPFor("StridedSliceGrad")
2615def _convert_strided_slice_grad(pfor_input):
2616  shape = pfor_input.unstacked_input(0)
2617  begin = pfor_input.unstacked_input(1)
2618  end = pfor_input.unstacked_input(2)
2619  strides = pfor_input.unstacked_input(3)
2620  dy = pfor_input.stacked_input(4)
2621  begin_mask = pfor_input.get_attr("begin_mask")
2622  end_mask = pfor_input.get_attr("end_mask")
2623  ellipsis_mask = pfor_input.get_attr("ellipsis_mask")
2624  new_axis_mask = pfor_input.get_attr("new_axis_mask")
2625  shrink_axis_mask = pfor_input.get_attr("shrink_axis_mask")
2626
2627  shape = array_ops.concat(
2628      [math_ops.cast(pfor_input.pfor.loop_len_vector, shape.dtype), shape],
2629      axis=0)
2630  begin = array_ops.concat([[0], begin], axis=0)
2631  end = array_ops.concat([[0], end], axis=0)
2632  strides = array_ops.concat([[1], strides], axis=0)
2633  begin_mask = begin_mask << 1 | 1
2634  end_mask = end_mask << 1 | 1
2635  ellipsis_mask <<= 1
2636  new_axis_mask <<= 1
2637  shrink_axis_mask <<= 1
2638  return wrap(
2639      array_ops.strided_slice_grad(
2640          shape,
2641          begin,
2642          end,
2643          strides,
2644          dy,
2645          begin_mask=begin_mask,
2646          end_mask=end_mask,
2647          ellipsis_mask=ellipsis_mask,
2648          new_axis_mask=new_axis_mask,
2649          shrink_axis_mask=shrink_axis_mask), True)
2650
2651
2652@RegisterPFor("CheckNumerics")
2653def _convert_check_numerics(pfor_input):
2654  t = pfor_input.stacked_input(0)
2655  message = pfor_input.get_attr("message")
2656  return wrap(gen_array_ops.check_numerics(t, message), True)
2657
2658
2659@RegisterPFor("EnsureShape")
2660def _convert_ensure_shape(pfor_input):
2661  t = pfor_input.stacked_input(0)
2662  shape = tensor_shape.TensorShape(pfor_input.get_attr("shape"))
2663  return wrap(gen_array_ops.ensure_shape(t, [None] + shape), True)
2664
2665
2666# manip_ops
2667
2668
2669@RegisterPFor("Roll")
2670def _convert_roll(pfor_input):
2671  t = pfor_input.stacked_input(0)
2672  shift, shift_stacked, _ = pfor_input.input(1)
2673  axis = pfor_input.unstacked_input(2)
2674  if not shift_stacked:
2675    return wrap(manip_ops.roll(t, shift, axis + 1), True)
2676  else:
2677    # `axis` and `shift` may both be vectors, with repeated axes summing the
2678    # corresponding `shift`s. We scatter shifts into a dense array of shape
2679    # [loop_len, num_unstacked_axes] indicating the offset for each axis.
2680    num_unstacked_axes = math_ops.cast(array_ops.rank(t), dtypes.int64) - 1
2681    axis = math_ops.cast(array_ops.reshape(axis, [-1]), dtypes.int64)
2682    loop_len = math_ops.cast(pfor_input.pfor.loop_len_vector[0], dtypes.int64)
2683    shift = math_ops.cast(array_ops.reshape(shift, [loop_len, -1]),
2684                          dtypes.int64)
2685    axis_segment_ids = (
2686        math_ops.range(loop_len, dtype=dtypes.int64)[:, None]
2687        * num_unstacked_axes + axis[None, :])
2688    axis_offsets = array_ops.reshape(
2689        math_ops.unsorted_segment_sum(
2690            data=shift, segment_ids=axis_segment_ids,
2691            num_segments=loop_len * num_unstacked_axes),
2692        [loop_len, num_unstacked_axes])
2693
2694    # Determine the coordinates in the input array of each result and gather
2695    # them.
2696    unstacked_shape = array_ops.shape(t, out_type=dtypes.int64)[1:]
2697    cumsize = math_ops.cumprod(unstacked_shape, exclusive=True, reverse=True)
2698    num_unstacked_elements = math_ops.reduce_prod(unstacked_shape)
2699    result_coordinates = (
2700        (math_ops.range(num_unstacked_elements,
2701                        dtype=dtypes.int64)[None, :, None]
2702         // cumsize[None, None, :] - axis_offsets[:, None, :])
2703        % unstacked_shape[None, None, :])
2704    result_flat = array_ops.gather_nd(params=t, indices=result_coordinates,
2705                                      batch_dims=1)
2706    return wrap(array_ops.reshape(result_flat, array_ops.shape(t)),
2707                True)
2708
2709# math_ops
2710
2711
2712@RegisterPFor("MatMul")
2713def _convert_matmul(pfor_input):
2714  # TODO(agarwal): Check if tiling is faster than two transposes.
2715  a, a_stacked, _ = pfor_input.input(0)
2716  b, b_stacked, _ = pfor_input.input(1)
2717  tr_a = pfor_input.get_attr("transpose_a")
2718  tr_b = pfor_input.get_attr("transpose_b")
2719  if a_stacked and b_stacked:
2720    output = wrap(math_ops.matmul(a, b, adjoint_a=tr_a, adjoint_b=tr_b), True)
2721    return output
2722  elif a_stacked:
2723    if tr_a:
2724      a = array_ops.transpose(a, [0, 2, 1])
2725    if a.shape.is_fully_defined():
2726      x, y, z = a.shape
2727    else:
2728      x, y, z = [
2729          array_ops.reshape(i, [])
2730          for i in array_ops.split(array_ops.shape(a), 3)
2731      ]
2732    a = array_ops.reshape(a, [x * y, z])
2733    prod = math_ops.matmul(a, b, transpose_b=tr_b)
2734    return wrap(array_ops.reshape(prod, [x, y, -1]), True)
2735  else:
2736    assert b_stacked
2737    if tr_b:
2738      perm = [2, 0, 1]
2739      b = array_ops.transpose(b, perm)
2740    else:
2741      # As an optimization, if one of the first two dimensions is 1, then we can
2742      # reshape instead of transpose.
2743      # TODO(agarwal): This check can be done inside Transpose kernel.
2744      b_shape = array_ops.shape(b)
2745      min_dim = math_ops.minimum(b_shape[0], b_shape[1])
2746      perm = array_ops.where(
2747          math_ops.equal(min_dim, 1), [0, 1, 2], [1, 0, 2])
2748      new_shape = array_ops.stack([b_shape[1], b_shape[0], b_shape[2]])
2749      b = array_ops.transpose(b, perm)
2750      b = array_ops.reshape(b, new_shape)
2751
2752    if b.shape.is_fully_defined():
2753      x, y, z = b.shape
2754    else:
2755      x, y, z = [
2756          array_ops.reshape(i, [])
2757          for i in array_ops.split(array_ops.shape(b), 3)
2758      ]
2759    b = array_ops.reshape(b, [x, y * z])
2760    prod = math_ops.matmul(a, b, transpose_a=tr_a)
2761    prod = array_ops.reshape(prod, [-1, y, z])
2762    prod = array_ops.transpose(prod, [1, 0, 2])
2763    return wrap(prod, True)
2764
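
# A minimal sketch (illustrative; assumes static shapes and no transposes) of
# the reshape trick used in the `a_stacked` branch above: the loop and row
# dimensions of `a` are merged so that a single large MatMul replaces a batch
# of small ones. The helper name is hypothetical.
def _example_stacked_matmul_sketch(a, b):
  """`a` has shape [n, y, z] (stacked); `b` has shape [z, w] (unstacked)."""
  n, y, z = a.shape
  a_flat = array_ops.reshape(a, [n * y, z])   # merge loop and row dimensions
  prod = math_ops.matmul(a_flat, b)           # one [n * y, w] matmul
  return array_ops.reshape(prod, [n, y, -1])  # split the loop dimension back
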
2765
2766# TODO(rmlarsen): Use the converter of BatchMatMulV2 once the forward
2767# compatibility window has passed.
2768@RegisterPFor("BatchMatMul")
2769def _convert_batch_mat_mul(pfor_input):
2770  # TODO(agarwal): There may be a more efficient way to do this instead of
2771  # stacking the inputs.
2772  pfor_input.stack_inputs()
2773  x = pfor_input.stacked_input(0)
2774  y = pfor_input.stacked_input(1)
2775  adj_x = pfor_input.get_attr("adj_x")
2776  adj_y = pfor_input.get_attr("adj_y")
2777
2778  x = _flatten_first_two_dims(x)
2779  y = _flatten_first_two_dims(y)
2780  output = math_ops.matmul(x, y, adjoint_a=adj_x, adjoint_b=adj_y)
2781  output = _unflatten_first_dim(output, pfor_input.pfor.loop_len_vector)
2782  return wrap(output, True)
2783
2784
2785@RegisterPFor("BatchMatMulV2")
2786def _convert_batch_mat_mul_v2(pfor_input):
2787  pfor_input.expanddim_inputs_for_broadcast()
2788  x = pfor_input.input(0)[0]
2789  y = pfor_input.input(1)[0]
2790  adj_x = pfor_input.get_attr("adj_x")
2791  adj_y = pfor_input.get_attr("adj_y")
2792
2793  output = math_ops.matmul(x, y, adjoint_a=adj_x, adjoint_b=adj_y)
2794  return wrap(output, True)
2795
2796
2797@RegisterPForWithArgs("Sum", math_ops.reduce_sum)
2798@RegisterPForWithArgs("Prod", math_ops.reduce_prod)
2799@RegisterPForWithArgs("Max", math_ops.reduce_max)
2800@RegisterPForWithArgs("Min", math_ops.reduce_min)
2801@RegisterPForWithArgs("Mean", math_ops.reduce_mean)
2802@RegisterPForWithArgs("All", math_ops.reduce_all)
2803@RegisterPForWithArgs("Any", math_ops.reduce_any)
2804def _convert_reduction(pfor_input, _, op_func):
2805  t = pfor_input.stacked_input(0)
2806  indices = pfor_input.unstacked_input(1)
2807  # Shift positive indices by one to account for the extra dimension.
2808  indices += math_ops.cast(indices >= 0, indices.dtype)
2809  keep_dims = pfor_input.get_attr("keep_dims")
2810  return wrap(op_func(t, indices, keepdims=keep_dims), True)
2811
2812
2813@RegisterPForWithArgs("ArgMax", math_ops.argmax)
2814@RegisterPForWithArgs("ArgMin", math_ops.argmin)
2815def _convert_argmax_argmin(pfor_input, _, op_func):
2816  t = pfor_input.stacked_input(0)
2817  dimension = pfor_input.unstacked_input(1)
2818  dimension += math_ops.cast(dimension >= 0, dimension.dtype)
2819  output_type = pfor_input.get_attr("output_type")
2820  return wrap(op_func(t, axis=dimension, output_type=output_type), True)
2821
2822
2823@RegisterPFor("Bucketize")
2824def _convert_bucketize(pfor_input):
2825  t = pfor_input.stacked_input(0)
2826  boundaries = pfor_input.get_attr("boundaries")
2827  return wrap(math_ops.bucketize(t, boundaries), True)
2828
2829
2830@RegisterPFor("ClipByValue")
2831def _convert_clip_by_value(pfor_input):
2832  t = pfor_input.stacked_input(0)
2833  clip_value_min = pfor_input.unstacked_input(1)
2834  clip_value_max = pfor_input.unstacked_input(2)
2835  return wrap(gen_math_ops.clip_by_value(t, clip_value_min, clip_value_max),
2836              True)
2837
2838
2839@RegisterPForWithArgs("Cumsum", math_ops.cumsum)
2840@RegisterPForWithArgs("Cumprod", math_ops.cumprod)
2841def _convert_cumfoo(pfor_input, _, op_func):
2842  t = pfor_input.stacked_input(0)
2843  axis = pfor_input.unstacked_input(1)
2844  # Shift positive indices by one to account for the extra dimension.
2845  axis += math_ops.cast(axis >= 0, axis.dtype)
2846  exclusive = pfor_input.get_attr("exclusive")
2847  reverse = pfor_input.get_attr("reverse")
2848  return wrap(op_func(t, axis, exclusive=exclusive, reverse=reverse), True)
2849
2850
2851@RegisterPFor("BiasAdd")
2852def _convert_biasadd(pfor_input):
2853  t, t_stacked, _ = pfor_input.input(0)
2854  bias, bias_stacked, _ = pfor_input.input(1)
2855  data_format = pfor_input.get_attr("data_format").decode()
2856  if bias_stacked:
2857    # BiasAdd only supports 1-D biases, so cast bias to match value and use Add.
2858    pfor_input.expanddim_inputs_for_broadcast()
2859    t, _, _ = pfor_input.input(0)
2860    bias = math_ops.cast(pfor_input.stacked_input(1), t.dtype)
2861    if compat.as_bytes(data_format) == b"NCHW":
2862      b_shape = array_ops.shape(bias)
2863      new_b_shape = array_ops.concat(
2864          [b_shape[:-3], b_shape[-1:], b_shape[-3:-1]], axis=0)
2865      bias = array_ops.reshape(bias, new_b_shape)
2866    return wrap(math_ops.add(t, bias), True)
2867  else:
2868    assert t_stacked, "At least one input to BiasAdd should be loop variant."
2869    if compat.as_bytes(data_format) == b"NCHW":
2870      shape = array_ops.shape(t)
2871      flattened_shape = array_ops.concat([[-1], shape[2:]], axis=0)
2872      t = array_ops.reshape(t, flattened_shape)
2873      t = nn_ops.bias_add(t, bias, data_format="NCHW")
2874      t = array_ops.reshape(t, shape)
2875      return wrap(t, True)
2876    return wrap(nn_ops.bias_add(t, bias, data_format=data_format), True)
2877
2878
2879@RegisterPForWithArgs("UnsortedSegmentSum", math_ops.unsorted_segment_sum)
2880@RegisterPForWithArgs("UnsortedSegmentMax", math_ops.unsorted_segment_max)
2881@RegisterPForWithArgs("UnsortedSegmentMin", math_ops.unsorted_segment_min)
2882@RegisterPForWithArgs("UnsortedSegmentProd", math_ops.unsorted_segment_prod)
2883def _convert_unsortedsegmentsum(pfor_input, _, op_func):
2884  pfor_input.stack_inputs([0, 1])
2885  data = pfor_input.stacked_input(0)
2886  segment_ids = pfor_input.stacked_input(1)
2887  # TODO(agarwal): handle stacked?
2888  num_segments = pfor_input.unstacked_input(2)
2889  if segment_ids.dtype != num_segments.dtype:
2890    segment_ids = math_ops.cast(segment_ids, dtypes.int64)
2891    num_segments = math_ops.cast(num_segments, dtypes.int64)
2892  dtype = segment_ids.dtype
2893  segment_shape = array_ops.shape(segment_ids, out_type=dtype)
2894  n = segment_shape[0]
2895  ones = array_ops.ones_like(segment_shape, dtype=dtype)[1:]
2896  segment_offset = num_segments * math_ops.range(n, dtype=dtype)
2897  segment_offset = array_ops.reshape(segment_offset,
2898                                     array_ops.concat([[n], ones], axis=0))
2899  segment_ids += segment_offset
2900  num_segments = math_ops.cast(num_segments, dtypes.int64) * math_ops.cast(
2901      n, dtypes.int64)
2902  output = op_func(data, segment_ids, num_segments)
2903  new_output_shape = array_ops.concat(
2904      [[n, -1], array_ops.shape(output)[1:]], axis=0)
2905  output = array_ops.reshape(output, new_output_shape)
2906  return wrap(output, True)
2907
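# Worked example (illustrative): with loop_len n == 2, num_segments == 3 and
# per-iteration segment_ids [[0, 2], [1, 1]], the offsets [0, 3] turn the ids
# into [[0, 2], [4, 4]], so a single segment reduction over 3 * 2 == 6 segments
# keeps the two iterations disjoint; reshaping the [6, ...] result to
# [2, 3, ...] then recovers the per-iteration outputs.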
2908
2909def _flatten_array_with_offset(ids, offset_delta, num_rows):
2910  """Flattens a rank 2 tensor, adding an offset to each row."""
2911  # Note that if `ids` is rank 1, it is broadcast to rank 2.
2912  offset_delta = math_ops.cast(offset_delta, ids.dtype)
2913  n = math_ops.cast(num_rows, dtype=ids.dtype)
2914  offsets = math_ops.range(
2915      start=0, limit=n * offset_delta, delta=offset_delta, dtype=ids.dtype)
2916  offsets = array_ops.expand_dims(offsets, -1)
2917  ids += offsets
2918  return array_ops.reshape(ids, [-1])
2919
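# For example (illustrative): _flatten_array_with_offset([[0, 1], [1, 2]],
# offset_delta=5, num_rows=2) adds per-row offsets [0, 5] and returns
# [0, 1, 6, 7]. The sparse segment converters below use this to keep the
# indices and segment ids of different pfor iterations disjoint.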
2920
2921@RegisterPForWithArgs("SparseSegmentSum", math_ops.sparse_segment_sum_v2)
2922@RegisterPForWithArgs("SparseSegmentMean", math_ops.sparse_segment_mean_v2)
2923@RegisterPForWithArgs("SparseSegmentSqrtN", math_ops.sparse_segment_sqrt_n_v2)
2924@RegisterPForWithArgs("SparseSegmentSumWithNumSegments",
2925                      math_ops.sparse_segment_sum_v2)
2926@RegisterPForWithArgs("SparseSegmentMeanWithNumSegments",
2927                      math_ops.sparse_segment_mean_v2)
2928@RegisterPForWithArgs("SparseSegmentSqrtNWithNumSegments",
2929                      math_ops.sparse_segment_sqrt_n_v2)
2930def _convert_sparse_segment(pfor_input, _, op_func):
2931  _, segment_ids_stacked, _ = pfor_input.input(2)
2932  if segment_ids_stacked:
2933    pfor_input.stack_inputs([1])
2934  data, data_stacked, _ = pfor_input.input(0)
2935  indices, _, _ = pfor_input.input(1)
2936  num_inputs = len(pfor_input.inputs)
2937  assert num_inputs in (3, 4)
2938  if num_inputs == 3:
2939    # `segment_ids` needs to be unstacked since otherwise output sizes could
2940    # differ across pfor iterations.
2941    segment_ids = pfor_input.unstacked_input(2)
2942    num_segments = nn_ops.relu(math_ops.reduce_max(segment_ids) + 1)
2943  else:
2944    segment_ids, _, _ = pfor_input.input(2)
2945    num_segments = pfor_input.unstacked_input(3)
2946
2947  n = pfor_input.pfor.loop_len_vector[0]
2948  if data_stacked:
2949    indices = _flatten_array_with_offset(indices, array_ops.shape(data)[1], n)
2950    data = _flatten_first_two_dims(data)
2951  else:
2952    indices = array_ops.reshape(indices, [-1])
2953  segment_ids = _flatten_array_with_offset(segment_ids, num_segments, n)
2954
2955  if num_inputs == 3:
2956    num_segments = None
2957  else:
2958    num_segments *= n
2959  output = op_func(data, indices, segment_ids, num_segments=num_segments)
2960  output = _unflatten_first_dim(output, [n])
2961  return wrap(output, True)
2962
2963
2964@RegisterPForWithArgs("SparseSegmentSumGrad", math_ops.sparse_segment_sum_grad)
2965@RegisterPForWithArgs("SparseSegmentMeanGrad",
2966                      math_ops.sparse_segment_mean_grad)
2967@RegisterPForWithArgs("SparseSegmentSqrtNGrad",
2968                      math_ops.sparse_segment_sqrt_n_grad)
2969def _convert_sparse_segment_grad(pfor_input, _, op_func):
2970  grad = pfor_input.stacked_input(0)
2971  indices = pfor_input.unstacked_input(1)
2972  segment_ids = pfor_input.unstacked_input(2)
2973  dim0 = pfor_input.unstacked_input(3)
2974
2975  n = pfor_input.pfor.loop_len_vector[0]
2976  indices = _flatten_array_with_offset(indices, dim0, n)
2977  num_segments = nn_ops.relu(math_ops.reduce_max(segment_ids) + 1)
2978  segment_ids = _flatten_array_with_offset(segment_ids, num_segments, n)
2979  grad = _flatten_first_two_dims(grad)
2980  dim0 *= n
2981  output = op_func(grad, indices, segment_ids, dim0)
2982  output = _unflatten_first_dim(output, [n])
2983  return wrap(output, True)
2984
2985
2986@RegisterPFor("Cast")
2987def _convert_cast(pfor_input):
2988  inp = pfor_input.stacked_input(0)
2989  dtype = pfor_input.get_attr("DstT")
2990  return wrap(math_ops.cast(inp, dtype), True)
2991
2992
2993@RegisterPFor("Abs")
2994@RegisterPFor("Acos")
2995@RegisterPFor("Acosh")
2996@RegisterPFor("Add")
2997@RegisterPFor("AddV2")
2998@RegisterPFor("Angle")
2999@RegisterPFor("Asin")
3000@RegisterPFor("Asinh")
3001@RegisterPFor("Atan")
3002@RegisterPFor("Atan2")
3003@RegisterPFor("Atanh")
3004@RegisterPFor("BesselI0")
3005@RegisterPFor("BesselI1")
3006@RegisterPFor("BesselI0e")
3007@RegisterPFor("BesselI1e")
3008@RegisterPFor("BesselK0")
3009@RegisterPFor("BesselK1")
3010@RegisterPFor("BesselK0e")
3011@RegisterPFor("BesselK1e")
3012@RegisterPFor("BesselJ0")
3013@RegisterPFor("BesselJ1")
3014@RegisterPFor("BesselY0")
3015@RegisterPFor("BesselY1")
3016@RegisterPFor("BitwiseAnd")
3017@RegisterPFor("BitwiseOr")
3018@RegisterPFor("BitwiseXor")
3019@RegisterPFor("Ceil")
3020@RegisterPFor("Complex")
3021@RegisterPFor("ComplexAbs")
3022@RegisterPFor("Conj")
3023@RegisterPFor("Cos")
3024@RegisterPFor("Cosh")
3025@RegisterPFor("Dawsn")
3026@RegisterPFor("Digamma")
3027@RegisterPFor("Div")
3028@RegisterPFor("DivNoNan")
3029@RegisterPFor("Elu")
3030@RegisterPFor("Erf")
3031@RegisterPFor("Erfc")
3032@RegisterPFor("Erfinv")
3033@RegisterPFor("Exp")
3034@RegisterPFor("Expint")
3035@RegisterPFor("Expm1")
3036@RegisterPFor("Floor")
3037@RegisterPFor("FloorDiv")
3038@RegisterPFor("FloorMod")
3039@RegisterPFor("FresnelCos")
3040@RegisterPFor("FresnelSin")
3041@RegisterPFor("Greater")
3042@RegisterPFor("GreaterEqual")
3043@RegisterPFor("Igamma")
3044@RegisterPFor("IgammaGradA")
3045@RegisterPFor("Igammac")
3046@RegisterPFor("Imag")
3047@RegisterPFor("Inv")
3048@RegisterPFor("Invert")
3049@RegisterPFor("IsFinite")
3050@RegisterPFor("IsInf")
3051@RegisterPFor("IsNan")
3052@RegisterPFor("LeftShift")
3053@RegisterPFor("Less")
3054@RegisterPFor("LessEqual")
3055@RegisterPFor("Lgamma")
3056@RegisterPFor("Log")
3057@RegisterPFor("Log1p")
3058@RegisterPFor("LogicalAnd")
3059@RegisterPFor("LogicalNot")
3060@RegisterPFor("LogicalOr")
3061@RegisterPFor("LogicalXor")
3062@RegisterPFor("Maximum")
3063@RegisterPFor("Minimum")
3064@RegisterPFor("Mod")
3065@RegisterPFor("Mul")
3066@RegisterPFor("MulNoNan")
3067@RegisterPFor("Ndtri")
3068@RegisterPFor("Neg")
3069@RegisterPFor("Polygamma")
3070@RegisterPFor("Pow")
3071@RegisterPFor("Real")
3072@RegisterPFor("RealDiv")
3073@RegisterPFor("Reciprocal")
3074@RegisterPFor("Relu")
3075@RegisterPFor("Relu6")
3076@RegisterPFor("RightShift")
3077@RegisterPFor("Rint")
3078@RegisterPFor("Round")
3079@RegisterPFor("Rsqrt")
3080@RegisterPFor("Selu")
3081@RegisterPFor("Sigmoid")
3082@RegisterPFor("Sign")
3083@RegisterPFor("Sin")
3084@RegisterPFor("Sinh")
3085@RegisterPFor("Softplus")
3086@RegisterPFor("Softsign")
3087@RegisterPFor("Spence")
3088@RegisterPFor("Sqrt")
3089@RegisterPFor("Square")
3090@RegisterPFor("SquaredDifference")
3091@RegisterPFor("Sub")
3092@RegisterPFor("Tan")
3093@RegisterPFor("Tanh")
3094@RegisterPFor("TruncateDiv")
3095@RegisterPFor("TruncateMod")
3096@RegisterPFor("Xdivy")
3097@RegisterPFor("Xlogy")
3098@RegisterPFor("Xlog1py")
3099@RegisterPFor("Zeta")
3100def _convert_cwise(pfor_input):
3101  if pfor_input.num_inputs > 1:
3102    pfor_input.expanddim_inputs_for_broadcast()
3103
3104  out = _create_op(
3105      pfor_input.op_type, [x.t for x in pfor_input.inputs],
3106      [x.dtype for x in pfor_input.outputs],
3107      attrs=pfor_input.op.node_def.attr).outputs
3108  assert len(out) == 1
3109  out = out[0]
3110
3111  op_output = wrap(out, True)
3112  return op_output
3113
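# For example (illustrative): if the loop body computes `x_i * y`, where `x` is
# loop variant with per-iteration shape [3] and `y` is loop invariant with
# shape [3], the stacked `x` has shape [loop_len, 3] and a single Mul of
# [loop_len, 3] by [3] broadcasts to the stacked result. The call to
# expanddim_inputs_for_broadcast above handles the cases where operand ranks
# differ, so that broadcasting against the new leading loop dimension stays
# correct.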
3114
3115@RegisterPFor("XlaSharding")
3116def _convert_xla_sharding(pfor_input):
3117  t = pfor_input.stacked_input(0)
3118  sharding = pfor_input.get_attr("sharding")
3119  return wrap(xla.sharding(t, sharding=sharding), True)
3120
3121
3122@RegisterPFor("LeakyRelu")
3123def _convert_leaky_relu(pfor_input):
3124  t = pfor_input.stacked_input(0)
3125  alpha = pfor_input.get_attr("alpha")
3126  return wrap(gen_nn_ops.leaky_relu(t, alpha=alpha), True)
3127
3128
3129@RegisterPFor("Equal")
3130def _convert_equal(pfor_input):
3131  pfor_input.expanddim_inputs_for_broadcast()
3132  x = pfor_input.input(0)[0]
3133  y = pfor_input.input(1)[0]
3134  incompatible_shape_error = pfor_input.get_attr("incompatible_shape_error")
3135  return wrap(gen_math_ops.equal(
3136      x, y, incompatible_shape_error=incompatible_shape_error), True)
3137
3138
3139@RegisterPFor("NotEqual")
3140def _convert_not_equal(pfor_input):
3141  pfor_input.expanddim_inputs_for_broadcast()
3142  x = pfor_input.input(0)[0]
3143  y = pfor_input.input(1)[0]
3144  incompatible_shape_error = pfor_input.get_attr("incompatible_shape_error")
3145  return wrap(gen_math_ops.not_equal(
3146      x, y, incompatible_shape_error=incompatible_shape_error), True)
3147
3148
3149@RegisterPFor("ApproximateEqual")
3150def _convert_approximate_equal(pfor_input):
3151  pfor_input.expanddim_inputs_for_broadcast()
3152  x = pfor_input.input(0)[0]
3153  y = pfor_input.input(1)[0]
3154  tolerance = pfor_input.get_attr("tolerance")
3155  return wrap(math_ops.approximate_equal(x, y, tolerance=tolerance), True)
3156
3157
3158@RegisterPFor("Shape")
3159def _convert_shape(pfor_input):
3160  out_type = pfor_input.get_attr("out_type")
3161  return wrap(
3162      array_ops.shape(pfor_input.stacked_input(0), out_type=out_type)[1:],
3163      False)
3164
3165
3166@RegisterPFor("ShapeN")
3167def _convert_shape_n(pfor_input):
3168  out_type = pfor_input.get_attr("out_type")
3169  shapes = [
3170      array_ops.shape(x, out_type=out_type)[1:] if stacked else array_ops.shape(
3171          x, out_type=out_type) for x, stacked, _ in pfor_input.inputs
3172  ]
3173  return [wrap(x, False) for x in shapes]
3174
3175
3176@RegisterPFor("Size")
3177def _convert_size(pfor_input):
3178  out_type = pfor_input.get_attr("out_type")
3179  n = math_ops.cast(pfor_input.pfor.loop_len_vector[0], out_type)
3180  return wrap(
3181      array_ops.size(pfor_input.stacked_input(0), out_type=out_type) // n,
3182      False)
3183
3184
3185@RegisterPFor("Rank")
3186def _convert_rank(pfor_input):
3187  return wrap(array_ops.rank(pfor_input.stacked_input(0)) - 1, False)
3188
3189
3190@RegisterPFor("AddN")
3191def _convert_addn(pfor_input):
3192  # AddN does not support broadcasting.
3193  pfor_input.stack_inputs(tile_variants=False)
3194  return _wrap_and_tile_variants(
3195      math_ops.add_n([x.t for x in pfor_input.inputs]),
3196      pfor_input.pfor.loop_len_vector)
3197
3198
3199@RegisterPFor("Cross")
3200def _convert_cross(pfor_input):
3201  pfor_input.stack_inputs()
3202  a = pfor_input.stacked_input(0)
3203  b = pfor_input.stacked_input(1)
3204  return wrap(math_ops.cross(a, b), True)
3205
3206
3207@RegisterPFor("BiasAddGrad")
3208def _convert_biasaddgrad(pfor_input):
3209  grad = pfor_input.stacked_input(0)
3210  fmt = pfor_input.get_attr("data_format")
3211  if fmt == b"NCHW":
3212    output = math_ops.reduce_sum(grad, axis=[1, 3, 4], keepdims=False)
3213  else:
3214    grad_shape = array_ops.shape(grad)
3215    last_dim_shape = grad_shape[-1]
3216    first_dim_shape = grad_shape[0]
3217    output = array_ops.reshape(grad, [first_dim_shape, -1, last_dim_shape])
3218    output = math_ops.reduce_sum(output, axis=[1], keepdims=False)
3219  return wrap(output, True)
3220
3221
3222# Some required ops are not exposed under the tf namespace. Hence relying on
3223# _create_op to create them.
3224@RegisterPForWithArgs("EluGrad")
3225@RegisterPForWithArgs("LeakyReluGrad")
3226@RegisterPForWithArgs("ReciprocalGrad")
3227@RegisterPForWithArgs("Relu6Grad")
3228@RegisterPForWithArgs("ReluGrad")
3229@RegisterPForWithArgs("RsqrtGrad")
3230@RegisterPForWithArgs("SeluGrad")
3231@RegisterPForWithArgs("SigmoidGrad")
3232@RegisterPForWithArgs("SoftplusGrad")
3233@RegisterPForWithArgs("SoftsignGrad")
3234@RegisterPForWithArgs("SqrtGrad")
3235@RegisterPForWithArgs("TanhGrad")
3236def _convert_grads(pfor_input, op_type, *args, **kw_args):
3237  del args
3238  del kw_args
3239  # TODO(agarwal): Looks like these ops don't support broadcasting. Hence we
3240  # have to use tiling here.
3241  pfor_input.stack_inputs()
3242  outputs = _create_op(
3243      op_type, [x.t for x in pfor_input.inputs],
3244      [x.dtype for x in pfor_input.outputs],
3245      attrs=pfor_input.op.node_def.attr).outputs
3246  return [wrap(x, True) for x in outputs]
3247
3248
3249@RegisterPFor("Select")
3250def _convert_select(pfor_input):
3251  pfor_input.stack_inputs()
3252  cond = pfor_input.stacked_input(0)
3253  t = pfor_input.stacked_input(1)
3254  e = pfor_input.stacked_input(2)
3255  cond_rank = array_ops.rank(cond)
3256  cond, t, e = smart_cond.smart_cond(
3257      cond_rank > 1, lambda: _inputs_with_flattening(pfor_input, [0, 1, 2]),
3258      lambda: [cond, t, e])
3259  outputs = _create_op(
3260      pfor_input.op_type, [cond, t, e], [x.dtype for x in pfor_input.outputs],
3261      attrs=pfor_input.op.node_def.attr).outputs
3262  n = pfor_input.pfor.loop_len_vector
3263  out = smart_cond.smart_cond(cond_rank > 1,
3264                              lambda: _unflatten_first_dim(outputs[0], n),
3265                              lambda: outputs[0])
3266  return [wrap(out, True) for _ in outputs]
3267
3268
3269@RegisterPFor("SelectV2")
3270def _convert_selectv2(pfor_input):
3271  pfor_input.expanddim_inputs_for_broadcast()
3272  cond = pfor_input.input(0)[0]
3273  t = pfor_input.input(1)[0]
3274  e = pfor_input.input(2)[0]
3275  out = array_ops.where_v2(cond, t, e)
3276  return wrap(out, True)
3277
3278
3279# random_ops
3280
3281
3282def _transpose_dim_to_front(x, dim):
3283  rank = array_ops.rank(x)
3284  return array_ops.transpose(
3285      x,
3286      perm=array_ops.concat(
3287          [[dim], math_ops.range(0, dim),
3288           math_ops.range(dim + 1, rank)],
3289          axis=0))
3290
3291
3292@RegisterPForWithArgs("RandomUniform")
3293@RegisterPForWithArgs("RandomUniformInt")
3294@RegisterPForWithArgs("RandomStandardNormal")
3295@RegisterPForWithArgs("TruncatedNormal")
3296def _convert_random(pfor_input, op_type, *args, **kw_args):
3297  del args
3298  del kw_args
3299  inputs = [pfor_input.unstacked_input(i) for i in range(pfor_input.num_inputs)]
3300  # inputs[0] is "shape"
3301  inputs[0] = array_ops.concat([pfor_input.pfor.loop_len_vector, inputs[0]],
3302                               axis=0)
3303  # TODO(b/222761732): Turn this warning back on when legacy RNGs are
3304  #   deprecated.
3305  # logging.warning(
3306  #     "Note that %s inside pfor op may not give same output as "
3307  #     "inside a sequential loop.", op_type)
3308  outputs = _create_op(
3309      op_type,
3310      inputs, [x.dtype for x in pfor_input.outputs],
3311      attrs=pfor_input.op.node_def.attr).outputs
3312  return [wrap(x, True) for x in outputs]
3313
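# For example (illustrative): a per-iteration RandomUniform of shape [2, 3]
# with loop_len 4 becomes a single RandomUniform of shape [4, 2, 3], whose
# first dimension indexes the pfor iterations. As noted above, the sampled
# values generally differ from what a sequential loop would produce.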
3314
3315@RegisterPFor("RandomGamma")
3316@RegisterPFor("RandomPoissonV2")
3317def _convert_random_with_param(pfor_input):
3318  shape = pfor_input.unstacked_input(0)
3319  # param is lam (Poisson rate) or alpha (Gamma shape).
3320  param, param_stacked, _ = pfor_input.input(1)
3321  # TODO(b/222761732): Turn this warning back on when legacy RNGs are
3322  #   deprecated.
3323  # logging.warning(
3324  #     "Note that %s inside pfor op may not give same output as "
3325  #     "inside a sequential loop.", pfor_input.op_type)
3326
3327  if param_stacked:
3328    samples = _create_op(
3329        pfor_input.op_type,
3330        inputs=[shape, param],
3331        op_dtypes=[x.dtype for x in pfor_input.outputs],
3332        attrs=pfor_input.op.node_def.attr).outputs[0]
3333    loop_dim = array_ops.shape(shape)[0]
3334    stacked_samples = _transpose_dim_to_front(samples, loop_dim)
3335  else:
3336    shape = array_ops.concat([pfor_input.pfor.loop_len_vector, shape], axis=0)
3337    stacked_samples = _create_op(
3338        pfor_input.op_type,
3339        inputs=[shape, param],
3340        op_dtypes=[x.dtype for x in pfor_input.outputs],
3341        attrs=pfor_input.op.node_def.attr).outputs[0]
3342
3343  return wrap(stacked_samples, True)
3344
3345
3346@RegisterPFor("Multinomial")
3347def _convert_multinomial(pfor_input):
3348  logits, logits_stacked, _ = pfor_input.input(0)
3349  num_samples = pfor_input.unstacked_input(1)
3350  seed = pfor_input.get_attr("seed")
3351  seed2 = pfor_input.get_attr("seed2")
3352  output_dtype = pfor_input.get_attr("output_dtype")
3353  # TODO(b/222761732): Turn this warning back on when legacy RNGs are
3354  #   deprecated.
3355  # logging.warning(
3356  #     "Note that Multinomial inside pfor op may not give same output as "
3357  #     "inside a sequential loop.")
3358
3359  n = pfor_input.pfor.loop_len_vector[0]
3360  if logits_stacked:
3361    flattened_logits = _flatten_first_two_dims(logits)
3362    samples = gen_random_ops.multinomial(
3363        flattened_logits,
3364        num_samples,
3365        seed=seed,
3366        seed2=seed2,
3367        output_dtype=output_dtype)
3368    stacked_samples = _unflatten_first_dim(samples, [n])
3369  else:
3370    samples = gen_random_ops.multinomial(
3371        logits,
3372        num_samples * n,
3373        seed=seed,
3374        seed2=seed2,
3375        output_dtype=output_dtype)
3376    stacked_samples = array_ops.transpose(
3377        array_ops.reshape(samples, [-1, n, num_samples]), [1, 0, 2])
3378
3379  return wrap(stacked_samples, True)
3380
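# Shape bookkeeping for the unstacked-logits branch above (illustrative): with
# logits of shape [b, c], n pfor iterations and num_samples == s, a single
# draw of n * s samples has shape [b, n * s]; reshaping to [b, n, s] and
# transposing to [n, b, s] yields the stacked per-iteration outputs (valid
# because the extra samples are drawn i.i.d. from the same logits).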
3381
3382@RegisterPFor("StatelessMultinomial")
3383@RegisterPFor("StatelessParameterizedTruncatedNormal")
3384@RegisterPFor("StatelessRandomBinomial")
3385@RegisterPFor("StatelessRandomGammaV2")
3386@RegisterPFor("StatelessRandomNormal")
3387@RegisterPFor("StatelessRandomPoisson")
3388@RegisterPFor("StatelessRandomUniform")
3389@RegisterPFor("StatelessRandomUniformInt")
3390@RegisterPFor("StatelessRandomUniformFullInt")
3391@RegisterPFor("StatelessTruncatedNormal")
3392def _convert_stateless_random(pfor_input):
3393  # Unlike stateful random ops, for stateless ones we want better
3394  # reproducibility based on seed. Hence we don't want to use a similar strategy
3395  # as used for stateful ones where we generate a possibly different set of
3396  # random numbers under vectorization.
3397  # Unfortunately, the kernels are currently not set up to do this
3398  # efficiently, so we fall back to a sequential loop for vectorization.
3399  return _fallback_converter(pfor_input, warn=False)
3400
3401
3402# linalg_ops
3403
3404
3405@RegisterPForWithArgs("XlaEinsum")
3406@RegisterPForWithArgs("Einsum")
3407def _convert_einsum(pfor_input, op_type):
3408  # Einsum may have either 1 or 2 inputs.
3409  inputs, input_stacked, _ = zip(*[
3410      pfor_input.input(i)
3411      for i in range(pfor_input.num_inputs)])
3412
3413  # Parse the einsum equation.
3414  equation = pfor_input.get_attr("equation").decode("utf-8")
3415  input_expr, output_expr = equation.split("->")
3416  input_exprs = input_expr.split(",")
3417
3418  # Pick a placeholder symbol to use for the new axis.
3419  chosen_symbol = None
3420  for s in string.ascii_letters:
3421    if s in equation:
3422      continue
3423    else:
3424      chosen_symbol = s
3425      break
3426
3427  if chosen_symbol is None:
3428    raise ValueError("Could not figure out what symbol to use for new axis.")
3429
3430  assert any(input_stacked)
3431  for i in range(len(inputs)):
3432    if input_stacked[i]:
3433      input_exprs[i] = "{}{}".format(chosen_symbol, input_exprs[i])
3434  output_expr = "{}{}".format(chosen_symbol, output_expr)
3435
3436  new_equation = "{}->{}".format(",".join(input_exprs), output_expr)
3437
3438  if op_type == "XlaEinsum":
3439    if len(inputs) == 1:
3440      result = xla.einsum(equation=new_equation, a=inputs[0])
3441    else:
3442      result = xla.einsum(equation=new_equation, a=inputs[0], b=inputs[1])
3443  else:
3444    assert op_type == "Einsum"
3445    result = special_math_ops.einsum(new_equation, *inputs)
3446
3447  return wrap(result, True)
3448
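# Worked example (illustrative): for equation "ij,jk->ik" with only the first
# operand stacked, "a" is the first letter not already used, so the rewritten
# equation is "aij,jk->aik": the loop dimension rides along as an extra batch
# axis on the stacked operand and on the output.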
3449
3450@RegisterPFor("Cholesky")
3451def _convert_cholesky(pfor_input):
3452  t = pfor_input.stacked_input(0)
3453  return wrap(linalg_ops.cholesky(t), True)
3454
3455
3456@RegisterPFor("LogMatrixDeterminant")
3457def _convert_log_matrix_determinant(pfor_input):
3458  t = pfor_input.stacked_input(0)
3459  return [wrap(x, True) for x in linalg_ops.log_matrix_determinant(t)]
3460
3461
3462@RegisterPFor("MatrixInverse")
3463def _convert_matrix_inverse(pfor_input):
3464  t = pfor_input.stacked_input(0)
3465  adjoint = pfor_input.get_attr("adjoint")
3466  return wrap(gen_linalg_ops.matrix_inverse(t, adjoint=adjoint), True)
3467
3468
3469@RegisterPFor("MatrixSolve")
3470def _convert_matrix_solve(pfor_input):
3471  pfor_input.stack_inputs()
3472  matrix = pfor_input.stacked_input(0)
3473  rhs = pfor_input.stacked_input(1)
3474  adjoint = pfor_input.get_attr("adjoint")
3475  output = gen_linalg_ops.matrix_solve(
3476      matrix, rhs, adjoint=adjoint)
3477  return wrap(output, True)
3478
3479
3480@RegisterPFor("MatrixTriangularSolve")
3481def _convert_matrix_triangular_solve(pfor_input):
3482  pfor_input.expanddim_inputs_for_broadcast()
3483  matrix = pfor_input.input(0)[0]
3484  rhs = pfor_input.input(1)[0]
3485  lower = pfor_input.get_attr("lower")
3486  adjoint = pfor_input.get_attr("adjoint")
3487  output = linalg_ops.matrix_triangular_solve(
3488      matrix, rhs, lower=lower, adjoint=adjoint)
3489  return wrap(output, True)
3490
3491
3492@RegisterPFor("SelfAdjointEigV2")
3493def _convert_self_adjoint_eig(pfor_input):
3494  t = pfor_input.stacked_input(0)
3495  compute_v = pfor_input.get_attr("compute_v")
3496  e, v = gen_linalg_ops.self_adjoint_eig_v2(t, compute_v=compute_v)
3497  # If compute_v is False, v will have shape [0].
3498  return wrap(e, True), wrap(v, compute_v)
3499
3500
3501# logging_ops
3502
3503
3504@RegisterPFor("Assert")
3505def _convert_assert(pfor_input):
3506  cond, cond_stacked, _ = pfor_input.input(0)
3507  if cond_stacked:
3508    cond = math_ops.reduce_all(cond)
3509
3510  data_list = [x.t for x in pfor_input.inputs][1:]
3511  return _create_op(
3512      "Assert", [cond] + data_list, [], attrs=pfor_input.op.node_def.attr)
3513
3514
3515@RegisterPFor("Print")
3516def _convert_print(pfor_input):
3517  # Note that we don't stack all the inputs. Hence unstacked values are printed
3518  # once here vs multiple times in a while_loop.
3519  pfor_input.stack_inputs([0])
3520  outputs = _create_op(
3521      "Print", [x.t for x in pfor_input.inputs],
3522      [x.dtype for x in pfor_input.outputs],
3523      attrs=pfor_input.op.node_def.attr).outputs
3524  return [wrap(x, True) for x in outputs]
3525
3526
3527@RegisterPFor("PrintV2")
3528def _convert_print_v2(pfor_input):
3529  # Print the full input Tensor(s), including the batch dimension if stacked.
3530  return _create_op(
3531      "PrintV2", [x.t for x in pfor_input.inputs],
3532      [x.dtype for x in pfor_input.outputs],
3533      attrs=pfor_input.op.node_def.attr)
3534
3535
3536@RegisterPFor("StringFormat")
3537def _convert_string_format(pfor_input):
3538  # Format using the full input Tensor(s), including the batch dimension if
3539  # stacked.
3540  op = _create_op(
3541      "StringFormat", [x.t for x in pfor_input.inputs],
3542      [x.dtype for x in pfor_input.outputs],
3543      attrs=pfor_input.op.node_def.attr)
3544  return [wrap(output, False) for output in op.outputs]
3545
3546
3547# data_flow_ops
3548
3549# TensorArray conversion is tricky since we don't support arrays of
3550# TensorArrays. For converting them, we consider two distinct cases:
3551#
3552# 1. The array is constructed outside the pfor call, and read/written inside the
3553# loop.
3554# This is an easier case since we don't need to make an array of TensorArrays.
3555# A correctness requirement is that these parallel iterations shouldn't attempt
3556# to write to the same location. Hence at conversion time we disallow indices to
3557# be loop-invariant as that would guarantee a collision. Even if the indices are
2558# not loop-invariant, they could still conflict, triggering runtime errors.
3559#
3560# 2. The array is constructed and used entirely inside each pfor iteration.
3561# For simplicity, here we require that the indices used for write/scatter are
3562# "unstacked". Otherwise it becomes hard to merge the TensorArrays created in
2563# different pfor iterations. We consider two sub-cases:
3564#
3565# 2a Elements written to the array are "stacked"
3566# To simulate multiple TensorArrays, we may increase the dimension of each
3567# element of the array. i.e. the i_th row of the j_th entry of the converted
3568# TensorArray corresponds to the j_th entry of the TensorArray in the i_th
3569# pfor iteration.
3570#
3571# 2b Elements written to the array are "unstacked"
3572# In this case we don't increase the dimensions to avoid redundant tiling. Each
3573# iteration is trying to write the same value. So we convert that to a single
3574# write.
3575#
3576# Here are some tricks used to implement the above:
3577# - TensorArrayV3 constructor encodes the element shape as an attr. Instead of
3578# trying to trace whether future writes are stacked or unstacked in order to set
3579# this attr, we set it to correspond to unknown shape.
3580# - We use the "flow" output of the different ops to track whether the array
3581# elements are stacked or unstacked. If a stacked write/scatter is done, we make
3582# the flow stacked as well.
3583# - We use some heuristic traversal of the graph to track whether the
3584# TensorArray handle was created inside or outside the pfor loop.
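#
# Illustrative usage of the two main cases, written in terms of the public
# tf.TensorArray API rather than the raw ops handled below (hedged sketch):
#
#   # Case 1: created outside pfor; written inside with loop-variant indices.
#   ta = tf.TensorArray(tf.float32, size=num_iters)
#   def loop_body(i):
#     return ta.write(i, tf.cast(i, tf.float32) ** 2).read(i)
#
#   # Case 2a: created inside each iteration; write indices are loop
#   # invariant while the written values are loop variant ("stacked").
#   def loop_body(i):
#     local_ta = tf.TensorArray(tf.float32, size=2)
#     local_ta = local_ta.write(0, tf.cast(i, tf.float32))
#     return local_ta.stack()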
3585
3586
3587@RegisterPFor("TensorArrayV3")
3588def _convert_tensor_array_v3(pfor_input):
3589  size = pfor_input.unstacked_input(0)
3590  dtype = pfor_input.get_attr("dtype")
3591  dynamic_size = pfor_input.get_attr("dynamic_size")
3592  clear_after_read = pfor_input.get_attr("clear_after_read")
3593  identical_element_shapes = pfor_input.get_attr("identical_element_shapes")
3594  tensor_array_name = pfor_input.get_attr("tensor_array_name")
3595  handle, flow = data_flow_ops.tensor_array_v3(
3596      size,
3597      dtype=dtype,
3598      # We don't set element shape since we don't know if writes are stacked or
3599      # not yet.
3600      element_shape=None,
3601      dynamic_size=dynamic_size,
3602      clear_after_read=clear_after_read,
3603      identical_element_shapes=identical_element_shapes,
3604      tensor_array_name=tensor_array_name)
3605  # Note we keep flow unstacked for now since we don't know if writes will be
3606  # stacked or not.
3607  return wrap(handle, False), wrap(flow, False)
3608
3609
3610@RegisterPFor("TensorArraySizeV3")
3611def _convert_tensor_array_size_v3(pfor_input):
3612  handle = pfor_input.unstacked_input(0)
3613  flow, flow_stacked, _ = pfor_input.input(1)
3614  if flow_stacked:
3615    flow = _unstack_flow(flow)
3616  size = data_flow_ops.tensor_array_size_v3(handle, flow)
3617  return wrap(size, False)
3618
3619
3620def _handle_inside_pfor(pfor_input, handle):
3621  """Returns True if handle was created inside the pfor loop."""
3622  # We use some heuristic to find the original TensorArray creation op.
3623  # The logic should handle the common cases (except cond-based subgraphs).
3624  # In theory the user could perform different operations on the handle (like
3625  # reshaping it, stacking multiple handles, etc.) which could break this logic.
3626  # TODO(agarwal): handle Switch/Merge.
3627  while handle.op.type in ("Enter", "Identity"):
3628    handle = handle.op.inputs[0]
3629  if handle.op.type not in [
3630      "TensorArrayV3", "TensorArrayGradV3", "TensorArrayGradWithShape"
3631  ]:
3632    raise ValueError(f"Unable to find source for handle {handle}.")
3633  else:
3634    return pfor_input.pfor.op_is_inside_loop(handle.op)
3635
3636
3637def _unstack_flow(value):
3638  # TODO(agarwal): consider looking if this is a Tile op then get its input.
3639  # This may avoid running the Tile operations.
3640  return array_ops.gather(value, 0)
3641
3642
3643@RegisterPFor("TensorArrayReadV3")
3644def _convert_tensor_array_read_v3(pfor_input):
3645  handle = pfor_input.unstacked_input(0)
3646  index, index_stacked, _ = pfor_input.input(1)
3647  dtype = pfor_input.get_attr("dtype")
3648  flow, flow_stacked, _ = pfor_input.input(2)
3649  if flow_stacked:
3650    flow = _unstack_flow(flow)
3651
3652  is_inside_pfor = _handle_inside_pfor(pfor_input, pfor_input.op.inputs[0])
3653  if is_inside_pfor:
3654    # Note that if we are inside a control flow construct inside the pfor, and
3655    # only some of the iterations are doing the read (i.e.
3656    # `all_indices_partitioned` is True), then the read operation should only
3657    # return values for the currently active pfor iterations (`all_indices`
3658    # below). Hence, whenever the returned value is stacked (i.e. `flow` is
3659    # stacked), we may need to do an extra gather after reading the values. Also
3660    # note that if `is_inside_pfor` is false, then values in the tensor
3661    # array are unstacked. So the check is only needed in this branch.
3662    all_indices = pfor_input.pfor.all_indices
3663    all_indices_partitioned = pfor_input.pfor.all_indices_partitioned
3664    # Note: flow_stacked indicates if values in the TensorArray are stacked or
3665    # not.
3666    if index_stacked:
3667      if flow_stacked:
3668        raise ValueError(
3669            "It looks like TensorArrayReadV3 was called on a TensorArray whose"
3670            " values are not loop-invariant, and the read indices were also"
3671            " not loop invariant. This is currently unsupported.")
3672      value = data_flow_ops.tensor_array_gather_v3(
3673          handle, index, flow, dtype=dtype)
3674      return wrap(value, True)
3675    value = data_flow_ops.tensor_array_read_v3(handle, index, flow, dtype=dtype)
3676    if flow_stacked and all_indices_partitioned:
3677      value = array_ops.gather(value, all_indices)
3678    return wrap(value, flow_stacked)
3679  # Values in the TensorArray should be unstacked (since different iterations
3680  # couldn't write to the same location). So whether output is stacked or not
3681  # depends on index_stacked.
3682  if index_stacked:
3683    value = data_flow_ops.tensor_array_gather_v3(
3684        handle, index, flow, dtype=dtype)
3685  else:
3686    value = data_flow_ops.tensor_array_read_v3(handle, index, flow, dtype=dtype)
3687  return wrap(value, index_stacked)
3688
3689
3690@RegisterPFor("TensorArrayWriteV3")
3691def _convert_tensor_array_write_v3(pfor_input):
3692  handle = pfor_input.unstacked_input(0)
3693  index, index_stacked, _ = pfor_input.input(1)
3694  value, value_stacked, _ = pfor_input.input(2)
3695  flow, flow_stacked, _ = pfor_input.input(3)
3696  if value_stacked and pfor_input.pfor.all_indices_partitioned:
3697    # Looks like we are inside a control flow construct in the pfor where not all
3698    # iterations are currently active. We don't allow that since it could lead to
3699    # different indices having different shapes, which would be hard to merge later.
3700    raise ValueError("Writing non loop invariant values to TensorArray from "
3701                     "inside a while_loop/cond not supported.")
3702  if flow_stacked:
3703    flow = _unstack_flow(flow)
3704  is_inside = _handle_inside_pfor(pfor_input, pfor_input.op.inputs[0])
3705  if is_inside:
3706    if index_stacked:
3707      raise ValueError(f"Need indices for {handle} to be loop invariant.")
3708    if not flow_stacked and not value_stacked:
3709      flow_out = data_flow_ops.tensor_array_write_v3(handle, index, value, flow)
3710      return wrap(flow_out, False)
3711    else:
3712      if not value_stacked:
3713        value = _stack(value, pfor_input.pfor.loop_len_vector).t
3714      # TODO(agarwal): Note that if flow is unstacked and value is stacked, then
3715      # this may or may not be a safe situation. flow is unstacked both for a
3716      # freshly created TensorArray, as well as after unstacked values are
3717      # written to it. If it is the latter, then we cannot write a stacked value
3718      # now since that may cause runtime errors due to different shapes in the
3719      # array. At the moment we are not able to handle this gracefully and
3720      # distinguish between the two cases. That would require some heuristic
3721      # traversal of the graph to figure out whether all the writes are
3722      # unstacked or not.
3723      flow_out = data_flow_ops.tensor_array_write_v3(handle, index, value, flow)
3724      return _stack(flow_out, pfor_input.pfor.loop_len_vector)
3725  else:
3726    if not index_stacked:
3727      raise ValueError(f"Need indices for {handle} to be not loop invariant.")
3728    # Note that even when index_stacked is true, actual values in index may
3729    # still not be unique. However, that will cause a runtime error when
3730    # executing the scatter operation below.
3731    if not value_stacked:
3732      value = _stack(value, pfor_input.pfor.loop_len_vector).t
3733    flow_out = data_flow_ops.tensor_array_scatter_v3(handle, index, value, flow)
3734    return _stack(flow_out, pfor_input.pfor.loop_len_vector)
3735
3736
3737def _transpose_first_two_dims(value):
3738  # TODO(agarwal): optimize if one of the dims == 1.
3739  value_shape = array_ops.shape(value)
3740  v0 = value_shape[0]
3741  v1 = value_shape[1]
3742  value = array_ops.reshape(value, [v0, v1, -1])
3743  value = array_ops.transpose(value, [1, 0, 2])
3744  new_shape = array_ops.concat([[v1, v0], value_shape[2:]], axis=0)
3745  return array_ops.reshape(value, new_shape)
3746
3747
3748@RegisterPFor("TensorArrayGatherV3")
3749def _convert_tensor_array_gather_v3(pfor_input):
3750  handle = pfor_input.unstacked_input(0)
3751  indices, indices_stacked, _ = pfor_input.input(1)
3752  indices = array_ops.reshape(indices, [-1])
3753  flow, flow_stacked, _ = pfor_input.input(2)
3754  if flow_stacked:
3755    flow = _unstack_flow(flow)
3756  dtype = pfor_input.get_attr("dtype")
3757  # TODO(agarwal): support element_shape attr?
3758
3759  n = pfor_input.pfor.loop_len_vector
3760  value = data_flow_ops.tensor_array_gather_v3(
3761      handle, indices, flow, dtype=dtype)
3762  is_inside = _handle_inside_pfor(pfor_input, pfor_input.op.inputs[0])
3763  if is_inside:
3764    # flow_stacked indicates if values in the TensorArray are stacked or not.
3765    if indices_stacked:
3766      if flow_stacked:
3767        raise ValueError(
3768            "It looks like TensorArrayGatherV3 was called on a TensorArray "
3769            "whose values are not loop-invariant, and the indices were also "
3770            "not loop invariant. This is currently unsupported.")
3771      else:
3772        value = _unflatten_first_dim(value, n)
3773        return wrap(value, True)
3774    else:
3775      if flow_stacked:
3776        # Since elements in this array are stacked and `value` was produced by
3777        # gather, its first two dims are "gathered elements" and "stack
3778        # dimension". Our semantics require these two to be flipped.
3779        value = _transpose_first_two_dims(value)
3780      return wrap(value, flow_stacked)
3781  else:
3782    # Values in the TensorArray should be unstacked (since different iterations
3783    # couldn't write to the same location). So whether output is stacked or not
3784    # depends on indices_stacked.
3785    if indices_stacked:
3786      value = _unflatten_first_dim(value, n)
3787    return wrap(value, indices_stacked)
3788
3789
3790@RegisterPFor("TensorArrayScatterV3")
3791def _convert_tensor_array_scatter_v3(pfor_input):
3792  handle = pfor_input.unstacked_input(0)
3793  indices, indices_stacked, _ = pfor_input.input(1)
3794  indices = array_ops.reshape(indices, [-1])
3795  value, value_stacked, _ = pfor_input.input(2)
3796  flow, flow_stacked, _ = pfor_input.input(3)
3797
3798  if flow_stacked:
3799    flow = _unstack_flow(flow)
3800
3801  is_inside = _handle_inside_pfor(pfor_input, pfor_input.op.inputs[0])
3802  if is_inside:
3803    if indices_stacked:
3804      raise ValueError(f"Need indices for {handle} to be loop invariant.")
3805    # Note that flow_stacked indicates if existing values in the array are
3806    # stacked or not.
3807    if not flow_stacked and not value_stacked:
3808      flow_out = data_flow_ops.tensor_array_scatter_v3(handle, indices, value,
3809                                                       flow)
3810      return wrap(flow_out, False)
3811    if not value_stacked:
3812      # TODO(agarwal): tile in the second dimension directly instead of
3813      # transposing below.
3814      value = _stack(value, pfor_input.pfor.loop_len_vector).t
3815
3816    value = _transpose_first_two_dims(value)
3817    # TODO(agarwal): Note that if a previous write was unstacked, flow will be
3818    # unstacked, and a stacked value may be written here which may cause
3819    # runtime error due to different elements having different shape. We do
3820    # not try to prevent that.
3821    flow_out = data_flow_ops.tensor_array_scatter_v3(handle, indices, value,
3822                                                     flow)
3823    return _stack(flow_out, pfor_input.pfor.loop_len_vector)
3824  if not indices_stacked:
3825    raise ValueError(f"Need indices for {handle} to be not loop invariant.")
3826  if not value_stacked:
3827    value = _stack(value, pfor_input.pfor.loop_len_vector).t
3828  value = _flatten_first_two_dims(value)
3829  flow_out = data_flow_ops.tensor_array_scatter_v3(handle, indices, value, flow)
3830  return _stack(flow_out, pfor_input.pfor.loop_len_vector)
3831
3832
3833@RegisterPFor("TensorArrayGradV3")
3834def _convert_tensor_array_grad_v3(pfor_input):
3835  handle = pfor_input.unstacked_input(0)
3836  flow, flow_stacked, _ = pfor_input.input(1)
3837  if flow_stacked:
3838    flow = _unstack_flow(flow)
3839  source = pfor_input.get_attr("source")
3840  # TODO(agarwal): For now, we assume that gradients are stacked if the
3841  # TensorArrayGradV3 call is being done inside the pfor. Getting that wrong
3842  # will give runtime error due to incorrect shape being written to the
3843  # accumulator. It is difficult to know in advance if gradients written will be
3844  # stacked or not. Note that flow being stacked is not indicative of the
3845  # gradient being stacked or not. Revisit this later.
3846  shape_to_prepend = pfor_input.pfor.loop_len_vector
3847  grad_handle, flow_out = data_flow_ops.tensor_array_grad_with_shape(
3848      handle=handle,
3849      flow_in=flow,
3850      shape_to_prepend=shape_to_prepend,
3851      source=source)
3852  flow_out = _stack(flow_out, pfor_input.pfor.loop_len_vector).t
3853  return [wrap(grad_handle, False), wrap(flow_out, True)]
3854
3855
3856def _stack_tensor_list_shape(shape, first_dim):
3857  shape_value = tensor_util.constant_value(shape)
3858  # Note that negative values in the shape are used to signify unknown shapes
3859  # and are handled in a special way.
3860  if shape_value is not None:
3861    shape_value = np.asarray(shape_value)
3862    if -1 in shape_value:
3863      return constant_op.constant(-1)
3864    elif not shape_value.size:
3865      return first_dim
3866  else:
3867    shape = array_ops.reshape(shape, [-1])
3868    return control_flow_ops.cond(
3869        math_ops.reduce_any(shape < 0),
3870        lambda: constant_op.constant(-1),
3871        lambda: array_ops.concat([first_dim, shape], axis=0))
3872
3873
3874def _tile_variant_with_length(t, length):
3875  """stacks `t` `length` times."""
3876  if _is_variant_with_internal_stacking(t):
3877    # The content of TensorLists is vectorized, not the variant itself.
3878    return t
3879  original_tensor = t
3880  t.set_shape([])
3881  t = array_ops.reshape(t, [-1])
3882  with ops.device("CPU:0"):
3883    result = array_ops.tile(t, length)
3884    # TODO(b/169968286): Should regular shape functions do handle data
3885    # propagation here?
3886    handle_data_util.copy_handle_data(original_tensor, result)
3887    return result
3888
3889
3890def _tile_variant(t, pfor_input):
3891  """stacks `t` according to its loop context."""
3892  return _tile_variant_with_length(t, pfor_input.pfor.loop_len_vector)
3893
3894
3895def _untile_variant(t):
3896  if _is_variant_with_internal_stacking(t):
3897    # The content of TensorLists is vectorized, not the variant itself.
3898    if not t.shape.is_compatible_with([]):
3899      raise AssertionError(
3900          ("Unexpectedly saw a vectorized variant (e.g. TensorList) with "
3901           f"non-scalar shape: {t!r}"))
3902    return t
3903  return array_ops.gather(t, 0)
3904
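# In other words (illustrative): a loop-invariant TensorList handle is a scalar
# variant tensor; "stacking" it above just tiles that scalar into a [loop_len]
# vector of aliases of the same list, and _untile_variant recovers the handle
# by taking element 0. Variants whose contents are themselves vectorized (see
# _is_variant_with_internal_stacking) keep their scalar shape and pass through
# unchanged.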
3905
3906@RegisterPFor("OptionalFromValue")
3907def _convert_optional_from_value(pfor_input):
3908  pfor_input.stack_inputs()
3909  return wrap(
3910      gen_dataset_ops.optional_from_value([x.t for x in pfor_input.inputs]),
3911      True)
3912
3913
3914@RegisterPFor("OptionalGetValue")
3915def _convert_optional_get_value(pfor_input):
3916  handle = pfor_input.stacked_input(0)
3917  output_types = pfor_input.get_attr("output_types")
3918  original_output_shapes = pfor_input.get_attr("output_shapes")
3919  output_shapes = []
3920  for shape in original_output_shapes:
3921    shape = tensor_shape.TensorShape(shape)
3922    loop_len_shape = tensor_shape.TensorShape(
3923        [tensor_util.constant_value(pfor_input.pfor.loop_len_vector)])
3924    shape = loop_len_shape.concatenate(shape)
3925    output_shapes.append(shape.as_proto())
3926  results = gen_dataset_ops.optional_get_value(handle, output_types,
3927                                               output_shapes)
3928  return [wrap(t, True) for t in results]
3929
3930
3931@RegisterPFor("TensorListReserve")
3932def _convert_tensor_list_reserve(pfor_input):
3933  element_shape = pfor_input.unstacked_input(0)
3934  num_elements = pfor_input.unstacked_input(1)
3935  element_dtype = pfor_input.get_attr("element_dtype")
3936
3937  # Prepend a dimension to element_shape.
3938  element_shape = _stack_tensor_list_shape(element_shape,
3939                                           pfor_input.pfor.loop_len_vector)
3940  handle = list_ops.tensor_list_reserve(
3941      element_shape, num_elements, element_dtype=element_dtype)
3942
3943  return wrap(_tile_variant(handle, pfor_input), True)
3944
3945
3946@RegisterPFor("TensorListElementShape")
3947def _convert_tensor_list_element_shape(pfor_input):
3948  handle = _untile_variant(pfor_input.stacked_input(0))
3949  shape_type = pfor_input.get_attr("shape_type")
3950  shape = list_ops.tensor_list_element_shape(handle, shape_type)
3951  shape = array_ops.reshape(shape, [-1])
3952  shape = shape[1:]
3953  return wrap(shape, False)
3954
3955
3956@RegisterPFor("TensorListLength")
3957def _convert_tensor_list_length(pfor_input):
3958  handle = _untile_variant(pfor_input.stacked_input(0))
3959  return wrap(list_ops.tensor_list_length(handle), False)
3960
3961
3962def _stack_tensor_list(handle, dtype, loop_len_vector, element_shape=None):
3963  if element_shape is None:
3964    element_shape = list_ops.tensor_list_element_shape(handle, dtypes.int32)
3965  length = list_ops.tensor_list_length(handle)
3966  new_handle = list_ops.tensor_list_reserve(
3967      _stack_tensor_list_shape(element_shape, loop_len_vector), length, dtype)
3968
3969  def _body_fn(i, h):
3970    elem = list_ops.tensor_list_get_item(handle, i, dtype, element_shape)
3971    elem = _stack(elem, loop_len_vector).t
3972    return i + 1, list_ops.tensor_list_set_item(h, i, elem)
3973
3974  return control_flow_ops.while_loop(lambda i, _: i < length, _body_fn,
3975                                     [0, new_handle])[1]
3976
3977
3978@RegisterPFor("TensorListGetItem")
3979def _convert_tensor_list_get_item(pfor_input):
3980  handle, handle_stacked, _ = pfor_input.input(0)
3981  index, index_stacked, _ = pfor_input.input(1)
3982  element_shape = pfor_input.unstacked_input(2)
3983  element_dtype = pfor_input.get_attr("element_dtype")
3984
3985  if handle_stacked:
3986    handle = _untile_variant(handle)
3987    element_shape = _stack_tensor_list_shape(element_shape,
3988                                             pfor_input.pfor.loop_len_vector)
3989    if index_stacked:
      # We use a sequential loop since that may be more efficient than first
      # gathering and concatenating all the elements corresponding to `index`,
      # and then doing a gather on it.
3993      def _map_fn(i):
3994        item_i = list_ops.tensor_list_get_item(
3995            handle,
3996            index[i],
3997            element_dtype=element_dtype)
3998        return array_ops.gather(item_i, i)
3999
4000      output = map_fn.map_fn(_map_fn, pfor_input.pfor.all_indices)
4001      return wrap(output, True)
4002    else:
4003      output = list_ops.tensor_list_get_item(
4004          handle,
4005          index,
4006          element_shape=element_shape,
4007          element_dtype=element_dtype)
4008      return wrap(output, True)
4009  else:
4010    assert index_stacked
4011    return wrap(
4012        list_ops.tensor_list_gather(
4013            handle,
4014            index,
4015            element_shape=element_shape,
4016            element_dtype=element_dtype), True)
4017
4018
4019@RegisterPFor("TensorListSetItem")
def _convert_tensor_list_set_item(pfor_input):
4021  handle, handle_stacked, _ = pfor_input.input(0)
4022  index, index_stacked, _ = pfor_input.input(1)
4023  item, item_stacked, _ = pfor_input.input(2)
4024
4025  if not handle_stacked:
4026    # Special case where we can statically guarantee that the indices are
4027    # disjoint.
4028    if index is pfor_input.pfor.all_indices:
4029      if not item_stacked:
4030        item = _stack(item, pfor_input.pfor.loop_len_vector).t
4031      return wrap(
4032          list_ops.tensor_list_scatter(item, index, input_handle=handle), False)
4033    else:
4034      handle = _stack_tensor_list(handle, item.dtype,
4035                                  pfor_input.pfor.loop_len_vector)
4036  else:
4037    handle = _untile_variant(handle)
4038
4039  if index_stacked:
4040    # TODO(agarwal): handle this.
4041    raise ValueError("Vectorizing writes to a TensorList with loop "
4042                     "variant indices is currently unsupported.")
4043
4044  else:
4045    if not item_stacked:
4046      item = _stack(item, pfor_input.pfor.loop_len_vector).t
4047    handle = list_ops.tensor_list_set_item(handle, index, item)
4048    return wrap(_tile_variant(handle, pfor_input), True)
4049
4050
4051@RegisterPFor("TensorListPushBack")
4052def _convert_tensor_list_push_back(pfor_input):
4053  handle, handle_stacked, _ = pfor_input.input(0)
4054  tensor, tensor_stacked, _ = pfor_input.input(1)
4055  if handle_stacked:
4056    handle = _untile_variant(handle)
4057  else:
4058    handle = _stack_tensor_list(handle, tensor.dtype,
4059                                pfor_input.pfor.loop_len_vector)
4060  if not tensor_stacked:
4061    tensor = _stack(tensor, pfor_input.pfor.loop_len_vector).t
4062  handle = list_ops.tensor_list_push_back(handle, tensor)
4063  return wrap(_tile_variant(handle, pfor_input), True)
4064
4065
4066@RegisterPFor("TensorListPopBack")
def _convert_tensor_list_pop_back(pfor_input):
4068  handle = pfor_input.stacked_input(0)
4069  element_shape = pfor_input.unstacked_input(1)
4070  handle = _untile_variant(handle)
4071
4072  if element_shape.shape.ndims == 0:
4073    # Default / unspecified
4074    vectorized_shape = -1
4075  else:
    # PopBack has an element shape set when it is the gradient of PushBack; the
    # shape is only used when the list is uninitialized.
4078    vectorized_shape = array_ops.concat(
4079        [pfor_input.pfor.loop_len_vector, element_shape], axis=0)
4080
4081  output_handle, tensor = gen_list_ops.tensor_list_pop_back(
4082      input_handle=handle, element_dtype=pfor_input.get_attr("element_dtype"),
4083      element_shape=vectorized_shape)
4084  return wrap(output_handle, True), wrap(tensor, True)
4085
4086
4087@RegisterPFor("TensorListConcatV2")
4088def _convert_tensor_list_concat_v2(pfor_input):
4089  input_handle = pfor_input.stacked_input(0)
4090  element_shape = pfor_input.unstacked_input(1)
4091  leading_dims = pfor_input.unstacked_input(2)
4092  element_dtype = pfor_input.get_attr("element_dtype")
4093
4094  handle = _untile_variant(input_handle)
4095  length = list_ops.tensor_list_length(handle)
  # Note that `element_shape` may be an incomplete shape. That does not seem to
  # work well when creating another list and then concatenating it. Hence we
  # compute the dynamic element shape here instead.
4099  element_shape = control_flow_ops.cond(
4100      length > 0, lambda: array_ops.shape(
4101          list_ops.tensor_list_get_item(handle, 0, element_dtype, None)),
4102      lambda: constant_op.constant([0, 0], dtype=dtypes.int32))
  # The code below creates a copy of the list with each element's first two
  # dimensions transposed, so that the concat below runs along the per-element
  # dimension rather than the loop dimension; the result is transposed back
  # afterwards.
4105  new_element_shape = array_ops.concat(
4106      [element_shape[1:2], element_shape[0:1], element_shape[2:]], axis=0)
4107
4108  # Create a new TensorList with elements transposed.
4109  def _transpose_elem(i, h):
4110    elem = list_ops.tensor_list_get_item(handle, i, element_dtype, None)
4111    elem = _transpose_first_two_dims(elem)
4112    return i + 1, list_ops.tensor_list_set_item(h, i, elem)
4113
4114  new_handle = list_ops.tensor_list_reserve(new_element_shape, length,
4115                                            element_dtype)
4116  new_handle = control_flow_ops.while_loop(lambda i, _: i < length,
4117                                           _transpose_elem, [0, new_handle])[1]
4118  output, lengths = gen_list_ops.tensor_list_concat_v2(
4119      input_handle=new_handle,
4120      element_dtype=element_dtype,
4121      element_shape=new_element_shape,
4122      leading_dims=leading_dims)
4123  output = _transpose_first_two_dims(output)
4124  return wrap(output, True), wrap(lengths, False)
4125
4126
4127@RegisterPFor("TensorListStack")
4128def _convert_tensor_list_stack(pfor_input):
4129  handle = pfor_input.stacked_input(0)
4130  input_shape = pfor_input.unstacked_input(1)
4131  element_dtype = pfor_input.get_attr("element_dtype")
4132  num_elements = pfor_input.get_attr("num_elements")
4133
4134  handle = _untile_variant(handle)
4135  input_shape = _stack_tensor_list_shape(input_shape,
4136                                         pfor_input.pfor.loop_len_vector)
4137  output = list_ops.tensor_list_stack(
4138      handle,
4139      element_dtype,
4140      element_shape=input_shape,
4141      num_elements=num_elements)
4142  output = _transpose_first_two_dims(output)
4143  return wrap(output, True)
4144
4145
4146@RegisterPFor("TensorListGather")
4147def _convert_tensor_list_gather(pfor_input):
4148  handle, handle_stacked, _ = pfor_input.input(0)
4149  index, index_stacked, _ = pfor_input.input(1)
4150  element_shape = pfor_input.unstacked_input(2)
4151  element_dtype = pfor_input.get_attr("element_dtype")
4152
4153  if handle_stacked:
4154    handle = _untile_variant(handle)
4155    element_shape = _stack_tensor_list_shape(element_shape,
4156                                             pfor_input.pfor.loop_len_vector)
4157    if index_stacked:
      # We use a sequential loop since that may be more efficient than first
      # gathering and concatenating all the elements corresponding to `index`,
      # and then doing a gather on it.
4161      def _map_fn(i):
4162        item_i = list_ops.tensor_list_gather(
4163            handle,
4164            index[i],
4165            element_dtype=element_dtype)
4166        axis = array_ops.rank(index) - 1
4167        return array_ops.gather(item_i, i, axis=axis)
4168
4169      output = map_fn.map_fn(_map_fn, pfor_input.pfor.all_indices)
4170      return wrap(output, True)
4171    else:
4172      output = list_ops.tensor_list_gather(
4173          handle,
4174          index,
4175          element_shape=element_shape,
4176          element_dtype=element_dtype)
4177      return wrap(output, True)
4178  else:
4179    assert index_stacked
4180    index_shape = array_ops.shape(index)
4181    index = array_ops.reshape(index, [-1])
4182    values = list_ops.tensor_list_gather(
4183        handle, index, element_shape=element_shape, element_dtype=element_dtype)
4184    final_shape = array_ops.concat(
4185        [index_shape, array_ops.shape(values)[1:]], axis=0)
4186    return wrap(array_ops.reshape(values, final_shape), True)
4187
4188
4189@RegisterPFor("TensorListScatterIntoExistingList")
4190def _convert_tensor_list_scatter(pfor_input):
4191  pfor_input.stack_inputs([1])
4192  handle, handle_stacked, _ = pfor_input.input(0)
4193  item = pfor_input.stacked_input(1)
4194  indices, indices_stacked, _ = pfor_input.input(2)
4195  if handle_stacked:
4196    handle = _untile_variant(handle)
4197  else:
4198    handle = _stack_tensor_list(handle, item.dtype,
4199                                pfor_input.pfor.loop_len_vector)
4200
4201  item = _transpose_first_two_dims(item)
4202  if indices_stacked:
4203    # Pretend the list is a dense tensor:
4204    #   list_as_dense: Tensor[list_len, loop_len, ...]
4205    # And indices are a tensor with shape (before transpose):
4206    #   indices: Tensor[loop_len, num_scatters]
4207    # The item to scatter has shape (before transpose):
4208    #   item: Tensor[loop_len, num_scatters, ...]
4209    #
4210    # We want list_as_dense[indices[i, j], i] = item[i, j]
4211    #
4212    # Since we're not just indexing along the first axis of `list_as_dense`, we
4213    # need to first extract the relevant list entries based on `indices`,
4214    # scatter into them according to the loop index, and re-scatter the chunks
4215    # we updated back into the list.
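    #
    # For example (illustrative shapes only): with loop_len=2, num_scatters=3
    # and scalar list elements, `indices` and `item` both have shape [2, 3]
    # before the transposes and [3, 2] after them, and each chunk gathered
    # from the list below has shape [2] (one slot per pfor iteration).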
4216    indices = _transpose_first_two_dims(indices)
4217    indices_flat = array_ops.reshape(indices, [-1])
4218    # In many cases `indices` will be unique across pfor iterations, but this is
4219    # not guaranteed. If there are duplicates, we need to map multiple updates
4220    # to a single chunk extracted from the list. The last update should win.
4221    unique_indices = array_ops.unique(indices_flat)
4222    gathered_items = list_ops.tensor_list_gather(
4223        handle, unique_indices.y, element_dtype=item.dtype,
4224        element_shape=array_ops.shape(item)[1:])
4225    loop_idx = math_ops.range(pfor_input.pfor.loop_len_vector[0])
4226    scatters_per_op = array_ops.shape(indices)[0]
4227
4228    unique_indices_loop_idx = array_ops.reshape(array_ops.tile(
4229        loop_idx[None, :], [scatters_per_op, 1]), [-1])
4230    scatter_indices = array_ops.stack(
4231        [unique_indices.idx, unique_indices_loop_idx],
4232        axis=1)
4233    # This op does *not* guarantee last-update-wins on GPU, so semantics may not
4234    # be exactly preserved for duplicate updates there.
4235    scattered = array_ops.tensor_scatter_nd_update(
4236        tensor=gathered_items,
4237        indices=scatter_indices,
4238        updates=_flatten_first_two_dims(item))
4239    handle = list_ops.tensor_list_scatter(
4240        scattered, unique_indices.y, input_handle=handle)
4241  else:
4242    handle = list_ops.tensor_list_scatter(item, indices, input_handle=handle)
4243  return wrap(_tile_variant(handle, pfor_input), True)
4244
4245
4246@RegisterPFor("TensorListFromTensor")
4247def _convert_tensor_list_from_tensor(pfor_input):
4248  tensor = pfor_input.stacked_input(0)
4249  element_shape = pfor_input.unstacked_input(1)
4250  tensor = _transpose_first_two_dims(tensor)
4251  element_shape = _stack_tensor_list_shape(element_shape,
4252                                           pfor_input.pfor.loop_len_vector)
4253  handle = list_ops.tensor_list_from_tensor(tensor, element_shape)
4254  return wrap(_tile_variant(handle, pfor_input), True)
4255
4256
4257@RegisterPFor("TensorScatterUpdate")
4258def _convert_tensor_scatter_update(pfor_input):
4259  pfor_input.stack_inputs([0, 1, 2])
4260  tensor = pfor_input.stacked_input(0)
4261  indices = pfor_input.stacked_input(1)
4262  updates = pfor_input.stacked_input(2)
4263
4264  indices_shape = array_ops.shape(indices)
4265  indices_rank = array_ops.rank(indices)
4266  loop_length = indices_shape[0]
4267
4268  # Create a loop count range and extend its dimensions to match `indices`.
4269  loop_count_shape = array_ops.tensor_scatter_nd_update(
4270      array_ops.ones([indices_rank], dtype=dtypes.int32), [[0]], [loop_length])
4271  loop_count = array_ops.reshape(math_ops.range(loop_length), loop_count_shape)
4272
  # Tile the loop count range across the batch dimensions (all except the first
  # and last dimensions of indices).
  # Rank(indices) >= 3 always holds for this function, so there is at least one
  # batch dimension.
4276  tile_multiplier = array_ops.tensor_scatter_nd_update(
4277      indices_shape, [[0], [indices_rank - 1]], [1, 1])
4278  meta_index = array_ops.tile(loop_count, tile_multiplier)
4279
4280  # Insert the loop-identifying index.
4281  indices = array_ops.concat([meta_index, indices], axis=-1)
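  # For example (illustrative shapes only): if each per-iteration `indices`
  # has shape [4, 2], the stacked `indices` has shape [loop_len, 4, 2], so
  # loop_count_shape is [loop_len, 1, 1], tile_multiplier is [1, 4, 1], and
  # the concat above produces indices of shape [loop_len, 4, 3] whose first
  # column selects the pfor iteration.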
4282
4283  result = array_ops.tensor_scatter_nd_update(tensor, indices, updates)
4284  return wrap(result, True)
4285
4286# StackV2 conversion is tricky since we don't have arrays of StackV2. So similar
4287# to TensorArrays, we convert them by changing the dimension of the elements
4288# inside the stack.
4289#
4290# We consider two cases:
4291#
4292# 1. StackV2 is constructed and used entirely inside the pfor loop.
4293# We keep a single Stack and perform the push/pop operations of all the
4294# iterations in lock-step. We also assume that all the iterations perform these
4295# operations. In case of dynamic control flow, if only some of the iterations
4296# try to perform a push/pop, then the conversion may not work correctly and may
4297# cause undefined behavior.
4298# TODO(agarwal): test StackV2 with dynamic control flow.
4299#
# 2. StackV2 is constructed outside the pfor loop.
# Performing stack push/pop in a parallel fashion is ill-defined. However,
# given that reading stacks created externally is a common operation when
# computing Jacobians, we provide the following special semantics.
4304#  - disallow push operations to the stack
4305#  - pop operations are performed in lock step by all iterations, similar to the
4306#  case when the stack is created inside. A single value is popped during the
4307#  lock-step operation and broadcast to all the iterations. Values in the stack
4308#  are assumed to be loop-invariant.
4309#
# Some other implementation details:
# We use a somewhat crude heuristic to determine whether values in the Stack
# data structure are loop invariant or not. When converting push/pop
# operations, we keep track of whether the last conversion used a stacked value
# or not (see _stack_cache below). As a result, if an unstacked value is
# written first, subsequent stacked writes are disallowed even when they could
# have been allowed in theory.
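#
# For illustration (a rough sketch of the case 2 semantics above, not actual
# code): if a stack created outside the pfor loop holds loop-invariant values,
# then each converted StackPopV2 performs a single lock-step pop and broadcasts
# the popped value to all pfor iterations, while a StackPushV2 on that stack
# inside the loop raises an error (see _convert_stack_push_v2 below).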
4316
4317# Map from cache key based on StackV2 handle to a bool indicating whether values
4318# are stacked or not.
4319# TODO(agarwal): move _stack_cache inside pfor?
4320_stack_cache = {}
4321
4322
4323def _stack_cache_key(pfor_input):
4324  """Create cache key corresponding to a stack handle."""
4325  op_type = pfor_input.op_type
4326  assert op_type in ["StackPushV2", "StackPopV2"], op_type
4327  orig_handle = pfor_input.op.inputs[0]
4328  while orig_handle.op.type in ["Identity", "Enter"]:
4329    orig_handle = orig_handle.op.inputs[0]
4330  assert orig_handle.op.type == "StackV2", orig_handle.op
4331  return ops.get_default_graph(), pfor_input.pfor, orig_handle
4332
4333
4334def _stack_handle_inside_pfor(handle, pfor_input):
4335  while handle.op.type in ["Identity", "Enter"]:
4336    handle = handle.op.inputs[0]
4337  assert handle.op.type == "StackV2", ("Unable to find StackV2 op. Got %s" %
4338                                       handle.op)
4339  return pfor_input.pfor.op_is_inside_loop(handle.op)
4340
4341
4342@RegisterPFor("StackPushV2")
4343def _convert_stack_push_v2(pfor_input):
4344  handle = pfor_input.unstacked_input(0)
4345  elem, elem_stacked, _ = pfor_input.input(1)
4346  swap_memory = pfor_input.get_attr("swap_memory")
4347
4348  if not _stack_handle_inside_pfor(pfor_input.op.inputs[0], pfor_input):
4349    raise ValueError("StackPushV2 not allowed on stacks created outside pfor.")
4350  stack_cache_key = _stack_cache_key(pfor_input)
4351  stacked = _stack_cache.get(stack_cache_key, None)
4352  if stacked is None:
4353    stacked = elem_stacked
4354    _stack_cache[stack_cache_key] = stacked
4355  else:
4356    # If we previously made it unstacked then we can't revert to being stacked.
4357    if not stacked and elem_stacked:
4358      raise ValueError(
4359          "It looks like the stack was previously determined to be loop "
4360          "invariant, but we are now trying to push a loop dependent value "
4361          "to it. This is currently unsupported.")
4362    if stacked and not elem_stacked:
4363      elem = _stack(elem, pfor_input.pfor.loop_len_vector).t
4364  out = data_flow_ops.stack_push_v2(handle, elem, swap_memory=swap_memory)
4365  return wrap(out, stacked)
4366
4367
# Note that inputs to this converter will be unstacked. However, it should
# still get called since it is a stateful op.
4370@RegisterPFor("StackPopV2")
4371def _convert_stack_pop_v2(pfor_input):
4372  handle = pfor_input.unstacked_input(0)
4373  stack_cache_key = _stack_cache_key(pfor_input)
4374  stacked = _stack_cache.get(stack_cache_key, None)
  # If a StackPushV2 has not been converted yet, we default to unstacked since
  # the push could be outside of pfor, or the converter may not be called if
  # its inputs are unconverted.
4378  if stacked is None:
4379    stacked = False
4380    _stack_cache[stack_cache_key] = False
4381  elem_type = pfor_input.get_attr("elem_type")
4382  out = data_flow_ops.stack_pop_v2(handle, elem_type)
4383  return wrap(out, stacked)
4384
4385
4386# parsing_ops
4387
4388
4389@RegisterPFor("DecodeCSV")
4390def _convert_decode_csv(pfor_input):
4391  lines = pfor_input.stacked_input(0)
4392  record_defaults = [
4393      pfor_input.unstacked_input(i) for i in range(1, pfor_input.num_inputs)
4394  ]
4395  field_delim = pfor_input.get_attr("field_delim")
4396  use_quote_delim = pfor_input.get_attr("use_quote_delim")
4397  select_cols = pfor_input.get_attr("select_cols")
4398  if not select_cols:
4399    select_cols = None
4400  return [
4401      wrap(t, True) for t in parsing_ops.decode_csv(
4402          lines,
4403          record_defaults,
4404          field_delim=field_delim,
4405          use_quote_delim=use_quote_delim,
4406          select_cols=select_cols)
4407  ]
4408
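# Illustrative usage sketch (an assumption for exposition: using the public
# tf.io.decode_csv and tf.vectorized_map APIs, which exercises the DecodeCSV
# conversion above):
#
#   import tensorflow as tf
#
#   lines = tf.constant(["1,2.5", "3,4.5"])
#   # Vectorizing over the batch of lines; record_defaults determines the
#   # output dtypes (here int32 and float32).
#   ints, floats = tf.vectorized_map(
#       lambda line: tf.io.decode_csv(line, record_defaults=[0, 0.0]), lines)
#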
4409
4410@RegisterPFor("ParseSingleExample")
4411def _convert_parse_single_example(pfor_input):
4412  serialized = pfor_input.stacked_input(0)
4413  dense_defaults = [
4414      pfor_input.unstacked_input(i) for i in range(1, pfor_input.num_inputs)
4415  ]
4416  sparse_keys = pfor_input.get_attr("sparse_keys")
4417  dense_keys = pfor_input.get_attr("dense_keys")
4418  sparse_types = pfor_input.get_attr("sparse_types")
4419  dense_shapes = pfor_input.get_attr("dense_shapes")
4420  output = gen_parsing_ops.parse_example(
4421      serialized=serialized,
4422      names=[],
4423      dense_defaults=dense_defaults,
4424      sparse_keys=sparse_keys,
4425      dense_keys=dense_keys,
4426      sparse_types=sparse_types,
4427      dense_shapes=dense_shapes)
4428  return [wrap(t, True, True) for t in nest.flatten(output)]
4429
4430
4431@RegisterPFor("ParseExampleV2")
4432def _convert_parse_example_v2(pfor_input):
4433  serialized = pfor_input.stacked_input(0)
4434  sparse_keys = pfor_input.unstacked_input(2)
4435  dense_keys = pfor_input.unstacked_input(3)
4436  ragged_keys = pfor_input.unstacked_input(4)
4437  dense_defaults = [
4438      pfor_input.unstacked_input(i) for i in range(5, pfor_input.num_inputs)
4439  ]
4440  num_sparse = pfor_input.get_attr("num_sparse")
4441  sparse_types = pfor_input.get_attr("sparse_types")
4442  ragged_value_types = pfor_input.get_attr("ragged_value_types")
4443  ragged_split_types = pfor_input.get_attr("ragged_split_types")
4444  dense_shapes = pfor_input.get_attr("dense_shapes")
  if serialized.shape.ndims not in (None, 1):
    raise ValueError("ParseExampleV2 can only be converted if `serialized` is "
                     "a scalar in each pfor iteration, i.e. if the stacked "
                     f"input is rank 1. Received shape: {serialized.shape}.")
4448  output = gen_parsing_ops.parse_example_v2(
4449      serialized=serialized,
4450      names=[],
4451      sparse_keys=sparse_keys,
4452      dense_keys=dense_keys,
4453      ragged_keys=ragged_keys,
4454      dense_defaults=dense_defaults,
4455      num_sparse=num_sparse,
4456      sparse_types=sparse_types,
4457      ragged_value_types=ragged_value_types,
4458      ragged_split_types=ragged_split_types,
4459      dense_shapes=dense_shapes)
4460  return [wrap(t, True, True) for t in nest.flatten(output)]
4461
4462
4463# functional_ops
4464
4465
4466def _convert_function_call(func, converter, inputs):
4467  assert isinstance(func.graph, func_graph.FuncGraph), func
4468  assert isinstance(converter, PFor)
4469
4470  # TODO(agarwal): consider caching this function definition.
4471  @def_function.function
4472  def f(*args):
4473    assert all(isinstance(arg, WrappedTensor) for arg in args), args
4474    assert len(args) == len(func.graph.inputs), (args, func.graph.inputs)
4475    #  Map inputs to function arguments.
4476    for inp, arg in zip(func.graph.inputs, args):
4477      converter._add_conversion(inp, arg)
4478    # Convert output tensors.
4479    return tuple(
4480        [converter._convert_helper(x).t for x in func._func_graph_outputs])
4481
4482  call_outputs = f(*inputs)
4483  assert len(call_outputs) == len(func._func_graph_outputs)
4484  outputs = []
4485  for call_output, output_tensor in zip(call_outputs, func._func_graph_outputs):
4486    func_output = converter._convert_helper(output_tensor)
4487    outputs.append(
4488        wrap(call_output, func_output.is_stacked,
4489             func_output.is_sparse_stacked))
4490  return outputs
4491
4492
4493@RegisterPFor("StatefulPartitionedCall")
4494@RegisterPFor("PartitionedCall")
4495def _convert_partitioned_call(pfor_input):
4496  func_name = pfor_input.get_attr("f").name
4497  func = pfor_input.op.graph._get_function(compat.as_bytes(func_name))
4498  assert isinstance(func.graph, func_graph.FuncGraph), (
4499      "Could not find FuncGraph object for %s. Got func %s" % (func_name, func))
4500  pfor = pfor_input.pfor
4501  converter = PFor(
4502      loop_var=pfor.loop_var,
4503      loop_len=pfor.loop_len_vector[0],
4504      pfor_ops=func.graph.get_operations(),
4505      fallback_to_while_loop=pfor.fallback_to_while_loop,
4506      all_indices=pfor.all_indices,
4507      all_indices_partitioned=pfor.all_indices_partitioned,
4508      pfor_config=pfor.pfor_config)
4509  return _convert_function_call(func, converter, pfor_input.inputs)
4510
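# Illustrative usage sketch (an assumption for exposition, not part of this
# module: exercising the converter above through the public tf.vectorized_map
# API; a tf.function call inside the vectorized function typically lowers to a
# PartitionedCall or StatefulPartitionedCall op):
#
#   import tensorflow as tf
#
#   @tf.function
#   def square(x):
#     return x * x
#
#   # Each (Stateful)PartitionedCall in the traced loop body is handled by
#   # _convert_partitioned_call above.
#   y = tf.vectorized_map(lambda t: square(t) + 1.0, tf.range(4.0))
#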
4511
4512def _partition_inputs_for_indices(inputs, indices):
4513  new_inputs = []
4514  for inp in inputs:
4515    if inp.is_stacked:
4516      new_inputs.append(wrap(array_ops.gather(inp.t, indices), True))
4517    else:
4518      new_inputs.append(inp)
4519  return new_inputs
4520
4521
4522def _outputs_for_branch(func_name, indices, pfor_input, inputs):
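  """Converts branch `func_name` for the pfor iterations in `indices`.

  Returns the converted branch outputs, each stacked along the loop dimension.
  """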
4523  if indices is None:
4524    indices = pfor_input.pfor.all_indices
4525    partitioned = pfor_input.pfor.all_indices_partitioned
4526  else:
4527    partitioned = True
4528  func = pfor_input.op.graph._get_function(func_name)
4529  converter = PFor(
4530      loop_var=pfor_input.pfor.loop_var,
4531      loop_len=array_ops.size(indices),
4532      pfor_ops=func.graph.get_operations(),
4533      fallback_to_while_loop=pfor_input.pfor.fallback_to_while_loop,
4534      all_indices=indices,
4535      all_indices_partitioned=partitioned,
4536      pfor_config=pfor_input.pfor.pfor_config)
4537  outputs = _convert_function_call(func, converter, inputs)
4538  stacked_outputs = []
4539  for out in outputs:
4540    if not out.is_stacked:
4541      stacked_outputs.append(_stack(out.t, [array_ops.size(indices)]).t)
4542    else:
4543      stacked_outputs.append(out.t)
4544  return stacked_outputs
4545
4546
4547# TODO(agarwal): Currently the converted code aggressively tiles loop variant
4548# outputs from the then/else branches. Instead, it could do so only if at least
4549# one of the branch outputs is loop variant.
4550@RegisterPFor("StatelessIf")
4551@RegisterPFor("If")
4552def _convert_if(pfor_input):
4553  cond, cond_stacked, _ = pfor_input.input(0)
4554  inputs = pfor_input.inputs[1:]
4555  then_branch = pfor_input.get_attr("then_branch")
4556  else_branch = pfor_input.get_attr("else_branch")
4557
4558  if cond_stacked:
4559    cond_int = math_ops.cast(cond, dtypes.int32)
4560    # Compute loop indices for the different branches
4561    false_indices, true_indices = data_flow_ops.dynamic_partition(
4562        pfor_input.pfor.all_indices, cond_int, 2)
4563    # Compute indices for cond being True or False.
4564    if pfor_input.pfor.all_indices_partitioned:
4565      else_indices, then_indices = data_flow_ops.dynamic_partition(
4566          math_ops.range(pfor_input.pfor.loop_len_vector[0]),
4567          cond_int, 2)
4568    else:
4569      else_indices, then_indices = false_indices, true_indices
4570    # Partition inputs
4571    then_inputs = _partition_inputs_for_indices(inputs, then_indices)
4572    else_inputs = _partition_inputs_for_indices(inputs, else_indices)
4573
4574    # Convert "then" branch.
4575    then_outputs = _outputs_for_branch(then_branch.name, true_indices,
4576                                       pfor_input, then_inputs)
4577
4578    # Convert "else" branch.
4579    else_outputs = _outputs_for_branch(else_branch.name, false_indices,
4580                                       pfor_input, else_inputs)
4581
4582    assert len(then_outputs) == len(else_outputs)
    # Note that if the "then" and "else" branches update the same state, and
    # possibly read it as well, it could lead to undefined behavior since the
    # ordering of those operations is not well defined.
    # One possibility is to order all the "then" branches to execute before all
    # the "else" branches, so that the side effects of the former are visible
    # to the latter. For now, we leave that as undefined behavior.
4589    outputs = []
4590    # Merge outputs
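    # For example (illustrative values): with all_indices == [0, 1, 2, 3] and
    # cond == [True, False, True, False], the partition above yields
    # true_indices == [0, 2] and false_indices == [1, 3]; dynamic_stitch below
    # interleaves the branch outputs back into the original iteration order.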
4591    for then_output, else_output in zip(then_outputs, else_outputs):
4592      out = data_flow_ops.dynamic_stitch([then_indices, else_indices],
4593                                         [then_output, else_output])
4594      outputs.append(wrap(out, True))
4595    return outputs
4596  else:
4597    outputs = control_flow_ops.cond(
4598        cond,
4599        lambda: _outputs_for_branch(then_branch.name, None, pfor_input, inputs),
4600        lambda: _outputs_for_branch(else_branch.name, None, pfor_input, inputs))
4601    return [wrap(t, True) for t in outputs]
4602
4603
4604@RegisterPFor("Case")
4605@RegisterPFor("StatelessCase")
4606def _convert_stateless_case(pfor_input):
4607  branch_idx, is_stacked, _ = pfor_input.input(0)
4608  branches = pfor_input.get_attr("branches")
4609  inputs = pfor_input.inputs[1:]
4610
4611  if is_stacked:
    logging.info("Converting Case/StatelessCase with a stacked branch index.")
4613
4614    # Compute loop indices for the different branches
4615    switch_indices = data_flow_ops.dynamic_partition(
4616        pfor_input.pfor.all_indices, branch_idx, len(branches))
4617    if pfor_input.pfor.all_indices_partitioned:
4618      partitioned_indices = data_flow_ops.dynamic_partition(
4619          math_ops.range(pfor_input.pfor.loop_len_vector[0]), branch_idx,
4620          len(branches))
4621    else:
4622      partitioned_indices = switch_indices
4623    # Partition inputs
4624    input_list = []
4625    for indices in partitioned_indices:
4626      input_list.append(_partition_inputs_for_indices(inputs, indices))
4627
4628    outputs = []
4629    for (b, indices, inputs) in zip(branches, switch_indices, input_list):
4630      out = _outputs_for_branch(b.name, indices, pfor_input, inputs)
4631      outputs.extend(out)
4632
4633    out = data_flow_ops.dynamic_stitch(partitioned_indices, outputs)
4634    return [wrap(out, True)]
4635  else:
4636    new_branches = []
4637    for b in branches:
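      # Bind the branch name via a default argument so that each closure sees
      # its own branch (a plain closure over `b` would see only the last one).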
4638      def new_function(func=b.name):
4639        return _outputs_for_branch(func, None, pfor_input,
4640                                   pfor_input.inputs[1:])
4641
4642      new_branches.append(new_function)
4643
    outputs = control_flow_ops.switch_case(branch_idx, new_branches)
4646    return [wrap(t, True) for t in outputs]
4647
4648
4649class WhileV2:
4650  """Object for vectorizing V2 while_loop op."""
4651
4652  def __init__(self, pfor_input):
4653    self._pfor_input = pfor_input
4654    self._pfor = pfor_input.pfor
4655    cond_func_name = pfor_input.get_attr("cond").name
4656    self._cond_func = pfor_input.op.graph._get_function(compat.as_bytes(
4657        cond_func_name))
4658    body_func_name = pfor_input.get_attr("body").name
4659    self._body_func = pfor_input.op.graph._get_function(compat.as_bytes(
4660        body_func_name))
4661    if self._cond_func is None or self._body_func is None:
4662      raise ValueError("Error extracting cond and body functions for op "
4663                       f"{self._pfor_input.op}.")
4664    # Indices of inputs that are passed unchanged through the while loop body.
4665    # Typically these are tensors captured from outside the body context.
4666    self._body_pass_through_indices = set()
4667    for i, (inp, out) in enumerate(zip(self._body_func.graph.inputs,
4668                                       self._body_func.graph.outputs)):
4669      if id(inp) == id(out):
4670        self._body_pass_through_indices.add(i)
4671    self._parallel_iterations = self._pfor_input.get_attr("parallel_iterations")
4672
    # Calculate output shapes for the vectorized loop. These will be used as
    # shape invariants. Merges shape inference outputs with the `output_shapes`
    # attribute of the op.
4677    output_shapes = [out.shape for out in self._pfor_input.op.outputs]
4678    shapes = self._pfor_input.get_attr("output_shapes")
4679    if not shapes:
4680      shapes = [tensor_shape.TensorShape(None) for _ in output_shapes]
4681    else:
4682      shapes = [tensor_shape.TensorShape(shape) for shape in shapes]
4683    for i, shape in enumerate(shapes):
4684      shape = shape.merge_with(output_shapes[i])
4685      pfor_input = self._pfor_input.input(i)
4686      if pfor_input.is_stacked:
4687        if _is_variant_with_internal_stacking(pfor_input.t):
4688          shape = tensor_shape.TensorShape([]).concatenate(shape)
4689        else:
4690          shape = tensor_shape.TensorShape([None]).concatenate(shape)
4691      output_shapes[i] = shape
4692    assert len(output_shapes) == self._pfor_input.num_inputs
4693    return output_shapes
4694
4695  def _init_values(self):
4696    """Create arguments passed to converted while_loop."""
4697    loop_len = self._pfor.loop_len_vector[0]
4698    inputs = []
4699    # TensorArrays for outputs of converted while loop
4700    output_tas = []
4701
4702    with ops.name_scope("while_init"):
4703      for inp in self._pfor_input.inputs:
4704        inputs.append(inp.t)
4705        variant_type_id = _variant_type_id(inp.t)
4706        if variant_type_id in _INTERNAL_STACKING_TYPE_IDS:
4707          if variant_type_id != full_type_pb2.TFT_ARRAY:
4708            raise NotImplementedError(
4709                "While loop conversion is only supported for TensorLists. Got "
4710                f"another variant {inp.t}, probably an optional. Please file "
4711                "a bug.")
4712
4713          # For TensorLists, the input format is:
4714          #
4715          #   List[user_list_len, Tensor[loop_len, ...]]
4716          #
4717          # rather than the usual
4718          #
4719          #   Tensor[loop_len, ...]
4720          #
4721          # The body of the loop will take and return lists in this "internal
4722          # vectorization" format, so we want to keep it that way as much as
4723          # possible. We'll accumulate finished iterations (only relevant for
4724          # pfor-loop-variant while_loop conditions) in an accumulator with
4725          # type :
4726          #
4727          #   List[user_list_len, List[loop_len, Tensor[...]]]
4728          #
4729          # This means that each while_loop iteration, we'll iterate over the
4730          # length of the TensorList, dividing done/remaining pfor loop indices
4731          # and scattering the done indices into the inner nested list of the
4732          # accumulator.
4733          element_shape = list_ops.tensor_list_element_shape(
4734              inp.t, dtypes.int32)
4735          if inp.is_stacked:
4736            # Shapes may be tf.constant(-1) for fully dynamic, in which case
4737            # slicing is an error.
4738            element_shape = control_flow_ops.cond(
4739                math_ops.equal(array_ops.rank(element_shape), 0),
4740                lambda: element_shape,
4741                lambda: element_shape[1:])
4742          dtype = _parse_variant_shapes_and_types(inp.t)[0].dtype
4743
4744          def _init_loop_body(index, output_ta):
4745            output_ta = output_ta.write(
4746                index,
4747                list_ops.tensor_list_reserve(element_shape, loop_len, dtype))
4748            return index + 1, output_ta
4749
4750          length = list_ops.tensor_list_length(inp.t)
          output_ta = tensor_array_ops.TensorArray(
              inp.t.dtype,  # Variant; this is a nested TensorList
              size=length,
              dynamic_size=True,
              infer_shape=False)
4756          _, output_ta = control_flow_ops.while_loop(
4757              lambda index, _: index < length,
4758              _init_loop_body,
4759              [0, output_ta])
4760        else:
          output_ta = tensor_array_ops.TensorArray(
              inp.t.dtype,
              size=loop_len,
              dynamic_size=False,
              infer_shape=True)
4766        output_tas.append(output_ta)
4767    # See documentation for __call__ for the structure of init_values.
4768    indices = (
4769        math_ops.range(self._pfor.loop_len_vector[0])
4770        if self._pfor.all_indices_partitioned else self._pfor.all_indices)
4771    return [True, indices] + inputs + output_tas
4772
4773  def _process_cond_unstacked(self, conditions, indices, inputs, output_tas):
4774    """Handles case when condition is pfor loop invariant."""
4775    # Note that all iterations end together. So we don't need to partition the
4776    # inputs.
4777    not_all_done = array_ops.reshape(conditions, [])
4778    return not_all_done, indices, inputs, output_tas
4779
4780  def _process_cond_stacked(self, conditions, indices, inputs, inputs_stacked,
4781                            output_tas):
4782    """Handles case when condition is pfor loop dependent."""
4783    # Compute if all iterations are done.
4784    not_all_done = math_ops.reduce_any(conditions)
4785    conditions_int = math_ops.cast(conditions, dtypes.int32)
4786    # Partition the indices.
4787    done_indices, new_indices = data_flow_ops.dynamic_partition(
4788        indices, conditions_int, 2)
4789
4790    new_inputs = []
4791    new_output_tas = []
4792    for i, (inp, stacked) in enumerate(zip(inputs, inputs_stacked)):
4793      pass_through = i in self._body_pass_through_indices
      if not pass_through and _variant_type_id(inp) == full_type_pb2.TFT_ARRAY:
4795        shape_and_type = _parse_variant_shapes_and_types(inp)[0]
4796        element_shape = list_ops.tensor_list_element_shape(inp, dtypes.int32)
4797        user_list_len = list_ops.tensor_list_length(inp)
4798
4799        def _split_vectorized_ta_element(index, new_inp, new_out_ta):
4800          elem = list_ops.tensor_list_get_item(inp, index, shape_and_type.dtype,
4801                                               element_shape)
4802          if stacked:
4803            done_elem, new_elem = data_flow_ops.dynamic_partition(
4804                elem, conditions_int, 2)
4805            new_inp = list_ops.tensor_list_set_item(new_inp, index, new_elem)
4806          else:
4807            done_elem = _stack(elem, [array_ops.size(done_indices)]).t
4808          done_accum = new_out_ta.read(index)
4809          done_accum = list_ops.tensor_list_scatter(
4810              tensor=done_elem, indices=done_indices, input_handle=done_accum)
4811          new_out_ta = new_out_ta.write(index, done_accum)
4812          return index + 1, new_inp, new_out_ta
4813
4814        length = list_ops.tensor_list_length(inp)
4815        new_inp = list_ops.tensor_list_reserve(
4816            tensor_shape.TensorShape([None])
4817            + tensor_shape.TensorShape(shape_and_type.shape)[1:],
4818            user_list_len, shape_and_type.dtype)
4819        _, new_inp, out_ta = control_flow_ops.while_loop(
4820            lambda index, unused_new_inp, unused_new_out_ta: index < length,
4821            _split_vectorized_ta_element,
4822            [0, new_inp, output_tas[i]])
4823      else:
4824        # Partition the inputs.
4825        if stacked:
4826          done_inp, new_inp = data_flow_ops.dynamic_partition(
4827              inp, conditions_int, 2)
4828        else:
4829          if not pass_through:
4830            done_inp = _stack(inp, [array_ops.size(done_indices)]).t
4831          new_inp = inp
4832
4833        out_ta = output_tas[i]
4834        if not pass_through:
4835          # Note that done_indices can be empty. done_inp should also be empty
4836          # in that case.
4837          out_ta = out_ta.scatter(done_indices, done_inp)
4838      new_inputs.append(new_inp)
4839      new_output_tas.append(out_ta)
4840
4841    assert len(new_output_tas) == len(output_tas)
4842    assert len(new_inputs) == len(inputs)
4843    return not_all_done, new_indices, new_inputs, new_output_tas
4844
4845  def _process_body(self, inputs_stacked, new_indices, cond_stacked,
4846                    new_inputs, not_all_done):
4847    """Convert the body function."""
4848    # This is used to store the indices of inputs to the while op that need to
    # This is used to store the indices of inputs to the while op that need to
    # be stacked. This stacking may be needed in cases where the input to the
    # while_loop is loop invariant but the corresponding output is not.
4852
4853    def true_fn():
4854      """Converts the body function for all but last iteration."""
4855      wrapped_inputs = [wrap(inp, stacked) for inp, stacked in
4856                        zip(new_inputs, inputs_stacked)]
      # Note the iterative process below to figure out loop invariance.
      # Here we iterate the vectorization process until a fixed point. The
      # issue is that the while body can take pfor loop invariant inputs but
      # return loop variant outputs. For any loop variant output, the
      # corresponding input then has to be made loop variant (since subsequent
      # while iterations will need to see loop variant values).
      # However, once we make a new input loop variant, we might make other
      # outputs loop variant. Hence we need to iterate until we reach a fixed
      # point.
4865      while True:
4866        if self._pfor.all_indices_partitioned:
4867          indices = array_ops.gather(self._pfor.all_indices, new_indices)
4868        else:
4869          indices = new_indices
4870        body_pfor = PFor(
4871            loop_var=self._pfor.loop_var,
4872            loop_len=array_ops.size(new_indices),
4873            pfor_ops=self._body_func.graph.get_operations(),
4874            fallback_to_while_loop=self._pfor.fallback_to_while_loop,
4875            all_indices=indices,
4876            all_indices_partitioned=(self._pfor.all_indices_partitioned or
4877                                     cond_stacked),
4878            pfor_config=self._pfor.pfor_config)
4879        stacking_mismatch = False
4880        outputs = _convert_function_call(self._body_func,
4881                                         body_pfor,
4882                                         wrapped_inputs)
4883        for i, (out, inp) in enumerate(zip(outputs, wrapped_inputs)):
4884          if out.is_stacked != inp.is_stacked:
4885            stacking_mismatch = True
4886            mismatching_stacked_indices.append(i)
4887            stacked = _stack(inp.t, [array_ops.size(new_indices)])
4888            if inp.t.dtype == dtypes.variant:
4889              stacked = wrap(
4890                  _tile_variant_with_length(stacked.t,
4891                                            [array_ops.size(new_indices)]))
4892            wrapped_inputs[i] = stacked
4893        if not stacking_mismatch:
4894          if mismatching_stacked_indices:
4895            # We needed to stack some inputs. This code will be abandoned and
4896            # should not get executed. Hence we simply return `new_inputs` to
4897            # make sure the graph construction code completes.
4898            with ops.control_dependencies([
4899                control_flow_ops.Assert(
4900                    False, ["pfor ERROR: this branch should never execute"])]):
4901              return [array_ops.identity(x) for x in new_inputs]
4902          else:
4903            return [out.t for out in outputs]
4904
    # If all iterations are done, we simply return `new_inputs`. Otherwise we
    # need to run the body function.
4907    return control_flow_ops.cond(
4908        not_all_done,
4909        true_fn,
4910        lambda: list(new_inputs)), mismatching_stacked_indices
4911
4912  def __call__(self):
4913    """Converter for the V2 while_loop.
4914
4915    The conversion of a while_loop is another while_loop.
4916
    The arguments to this converted while_loop are as follows:
    not_all_done: Boolean scalar Tensor indicating if all the pfor iterations
      are done.
    indices: int32 1-D Tensor storing the ids of the pfor iterations that are
      not done.
    args: Remaining arguments. These can be divided into 2 categories:
      - The first set of arguments corresponds one-to-one to the inputs to the
        unvectorized while_loop.
      - The second set are TensorArrays, corresponding one-to-one to each output
        of the unvectorized while_loop. Each TensorArray has `PFor.loop_len`
        elements, i.e. the number of pfor iterations. At the end, the i'th
        element of each TensorArray will contain the output computed by the i'th
        iteration of pfor. Note that elements can be written into these
        TensorArrays in any order, depending on when the corresponding pfor
        iteration is done.
4932    In each iteration, the while_loop body recomputes the condition for all
4933    active pfor iterations to see which of them are now done. It then partitions
4934    all the inputs and passes them along to the converted body. Values for all
4935    the iterations that are done are written to TensorArrays indexed by the pfor
4936    iteration number. When all iterations are done, the TensorArrays are stacked
4937    to get the final value.
4938
4939    Returns:
4940      List of converted outputs.
4941    """
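    # Schematically, the converted while_loop carries:
    #   [not_all_done, indices, inp_0, ..., inp_{n-1}, ta_0, ..., ta_{n-1}]
    # where n == self._pfor_input.num_inputs (see _init_values and body below).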
4942    output_shapes = self._output_shapes()
4943    # Note that we use these lists as a hack since we need the `body` to compute
4944    # these values during construction of the while_loop graph.
4945    cond_is_stacked = [None]
4946    indices_to_stack = []
4947
4948    def cond(not_all_done, *_):
4949      return not_all_done
4950
4951    def body(not_all_done, indices, *args):
4952      # See documentation for __call__ for the structure of *args.
4953      num_inputs = self._pfor_input.num_inputs
4954      inputs = args[:num_inputs]
4955      output_tas = args[num_inputs:]
4956      inputs_stacked = [x.is_stacked for x in self._pfor_input.inputs]
4957      assert len(inputs) >= len(output_tas)
4958      assert len(inputs) == len(inputs_stacked)
4959      # Convert condition
4960      with ops.name_scope("while_cond"):
4961        # Note that we set all_indices_partitioned to True here. At this point
4962        # we don't know if indices will be partitioned. Hence we use the
4963        # conservative value.
4964        cond_pfor = PFor(
4965            loop_var=self._pfor.loop_var,
4966            loop_len=array_ops.size(indices),
4967            pfor_ops=self._cond_func.graph.get_operations(),
4968            fallback_to_while_loop=self._pfor.fallback_to_while_loop,
4969            all_indices=indices,
4970            all_indices_partitioned=True,
4971            pfor_config=self._pfor.pfor_config)
4972
4973        wrapped_inputs = [wrap(inp, stacked) for inp, stacked
4974                          in zip(inputs, inputs_stacked)]
4975        conditions, cond_stacked, _ = _convert_function_call(
4976            self._cond_func,
4977            cond_pfor,
4978            wrapped_inputs)[0]
4979        cond_is_stacked[0] = cond_stacked
4980
4981      # Recompute the new condition, write outputs of done iterations, and
4982      # partition the inputs if needed.
4983      if not cond_stacked:
4984        (not_all_done, new_indices, new_inputs,
4985         new_output_tas) = self._process_cond_unstacked(conditions, indices,
4986                                                        inputs, output_tas)
4987      else:
4988        (not_all_done, new_indices, new_inputs,
4989         new_output_tas) = self._process_cond_stacked(conditions, indices,
4990                                                      inputs, inputs_stacked,
4991                                                      output_tas)
4992      # Convert body
4993      with ops.name_scope("while_body"):
4994        #  Compute the outputs from the body.
4995        new_outputs, mismatching_stacked_indices = self._process_body(
4996            inputs_stacked, new_indices, cond_stacked, new_inputs, not_all_done)
4997
4998      indices_to_stack[:] = mismatching_stacked_indices
4999      for i, new_output in enumerate(new_outputs):
5000        new_output.set_shape(output_shapes[i])
5001      new_args = ([not_all_done, new_indices] + new_outputs +
5002                  list(new_output_tas))
5003      return tuple(new_args)
5004
5005    # Note that we run the code below in a function since we might abandon the
5006    # generated code in cases where the conversion dictates that some inputs be
5007    # further stacked. Hence we run the graph construction using
5008    # `get_concrete_function` and avoid calling the constructed function if not
5009    # needed.
5010    @def_function.function
5011    def while_fn():
5012      # Create init_values that will be passed to the while_loop.
5013      init_values = self._init_values()
5014      ta_shape_invariants = [tensor_shape.TensorShape([]) for _ in
5015                             self._pfor_input.outputs]
5016      shape_invariants = (
5017          [tensor_shape.TensorShape([]), tensor_shape.TensorShape([None])]
5018          + output_shapes + ta_shape_invariants)
5019
5020      while_outputs = control_flow_ops.while_loop(
5021          cond, body, init_values,
5022          shape_invariants=shape_invariants,
5023          parallel_iterations=self._parallel_iterations)
5024      if indices_to_stack:
5025        # This function will be abandoned.
5026        return while_outputs
5027      else:
5028        num_inputs = self._pfor_input.num_inputs
5029        new_inputs = while_outputs[2:num_inputs+2]
5030        output_tas = while_outputs[num_inputs+2:]
5031        assert cond_is_stacked[0] is not None
5032        outputs = []
5033        for i, inp in enumerate(new_inputs):
5034          if cond_is_stacked[0]:
5035            if i in self._body_pass_through_indices:
5036              outputs.append(init_values[i + 2])
5037            else:
5038              ta = output_tas[i]
5039              if _variant_type_id(inp) == full_type_pb2.TFT_ARRAY:
5040                shape_and_type = _parse_variant_shapes_and_types(inp)[0]
5041                length = list_ops.tensor_list_length(inp)
5042
5043                # We have been accumulating values in a:
5044                #
5045                #   List[user_list_len, List[loop_len, Tensor[...]]]
5046                #
5047                # We want to return an output in the same format as the input:
5048                #
5049                #   List[user_list_len, Tensor[loop_len, ...]]
5050                #
5051                # So we need to loop over the list and stack its contents.
5052                def _stack_loop_body(index, output_list):
5053                  current_value = ta.read(index)
5054                  output_list = list_ops.tensor_list_set_item(
5055                      output_list, index,
5056                      list_ops.tensor_list_stack(
5057                          current_value, shape_and_type.dtype))
5058                  return index + 1, output_list
5059
5060                output_list = list_ops.tensor_list_reserve(
5061                    tensor_shape.TensorShape(shape_and_type.shape), length,
5062                    shape_and_type.dtype)
5063                _, output_list = control_flow_ops.while_loop(
5064                    lambda index, _: index < length,
5065                    _stack_loop_body,
5066                    [0, output_list])
5067                outputs.append(output_list)
5068              else:
5069                outputs.append(ta.stack())
5070          else:
5071            outputs.append(inp)
5072        return outputs
5073
5074    _ = while_fn.get_concrete_function()
5075    if indices_to_stack:
5076      # Need to abandon the current conversion, stack some inputs and restart.
5077      self._pfor_input.stack_inputs(
5078          stack_indices=indices_to_stack, tile_variants=True)
5079      # Note that this call will recurse at most one time. The first call will
5080      # do the required stacking, based on the iterative procedure in
5081      # _process_body, and the next invocation to __call__ should not need to do
5082      # any more stacking.
5083      # We invoke `self()` here as a way to discard any corrupted state.
5084      return self()
5085    else:
5086      outputs = while_fn()
5087      wrapped_outputs = []
5088      for i, (out, inp) in enumerate(zip(outputs, self._pfor_input.inputs)):
5089        if i not in self._body_pass_through_indices and cond_is_stacked[0]:
5090          wrapped_outputs.append(wrap(out, True))
5091        else:
5092          wrapped_outputs.append(wrap(out, inp.is_stacked))
5093      return wrapped_outputs
5094
5095
5096@RegisterPFor("StatelessWhile")
5097@RegisterPFor("While")
5098def _convert_while(pfor_input):
5099  converter = WhileV2(pfor_input)
5100  return converter()
5101
5102
5103# spectral_ops
5104
5105
5106@RegisterPForWithArgs("FFT", gen_spectral_ops.fft)
5107@RegisterPForWithArgs("FFT2D", gen_spectral_ops.fft2d)
5108@RegisterPForWithArgs("FFT3D", gen_spectral_ops.fft3d)
5109@RegisterPForWithArgs("IFFT", gen_spectral_ops.ifft)
5110@RegisterPForWithArgs("IFFT2D", gen_spectral_ops.ifft2d)
5111@RegisterPForWithArgs("IFFT3D", gen_spectral_ops.ifft3d)
5112def _convert_fft(pfor_input, _, op_func):
5113  return wrap(op_func(pfor_input.stacked_input(0)), True)
5114
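# Illustrative usage sketch (an assumption for exposition: using the public
# tf.signal and tf.vectorized_map APIs, which dispatch to the FFT conversions
# registered above):
#
#   import tensorflow as tf
#
#   x = tf.complex(tf.random.normal([8, 128]), tf.random.normal([8, 128]))
#   # The converted "FFT" op simply applies the batched kernel to the stacked
#   # input.
#   y = tf.vectorized_map(tf.signal.fft, x)
#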
5115
5116@RegisterPForWithArgs("RFFT", gen_spectral_ops.rfft, "Tcomplex")
5117@RegisterPForWithArgs("RFFT2D", gen_spectral_ops.rfft2d, "Tcomplex")
5118@RegisterPForWithArgs("RFFT3D", gen_spectral_ops.rfft3d, "Tcomplex")
5119@RegisterPForWithArgs("IRFFT", gen_spectral_ops.irfft, "Treal")
5120@RegisterPForWithArgs("IRFFT2D", gen_spectral_ops.irfft2d, "Treal")
5121@RegisterPForWithArgs("IRFFT3D", gen_spectral_ops.irfft3d, "Treal")
5122def _convert_rfft(pfor_input, _, op_func, attr_name):
5123  inp = pfor_input.stacked_input(0)
5124  fft_length = pfor_input.unstacked_input(1)
5125  attr = pfor_input.get_attr(attr_name)
5126  return wrap(op_func(inp, fft_length, attr), True)
5127