1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Compiled parallel-for loop."""
16# pylint: disable=missing-docstring,g-direct-tensorflow-import
17
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22import collections
23import string
24import sys
25import traceback
26
27import numpy as np
28import six
29
30from tensorflow.compiler.tf2xla.python import xla
31from tensorflow.core.framework import full_type_pb2
32from tensorflow.python.eager import context
33from tensorflow.python.eager import def_function
34from tensorflow.python.eager import execute
35from tensorflow.python.framework import constant_op
36from tensorflow.python.framework import dtypes
37from tensorflow.python.framework import func_graph
38from tensorflow.python.framework import ops
39from tensorflow.python.framework import smart_cond
40from tensorflow.python.framework import sparse_tensor
41from tensorflow.python.framework import tensor_shape
42from tensorflow.python.framework import tensor_spec
43from tensorflow.python.framework import tensor_util
44from tensorflow.python.ops import array_ops
45from tensorflow.python.ops import control_flow_ops
46from tensorflow.python.ops import data_flow_ops
47from tensorflow.python.ops import gen_array_ops
48from tensorflow.python.ops import gen_dataset_ops
49from tensorflow.python.ops import gen_image_ops
50from tensorflow.python.ops import gen_linalg_ops
51from tensorflow.python.ops import gen_list_ops
52from tensorflow.python.ops import gen_math_ops
53from tensorflow.python.ops import gen_nn_ops
54from tensorflow.python.ops import gen_parsing_ops
55from tensorflow.python.ops import gen_random_ops
56from tensorflow.python.ops import gen_sparse_ops
57from tensorflow.python.ops import gen_spectral_ops
58from tensorflow.python.ops import handle_data_util
59from tensorflow.python.ops import linalg_ops
60from tensorflow.python.ops import list_ops
61from tensorflow.python.ops import manip_ops
62from tensorflow.python.ops import map_fn
63from tensorflow.python.ops import math_ops
64from tensorflow.python.ops import nn_ops
65from tensorflow.python.ops import parsing_ops
66from tensorflow.python.ops import resource_variable_ops
67from tensorflow.python.ops import sparse_ops
68from tensorflow.python.ops import special_math_ops
69from tensorflow.python.ops import tensor_array_ops
70from tensorflow.python.platform import flags
71from tensorflow.python.platform import tf_logging as logging
72from tensorflow.python.util import compat
73from tensorflow.python.util import nest
74from tensorflow.python.util import object_identity
75
76
77# TODO(agarwal): remove flag.
78flags.DEFINE_bool(
79    "op_conversion_fallback_to_while_loop", True,
80    "DEPRECATED: Flag is ignored.")
81
82
83def _variant_handle_data(t):
84  """Fetches handle data for a variant tensor `t`, or None if unavailable."""
85  handle_data = resource_variable_ops.get_eager_safe_handle_data(t)
86  if not handle_data.is_set:
87    return None
88  return handle_data.shape_and_type
89
90
91def _variant_type_id(t):
92  """Returns the full_type_pb2 type of `t`, or None if it is not available."""
93  if t.dtype != dtypes.variant:
94    return None
95  shapes_and_types = _variant_handle_data(t)
96  if shapes_and_types is None or not shapes_and_types:
97    # TODO(b/169968286): Identify all variant tensors (e.g. maps) and we can
98    # make this an error instead of assuming TensorLists have handle data.
99    return None  # Presumed not a TensorList/Optional
100  return shapes_and_types[0].type.type_id
101
102
103_INTERNAL_STACKING_TYPE_IDS = (
104    full_type_pb2.TFT_ARRAY,
105    full_type_pb2.TFT_OPTIONAL)
106
107
108def _is_variant_with_internal_stacking(t):
109  """Identifies variant tensors which pfor always maintains as scalars.
110
111  For these, the pfor tensor is recorded as "stacked" if the contents of the
112  variant tensor (e.g. the elements of a TensorList) are all stacked.
113
114  Args:
115    t: A tensor to identify.
116  Returns:
117    True if `t` is a TensorList/Optional, False if not, None if unknown.
118  """
119  type_id = _variant_type_id(t)
120  return type_id in _INTERNAL_STACKING_TYPE_IDS
121
122
123def _parse_variant_shapes_and_types(t):
124  """Extracts shape and dtype information from a variant tensor `t`."""
125  shapes_and_types = _variant_handle_data(t)
126  if shapes_and_types is None or not shapes_and_types:
127    raise ValueError("Required handle data not set for {!r}".format(t))
128  if shapes_and_types[0].type.type_id == full_type_pb2.TFT_ARRAY:
129    return shapes_and_types
130  else:
131    if shapes_and_types[0].type.type_id == full_type_pb2.TFT_UNSET:
132      return shapes_and_types
133    else:
134      raise ValueError(
135          "Attempted to stack a variant-dtype tensor with no type set ({!r})"
136          .format(t))
137
138
139def _stack(t, length):
140  """stacks `t` `length` times."""
141  # Note that this stacking may currently be triggered, for example, when a
142  # loop invariant tensor with dtype variant is input to a while_loop which then
143  # produces a loop dependent output. Simply stacking the variants may not be
144  # suitable since operations on stacked handles may expect a vectorized version
145  # of the variant.
146  if t.dtype == dtypes.variant:
147    shapes_and_types = _parse_variant_shapes_and_types(t)
148    if shapes_and_types[0].type.type_id == full_type_pb2.TFT_ARRAY:
149      if len(shapes_and_types) != 1:
150        raise ValueError(
151            "Expected handle data of length 1, got {!r} of length {}"
152            .format(shapes_and_types, len(shapes_and_types)))
153      return wrap(
154          _stack_tensor_list(t, shapes_and_types[0].dtype, length),
155          True)
156    else:
157      raise ValueError(
158          ("Attempted to stack an unhandled variant-dtype tensor of "
159           "type {!r} ({!r})").format(shapes_and_types[0].type, t))
160  ones = array_ops.ones_like(array_ops.shape(t))
161  ones = array_ops.reshape(ones, [-1])
162  length = array_ops.reshape(length, [-1])
163  multiples = array_ops.concat([length, ones], 0)
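  # Illustrative example: if `t` has shape [2, 3] and `length` is [4], then
  # `ones` is [1, 1], `multiples` is [4, 1, 1], and the tiled result below has
  # shape [4, 2, 3].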
164  t = array_ops.tile(array_ops.expand_dims(t, 0), multiples)
165  return wrap(t, True)
166
167
168# The following stateful ops can be safely called once, and with the same
169# signature as the unconverted version, if their inputs are loop invariant.
170# TODO(agarwal): implement a strategy for converting Variable reads/writes. The
171# plan is to map each read/write in the loop_fn to a corresponding merged
172# read/write in the converted graph. Writes need to be mergeable (e.g.
173# AssignAdd) to be used in `pfor`. Given a certain read/write order in the
174# loop_fn, doing a one-to-one conversion will simulate executing such
175# instructions in lock-step across all iterations.
176passthrough_stateful_ops = set([
177    "VariableV2",
178    "VarHandleOp",
179    "VariableShape",
180    "ReadVariableOp",
181    "StackV2",
182    "TensorArrayWriteV3",
183    "TensorArrayReadV3",
184    "TensorArraySizeV3",
185])
186
187
188# Ops which we will treat like stateful for the purpose of vectorization.
189# Typically this is used to force pfor converters to run for these ops.
190force_stateful_ops = set([
191    # We vectorize this since we need to change the element shape set on the
192    # list.
193    "TensorListReserve",
194])
195
196
197def _is_stateful_pfor_op(op):
198  if isinstance(op, WhileOp):
199    return op.is_stateful
200  if op.type == "Const":
201    # Const doesn't have an op_def.
202    return False
203  if op.type in passthrough_stateful_ops:
204    return False
205  if op.type in force_stateful_ops:
206    return True
207  assert hasattr(op, "op_def") and op.op_def is not None, op
208  return op.op_def.is_stateful
209
210
211# pylint: disable=protected-access
212class WhileOp(object):
213  """Object for storing state for converting the outputs of a while_loop."""
214
215  def __init__(self, exit_node, pfor_ops, fallback_to_while_loop, pfor_config):
216    """Initializer.
217
218    Args:
219      exit_node: A tensor output from the while_loop.
220      pfor_ops: list of ops inside the current pfor loop.
221      fallback_to_while_loop: If True, fall back to a while loop when conversion
222        of an op is not supported.
223      pfor_config: PForConfig object used while constructing loop body.
224    """
225    self._fallback_to_while_loop = fallback_to_while_loop
226    self._pfor_config = pfor_config
227    self._pfor_ops = set(pfor_ops)
228    self._pfor_op_ids = set(x._id for x in pfor_ops)
229    assert isinstance(exit_node, ops.Tensor)
230    self._while_context = exit_node.op._get_control_flow_context()
231    assert isinstance(self._while_context, control_flow_ops.WhileContext)
232    self._context_name = self._while_context.name
233    self._condition = self._while_context.pivot.op.inputs[0]
234    # Parts of an external while_loop could be created inside a pfor loop.
235    # However for the purpose here, we declare such loops to be external. Also
236    # note that we check if the condition was created inside or outside to
237    # determine if the while_loop was first created inside or outside.
238    # TODO(agarwal): check that the Enter and Exit of this loop are unstacked.
239    self._is_inside_loop = self.op_is_inside_loop(self._condition.op)
240    if self._is_inside_loop:
241      for e in self._while_context.loop_exits:
242        assert self.op_is_inside_loop(e.op)
243
244    # Note that the code below tries to reverse engineer an existing while_loop
245    # graph by assuming the following pattern of nodes.
246    #
247    #          NextIteration <---- Body <--- Enter
248    #              |                ^
249    #              V             ___| Y
250    #    Enter -> Merge -> Switch___
251    #                       ^       | N
252    #                       |       V
253    #                  LoopCond    Exit
254
255    # Note that elements in the lists below correspond one-to-one with each
256    # other, i.e. these lists are the same size, and the i_th entry corresponds
257    # to different Operations/Tensors of a single cycle as illustrated above.
258    # List of Switch ops (ops.Operation) that feed into an Exit Node.
259    self._exit_switches = []
260    # List of inputs (ops.Tensor) to NextIteration.
261    self._body_outputs = []
262    # List of list of control inputs of the NextIteration nodes.
263    self._next_iter_control_inputs = []
264    # List of Merge ops (ops.Operation).
265    self._enter_merges = []
266    # List of output (ops.Tensor) of Exit nodes.
267    self._outputs = []
268
269    # List of Enter Tensors.
270    # There are two types of Enter nodes:
271    # - The Enter nodes that are used in the `loop_vars` argument to
272    # `while_loop` (see
273    # https://www.tensorflow.org/api_docs/python/tf/while_loop). We collect
274    # these Enter nodes immediately below by tracing backwards from the Exit
275    # nodes via Exit <- Switch <- Merge <- Enter. You can see this chain in the
276    # diagram above. This allows us to have a 1:1 correspondence between the
277    # self._outputs and the first elements in self._enters.
278    # - The Enter nodes that are used only by the body. They don't appear in the
279    # `loop_vars` and are not returned from the `while_loop`. In Python code,
280    # they are usually captured by the body lambda. We collect them below by
281    # iterating over all the ops in the graph. They are appended to the end of
282    # self._enters or self._direct_enters, and don't correspond to any outputs
283    # in self._outputs. Note that we keep the resource/variant Enter nodes in
284    # self._direct_enters and the constructed while_loop's body uses them
285    # directly as opposed to passing them as loop variables. This is done
286    # because the while_body cannot partition the resource/variant Tensors, so
287    # it has to leave them unchanged.
288    self._enters = []
289    self._direct_enters = []
290
291    for e in self._while_context.loop_exits:
292      self._outputs.append(e.op.outputs[0])
293      switch = e.op.inputs[0].op
294      assert switch.type == "Switch", switch
295      self._exit_switches.append(switch)
296      merge = switch.inputs[0].op
297      assert merge.type == "Merge", merge
298      self._enter_merges.append(merge)
299      enter = merge.inputs[0].op
300      assert enter.type == "Enter", enter
301      self._enters.append(enter.outputs[0])
302      next_iter = merge.inputs[1].op
303      assert next_iter.type == "NextIteration", next_iter
304      self._body_outputs.append(next_iter.inputs[0])
305      self._next_iter_control_inputs.append(next_iter.control_inputs)
306
307    # Collect all the Enter nodes that are not part of `loop_vars`, the second
308    # category described above.
309    # Also track whether the loop body has any stateful ops.
310    self._is_stateful = False
311    for op in ops.get_default_graph().get_operations():
312      # TODO(agarwal): make sure this works with nested case.
313      control_flow_context = op._get_control_flow_context()
314      if control_flow_context is None:
315        continue
316      if control_flow_context.name == self._context_name:
317        self._is_stateful |= _is_stateful_pfor_op(op)
318        if op.type == "Enter":
319          output = op.outputs[0]
320          if output not in self._enters:
321            if output.dtype in (dtypes.resource, dtypes.variant):
322              if output not in self._direct_enters:
323                self._direct_enters.append(output)
324            else:
325              self._enters.append(output)
326
327  def __str__(self):
328    """String representation."""
329    return "while_loop(%s)" % self.name
330
331  @property
332  def inputs(self):
333    """Input to all the Enter nodes."""
334    return [x.op.inputs[0] for x in self._enters + self._direct_enters]
335
336  @property
337  def control_inputs(self):
338    """Control input to all the Enter nodes."""
339    control_inputs = []
340    for x in self._enters + self._direct_enters:
341      control_inputs.extend(x.op.control_inputs)
342    return control_inputs
343
344  @property
345  def outputs(self):
346    """Outputs of all the Exit nodes."""
347    return self._outputs
348
349  @property
350  def name(self):
351    """Context name for the while loop."""
352    return self._context_name
353
354  @property
355  def is_inside_loop(self):
356    """Returns true if the while_loop was created inside the pfor."""
357    return self._is_inside_loop
358
359  def op_is_inside_loop(self, op):
360    """True if op was created inside the pfor loop body."""
361    assert isinstance(op, ops.Operation)
362    # Note that we use self._pfor_op_ids for the check and not self._pfor_ops
363    # since it appears the TensorFlow API could return different Python
364    # objects representing the same Operation node.
365    return op._id in self._pfor_op_ids
366
367  @property
368  def is_stateful(self):
369    return self._is_stateful
370
371  @property
372  def pfor_converter(self):
373    """Return a converter for the while loop."""
374    return self
375
376  def _init_pfor(self, parent_pfor, indices, cond_stacked, inputs,
377                 inputs_stacked):
378    """Create a PFor object for converting parts of the while_loop.
379
380    Args:
381      parent_pfor: PFor object being used for converting the while_loop.
382      indices: int32 Tensor of ids for the iterations that are still active
383        (i.e. did not exit the while_loop).
384      cond_stacked: True if the while_loop condition is stacked.
385      inputs: list of input Tensors corresponding 1-to-1 with self._enters. Note
386        that these Tensors are a subset of the loop variables for the generated
387        while_loop.
388      inputs_stacked: List of booleans corresponding 1-to-1 with `inputs`,
389        indicating if the value is stacked or not.
390
391    Returns:
392      A PFor instance. The instance is initialized by adding conversion mappings
393        of nodes that will be external to the conversion that the returned
394        instance will be used for. e.g. Enter nodes as well as Merge and Switch
395        outputs are mapped to converted values.
396    """
397    num_outputs = len(self._outputs)
398    assert len(inputs) == len(self._enters)
399    assert len(inputs_stacked) == len(self._enters)
400    loop_var = parent_pfor.loop_var
401    loop_len = array_ops.size(indices)
402    pfor = PFor(
403        loop_var,
404        loop_len,
405        pfor_ops=self._pfor_ops,
406        all_indices=indices,
407        all_indices_partitioned=cond_stacked,
408        fallback_to_while_loop=self._fallback_to_while_loop,
409        pfor_config=self._pfor_config)
410    # Map all inputs of Enter nodes in self._direct_enters to their converted
411    # values.
412    for enter in self._direct_enters:
413      enter_input = enter.op.inputs[0]
414      converted_enter, stacked, is_sparse_stacked = parent_pfor._convert_helper(
415          enter_input)
416      # Since these are resources / variants, they should be unstacked.
417      assert not stacked and not is_sparse_stacked, (enter, converted_enter)
418      pfor._add_conversion(enter, wrap(converted_enter, False))
419
420    # Map all Enter nodes to the inputs.
421    for enter, inp, stacked in zip(self._enters, inputs, inputs_stacked):
422      pfor._add_conversion(enter, wrap(inp, stacked))
423    # Map outputs of Switch and Merge.
424    for i in range(num_outputs):
425      wrapped_inp = wrap(inputs[i], inputs_stacked[i])
426      merge = self._enter_merges[i]
427      pfor._add_conversion(merge.outputs[0], wrapped_inp)
428      # Note that second output of Merge is typically not used, except possibly
429      # as a control dependency. To avoid trying to output the correct value, we
430      # employ a hack here. We output a dummy invalid value with an incorrect
431      # dtype. This will allow control dependency to work but if using it as an
432      # input, it should typically lead to errors during graph construction due
433      # to dtype mismatch.
434      # TODO(agarwal): Check in the original graph to see if there are any
435      # consumers of this Tensor that use it as an input.
436      pfor._add_conversion(merge.outputs[1],
437                           wrap(constant_op.constant(-1.0), False))
438      switch = self._exit_switches[i]
439      # No need to worry about switch.outputs[0], which will feed the Exit node.
440      pfor._add_conversion(switch.outputs[1], wrapped_inp)
441    return pfor
442
443  def _convert_enter(self, parent_pfor, enter):
444    """Converts an Enter node."""
445    inp, stacked, _ = parent_pfor._convert_helper(enter.op.inputs[0])
446    control_inputs = []
447    for x in enter.op.control_inputs:
448      converted = parent_pfor._convert_helper(x)
449      if not isinstance(converted, ops.Operation):
450        converted = converted.t
451      control_inputs.append(converted)
452    if control_inputs:
453      with ops.control_dependencies(control_inputs):
454        inp = array_ops.identity(inp)
455    return inp, stacked
456
457  def _maybe_stacked(self, cache, inp):
458    """Heuristic to figure out if the converting inp leads to a stacked value.
459
460
461    Args:
462      cache: map from Tensor to boolean indicating stacked/unstacked.
463      inp: input Tensor.
464
465    Returns:
466      True if `inp` could get stacked. If the function returns False, the
467      converted value should be guaranteed to be unstacked. If returning True,
468      it may or may not be stacked.
469    """
470    if inp in cache:
471      return cache[inp]
472    if not self.op_is_inside_loop(inp.op):
473      return False
474    op = inp.op
475    output = False
476    if op.type in [
477        "Shape",
478        "Rank",
479        "ShapeN",
480        "ZerosLike",
481        "TensorArrayV3",
482        "TensorArraySizeV3",
483    ]:
484      output = False
485    elif _is_stateful_pfor_op(op):
486      # This may be fairly aggressive.
487      output = True
488    elif op.type == "Exit":
489      # This may be fairly aggressive.
490      output = True
491    else:
492      for t in op.inputs:
493        if self._maybe_stacked(cache, t):
494          output = True
495          break
496    cache[inp] = output
497    return output
498
499  def _create_init_values(self, pfor_input):
500    """Create arguments passed to converted while_loop."""
501    with ops.name_scope("while_init"):
502      loop_len_vector = pfor_input.pfor.loop_len_vector
503      loop_len = loop_len_vector[0]
504      num_outputs = len(self._outputs)
505
506      inputs = []
507      maybe_stacked_cache = {}
508      # Convert all the Enters. Need to do this before checking for stacking
509      # below.
510      for i, enter in enumerate(self._enters):
511        inp, stacked = self._convert_enter(pfor_input.pfor, enter)
512        inputs.append(inp)
513        maybe_stacked_cache[enter] = stacked
514        # Since this enter node is part of the `loop_vars`, it corresponds to an
515        # output and its preceding switch. We mark this switch's output with the
516        # same stackness, to act as the base case for the logic below. Below, we
517        # will be going through the body figuring out which inputs might need to
518        # be stacked and which inputs can safely remain unstacked.
519        if i < num_outputs:
520          maybe_stacked_cache[self._exit_switches[i].outputs[1]] = stacked
521
522      # Shape invariants for init_values corresponding to self._enters.
523      input_shape_invariants = []
524      # TensorArrays for outputs of converted while loop
525      output_tas = []
526      # Shape invariants for output TensorArrays.
527      ta_shape_invariants = []
528      # List of booleans indicating stackness of inputs, i.e. tensors
529      # corresponding to self._enters.
530      inputs_stacked = []
531      for i, inp in enumerate(inputs):
532        enter = self._enters[i]
533        inp_stacked = self._maybe_stacked(maybe_stacked_cache, enter)
534        # Note that even when an input is unstacked, the body could make it
535        # stacked. We use a heuristic below to figure out if the body may be
536        # making it stacked.
537        if i < num_outputs:
538          body_output = self._body_outputs[i]
539          if enter.op in self._pfor_ops:
540            body_output_stacked = self._maybe_stacked(maybe_stacked_cache,
541                                                      body_output)
542          else:
543            # If constructed outside of pfor loop, then the output would not be
544            # stacked.
545            body_output_stacked = False
546          if body_output_stacked and not inp_stacked:
547            inp = _stack(inp, loop_len_vector).t
548            inputs[i] = inp
549            inp_stacked = True
550          # TODO(agarwal): other attributes for the TensorArray ?
551          output_tas.append(tensor_array_ops.TensorArray(inp.dtype, loop_len))
552          ta_shape_invariants.append(tensor_shape.TensorShape(None))
553
554        inputs_stacked.append(inp_stacked)
555        input_shape_invariants.append(tensor_shape.TensorShape(None))
556
557      # See documentation for __call__ for the structure of init_values.
558      init_values = [True, pfor_input.pfor.all_indices] + inputs + output_tas
559      # TODO(agarwal): try stricter shape invariants
560      shape_invariants = (
561          [tensor_shape.TensorShape(None),
562           tensor_shape.TensorShape(None)] + input_shape_invariants +
563          ta_shape_invariants)
564
565      return init_values, inputs_stacked, shape_invariants
566
567  def _process_cond_unstacked(self, conditions, indices, inputs, output_tas):
568    """Handles case when condition is unstacked.
569
570    Note that all iterations end together. So we don't need to partition the
571    inputs. When all iterations are done, we write the inputs to the
572    TensorArrays. Note that we only write to index 0 of output_tas. Since all
573    iterations end together, they can all be output together.
574    """
575    not_all_done = array_ops.reshape(conditions, [])
576    new_output_tas = []
577    # pylint: disable=cell-var-from-loop
578    for i, out_ta in enumerate(output_tas):
579      inp = inputs[i]
580      new_output_tas.append(
581          control_flow_ops.cond(not_all_done, lambda: out_ta,
582                                lambda: out_ta.write(0, inp)))
583    # pylint: enable=cell-var-from-loop
584    return not_all_done, indices, inputs, new_output_tas
585
586  def _process_cond_stacked(self, conditions, indices, inputs, inputs_stacked,
587                            output_tas):
588    num_outputs = len(self._outputs)
589    # Compute if all iterations are done.
590    not_all_done = math_ops.reduce_any(conditions)
591    conditions_int = math_ops.cast(conditions, dtypes.int32)
592    # Partition the indices.
593    done_indices, new_indices = data_flow_ops.dynamic_partition(
594        indices, conditions_int, 2)
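    # Illustrative example: if `indices` is [0, 1, 2, 3] and `conditions` is
    # [True, False, True, True], then `conditions_int` is [1, 0, 1, 1], so
    # `done_indices` is [1] and `new_indices` is [0, 2, 3].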
595
596    new_inputs = []
597    new_output_tas = []
598    for i, (inp, stacked) in enumerate(zip(inputs, inputs_stacked)):
599      # Partition the inputs.
600      if stacked:
601        done_inp, new_inp = data_flow_ops.dynamic_partition(
602            inp, conditions_int, 2)
603      else:
604        # TODO(agarwal): avoid this stacking. See TODO earlier in
605        # _process_cond_unstacked.
606        done_inp = _stack(inp, [array_ops.size(done_indices)]).t
607        new_inp = inp
608      new_inputs.append(new_inp)
609      # For iterations that are done, write them to TensorArrays.
610      if i < num_outputs:
611        out_ta = output_tas[i]
612        # Note that done_indices can be empty. done_inp should also be empty in
613        # that case.
614        new_output_tas.append(out_ta.scatter(done_indices, done_inp))
615    return not_all_done, new_indices, new_inputs, new_output_tas
616
617  def _process_body(self, pfor_input, inputs_stacked, new_indices, cond_stacked,
618                    new_inputs, not_all_done):
619    """Convert the body function."""
620
621    def true_fn(control_inputs, body_pfor, body_output, stacked):
622      """Converts the body function for all but last iteration.
623
624      This essentially converts body_output. Additionally, it needs to handle
625      any control dependencies on the NextIteration node. So it creates another
626      Identity node with the converted dependencies.
627      """
628      converted_control_inp = []
629      for x in control_inputs:
630        for t in x.outputs:
631          converted_control_inp.append(body_pfor._convert_helper(t).t)
632      if stacked:
633        # Note convert always does the stacking.
634        output = body_pfor.convert(body_output)
635      else:
636        output, convert_stacked, _ = body_pfor._convert_helper(body_output)
637        assert convert_stacked == stacked, body_output
638      with ops.control_dependencies(converted_control_inp):
639        return array_ops.identity(output)
640
641    body_pfor = self._init_pfor(pfor_input.pfor, new_indices, cond_stacked,
642                                new_inputs, inputs_stacked)
643    new_outputs = []
644
645    for i, (body_output,
646            stacked) in enumerate(zip(self._body_outputs, inputs_stacked)):
647      control_inp = self._next_iter_control_inputs[i]
648      out_dtype = body_output.dtype
649      # Note that we want to run the body only if not all pfor iterations are
650      # done. If all are done, we return empty tensors since these values will
651      # not be used. Notice that the value returned by the loop is based on
652      # TensorArrays and not directly on these returned values.
653      # pylint: disable=cell-var-from-loop
654      new_output = control_flow_ops.cond(
655          not_all_done,
656          lambda: true_fn(control_inp, body_pfor, body_output, stacked),
657          lambda: constant_op.constant([], dtype=out_dtype))
658      # pylint: enable=cell-var-from-loop
659      new_outputs.append(new_output)
660    return new_outputs
661
662  def __call__(self, pfor_input):
663    """Converter for the while_loop.
664
665    The conversion of a while_loop is another while_loop.
666
667    The arguments to this converted while_loop are as follows:
668    not_all_done: Boolean scalar Tensor indicating if all the pfor iterations
669      are done.
670    indices: int32 1-D Tensor storing the id of the iterations that are not
671      done.
672    args: Remaining arguments. These can be divided into 3 categories:
673      - The first set of arguments are the tensors that correspond to the
674        initial elements of self._enters, i.e. the elements that appear in the
675        original while loop's `loop_vars`.
676      - The second set of arguments are the tensors that correspond to the
677        remaining elements of self._enters. These are the tensors that directly
678        enter the original while loop body.
679      - Finally, the last set of arguments are TensorArrays. These TensorArrays
680        correspond to the outputs of the original while_loop, i.e. to the
681        elements in self._outputs. Each TensorArray has `PFor.loop_len`
682        elements, i.e. the number of pfor iterations. At the end, the i'th
683        element of each TensorArray will contain the output computed by the
684        i'th iteration of pfor. Note that elements can be written into these
685        tensor arrays in any order, depending on when the corresponding pfor
686        iteration is done.
687      If the original while_loop had `k` tensors in its `loop_vars` and its body
688      directly captured `m` tensors, the `args` will contain `2 * k + m` values.
689
690    In each iteration, the while_loop body recomputes the condition for all
691    active pfor iterations to see which of them are now done. It then partitions
692    all the inputs and passes them along to the converted body. Values for all
693    the iterations that are done are written to TensorArrays indexed by the pfor
694    iteration number. When all iterations are done, the TensorArrays are stacked
695    to get the final value.
696
697    Args:
698      pfor_input: A PForInput object corresponding to the output of any Exit
699        node from this while loop.
700
701    Returns:
702      List of converted outputs.
703    """
704    # Create init_values that will be passed to the while_loop.
705    init_values, inputs_stacked, shape_invariants = self._create_init_values(
706        pfor_input)
707    # Note that we use a list as a hack since we need the nested function body
708    # to set the value of cond_is_stacked. python2.x doesn't support nonlocal
709    # variables.
710    cond_is_stacked = [None]
711
712    def cond(not_all_done, *_):
713      return not_all_done
714
715    def body(not_all_done, indices, *args):
716      # See documentation for __call__ for the structure of *args.
717      num_enters = len(self._enters)
718      inputs = args[:num_enters]
719      output_tas = args[num_enters:]
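      # For example (restating the docstring above): with `k` original
      # `loop_vars` and `m` directly captured tensors, len(inputs) == k + m
      # and len(output_tas) == k.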
720      # TODO(agarwal): see which outputs have consumers and only populate the
721      # TensorArrays corresponding to those. Or do those paths get trimmed out
722      # from inside the while_loop body?
723      assert len(inputs) >= len(output_tas)
724      assert len(inputs) == len(inputs_stacked)
725
726      # Convert condition
727      with ops.name_scope("while_cond"):
728        # Note that we set cond_stacked to True here. At this point we don't
729        # know if it could be loop invariant, hence the conservative value is
730        # to assume stacked.
731        cond_pfor = self._init_pfor(
732            pfor_input.pfor,
733            indices,
734            cond_stacked=True,
735            inputs=inputs,
736            inputs_stacked=inputs_stacked)
737        conditions, cond_stacked, _ = cond_pfor._convert_helper(self._condition)
738        cond_is_stacked[0] = cond_stacked
739
740      # Recompute the new condition, write outputs of done iterations, and
741      # partition the inputs if needed.
742      if not cond_stacked:
743        (not_all_done, new_indices, new_inputs,
744         new_output_tas) = self._process_cond_unstacked(conditions, indices,
745                                                        inputs, output_tas)
746      else:
747        (not_all_done, new_indices, new_inputs,
748         new_output_tas) = self._process_cond_stacked(conditions, indices,
749                                                      inputs, inputs_stacked,
750                                                      output_tas)
751
752      # Convert body
753      with ops.name_scope("while_body"):
754        #  Compute the outputs from the body.
755        new_outputs = self._process_body(pfor_input, inputs_stacked,
756                                         new_indices, cond_stacked, new_inputs,
757                                         not_all_done)
758
759      # Note that the first num_outputs new values of inputs are computed using
760      # the body. The rest were direct Enters into the condition/body and
761      # the partitioning done earlier is sufficient to give the new value.
762      num_outputs = len(self._outputs)
763      new_args = ([not_all_done, new_indices] + new_outputs +
764                  list(new_inputs[num_outputs:]) + new_output_tas)
765      return tuple(new_args)
766
767    while_outputs = control_flow_ops.while_loop(
768        cond, body, init_values, shape_invariants=shape_invariants)
769    output_tas = while_outputs[-len(self._outputs):]
770    outputs = []
771    assert cond_is_stacked[0] is not None
772    for inp_stacked, ta in zip(inputs_stacked, output_tas):
773      if cond_is_stacked[0]:
774        outputs.append(wrap(ta.stack(), True))
775      else:
776        # Note that if while_loop condition is unstacked, all iterations exit at
777        # the same time and we wrote those outputs in index 0 of the tensor
778        # array.
779        outputs.append(wrap(ta.read(0), inp_stacked))
780    return outputs
781
782
783class ConversionNotImplementedError(Exception):
784  pass
785
786
787class _PforInput(object):
788  """Input object passed to registered pfor converters."""
789
790  __slots__ = ["pfor", "_op", "_inputs"]
791
792  def __init__(self, pfor, op, inputs):
793    """Creates a _PforInput object.
794
795    Args:
796      pfor: PFor converter object.
797      op: the Operation object that is being converted.
798      inputs: list of WrappedTensor objects representing converted values of the
799        inputs of `op`.
800    """
801    self.pfor = pfor
802    self._op = op
803    self._inputs = inputs
804
805  def stack_inputs(self, stack_indices=None, tile_variants=False):
806    """Stacks unstacked inputs at `stack_indices`.
807
808    Args:
809      stack_indices: indices of inputs at which stacking is done. If None,
810        stacking is done at all indices.
811      tile_variants: If True, affected indices which have a variant dtype will
812        be tiled after this operation to match the expected shape of a
813        vectorized tensor. Variants generally need to be un-tiled when they are
814        inputs to operations and tiled when returned.
815    """
816    if stack_indices is None:
817      stack_indices = range(len(self._inputs))
818    length = self.pfor.loop_len_vector
819    for i in stack_indices:
820      inp = self._inputs[i]
821      is_variant = inp.t.dtype == dtypes.variant
822      if not inp.is_stacked:
823        self._inputs[i] = _stack(inp.t, length)
824        if tile_variants and is_variant:
825          self._inputs[i] = wrap(
826              _tile_variant_with_length(self._inputs[i].t, length), True)
827      elif not tile_variants and is_variant:
828        self._inputs[i] = wrap(_untile_variant(self._inputs[i].t), True)
829
830  def expanddim_inputs_for_broadcast(self):
831    """Reshapes stacked inputs to prepare them for broadcast.
832
833    Since stacked inputs have an extra leading dimension, automatic broadcasting
834    rules could incorrectly try to expand dimensions before that leading
835    dimension. To avoid that, we reshape these stacked inputs to the maximum
836    rank they will need to be broadcasted to.
837    """
838    if not self._inputs:
839      return
840
841    # Find max rank
842    def _get_rank(x):
843      rank = array_ops.rank(x.t)
844      if not x.is_stacked:
845        rank += 1
846      return rank
847
848    ranks = [_get_rank(x) for x in self._inputs]
849    max_rank = ranks[0]
850    for rank in ranks[1:]:
851      max_rank = math_ops.maximum(rank, max_rank)
852
853    for i, inp in enumerate(self._inputs):
854      if inp.is_stacked:
855        shape = array_ops.shape(inp.t)
856        rank_diff = array_ops.reshape(max_rank - ranks[i], [1])
857        ones = array_ops.tile([1], rank_diff)
858        new_shape = array_ops.concat([shape[:1], ones, shape[1:]], axis=0)
859        self._inputs[i] = wrap(array_ops.reshape(inp.t, new_shape), True)
860
861  @property
862  def inputs(self):
863    return self._inputs
864
865  @property
866  def num_inputs(self):
867    return len(self._inputs)
868
869  def input(self, index):
870    assert len(self._inputs) > index, (index, self._inputs)
871    return self._inputs[index]
872
873  def stacked_input(self, index):
874    t, is_stacked, _ = self.input(index)
875    if not is_stacked:
876      op_type = self.op_type
877      op_def = getattr(self._op, "op_def", None)
878      if op_def is None:
879        input_name = "at index %d" % index
880      else:
881        input_name = "\"%s\"" % op_def.input_arg[index].name
882      raise ConversionNotImplementedError(
883          "Input %s of op \"%s\" expected to be not loop invariant" %
884          (input_name, op_type))
885    return t
886
887  def unstacked_input(self, index):
888    t, is_stacked, _ = self.input(index)
889    if is_stacked:
890      op_type = self.op_type
891      op_def = getattr(self._op, "op_def", None)
892      if op_def is None:
893        input_name = "at index %d" % index
894      else:
895        input_name = "\"%s\"" % op_def.input_arg[index].name
896      raise ConversionNotImplementedError(
897          "Input %s of op \"%s\" expected to be loop invariant" %
898          (input_name, op_type))
899    return t
900
901  @property
902  def op(self):
903    return self._op
904
905  @property
906  def op_type(self):
907    return self._op.type
908
909  def get_attr(self, attr):
910    return self._op.get_attr(attr)
911
912  @property
913  def outputs(self):
914    return self._op.outputs
915
916  def output(self, index):
917    assert index < len(self._op.outputs)
918    return self._op.outputs[index]
919
920
921_pfor_converter_registry = {}
922
923
924class RegisterPFor(object):
925  """Utility to register converters for pfor.
926
927  Usage:
928  @RegisterPFor(foo_op_type)
929  def _foo_converter(pfor_input):
930    ...
931
932  The above will register conversion function `_foo_converter` for handling
933  conversion of `foo_op_type`. These converters are called during vectorization
934  of a `pfor` loop body. For each operation node in this loop body,
935  the vectorization process will call the converter corresponding to the
936  operation type of the node.
937
938  During conversion, the registered function will be called with a single
939  argument `pfor_input`, of type `_PforInput`, which will contain state needed
940  for the conversion.  When the converter is called for a node, all its inputs
941  should already have been converted and these converted values are stored in
942  `pfor_input.inputs`.  This registered function should output a list of
943  WrappedTensor objects with the same length as the number of outputs of the
944  node being converted. If the node had zero outputs, then it should return an
945  ops.Operation object.  These new sets of nodes should implement the
946  functionality of running that operation for the number of iterations specified
947  by `pfor_input.pfor.loop_len_vector[0]` where the inputs of the node for each
948  iteration are picked from `pfor_input.inputs`.
949
950  One tricky aspect of the conversion process is keeping track of, and
951  leveraging, the loop invariance of computation. Each converted input is a
952  WrappedTensor which indicates whether the input was loop invariant or not. If
953  the converted value is loop invariant, its rank should match the rank of the
954  corresponding tensor in the loop body, else its rank is larger by 1. The
955  converter should look at the loop invariance of the inputs and generate new
956  nodes based on that. Note that the converter will not be called if all inputs
957  are loop invariant and the operation is not stateful. The converter should
958  determine if its own output is loop invariant and `wrap` its output
959  accordingly.
960
961  Example:
962
963  Here, the converter is trying to convert a Reshape node in the loop body. This
964  node will have two inputs: the tensor to reshape, and the new shape.  The
965  example here only handles the case where the shape is loop invariant.
966
967  @RegisterPFor("Reshape")
968  def _convert_reshape(pfor_input):
969    # We assume that input is not loop invariant. Call to `stacked_input`
970    # asserts that and returns the converted value. This value will have a rank
971    # larger by 1 compared to the rank of the input in the loop body.
972    t = pfor_input.stacked_input(0)
973
974    # We assume that shape input is loop invariant. Call to `unstacked_input`
975    # asserts that and returns the converted value.
976    shape = pfor_input.unstacked_input(1)
977
978    # We compute `new_shape` by prepending the number of iterations to the
979    # original shape.
980    new_shape = array_ops.concat([pfor_input.pfor.loop_len_vector, shape],
981                                 axis=0)
982
983    # The vectorized output involves reshaping the converted input `t` using
984    # `new_shape`.
985    new_output = array_ops.reshape(t, new_shape)
986
987    # The converted output is marked as not loop invariant using the call to
988    # wrap.
989    return wrap(new_output, True)
990  """
991
992  def __init__(self, op_type):
993    """Creates an object to register a converter for op with type `op_type`."""
994    self.op_type = op_type
995
996  def __call__(self, converter):
997    name = self.op_type
998    assert name not in _pfor_converter_registry, "Re-registering %s " % name
999    _pfor_converter_registry[name] = converter
1000    return converter
1001
1002
1003class RegisterPForWithArgs(RegisterPFor):
1004  """Utility to register converters for pfor.
1005
1006  Usage:
1007  @RegisterPForWithArgs(foo_op_type, foo=value, ...)
1008  def _foo_converter(pfor_input, op_type, foo=None, ...):
1009    ...
1010
1011  See RegisterPFor for details on the conversion function.
1012  `RegisterPForWithArgs` allows binding extra arguments to the
1013  conversion function at registration time.
1014  """
1015
1016  def __init__(self, op_type, *args, **kw_args):
1017    super(RegisterPForWithArgs, self).__init__(op_type)
1018    self._args = args
1019    self._kw_args = kw_args
1020
1021  def __call__(self, converter):
1022
1023    def _f(pfor_input):
1024      return converter(pfor_input, self.op_type, *self._args, **self._kw_args)
1025
1026    super(RegisterPForWithArgs, self).__call__(_f)
1027    return converter
1028
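# Illustrative sketch (not a registration performed here): a single converter
# can be shared across several element-wise ops by binding the op's computation
# function at registration time. For a hypothetical unary op:
#
#   @RegisterPForWithArgs("SomeUnaryCwiseOp", some_unary_fn)
#   def _convert_unary_cwise(pfor_input, op_type, op_func):
#     del op_type
#     return wrap(op_func(pfor_input.stacked_input(0)), True)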
1029
1030# TODO(agarwal): call raw_ops instead of calling these low level routines.
1031def _create_op(op_type, inputs, op_dtypes, attrs=None):
1032  """Utility to create an op."""
1033  op = ops.get_default_graph().create_op(
1034      op_type, inputs, op_dtypes, attrs=attrs, compute_device=True)
1035  flat_attrs = []
1036  # The tape expects an alternating flat list of names and attribute values.
1037  for a in (attrs or {}):
1038    flat_attrs.append(str(a))
1039    flat_attrs.append(op.get_attr(str(a)))
1040  execute.record_gradient(op_type, op.inputs, tuple(flat_attrs), op.outputs[:])
1041  return op
1042
1043
1044WrappedTensor = collections.namedtuple("WrappedTensor",
1045                                       ["t", "is_stacked", "is_sparse_stacked"])
1046"""Wrapper around the result of a Tensor conversion.
1047
1048The additional fields are useful for keeping track of the conversion state as
1049data flows through the ops in the loop body. For every op whose output is a
1050Tensor, its converter should return either a WrappedTensor or a list of
1051WrappedTensors.
1052
1053Args:
1054  t: The converted tensor
1055  is_stacked: True if the tensor is stacked, i.e. represents the results of all
1056    the iterations of the loop, where each row i of the tensor corresponds to
1057    that op's output on iteration i of the loop. False if the tensor is not
1058    stacked, i.e. represents the result of the op for a single iteration of
1059    the loop, where the result does not vary between iterations.
1060  is_sparse_stacked: True if the tensor corresponds to a component tensor
1061    (indices, values, or dense_shape) of a sparse tensor, and has been logically
1062    stacked via a sparse conversion.
1063"""
1064
1065
1066def wrap(tensor, is_stacked=True, is_sparse_stacked=False):
1067  """Helper to create a WrappedTensor object."""
1068  assert isinstance(is_stacked, bool)
1069  assert isinstance(is_sparse_stacked, bool)
1070  assert isinstance(tensor, ops.Tensor)
1071  assert not is_sparse_stacked or is_stacked, ("If the wrapped tensor is "
1072                                               "stacked via a sparse "
1073                                               "conversion, it must also be "
1074                                               "stacked.")
1075  return WrappedTensor(tensor, is_stacked, is_sparse_stacked)
1076
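# Illustrative example of the convention above: with loop_len N, a value whose
# per-iteration shape is [2, 3] is represented either as wrap(t, True) where t
# has shape [N, 2, 3] (stacked), or as wrap(t, False) where t has shape [2, 3]
# (loop invariant).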
1077
1078def _wrap_and_tile_variants(tensor, length):
1079  if tensor.dtype == dtypes.variant:
1080    tensor = _tile_variant_with_length(tensor, length)
1081  return wrap(tensor)
1082
1083
1084def _fallback_converter(pfor_input, warn=True):
1085  if warn:
1086    logging.warning("Using a while_loop for converting %s", pfor_input.op_type)
1087  output_dtypes = [x.dtype for x in pfor_input.outputs]
1088  iters = pfor_input.pfor.loop_len_vector[0]
1089
1090  def while_body(i, *ta_list):
1091    """Body of while loop."""
1092    inputs = [
1093        x[i, ...] if stacked else x for x, stacked, _ in pfor_input.inputs
1094    ]
1095    op_outputs = _create_op(
1096        pfor_input.op_type,
1097        inputs,
1098        output_dtypes,
1099        attrs=pfor_input.op.node_def.attr).outputs
1100
1101    outputs = []
1102    # TODO(agarwal): Add tf.debugging asserts to check that the shapes across
1103    # the different iterations are the same.
1104    for out, ta in zip(op_outputs, ta_list):
1105      assert isinstance(out, ops.Tensor)
1106      outputs.append(ta.write(i, array_ops.expand_dims(out, 0)))
1107    return tuple([i + 1] + outputs)
1108
1109  ta_list = control_flow_ops.while_loop(
1110      lambda i, *ta: i < iters, while_body, [0] +
1111      [tensor_array_ops.TensorArray(dtype, iters) for dtype in output_dtypes
1112      ])[1:]
1113  return tuple([wrap(ta.concat(), True) for ta in ta_list])
1114
1115
1116class PForConfig(object):
1117  """A configuration object used to communicate with loop body function."""
1118
1119  def __init__(self):
1120    # This may be set to the number of iterations.
1121    self._maybe_iters = None
1122    # Map from reduction node, created by `reduce`, to the bundle of reduction
1123    # function and arguments.
1124    self._reduce_map = {}
1125
1126  def _has_reductions(self):
1127    """True if some reductions where performed by loop body."""
1128    return len(self._reduce_map)
1129
1130  def _set_iters(self, iters):
1131    """Set number of pfor iterations."""
1132    if isinstance(iters, ops.Tensor):
1133      iters = tensor_util.constant_value(iters)
1134    self._maybe_iters = iters
1135
1136  def reduce(self, fn, *args):
1137    """Performs reduction `fn` on `args` vectorized across pfor iterations.
1138
1139    Note that `fn` is traced once inside the loop function context. Hence any
1140    captures or side-effects will happen in that context. Call to the traced
1141    version of `fn` happens during the construction of the vectorized code.
1142
1143    Note that this currently may not work inside a control flow construct.
1144    Args:
1145      fn: a reduction function. It will be called with arguments that have the
1146        same structure as *args but with individual values whose rank may be
1147        higher by 1 since they represent loop invariant vectorized versions of
1148        the corresponding Tensors in *args.
1149      *args: unvectorized Tensors.
1150
1151    Returns:
1152      The result of running `fn` on the vectorized versions of `*args`. These
1153      outputs will be available as loop invariant values to all the iterations.
1154    """
1155    assert not context.executing_eagerly()
1156    # Creates a concrete function that will be used for reduction.
1157    tensor_specs = []
1158    for arg in args:
1159      if not isinstance(arg, ops.Tensor):
1160        raise ValueError("Got a non-Tensor argument %s in reduce" % arg)
1161      batched_shape = tensor_shape.TensorShape([self._maybe_iters
1162                                               ]).concatenate(arg.shape)
1163      tensor_specs.append(
1164          tensor_spec.TensorSpec(shape=batched_shape, dtype=arg.dtype))
1165    concrete_function = def_function.function(fn).get_concrete_function(
1166        *tensor_specs)
1167
1168    # Creates PlaceholderWithDefault and IdentityN nodes corresponding to the
1169    # reduction.
1170    pl_outputs = []
1171    with ops.control_dependencies(args):
1172      for output in concrete_function.outputs:
1173        if not isinstance(output, ops.Tensor):
1174          raise ValueError("Got a non-Tensor output %s while running reduce" %
1175                           output)
1176        # Note that we use placeholder_with_default just to make XLA happy since
1177        # it does not like placeholder ops.
1178        if output.shape.is_fully_defined():
1179          dummy = array_ops.zeros(output.shape.as_list(), dtype=output.dtype)
1180          pl_outputs.append(
1181              array_ops.placeholder_with_default(dummy, shape=output.shape))
1182        else:
1183          # TODO(agarwal): support case when under XLA and output.shape is not
1184          # fully defined.
1185          pl_outputs.append(
1186              array_ops.placeholder(output.dtype, shape=output.shape))
1187
1188      reduction_op = array_ops.identity_n(pl_outputs)[0].op
1189    self._reduce_map[reduction_op] = (concrete_function, args)
1190    if len(reduction_op.outputs) == 1:
1191      return reduction_op.outputs[0]
1192    else:
1193      return tuple(reduction_op.outputs)
1194
1195  # TODO(agarwal): handle reductions inside control flow constructs.
1196  def reduce_concat(self, x):
1197    """Performs a concat reduction on `x` across pfor iterations.
1198
1199    Note that this currently may not work inside a control flow construct.
1200    Args:
1201      x: an unvectorized Tensor.
1202
1203    Returns:
1204      A Tensor that has rank one higher than `x`. The value is the vectorized
1205      version of `x`, i.e. stacking the value of `x` across different pfor
1206      iterations.
1207    """
1208    return self.reduce(lambda y: y, x)
1209
1210  def reduce_mean(self, x):
1211    """Performs a mean reduction on `x` across pfor iterations.
1212
1213    Note that this currently may not work inside a control flow construct.
1214    Args:
1215      x: an unvectorized Tensor.
1216
1217    Returns:
1218      A Tensor that has same rank as `x`. The value is the mean of the values
1219      of `x` across the pfor iterations.
1220    """
1221    return self.reduce(lambda y: math_ops.reduce_mean(y, axis=0), x)
1222
1223  def reduce_sum(self, x):
1224    """Performs a sum reduction on `x` across pfor iterations.
1225
1226    Note that this currently may not work inside a control flow construct.
1227    Args:
1228      x: an unvectorized Tensor.
1229
1230    Returns:
1231      A Tensor that has same rank as `x`. The value is the sum of the values
1232      of `x` across the pfor iterations.
1233    """
1234    return self.reduce(lambda y: math_ops.reduce_sum(y, axis=0), x)
1235
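  # Illustrative usage sketch for the reduce* methods above (assuming the pfor
  # loop function optionally accepts a PForConfig as its second argument; the
  # names below are only for illustration):
  #
  #   def loop_fn(i, pfor_config):
  #     x = array_ops.gather(data, i)
  #     return x - pfor_config.reduce_mean(x)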
1236  def _lookup_reduction(self, t):
1237    """Lookups Tensor `t` in the reduction maps."""
1238    assert isinstance(t, ops.Tensor), t
1239    return self._reduce_map.get(t.op)
1240
1241
1242class PFor(object):
1243  """Implementation of rewrite of parallel-for loops.
1244
1245  This class takes a DAG or a set of DAGs representing the body of a
1246  parallel-for loop, and adds new operations to the graph that implements
1247  functionality equivalent to running that loop body for a specified number of
1248  iterations. This new set of nodes may or may not use a tensorflow loop
1249  construct.
1250
1251  The process of conversion does not delete or change any existing operations.
1252  It only adds operations that efficiently implement the equivalent
1253  functionality. We refer to the added ops as "converted ops".
1254
1255  The conversion process uses a simple greedy heuristic. It walks the loop body
1256  and tries to express the functionality of running each node in a loop with a
1257  new set of nodes. When converting an op several cases are possible:
1258  - The op is not inside the loop body. Hence it can be used as is.
1259  - The op does not depend on the iteration number and is stateless. In this
1260    case, it can be used as is.
1261  - The op is not stateful, and depends on iteration number only through control
1262    dependencies. In this case, we can create a single op with same inputs and
1263    attributes, but with "converted" control dependencies.
1264  - The op is not stateful, and all its inputs are loop invariant. In this
1265    case, similar to above, we can create a single op with same inputs and
1266    attributes, but with "converted" control dependencies.
1267  - The op is stateful or at least one of the inputs is not loop invariant. In
1268    this case, we run the registered converter for that op to create a set of
1269    converted ops. All nodes in the set will have converted control dependencies
1270    corresponding to control dependencies of the original op. If the op returned
1271    multiple outputs, "converted outputs" could be produced by different ops in
1272    this set.
1273  """
1274
1275  def __init__(self,
1276               loop_var,
1277               loop_len,
1278               pfor_ops,
1279               fallback_to_while_loop,
1280               all_indices=None,
1281               all_indices_partitioned=False,
1282               pfor_config=None):
1283    """Creates an object to rewrite a parallel-for loop.
1284
1285    Args:
1286      loop_var: ops.Tensor output of a PlaceholderWithDefault operation. The
1287        value should be an int32 scalar representing the loop iteration number.
1288      loop_len: A scalar or scalar Tensor representing the number of iterations
1289        the loop is run for.
1290      pfor_ops: List of all ops inside the loop body.
1291      fallback_to_while_loop: If True, on failure to vectorize an op, a while
1292        loop is used to sequentially execute that op.
1293      all_indices: If not None, an int32 vector with size `loop_len`
1294        representing the iteration ids that are still active. These values
1295        should be unique and sorted. However they may not be contiguous. This is
1296        typically the case when inside a control flow construct which has
1297        partitioned the indices of the iterations that are being converted.
1298      all_indices_partitioned: If True, this object is being constructed from a
1299        control flow construct where not all the pfor iterations are guaranteed
1300        to be active.
1301      pfor_config: PForConfig object used while constructing the loop body.
1302    """
1303    assert isinstance(loop_var, ops.Tensor)
1304    assert loop_var.op.type == "PlaceholderWithDefault"
1305    self._loop_var = loop_var
1306    loop_len_value = tensor_util.constant_value(loop_len)
1307    if loop_len_value is not None:
1308      loop_len = loop_len_value
1309    self._loop_len_vector = array_ops.reshape(loop_len, [1])
1310    self._all_indices_partitioned = all_indices_partitioned
1311    if all_indices_partitioned:
1312      assert all_indices is not None
1313    self.all_indices = (
1314        math_ops.range(loop_len) if all_indices is None else all_indices)
1315
1316    self._conversion_map = object_identity.ObjectIdentityDictionary()
1317    self._conversion_map[loop_var] = wrap(self.all_indices, True)
1318    self._pfor_ops = set(pfor_ops)
1319    self._pfor_op_ids = set(x._id for x in pfor_ops)
1320    self._fallback_to_while_loop = fallback_to_while_loop
1321    self._pfor_config = pfor_config
1322
1323  def op_is_inside_loop(self, op):
1324    """True if op was created inside the pfor loop body."""
1325    assert isinstance(op, ops.Operation)
1326    # Note that we use self._pfor_op_ids for the check and not self._pfor_ops
1327    # since it appears the TensorFlow API can return different Python
1328    # objects representing the same Operation node.
1329    return op._id in self._pfor_op_ids
1330
1331  def _convert_sparse(self, y):
1332    """Returns the converted value corresponding to SparseTensor y.
1333
1334    For SparseTensors, instead of stacking the component tensors separately,
1335    resulting in component tensors with shapes (N, m, rank), (N, m), and (N,
1336    rank) respectively for indices, values, and dense_shape (where N is the loop
1337    length and m is the number of sparse tensor values per loop iter), we want
1338    to logically stack the SparseTensors, to create a SparseTensor whose
1339    components are size (N * m, rank + 1), (N * m, ), and (rank + 1,)
1340    respectively.
1341
1342    Here, we try to get the conversion of each component tensor.
1343    If the tensors are stacked via a sparse conversion, return the resulting
1344    SparseTensor composed of the converted components. Otherwise, the component
1345    tensors are either unstacked or stacked naively. In the latter case, we
1346    unstack the component tensors to reform loop_len SparseTensor elements,
1347    then correctly batch them.
1348
1349    The unstacked tensors must have the same rank. Each dimension of each
1350    SparseTensor will expand to be the largest among all SparseTensor elements
1351    for that dimension. For example, if there are N SparseTensors of rank 3
1352    being stacked, with N dense shapes, where the i_th shape is (x_i, y_i, z_i),
1353    the new dense shape will be (N, max_i(x_i), max_i(y_i), max_i(z_i)).
1354
1355    Args:
1356      y: A tf.sparse.SparseTensor.
1357
1358    Returns:
1359      A tf.sparse.SparseTensor that is the converted value corresponding to y.
1360    """
1361    outputs = [
1362        self._convert_helper(t) for t in (y.indices, y.values, y.dense_shape)
1363    ]
1364    assert all(isinstance(o, WrappedTensor) for o in outputs)
1365
1366    if all(w.is_sparse_stacked for w in outputs):
1367      return sparse_tensor.SparseTensor(*[w.t for w in outputs])
1368
1369    assert not any(w.is_sparse_stacked for w in outputs), (
1370        "Error converting SparseTensor. All components should be logically "
1371        "stacked, or none.")
1372
1373    # If component tensors were not sparsely stacked, they are either unstacked
1374    # or stacked without knowledge that they are components of sparse tensors.
1375    # In this case, we have to restack them.
1376    return self._restack_sparse_tensor_logically(
1377        *[self._unwrap_or_tile(w) for w in outputs])
1378
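  # Shape sketch for the SparseTensor conversion above (hypothetical sizes):
  # logically stacking N=2 SparseTensors, each with dense_shape [3] and m=2
  # values, yields indices of shape [4, 2], values of shape [4] and dense_shape
  # [2, 3], rather than the naively stacked components of shapes [2, 2, 1],
  # [2, 2] and [2, 1].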
1379  def _restack_sparse_tensor_logically(self, indices, values, shape):
1380    sparse_tensor_rank = indices.get_shape().dims[-1].value
1381    if sparse_tensor_rank is not None:
1382      sparse_tensor_rank += 1
1383
1384    def fn(args):
1385      res = gen_sparse_ops.serialize_sparse(
1386          args[0], args[1], args[2], out_type=dtypes.variant)
1387      return res
1388
1389    # Applies a map function to the component tensors to serialize each
1390    # sparse tensor element and batch them all, then deserializes the batch.
1391    # TODO(rachelim): Try to do this without map_fn -- add the right offsets
1392    # to shape and indices tensors instead.
1393    result = map_fn.map_fn(fn, [indices, values, shape], dtype=dtypes.variant)
1394    return sparse_ops.deserialize_sparse(
1395        result, dtype=values.dtype, rank=sparse_tensor_rank)
1396
1397  def _unwrap_or_tile(self, wrapped_tensor):
1398    """Given a wrapped tensor, returns it if stacked; otherwise tiles it."""
1399    output, is_stacked = wrapped_tensor.t, wrapped_tensor.is_stacked
1400    if is_stacked:
1401      return output
1402    else:
1403      return _stack(output, self._loop_len_vector).t
1404
1405  def convert(self, y):
1406    """Returns the converted value corresponding to y.
1407
1408    Args:
1409      y: An ops.Tensor or an ops.Operation object. If the latter, y should not
1410        have any outputs.
1411
1412    Returns:
1413      If y does not need to be converted, it is returned as is. Otherwise returns
1414      the "converted value" corresponding to y.
1415    """
1416    if y is None:
1417      return None
1418    if isinstance(y, sparse_tensor.SparseTensor):
1419      return self._convert_sparse(y)
1420    assert isinstance(y, (ops.Tensor, ops.Operation)), y
1421    output = self._convert_helper(y)
1422    if isinstance(output, WrappedTensor):
1423      assert isinstance(y, ops.Tensor)
1424      return self._unwrap_or_tile(output)
1425    else:
1426      assert isinstance(y, ops.Operation)
1427      assert not y.outputs
1428      assert isinstance(output, ops.Operation)
1429    return output
1430
1431  def _was_converted(self, t):
1432    """True if the converted value of `t` differs from `t` itself."""
1433    converted_t = self._conversion_map[t]
1434    return converted_t.t is not t
1435
1436  def _add_conversion(self, old_output, new_output):
1437    assert isinstance(old_output, (ops.Tensor, ops.Operation)), old_output
1438    assert isinstance(new_output, (WrappedTensor, ops.Operation)), new_output
1439    self._conversion_map[old_output] = new_output
1440
1441  def _convert_reduction(self, y):
1442    # Handle reductions.
1443    if self._pfor_config is None or isinstance(y, ops.Operation):
1444      return None
1445    reduction = self._pfor_config._lookup_reduction(y)
1446    if reduction is None:
1447      return None
1448    (reduction_fn, reduction_args) = reduction
1449    batched_args = []
1450    for reduction_arg in reduction_args:
1451      assert isinstance(reduction_arg, ops.Tensor), reduction_arg
1452      # Tensor being reduced should already be converted due to a control
1453      # dependency on the created placeholder.
1454      # Note that in cases where reduction_arg is in an outer context, one
1455      # needs to locate the corresponding Enter node and use that to look up
1456      # the conversion.
1457      # TODO(agarwal): handle reductions inside control flow constructs.
1458      assert reduction_arg in self._conversion_map, (
1459          "Unable to handle reduction of %s, possibly as it was used "
1460          "inside a control flow construct. Note that reductions across "
1461          "pfor iterations are currently not supported inside control flow "
1462          "constructs." % reduction_arg)
1463      batched_arg = self._conversion_map[reduction_arg]
1464      batched_args.append(self._unwrap_or_tile(batched_arg))
1465    outputs = reduction_fn(*batched_args)
1466    return [wrap(output, False) for output in nest.flatten(outputs)]
1467
1468  def _convert_helper(self, op_or_tensor):
1469    stack = collections.deque([op_or_tensor])
1470    while stack:
1471      y = stack[0]
1472      if y in self._conversion_map:
1473        assert isinstance(self._conversion_map[y],
1474                          (WrappedTensor, ops.Operation))
1475        stack.popleft()
1476        continue
1477      if isinstance(y, ops.Operation):
1478        assert not y.outputs, (
1479            "We only support converting Operation objects with no outputs. "
1480            "Got %s" % y)
1481        y_op = y
1482      else:
1483        assert isinstance(y, ops.Tensor), y
1484        y_op = y.op
1485
1486      is_while_loop = y_op.type == "Exit"
1487      if is_while_loop:
1488        while_op = WhileOp(
1489            y, pfor_ops=self._pfor_ops,
1490            fallback_to_while_loop=self.fallback_to_while_loop,
1491            pfor_config=self._pfor_config)
1492        is_inside_loop = while_op.is_inside_loop
1493        # If all nodes in the while_loop graph were created inside the pfor, we
1494        # treat the whole loop subgraph as a single op (y_op) and try to convert
1495        # it. For while_loops that are created completely or partially outside,
1496        # we treat them as external and should be able to simply return the Exit
1497        # node output as is without needing any conversion. Note that for
1498        # while_loops that are partially constructed inside, we assume they will
1499        # be loop invariant. If that is not the case, it will create runtime
1500        # errors since the converted graph would depend on the self._loop_var
1501        # placeholder.
1502        if is_inside_loop:
1503          y_op = while_op
1504      else:
1505        is_inside_loop = self.op_is_inside_loop(y_op)
1506
1507      # If this op was not created inside the loop body, we will return as is.
1508      # 1. Convert inputs and control inputs.
1509
1510      def _add_to_stack(x):
1511        if x not in self._conversion_map:
1512          stack.appendleft(x)
1513          return True
1514        else:
1515          return False
1516
1517      if is_inside_loop:
1518        added_to_stack = False
1519        for inp in y_op.inputs:
1520          added_to_stack |= _add_to_stack(inp)
1521        for cinp in y_op.control_inputs:
1522          if cinp.outputs:
1523            for t in cinp.outputs:
1524              added_to_stack |= _add_to_stack(t)
1525          else:
1526            added_to_stack |= _add_to_stack(cinp)
1527        if added_to_stack:
1528          continue
1529
1530        converted_inputs = [self._conversion_map[inp] for inp in y_op.inputs]
1531        some_input_converted = any(self._was_converted(x) for x in y_op.inputs)
1532        some_input_stacked = any(x.is_stacked for x in converted_inputs)
1533
1534        converted_control_ops = set()
1535        some_control_input_converted = False
1536        for cinp in y_op.control_inputs:
1537          if cinp.outputs:
1538            for t in cinp.outputs:
1539              converted_t = self._conversion_map[t]
1540              if self._was_converted(t):
1541                some_control_input_converted = True
1542              converted_control_ops.add(converted_t.t.op)
1543          else:
1544            converted_cinp = self._conversion_map[cinp]
1545            assert isinstance(converted_cinp, ops.Operation)
1546            if converted_cinp != cinp:
1547              some_control_input_converted = True
1548            converted_control_ops.add(converted_cinp)
1549        converted_control_ops = list(converted_control_ops)
1550        is_stateful = _is_stateful_pfor_op(y_op)
1551      else:
1552        converted_inputs = []
1553        converted_control_ops = []
1554      logging.vlog(3, "converting op:%s\ninputs:%s\ncontrol_inputs:%s", y_op,
1555                   converted_inputs, converted_control_ops)
1556
1557      # 2. Convert y_op
1558      # If converting a while_loop, we let the while_loop convertor deal with
1559      # putting the control dependencies appropriately.
1560      control_dependencies = [] if is_while_loop else converted_control_ops
1561      with ops.control_dependencies(control_dependencies), ops.name_scope(
1562          y_op.name + "/pfor/"), ops.get_default_graph()._original_op(y_op):
1563        # Op is a placeholder for a reduction.
1564        reduce_output = self._convert_reduction(y)
1565        if reduce_output is not None:
1566          new_outputs = reduce_output
1567        # None of the inputs and control inputs were converted.
1568        elif ((not is_inside_loop or
1569               (not is_stateful and not some_input_converted and
1570                not some_control_input_converted)) and
1571              y.graph == ops.get_default_graph()):
1572          if y is y_op:
1573            assert not isinstance(y_op, WhileOp)
1574            new_outputs = y_op
1575          else:
1576            new_outputs = [wrap(x, False) for x in y_op.outputs]
1577        elif not (is_stateful or is_while_loop or some_input_stacked):
1578          # All inputs are unstacked or unconverted but some control inputs are
1579          # converted.
1580          # TODO(rachelim): Handle the case where some inputs are sparsely
1581          # stacked (i.e. any(x.is_sparse_stacked for x in converted_inputs))
1582          new_op = _create_op(y_op.type, [x.t for x in converted_inputs],
1583                              [x.dtype for x in y_op.outputs],
1584                              y_op.node_def.attr)
1585          if y is y_op:
1586            new_outputs = new_op
1587          else:
1588            new_outputs = []
1589            for old_output, new_output in zip(y_op.outputs, new_op.outputs):
1590              handle_data_util.copy_handle_data(old_output, new_output)
1591              new_outputs.append(wrap(new_output, False))
1592        else:
1593          # Either some inputs are not loop invariant or op is stateful.
1594          if hasattr(y_op, "pfor_converter"):
1595            converter = y_op.pfor_converter
1596          else:
1597            converter = _pfor_converter_registry.get(y_op.type, None)
1598          if converter is None:
1599            has_variant_outputs = any(x.dtype == dtypes.variant for x in
1600                                      y_op.outputs)
1601            has_vectorized_variant_inputs = any(
1602                _is_variant_with_internal_stacking(x) for x in
1603                y_op.inputs)
1604            if (self._fallback_to_while_loop and not has_variant_outputs
1605                and not has_vectorized_variant_inputs):
1606              converter = _fallback_converter
1607            else:
1608              message = ("No pfor vectorization defined for %s\n"
1609                         "%s\n"
1610                         "inputs: %s. " %
1611                         (y_op.type, y_op, converted_inputs))
1612              if not self._fallback_to_while_loop:
1613                message += ("Consider enabling the fallback_to_while_loop "
1614                            "option to pfor, which may run slower.")
1615              raise ValueError(message)
1616          # TODO(rachelim): Handle the case where some inputs are sparsely
1617          # stacked. We should only call the converter if it supports handling
1618          # those inputs.
1619          pfor_inputs = _PforInput(self, y_op, converted_inputs)
1620          try:
1621            try:
1622              new_outputs = converter(pfor_inputs)
1623            except ConversionNotImplementedError as e:
1624              has_vectorized_variant_inputs = any(
1625                  _is_variant_with_internal_stacking(x) for x in
1626                  y_op.inputs)
1627              if (self._fallback_to_while_loop
1628                  and not has_vectorized_variant_inputs):
1629                new_outputs = _fallback_converter(pfor_inputs)
1630              else:
1631                six.reraise(ValueError, ValueError(str(e)), sys.exc_info()[2])
1632          except Exception as e:  # pylint: disable=broad-except
1633            logging.error(
1634                "Got error while pfor was converting op %s "
1635                "with inputs %s,\nconverted inputs %s\n"
1636                "%s\n"
1637                "Here are the pfor conversion stack traces:", y_op,
1638                y_op.inputs[:], pfor_inputs.inputs, str(e))
1639            original_op = y_op
1640            while isinstance(original_op, ops.Operation):
1641              logging.error(
1642                  "%s\ncreated at:\n  %s", original_op,
1643                  "  ".join(traceback.format_list(original_op.traceback)))
1644              original_op = original_op._original_op
1645            six.reraise(e.__class__, e, sys.exc_info()[2])
1646
1647          if isinstance(new_outputs, WrappedTensor):
1648            new_outputs = [new_outputs]
1649          assert isinstance(new_outputs,
1650                            (list, tuple, ops.Operation)), new_outputs
1651        logging.vlog(2, "converted %s %s", y_op, new_outputs)
1652
1653        # Insert into self._conversion_map
1654        if y is y_op:
1655          assert isinstance(new_outputs, ops.Operation)
1656          self._add_conversion(y_op, new_outputs)
1657        else:
1658          assert len(y_op.outputs) == len(new_outputs), (y_op, y_op.outputs,
1659                                                         new_outputs)
1660          for old_output, new_output in zip(y_op.outputs, new_outputs):
1661            assert isinstance(new_output, WrappedTensor), (new_output, y, y_op)
1662            assert old_output.dtype == new_output.t.dtype, (new_output, y, y_op)
1663            # Set shape for converted output.
1664            output_shape = old_output.shape
1665            if not new_output.is_sparse_stacked:
1666              if new_output.is_stacked:
1667                loop_len = tensor_util.constant_value(self.loop_len_vector)
1668                if loop_len is None:
1669                  batch_dim = tensor_shape.TensorShape([None])
1670                else:
1671                  batch_dim = tensor_shape.TensorShape(loop_len)
1672                output_shape = batch_dim.concatenate(output_shape)
1673              if _is_variant_with_internal_stacking(new_output.t):
1674                new_output.t.set_shape([])
1675              else:
1676                new_output.t.set_shape(output_shape)
1677            self._add_conversion(old_output, new_output)
1678        stack.popleft()
1679
1680    return self._conversion_map[op_or_tensor]
1681
1682  @property
1683  def loop_len_vector(self):
1684    """Returns a single-element vector whose value is the number of iterations."""
1685    return self._loop_len_vector
1686
1687  @property
1688  def loop_var(self):
1689    """Returns placeholder loop variable."""
1690    return self._loop_var
1691
1692  @property
1693  def pfor_ops(self):
1694    return self._pfor_ops
1695
1696  @property
1697  def pfor_config(self):
1698    return self._pfor_config
1699
1700  @property
1701  def all_indices_partitioned(self):
1702    """all_indices_partitioned property.
1703
1704    Returns:
1705      True if we are inside a control flow construct and not all pfor iterations
1706      may be active.
1707    """
1708    return self._all_indices_partitioned
1709
1710  @property
1711  def fallback_to_while_loop(self):
1712    return self._fallback_to_while_loop
1713
1714
1715# The code below defines converters for different operations. Please see the
1716# comment for RegisterPFor on how converters should be defined.
1717
1718
1719# image_ops
1720
1721
1722@RegisterPFor("AdjustContrastv2")
1723def _convert_adjust_contrastv2(pfor_input):
1724  images = pfor_input.stacked_input(0)
1725  contrast_factor = pfor_input.unstacked_input(1)
1726  return wrap(gen_image_ops.adjust_contrastv2(images, contrast_factor), True)
1727
1728
1729@RegisterPFor("AdjustHue")
1730def _convert_adjust_hue(pfor_input):
1731  images = pfor_input.stacked_input(0)
1732  delta = pfor_input.unstacked_input(1)
1733  return wrap(gen_image_ops.adjust_hue(images, delta), True)
1734
1735
1736@RegisterPFor("AdjustSaturation")
1737def _convert_adjust_saturation(pfor_input):
1738  images = pfor_input.stacked_input(0)
1739  scale = pfor_input.unstacked_input(1)
1740  return wrap(gen_image_ops.adjust_saturation(images, scale), True)
1741
1742
1743# nn_ops
1744
1745
1746def _flatten_first_two_dims(x):
1747  """Merges first two dimensions."""
1748  old_shape = array_ops.shape(x)
1749  new_shape = array_ops.concat([[-1], old_shape[2:]], axis=0)
1750  return array_ops.reshape(x, new_shape)
1751
1752
1753def _unflatten_first_dim(x, first_dim):
1754  """Splits first dimension into [first_dim, -1]."""
1755  old_shape = array_ops.shape(x)
1756  new_shape = array_ops.concat([first_dim, [-1], old_shape[1:]], axis=0)
1757  return array_ops.reshape(x, new_shape)
1758
1759
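# A minimal usage sketch (hypothetical shapes; not used by any converter)
# showing that `_unflatten_first_dim` undoes `_flatten_first_two_dims`:
def _flatten_unflatten_example():  # Illustrative only.
  x = array_ops.zeros([3, 4, 5])          # [loop_len, batch, features]
  flat = _flatten_first_two_dims(x)       # shape [12, 5]
  return _unflatten_first_dim(flat, [3])  # back to shape [3, 4, 5]

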
1760def _inputs_with_flattening(pfor_input, input_indices):
1761  """Stacks and flattens first dim of inputs at indices `input_indices`."""
1762  if input_indices is None:
1763    input_indices = []
1764  pfor_input.stack_inputs(stack_indices=input_indices)
1765  inputs = []
1766  for i in range(pfor_input.num_inputs):
1767    if i in input_indices:
1768      inp = pfor_input.stacked_input(i)
1769      inp = _flatten_first_two_dims(inp)
1770    else:
1771      inp = pfor_input.unstacked_input(i)
1772    inputs.append(inp)
1773  return inputs
1774
1775
1776@RegisterPForWithArgs("Conv2D", dims=[0])
1777@RegisterPForWithArgs("DepthToSpace", dims=[0])
1778@RegisterPForWithArgs("AvgPool", dims=[0])
1779@RegisterPForWithArgs("AvgPool3D", dims=[0])
1780@RegisterPForWithArgs("MaxPool", dims=[0])
1781@RegisterPForWithArgs("MaxPoolV2", dims=[0])
1782@RegisterPForWithArgs("MaxPool3D", dims=[0])
1783@RegisterPForWithArgs("MaxPool3DGrad", dims=[0, 1, 2])
1784@RegisterPForWithArgs("MaxPoolGrad", dims=[0, 1, 2])
1785@RegisterPForWithArgs("MaxPoolGradV2", dims=[0, 1, 2])
1786@RegisterPForWithArgs("MaxPool3DGradGrad", dims=[0, 1, 2])
1787@RegisterPForWithArgs("MaxPoolGradGrad", dims=[0, 1, 2])
1788@RegisterPForWithArgs("MaxPoolGradGradV2", dims=[0, 1, 2])
1789@RegisterPForWithArgs("SoftmaxCrossEntropyWithLogits", dims=[0, 1])
1790@RegisterPForWithArgs("SparseSoftmaxCrossEntropyWithLogits", dims=[0, 1])
1791@RegisterPForWithArgs("SpaceToDepth", dims=[0])
1792def _convert_flatten_batch(pfor_input, op_type, dims):
1793  del op_type
1794  inputs = _inputs_with_flattening(pfor_input, dims)
1795  outputs = _create_op(
1796      pfor_input.op_type,
1797      inputs, [x.dtype for x in pfor_input.outputs],
1798      attrs=pfor_input.op.node_def.attr).outputs
1799  n = pfor_input.pfor.loop_len_vector
1800  outputs = [_unflatten_first_dim(x, n) for x in outputs]
1801  return [wrap(x, True) for x in outputs]
1802
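# For example (hypothetical shapes), a stacked "MaxPool" input of shape
# [N, B, H, W, C] is flattened to [N * B, H, W, C], pooled as one large batch,
# and the pooled result is unflattened back to [N, B, H', W', C'].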
1803
1804_channel_flatten_input_cache = {}
1805
1806
1807@RegisterPFor("BatchToSpaceND")
1808def _convert_batch_to_space_nd(pfor_input):
1809  inp = pfor_input.stacked_input(0)
1810  block_shape = pfor_input.unstacked_input(1)
1811  crops = pfor_input.unstacked_input(2)
1812
1813  inp_shape = array_ops.shape(inp)
1814  n = pfor_input.pfor.loop_len_vector
1815
1816  # Reshape and transpose to move the vectorization axis inside the axes that
1817  # will move to space.
1818  # Reshape to 4D and transpose
1819  block_size = math_ops.reduce_prod(block_shape)
1820  new_shape = [n[0], block_size, inp_shape[1] // block_size, -1]
1821  inp = array_ops.reshape(inp, new_shape)
1822  inp = array_ops.transpose(inp, [1, 0, 2, 3])
1823  # Reshape back to merge the block, vectorization and batch dimension, and
1824  # restore the other dimensions.
1825  new_shape = array_ops.concat([n * inp_shape[1], inp_shape[2:]], axis=0)
1826  inp = array_ops.reshape(inp, new_shape)
1827  # Call batch_to_space and then split the new batch axis.
1828  output = gen_array_ops.batch_to_space_nd(inp, block_shape, crops)
1829  output = _unflatten_first_dim(output, n)
1830  return wrap(output, True)
1831
1832
1833@RegisterPFor("SpaceToBatchND")
1834def _convert_space_to_batch_nd(pfor_input):
1835  inp = pfor_input.stacked_input(0)
1836  block_shape = pfor_input.unstacked_input(1)
1837  paddings = pfor_input.unstacked_input(2)
1838
1839  n = pfor_input.pfor.loop_len_vector
1840  inp_shape = array_ops.shape(inp)
1841  inp = _flatten_first_two_dims(inp)
1842  output = gen_array_ops.space_to_batch_nd(inp, block_shape, paddings)
1843  output_shape = array_ops.shape(output)
1844  block_size = math_ops.reduce_prod(block_shape)
1845  new_shape = [block_size, n[0], -1]
1846  output = array_ops.reshape(output, new_shape)
1847  output = array_ops.transpose(output, [1, 0, 2])
1848  new_shape = array_ops.concat(
1849      [n, block_size * inp_shape[1:2], output_shape[1:]], axis=0)
1850  output = array_ops.reshape(output, new_shape)
1851  return wrap(output, True)
1852
1853
1854def _channel_flatten_input(x, data_format):
1855  """Merge the stack dimension with the channel dimension.
1856
1857  If S is pfor's stacking dimension, then:
1858    - for SNCHW, we transpose to NSCHW. If the N dimension has size 1, the
1859      transpose should be cheap.
1860    - for SNHWC, we transpose to NHWSC.
1861  We then merge the S and C dimensions.
1862
1863  Args:
1864    x: ops.Tensor to transform.
1865    data_format: "NCHW" or "NHWC".
1866
1867  Returns:
1868    A 3-element tuple of the transformed value, the transpose order, and the
1869    reshape shape required to transform back.
1870  """
1871
1872  graph = ops.get_default_graph()
1873  cache_key = (graph, x.ref(), data_format)
1874  if cache_key not in _channel_flatten_input_cache:
1875    x_shape = array_ops.shape(x)
1876    if data_format == b"NCHW":
1877      order = [1, 0, 2, 3, 4]
1878      shape = array_ops.concat([x_shape[1:2], [-1], x_shape[3:]], axis=0)
1879      reverse_order = order
1880    else:
1881      order = [1, 2, 3, 0, 4]
1882      shape = array_ops.concat([x_shape[1:4], [-1]], axis=0)
1883      reverse_order = [3, 0, 1, 2, 4]
1884    # Move S dimension next to C dimension.
1885    x = array_ops.transpose(x, order)
1886    reverse_shape = array_ops.shape(x)
1887    # Reshape to merge the S and C dimension.
1888    x = array_ops.reshape(x, shape)
1889    outputs = x, reverse_order, reverse_shape
1890    _channel_flatten_input_cache[cache_key] = outputs
1891  else:
1892    outputs = _channel_flatten_input_cache[cache_key]
1893  return outputs
1894
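# Shape sketch for _channel_flatten_input (hypothetical sizes): with
# data_format "NCHW", a stacked tensor of shape [S, N, C, H, W] =
# [2, 3, 4, 5, 5] is transposed to [3, 2, 4, 5, 5] and reshaped to
# [3, 8, 5, 5], i.e. the S and C dimensions are merged into a single channel
# dimension of size S * C.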
1895
1896# Note that with training=True, running FusedBatchNormV3 on individual examples
1897# is very different from running FusedBatchNormV3 on a batch of those examples.
1898# This is because, for the latter case, the operation can be considered as first
1899# computing the mean and variance over all the examples and then using these
1900# to scale all those examples. This creates a data dependency between these
1901# different "iterations" since the inputs to the scaling step depend on the
1902# statistics coming from all these inputs.
1903# As with other kernels, the conversion here effectively runs the kernel
1904# independently for each iteration, and returns outputs by stacking outputs from
1905# each of those iterations.
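# For example (hypothetical shapes), with data_format "NHWC" and training=True,
# a stacked input of shape [S, N, H, W, C] is rearranged by
# _channel_flatten_input into [N, H, W, S * C], so the per-channel statistics
# are still computed separately for each pfor iteration.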
1906@RegisterPFor("FusedBatchNormV3")
1907def _convert_fused_batch_norm(pfor_input):
1908  is_training = pfor_input.get_attr("is_training")
1909  # When BatchNorm is used with training=False, mean and variance are provided
1910  # externally and used as is by the op. Thus, we can merge the S and N
1911  # dimensions as we do for regular operations.
1912  # When BatchNorm is used with training=True, mean and variance are computed
1913  # for each channel across the batch dimension (first one). If we merge S and N
1914  # dimensions, the means and variances would be computed over a larger set. So, we
1915  # merge the S and C dimensions instead.
1916  if not is_training:
1917    # We return zeros for the batch_mean and batch_variance outputs. Note that CPU
1918    # and GPU seem to have different behavior for those two outputs. CPU outputs
1919    # zero because these values are not used during inference. GPU outputs
1920    # something, probably real means and variances.
1921    inputs = _inputs_with_flattening(pfor_input, [0])
1922    outputs = _create_op(
1923        pfor_input.op_type,
1924        inputs, [x.dtype for x in pfor_input.outputs],
1925        attrs=pfor_input.op.node_def.attr).outputs
1926    y = outputs[0]
1927    n = pfor_input.pfor.loop_len_vector
1928    y = _unflatten_first_dim(y, n)
1929    mean = pfor_input.unstacked_input(3)
1930    zeros = array_ops.zeros_like(mean)
1931    return [wrap(y, True)] + [wrap(zeros, False)] * 5
1932
1933  pfor_input.stack_inputs()
1934  data_format = pfor_input.get_attr("data_format")
1935  # We merge the first dimension with the "C" dimension, run FusedBatchNormV3,
1936  # and then transpose back.
1937  x = pfor_input.stacked_input(0)
1938  x, reverse_order, reverse_shape = _channel_flatten_input(x, data_format)
1939  # Note that we stack all the other inputs as well so that they are the same
1940  # size as the new size of the channel dimension.
1941  inputs = [x] + [
1942      array_ops.reshape(pfor_input.stacked_input(i), [-1])
1943      for i in range(1, pfor_input.num_inputs)
1944  ]
1945  outputs = _create_op(
1946      pfor_input.op_type,
1947      inputs, [x.dtype for x in pfor_input.outputs],
1948      attrs=pfor_input.op.node_def.attr).outputs
1949  y = outputs[0]
1950  y = array_ops.reshape(y, reverse_shape)
1951  y = array_ops.transpose(y, reverse_order)
1952  n = pfor_input.pfor.loop_len_vector
1953  outputs = [_unflatten_first_dim(x, n) for x in outputs[1:]]
1954  outputs = [y] + outputs
1955  return [wrap(x, True) for x in outputs]
1956
1957
1958@RegisterPFor("FusedBatchNormGradV3")
1959def _convert_fused_batch_norm_grad(pfor_input):
1960  pfor_input.stack_inputs()
1961  data_format = pfor_input.get_attr("data_format")
1962  y_backprop = pfor_input.stacked_input(0)
1963  y_backprop, _, _ = _channel_flatten_input(y_backprop, data_format)
1964  x = pfor_input.stacked_input(1)
1965  x, x_reverse_order, x_reverse_shape = _channel_flatten_input(x, data_format)
1966  inputs = [y_backprop, x] + [
1967      array_ops.reshape(pfor_input.stacked_input(i), [-1])
1968      for i in range(2, pfor_input.num_inputs)
1969  ]
1970  outputs = _create_op(
1971      pfor_input.op_type,
1972      inputs, [x.dtype for x in pfor_input.outputs],
1973      attrs=pfor_input.op.node_def.attr).outputs
1974  x_backprop = outputs[0]
1975  x_backprop = array_ops.reshape(x_backprop, x_reverse_shape)
1976  x_backprop = array_ops.transpose(x_backprop, x_reverse_order)
1977  n = pfor_input.pfor.loop_len_vector
1978  outputs = [_unflatten_first_dim(x, n) for x in outputs[1:]]
1979  outputs = [x_backprop] + outputs
1980  return [wrap(output, True) for output in outputs]
1981
1982
1983@RegisterPForWithArgs("Conv2DBackpropInput", flatten_dims=[2], shape_dim=0)
1984@RegisterPForWithArgs("AvgPoolGrad", flatten_dims=[1], shape_dim=0)
1985@RegisterPForWithArgs("AvgPool3DGrad", flatten_dims=[1], shape_dim=0)
1986def _convert_flatten_batch_shape_input(pfor_input, op_type, flatten_dims,
1987                                       shape_dim):
1988  del op_type
1989  inputs = _inputs_with_flattening(pfor_input, flatten_dims)
1990  n = pfor_input.pfor.loop_len_vector
1991  # Adjust the `input_sizes` input.
1992  ones = array_ops.ones([array_ops.shape(inputs[shape_dim])[0] - 1],
1993                        dtype=n.dtype)
1994  inputs[shape_dim] *= array_ops.concat([n, ones], axis=0)
1995  outputs = _create_op(
1996      pfor_input.op_type,
1997      inputs, [x.dtype for x in pfor_input.outputs],
1998      attrs=pfor_input.op.node_def.attr).outputs
1999  outputs = [_unflatten_first_dim(x, n) for x in outputs]
2000  return [wrap(x, True) for x in outputs]
2001
2002
2003@RegisterPFor("Conv2DBackpropFilter")
2004def _convert_conv2d_backprop_filter(pfor_input):
2005  pfor_input.stack_inputs(stack_indices=[2])
2006  inputs, inputs_stacked, _ = pfor_input.input(0)
2007  filter_sizes = pfor_input.unstacked_input(1)
2008  grads = pfor_input.stacked_input(2)
2009  strides = pfor_input.get_attr("strides")
2010  padding = pfor_input.get_attr("padding")
2011  use_cudnn_on_gpu = pfor_input.get_attr("use_cudnn_on_gpu")
2012  data_format = pfor_input.get_attr("data_format")
2013  dilations = pfor_input.get_attr("dilations")
2014  if inputs_stacked:
2015    # TODO(agarwal): Implement this efficiently.
2016    logging.warning("Conv2DBackpropFilter uses a while_loop. Fix that!")
2017
2018    def while_body(i, ta):
2019      inp_i = inputs[i, ...]
2020      grad_i = grads[i, ...]
2021      output = nn_ops.conv2d_backprop_filter(
2022          inp_i,
2023          filter_sizes,
2024          grad_i,
2025          strides=strides,
2026          padding=padding,
2027          use_cudnn_on_gpu=use_cudnn_on_gpu,
2028          data_format=data_format,
2029          dilations=dilations)
2030      return i + 1, ta.write(i, array_ops.expand_dims(output, 0))
2031
2032    n = array_ops.reshape(pfor_input.pfor.loop_len_vector, [])
2033    _, ta = control_flow_ops.while_loop(
2034        lambda i, ta: i < n, while_body,
2035        (0, tensor_array_ops.TensorArray(inputs.dtype, n)))
2036    output = ta.concat()
2037    return wrap(output, True)
2038  else:
2039    # We merge the stack dimension with the channel dimension of the gradients
2040    # and pretend we had a larger filter (see change to filter_sizes below).
2041    # Once the filter backprop is computed, we reshape and transpose back
2042    # appropriately.
2043    grads, _, _ = _channel_flatten_input(grads, data_format)
2044    n = pfor_input.pfor.loop_len_vector
2045    old_filter_sizes = filter_sizes
2046    filter_sizes *= array_ops.concat([[1, 1, 1], n], axis=0)
2047    output = nn_ops.conv2d_backprop_filter(
2048        inputs,
2049        filter_sizes,
2050        grads,
2051        strides=strides,
2052        padding=padding,
2053        use_cudnn_on_gpu=use_cudnn_on_gpu,
2054        data_format=data_format,
2055        dilations=dilations)
2056    new_filter_shape = array_ops.concat([old_filter_sizes[:3], n, [-1]], axis=0)
2057    output = array_ops.reshape(output, new_filter_shape)
2058    output = array_ops.transpose(output, [3, 0, 1, 2, 4])
2059    return wrap(output, True)
2060
2061
2062def _flatten_with_inner_dim(x, dim, x_rank):
2063  """Merges the first dim with the specified dim."""
2064  shape = array_ops.shape(x)
2065  x = array_ops.transpose(x,
2066                          list(range(1, dim)) + [0] + list(range(dim, x_rank)))
2067
2068  if dim < x_rank - 1:
2069    new_shape_pieces = [shape[1:dim], [-1], shape[dim + 1:]]
2070  else:
2071    new_shape_pieces = [shape[1:dim], [-1]]
2072  new_shape = array_ops.concat(new_shape_pieces, axis=0)
2073  return array_ops.reshape(x, new_shape)
2074
2075
2076def _unflatten_with_inner_dim(x, dim, x_rank, stack_size):
2077  """Undoes _flatten_with_inner_dim."""
2078  shape = array_ops.shape(x)
2079  if dim < x_rank - 1:
2080    new_shape_pieces = [shape[:dim], [stack_size], [-1], shape[dim + 1:]]
2081  else:
2082    new_shape_pieces = [shape[:dim], [stack_size], [-1]]
2083  new_shape = array_ops.concat(new_shape_pieces, axis=0)
2084  x = array_ops.reshape(x, new_shape)
2085  dims_permutation = [dim] + list(range(dim)) + list(range(dim + 1, x_rank + 1))
2086  return array_ops.transpose(x, dims_permutation)
2087
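# Shape sketch (hypothetical sizes): for x of shape [S, B, H, W, C] =
# [2, 3, 4, 4, 5], _flatten_with_inner_dim(x, dim=4, x_rank=5) produces shape
# [3, 4, 4, 10], and _unflatten_with_inner_dim(y, dim=3, x_rank=4,
# stack_size=2) restores the original [2, 3, 4, 4, 5].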
2088
2089@RegisterPFor("DepthwiseConv2dNative")
2090def _convert_depthwise_conv2d_native(pfor_input):
2091  # The kernel can be vectorized, so folding into the batch dimension does not
2092  # work. We instead fold into the channel dimension because it is parallel.
2093  stack_size = pfor_input.pfor.loop_len_vector[0]
2094  data_format = pfor_input.get_attr("data_format")
2095  c_dim = 1 if data_format == b"NCHW" else 3
2096  t = _flatten_with_inner_dim(pfor_input.stacked_input(0), c_dim + 1, 5)
2097  kernel = _flatten_with_inner_dim(pfor_input.stacked_input(1), 3, 5)
2098  conv = _create_op(
2099      "DepthwiseConv2dNative", [t, kernel],
2100      [x.dtype for x in pfor_input.outputs],
2101      attrs=pfor_input.op.node_def.attr).outputs[0]
2102  return wrap(_unflatten_with_inner_dim(conv, c_dim, 4, stack_size), True)
2103
2104
2105@RegisterPFor("DepthwiseConv2dNativeBackpropInput")
2106def _convert_depthwise_conv2d_native_backprop_input(pfor_input):
2107  stack_size = pfor_input.pfor.loop_len_vector[0]
2108  input_sizes = pfor_input.unstacked_input(0)
2109  data_format = pfor_input.get_attr("data_format")
2110  c_dim = 1 if data_format == b"NCHW" else 3
2111  input_sizes_multipliers = [
2112      constant_op.constant([1] * c_dim, dtype=dtypes.int32), [stack_size]
2113  ]
2114  if c_dim < 3:
2115    input_sizes_multipliers += [
2116        constant_op.constant([1] * (3 - c_dim), dtype=dtypes.int32)
2117    ]
2118  input_sizes *= array_ops.concat(input_sizes_multipliers, axis=0)
2119  kernel = _flatten_with_inner_dim(pfor_input.stacked_input(1), 3, 5)
2120  out_backprop = _flatten_with_inner_dim(
2121      pfor_input.stacked_input(2), c_dim + 1, 5)
2122  result = _create_op(
2123      "DepthwiseConv2dNativeBackpropInput", [input_sizes, kernel, out_backprop],
2124      [x.dtype for x in pfor_input.outputs],
2125      attrs=pfor_input.op.node_def.attr).outputs[0]
2126  return wrap(_unflatten_with_inner_dim(result, c_dim, 4, stack_size), True)
2127
2128
2129@RegisterPFor("DepthwiseConv2dNativeBackpropFilter")
2130def _convert_depthwise_conv2d_native_backprop_filter(pfor_input):
2131  stack_size = pfor_input.pfor.loop_len_vector[0]
2132  data_format = pfor_input.get_attr("data_format")
2133  c_dim = 1 if data_format == b"NCHW" else 3
2134  inputs = _flatten_with_inner_dim(pfor_input.stacked_input(0), c_dim + 1, 5)
2135  filter_sizes = pfor_input.unstacked_input(1)
2136  filter_sizes_multipliers = [
2137      constant_op.constant([1, 1], dtype=dtypes.int32), [stack_size],
2138      constant_op.constant([1], dtype=dtypes.int32)
2139  ]
2140  filter_sizes *= array_ops.concat(filter_sizes_multipliers, axis=0)
2141  out_backprop = _flatten_with_inner_dim(
2142      pfor_input.stacked_input(2), c_dim + 1, 5)
2143  result = _create_op(
2144      "DepthwiseConv2dNativeBackpropFilter",
2145      [inputs, filter_sizes, out_backprop],
2146      [x.dtype for x in pfor_input.outputs],
2147      attrs=pfor_input.op.node_def.attr).outputs[0]
2148  return wrap(_unflatten_with_inner_dim(result, 2, 4, stack_size), True)
2149
2150
2151@RegisterPForWithArgs("LogSoftmax", gen_nn_ops.log_softmax)
2152@RegisterPForWithArgs("Softmax", gen_nn_ops.softmax)
2153def _convert_softmax(pfor_input, op_type, op_func):
2154  del op_type
2155  return wrap(op_func(pfor_input.stacked_input(0)), True)
2156
2157
2158# array_ops
2159
2160
2161@RegisterPForWithArgs("Identity", array_ops.identity)
2162@RegisterPForWithArgs("StopGradient", array_ops.stop_gradient)
2163@RegisterPForWithArgs("MatrixDiag", array_ops.matrix_diag)
2164@RegisterPForWithArgs("MatrixDiagPart", array_ops.matrix_diag_part)
2165@RegisterPForWithArgs("_EagerConst", array_ops.identity)
2166def _convert_identity(pfor_input, op_type, op_func):
2167  del op_type
2168  return wrap(op_func(*[x.t for x in pfor_input.inputs]), True)
2169
2170
2171@RegisterPFor("IdentityN")
2172def _convert_identity_n(pfor_input):
2173  outputs = array_ops.identity_n([x.t for x in pfor_input.inputs])
2174  return [
2175      wrap(out, inp.is_stacked) for out, inp in zip(outputs, pfor_input.inputs)
2176  ]
2177
2178
2179@RegisterPFor("Reshape")
2180def _convert_reshape(pfor_input):
2181  t = pfor_input.stacked_input(0)
2182  shape = pfor_input.unstacked_input(1)
2183  new_shape = array_ops.concat([pfor_input.pfor.loop_len_vector, shape], axis=0)
2184  return wrap(array_ops.reshape(t, new_shape), True)
2185
2186
2187@RegisterPFor("Fill")
2188def _convert_fill(pfor_input):
2189  dims = pfor_input.unstacked_input(0)
2190  value = pfor_input.stacked_input(1)
2191  # Expand the rank of `value`
2192  new_shape = array_ops.concat(
2193      [[-1], array_ops.ones([array_ops.size(dims)], dtype=dtypes.int32)],
2194      axis=0)
2195  value = array_ops.reshape(value, new_shape)
2196  # Compute the new output shape
2197  new_dims = array_ops.concat([pfor_input.pfor.loop_len_vector, dims], axis=0)
2198  # Broadcast
2199  return wrap(array_ops.broadcast_to(value, new_dims), True)
2200
2201
2202@RegisterPFor("BroadcastTo")
2203def _convert_broadcast_to(pfor_input):
2204  t = pfor_input.stacked_input(0)
2205  shape = pfor_input.unstacked_input(1)
2206  new_shape = array_ops.concat([pfor_input.pfor.loop_len_vector, shape], axis=0)
2207
2208  # Expand dims of stacked t to broadcast against the new shape.
2209  # TODO(davmre): consider factoring out common code with
2210  # `expanddim_inputs_for_broadcast`, which has similar logic but with
2211  # implicit shapes (of input Tensors) rather than explicit shapes.
2212  rank_diff = array_ops.shape(new_shape)[0] - array_ops.rank(t)
2213  ones = array_ops.tile([1], array_ops.reshape(rank_diff, [1]))
2214  t_shape = array_ops.shape(t)
2215  t_expanded_shape = array_ops.concat([t_shape[:1], ones, t_shape[1:]], axis=0)
2216
2217  return wrap(
2218      array_ops.broadcast_to(array_ops.reshape(t, t_expanded_shape), new_shape),
2219      True)
2220
2221
2222@RegisterPFor("ExpandDims")
2223def _convert_expanddims(pfor_input):
2224  t = pfor_input.stacked_input(0)
2225  dim = pfor_input.unstacked_input(1)
2226  dim += math_ops.cast(dim >= 0, dim.dtype)
2227  return wrap(array_ops.expand_dims(t, axis=dim), True)
2228
2229
2230@RegisterPForWithArgs("LowerBound", gen_array_ops.lower_bound)
2231@RegisterPForWithArgs("UpperBound", gen_array_ops.upper_bound)
2232def _convert_searchsorted(pfor_input, _, op_func):
2233  pfor_input.stack_inputs()
2234  sorted_inputs = _flatten_first_two_dims(pfor_input.stacked_input(0))
2235  values = _flatten_first_two_dims(pfor_input.stacked_input(1))
2236  out_type = pfor_input.get_attr("out_type")
2237  output = op_func(sorted_inputs, values, out_type)
2238  return wrap(
2239      _unflatten_first_dim(output, pfor_input.pfor.loop_len_vector), True)
2240
2241
2242@RegisterPFor("MatrixBandPart")
2243def _convert_matrix_band_part(pfor_input):
2244  t = pfor_input.stacked_input(0)
2245  num_lower = pfor_input.unstacked_input(1)
2246  num_upper = pfor_input.unstacked_input(2)
2247  return wrap(
2248      array_ops.matrix_band_part(t, num_lower=num_lower, num_upper=num_upper),
2249      True)
2250
2251
2252@RegisterPFor("MatrixSetDiag")
2253def _convert_matrix_set_diag(pfor_input):
2254  pfor_input.stack_inputs()
2255  t = pfor_input.stacked_input(0)
2256  diag = pfor_input.stacked_input(1)
2257  return wrap(array_ops.matrix_set_diag(t, diag), True)
2258
2259
2260# Registrations for Matrix{Diag,DiagPart,SetDiag}V2-3.
2261# The input orders defined in the OpKernel and the actual python API are
2262# different (for compatibility with V1), so we cannot use _convert_identity.
2263# v2 is not compatible with v3 and is never exposed on the public API.
2264@RegisterPFor("MatrixDiagV2")
2265@RegisterPFor("MatrixDiagV3")
2266def _convert_matrix_diag_v2(pfor_input):
2267  params = {
2268      "diagonal": pfor_input.stacked_input(0),
2269      "k": pfor_input.unstacked_input(1),
2270      "num_rows": pfor_input.unstacked_input(2),
2271      "num_cols": pfor_input.unstacked_input(3),
2272      "padding_value": pfor_input.unstacked_input(4)
2273  }
2274  if pfor_input.op_type == "MatrixDiagV2":
2275    return wrap(array_ops.matrix_diag_v2(**params), True)
2276  params["align"] = pfor_input.get_attr("align")
2277  return wrap(array_ops.matrix_diag(**params), True)
2278
2279
2280@RegisterPFor("Diag")
2281def _convert_diag(pfor_input):
2282  diag = pfor_input.stacked_input(0)
2283  if diag.shape.ndims == 2:
2284    # We can use matrix_diag.
2285    return wrap(array_ops.matrix_diag(diag), True)
2286  else:
2287    # It is not clear if we can do better than a while loop here with existing
2288    # kernels.
2289    return _fallback_converter(pfor_input, warn=False)
2290
2291
2292# See notes for MatrixDiagV2
2293@RegisterPFor("MatrixDiagPartV2")
2294@RegisterPFor("MatrixDiagPartV3")
2295def _convert_matrix_diag_part_v2(pfor_input):
2296  params = {
2297      "input": pfor_input.stacked_input(0),
2298      "k": pfor_input.unstacked_input(1),
2299      "padding_value": pfor_input.unstacked_input(2)
2300  }
2301  if pfor_input.op_type == "MatrixDiagPartV2":
2302    return wrap(array_ops.matrix_diag_part_v2(**params), True)
2303  params["align"] = pfor_input.get_attr("align")
2304  return wrap(array_ops.matrix_diag_part(**params), True)
2305
2306
2307# See notes for MatrixDiagV2
2308@RegisterPFor("MatrixSetDiagV2")
2309@RegisterPFor("MatrixSetDiagV3")
2310def _convert_matrix_set_diag_v2(pfor_input):
2311  pfor_input.stack_inputs([0, 1])
2312  params = {
2313      "input": pfor_input.stacked_input(0),
2314      "diagonal": pfor_input.stacked_input(1),
2315      "k": pfor_input.unstacked_input(2)
2316  }
2317  if pfor_input.op_type == "MatrixSetDiagV2":
2318    return wrap(array_ops.matrix_set_diag_v2(**params), True)
2319  params["align"] = pfor_input.get_attr("align")
2320  return wrap(array_ops.matrix_set_diag(**params), True)
2321
2322
2323@RegisterPFor("DiagPart")
2324def _convert_diag_part(pfor_input):
2325  inp = pfor_input.stacked_input(0)
2326  if inp.shape.ndims == 3:
2327    # We can use matrix_diag_part.
2328    return wrap(array_ops.matrix_diag_part(inp), True)
2329  else:
2330    # It is not clear if we can do better than a while loop here with existing
2331    # kernels.
2332    return _fallback_converter(pfor_input, warn=False)
2333
2334
2335@RegisterPFor("OneHot")
2336def _convert_one_hot(pfor_input):
2337  indices = pfor_input.stacked_input(0)
2338  depth = pfor_input.unstacked_input(1)
2339  on_value = pfor_input.unstacked_input(2)
2340  off_value = pfor_input.unstacked_input(3)
2341  axis = pfor_input.get_attr("axis")
2342  if axis >= 0:
2343    axis += 1
2344  return wrap(
2345      array_ops.one_hot(indices, depth, on_value, off_value, axis), True)
2346
2347
2348@RegisterPFor("Slice")
2349def _convert_slice(pfor_input):
2350  t = pfor_input.stacked_input(0)
2351  begin = pfor_input.unstacked_input(1)
2352  size = pfor_input.unstacked_input(2)
2353  begin = array_ops.concat([[0], begin], axis=0)
2354  size = array_ops.concat([[-1], size], axis=0)
2355  return wrap(array_ops.slice(t, begin, size), True)
2356
2357
2358@RegisterPFor("Tile")
2359def _convert_tile(pfor_input):
2360  t = pfor_input.stacked_input(0)
2361  multiples = pfor_input.unstacked_input(1)
2362  multiples = array_ops.concat([[1], multiples], 0)
2363  return wrap(array_ops.tile(t, multiples), True)
2364
2365
2366@RegisterPFor("Pack")
2367def _convert_pack(pfor_input):
2368  pfor_input.stack_inputs()
2369  axis = pfor_input.get_attr("axis")
2370  if axis >= 0:
2371    axis += 1
2372  return wrap(
2373      array_ops.stack([x.t for x in pfor_input.inputs], axis=axis), True)
2374
2375
2376@RegisterPFor("Unpack")
2377def _convert_unpack(pfor_input):
2378  value = pfor_input.stacked_input(0)
2379  axis = pfor_input.get_attr("axis")
2380  if axis >= 0:
2381    axis += 1
2382  num = pfor_input.get_attr("num")
2383  return [wrap(x, True) for x in array_ops.unstack(value, axis=axis, num=num)]
2384
2385
2386@RegisterPFor("Pad")
2387def _convert_pad(pfor_input):
2388  t = pfor_input.stacked_input(0)
2389  paddings = pfor_input.unstacked_input(1)
2390  paddings = array_ops.concat([[[0, 0]], paddings], 0)
2391  return wrap(array_ops.pad(t, paddings, mode="CONSTANT"), True)
2392
2393
2394@RegisterPFor("Split")
2395def _convert_split(pfor_input):
2396  split_dim = pfor_input.unstacked_input(0)
2397  t = pfor_input.stacked_input(1)
2398  num_split = pfor_input.get_attr("num_split")
2399  split_dim += math_ops.cast(split_dim >= 0, dtypes.int32)
2400  return [wrap(x, True) for x in array_ops.split(t, num_split, axis=split_dim)]
2401
2402
2403@RegisterPFor("SplitV")
2404def _convert_split_v(pfor_input):
2405  t = pfor_input.stacked_input(0)
2406  splits = pfor_input.unstacked_input(1)
2407  split_dim = pfor_input.unstacked_input(2)
2408  split_dim += math_ops.cast(split_dim >= 0, dtypes.int32)
2409  return [wrap(x, True) for x in array_ops.split(t, splits, axis=split_dim)]
2410
2411
2412@RegisterPFor("Squeeze")
2413def _convert_squeeze(pfor_input):
2414  t = pfor_input.stacked_input(0)
2415  squeeze_dims = pfor_input.get_attr("squeeze_dims")
2416  squeeze_dims = [i + 1 if i >= 0 else i for i in squeeze_dims]
2417  return wrap(array_ops.squeeze(t, axis=squeeze_dims), True)
2418
2419
2420@RegisterPFor("ReverseV2")
2421def _convert_reverse(pfor_input):
2422  value = pfor_input.stacked_input(0)
2423  axis = pfor_input.unstacked_input(1)
2424  new_axis = array_ops.where_v2(axis >= 0, axis + 1, axis)
2425  return wrap(gen_array_ops.reverse_v2(value, axis=new_axis), True)
2426
2427
2428@RegisterPForWithArgs("Transpose", gen_array_ops.transpose)
2429@RegisterPForWithArgs("ConjugateTranspose", gen_array_ops.conjugate_transpose)
2430def _convert_transpose(pfor_input, _, op_func):
2431  t = pfor_input.stacked_input(0)
2432  perm = pfor_input.unstacked_input(1)
2433  new_perm = array_ops.concat([[0], perm + 1], axis=0)
2434  return wrap(op_func(t, new_perm), True)
2435
2436
2437@RegisterPFor("ZerosLike")
2438def _convert_zeroslike(pfor_input):
2439  t = pfor_input.stacked_input(0)
2440  shape = array_ops.shape(t)[1:]
2441  return wrap(array_ops.zeros(shape, dtype=t.dtype), False)
2442
2443
2444@RegisterPFor("Gather")
2445@RegisterPFor("GatherV2")
2446def _convert_gather(pfor_input):
2447  param, param_stacked, _ = pfor_input.input(0)
2448  indices, indices_stacked, _ = pfor_input.input(1)
2449  batch_dims = pfor_input.get_attr("batch_dims")
2450
2451  op_type = pfor_input.op_type
2452  if op_type == "Gather":
2453    validate_indices = pfor_input.get_attr("validate_indices")
2454    axis = 0
2455  else:
2456    validate_indices = None
2457    # Assume we will never have a Tensor with rank > 2**32.
2458    axis = math_ops.cast(pfor_input.unstacked_input(2), dtypes.int32)
2459    axis_value = tensor_util.constant_value(axis)
2460    if axis_value is not None:
2461      axis = axis_value
2462  if indices_stacked and not param_stacked:
2463    if indices is pfor_input.pfor.all_indices and axis == 0:
2464      param_shape0 = tensor_shape.dimension_value(param.shape[0])
2465      indices_shape0 = tensor_shape.dimension_value(indices.shape[0])
2466      if param_shape0 is not None and indices_shape0 == param_shape0:
2467        # Note that with loops and conditionals, indices may not be contiguous.
2468        # However, they will be sorted and unique. So if the shape matches, then
2469        # it must be picking up all the rows of param.
2470        return wrap(param, True)
2471
2472    if batch_dims != 0:
2473      # Convert `batch_dims` to its positive equivalent if necessary.
2474      batch_dims_pos = batch_dims
2475      if batch_dims < 0:
2476        batch_dims_pos += array_ops.rank(indices)
2477      # In order to maintain
2478      #   indices.shape[:batch_dims] == params.shape[:batch_dims]
2479      # with stacked indices, we move the first dimension of `indices` to the
2480      # `batch_dims + 1`th position. The (non-batch) index dimensions will be
2481      # inserted into the shape of `output` at the `axis` dimension, which is
2482      # then transposed to the front (below).
2483      order = array_ops.concat([
2484          math_ops.range(1, batch_dims_pos + 1),
2485          [0],
2486          math_ops.range(batch_dims_pos + 1, array_ops.rank(indices))], axis=0)
2487      indices = array_ops.transpose(indices, order)
2488
2489    output = array_ops.gather(
2490        param, indices, validate_indices=validate_indices, axis=axis,
2491        batch_dims=batch_dims)
2492    if axis != 0:
2493      axis = control_flow_ops.cond(axis < 0,
2494                                   lambda: axis + array_ops.rank(param),
2495                                   lambda: axis)
2496      order = array_ops.concat(
2497          [[axis],
2498           math_ops.range(axis),
2499           math_ops.range(axis + 1, array_ops.rank(output))],
2500          axis=0)
2501      output = control_flow_ops.cond(
2502          math_ops.equal(axis, 0), lambda: output,
2503          lambda: array_ops.transpose(output, order))
2504    return wrap(output, True)
2505  if param_stacked:
2506    pfor_input.stack_inputs(stack_indices=[1])
2507    indices = pfor_input.stacked_input(1)
2508
2509    output = array_ops.gather(
2510        param, indices,
2511        axis=array_ops.where(axis >= 0, axis + 1, axis),
2512        batch_dims=(batch_dims + 1 if batch_dims >= 0 else batch_dims))
2513    return wrap(output, True)
2514
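# Sketch of the simplest case above (hypothetical shapes): vectorizing
# `array_ops.gather(params, idx_i)` where `params` of shape [v, d] is loop
# invariant and `idx_i` is a per-iteration scalar index reduces to one gather
# with the stacked vector of all N indices, giving a stacked output of shape
# [N, d].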
2515
2516@RegisterPFor("GatherNd")
2517def _convert_gather_nd(pfor_input):
2518  # TODO(jmenick): Add support for unstacked params.
2519  pfor_input.stack_inputs(stack_indices=[1])
2520  params = pfor_input.stacked_input(0)
2521  indices = pfor_input.stacked_input(1)
2522  stacked_result = array_ops.gather_nd(params, indices, batch_dims=1)
2523  return wrap(stacked_result, True)
2524
2525
2526@RegisterPFor("ConcatV2")
2527def _convert_concatv2(pfor_input):
2528  n = pfor_input.num_inputs
2529  pfor_input.stack_inputs(stack_indices=range(n - 1))
2530  axis = pfor_input.unstacked_input(n - 1)
2531  axis += math_ops.cast(axis >= 0, axis.dtype)
2532  return wrap(
2533      array_ops.concat([x.t for x in pfor_input.inputs[:n - 1]], axis=axis),
2534      True)
2535
2536
2537@RegisterPFor("StridedSlice")
2538def _convert_strided_slice(pfor_input):
2539  inp = pfor_input.stacked_input(0)
2540  begin = pfor_input.unstacked_input(1)
2541  end = pfor_input.unstacked_input(2)
2542  strides = pfor_input.unstacked_input(3)
2543  begin_mask = pfor_input.get_attr("begin_mask")
2544  end_mask = pfor_input.get_attr("end_mask")
2545  ellipsis_mask = pfor_input.get_attr("ellipsis_mask")
2546  new_axis_mask = pfor_input.get_attr("new_axis_mask")
2547  shrink_axis_mask = pfor_input.get_attr("shrink_axis_mask")
2548
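  # Prepend a full slice over the new loop dimension: begin/end/strides get a
  # leading entry for dimension 0, every mask is shifted left by one bit, and
  # bit 0 of begin_mask/end_mask is set so the whole loop dimension is kept.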
2549  begin = array_ops.concat([[0], begin], axis=0)
2550  end = array_ops.concat([[0], end], axis=0)
2551  strides = array_ops.concat([[1], strides], axis=0)
2552  begin_mask = begin_mask << 1 | 1
2553  end_mask = end_mask << 1 | 1
2554  ellipsis_mask <<= 1
2555  new_axis_mask <<= 1
2556  shrink_axis_mask <<= 1
2557  return wrap(
2558      array_ops.strided_slice(
2559          inp,
2560          begin,
2561          end,
2562          strides,
2563          begin_mask=begin_mask,
2564          end_mask=end_mask,
2565          ellipsis_mask=ellipsis_mask,
2566          new_axis_mask=new_axis_mask,
2567          shrink_axis_mask=shrink_axis_mask), True)
2568
2569
2570@RegisterPFor("StridedSliceGrad")
2571def _convert_strided_slice_grad(pfor_input):
2572  shape = pfor_input.unstacked_input(0)
2573  begin = pfor_input.unstacked_input(1)
2574  end = pfor_input.unstacked_input(2)
2575  strides = pfor_input.unstacked_input(3)
2576  dy = pfor_input.stacked_input(4)
2577  begin_mask = pfor_input.get_attr("begin_mask")
2578  end_mask = pfor_input.get_attr("end_mask")
2579  ellipsis_mask = pfor_input.get_attr("ellipsis_mask")
2580  new_axis_mask = pfor_input.get_attr("new_axis_mask")
2581  shrink_axis_mask = pfor_input.get_attr("shrink_axis_mask")
2582
2583  shape = array_ops.concat([pfor_input.pfor.loop_len_vector, shape], axis=0)
2584  begin = array_ops.concat([[0], begin], axis=0)
2585  end = array_ops.concat([[0], end], axis=0)
2586  strides = array_ops.concat([[1], strides], axis=0)
2587  begin_mask = begin_mask << 1 | 1
2588  end_mask = end_mask << 1 | 1
2589  ellipsis_mask <<= 1
2590  new_axis_mask <<= 1
2591  shrink_axis_mask <<= 1
2592  return wrap(
2593      array_ops.strided_slice_grad(
2594          shape,
2595          begin,
2596          end,
2597          strides,
2598          dy,
2599          begin_mask=begin_mask,
2600          end_mask=end_mask,
2601          ellipsis_mask=ellipsis_mask,
2602          new_axis_mask=new_axis_mask,
2603          shrink_axis_mask=shrink_axis_mask), True)
2604
2605
2606@RegisterPFor("CheckNumerics")
2607def _convert_check_numerics(pfor_input):
2608  t = pfor_input.stacked_input(0)
2609  message = pfor_input.get_attr("message")
2610  return wrap(gen_array_ops.check_numerics(t, message), True)
2611
2612
2613@RegisterPFor("EnsureShape")
2614def _convert_ensure_shape(pfor_input):
2615  t = pfor_input.stacked_input(0)
2616  shape = tensor_shape.TensorShape(pfor_input.get_attr("shape"))
2617  return wrap(gen_array_ops.ensure_shape(t, [None] + shape), True)
2618
2619
2620# manip_ops
2621
2622
2623@RegisterPFor("Roll")
2624def _convert_roll(pfor_input):
2625  t = pfor_input.stacked_input(0)
2626  shift = pfor_input.unstacked_input(1)
2627  axis = pfor_input.unstacked_input(2)
2628  return wrap(manip_ops.roll(t, shift, axis + 1), True)
2629
2630
2631# math_ops
2632
2633
2634@RegisterPFor("MatMul")
2635def _convert_matmul(pfor_input):
2636  # TODO(agarwal): Check if tiling is faster than two transposes.
2637  a, a_stacked, _ = pfor_input.input(0)
2638  b, b_stacked, _ = pfor_input.input(1)
2639  tr_a = pfor_input.get_attr("transpose_a")
2640  tr_b = pfor_input.get_attr("transpose_b")
2641  if a_stacked and b_stacked:
2642    output = wrap(math_ops.matmul(a, b, adjoint_a=tr_a, adjoint_b=tr_b), True)
2643    return output
2644  elif a_stacked:
2645    if tr_a:
2646      a = array_ops.transpose(a, [0, 2, 1])
2647    if a.shape.is_fully_defined():
2648      x, y, z = a.shape
2649    else:
2650      x, y, z = [
2651          array_ops.reshape(i, [])
2652          for i in array_ops.split(array_ops.shape(a), 3)
2653      ]
2654    a = array_ops.reshape(a, [x * y, z])
2655    prod = math_ops.matmul(a, b, transpose_b=tr_b)
2656    return wrap(array_ops.reshape(prod, [x, y, -1]), True)
2657  else:
2658    assert b_stacked
2659    if tr_b:
2660      perm = [2, 0, 1]
2661      b = array_ops.transpose(b, perm)
2662    else:
2663      # As an optimization, if one of the first two dimensions is 1, then we can
2664      # reshape instead of transpose.
2665      # TODO(agarwal): This check can be done inside Transpose kernel.
2666      b_shape = array_ops.shape(b)
2667      min_dim = math_ops.minimum(b_shape[0], b_shape[1])
2668      perm = array_ops.where(
2669          math_ops.equal(min_dim, 1), [0, 1, 2], [1, 0, 2])
2670      new_shape = array_ops.stack([b_shape[1], b_shape[0], b_shape[2]])
2671      b = array_ops.transpose(b, perm)
2672      b = array_ops.reshape(b, new_shape)
2673
2674    if b.shape.is_fully_defined():
2675      x, y, z = b.shape
2676    else:
2677      x, y, z = [
2678          array_ops.reshape(i, [])
2679          for i in array_ops.split(array_ops.shape(b), 3)
2680      ]
2681    b = array_ops.reshape(b, [x, y * z])
2682    prod = math_ops.matmul(a, b, transpose_a=tr_a)
2683    prod = array_ops.reshape(prod, [-1, y, z])
2684    prod = array_ops.transpose(prod, [1, 0, 2])
2685    return wrap(prod, True)
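
# Example (illustrative; assumes `import tensorflow as tf`): when only `a`
# carries the loop dimension, the batched product einsum('bij,jk->bik', a, b)
# is computed with a single 2-D matmul by flattening the first two dims of `a`
# and reshaping back, as done above:
#
#   a = tf.random.normal([5, 2, 3])                      # [loop_len, 2, 3]
#   b = tf.random.normal([3, 4])
#   ref = tf.einsum('bij,jk->bik', a, b)
#   out = tf.reshape(tf.matmul(tf.reshape(a, [5 * 2, 3]), b), [5, 2, 4])
#   # ref and out agree up to floating point error.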
2686
2687
2688# TODO(rmlarsen): Use the converter of BatchMatMulV2 once compatibility window
2689# is met.
2690@RegisterPFor("BatchMatMul")
2691def _convert_batch_mat_mul(pfor_input):
2692  # TODO(agarwal): There may be a more efficient way to do this instead of
2693  # stacking the inputs.
2694  pfor_input.stack_inputs()
2695  x = pfor_input.stacked_input(0)
2696  y = pfor_input.stacked_input(1)
2697  adj_x = pfor_input.get_attr("adj_x")
2698  adj_y = pfor_input.get_attr("adj_y")
2699
2700  x = _flatten_first_two_dims(x)
2701  y = _flatten_first_two_dims(y)
2702  output = math_ops.matmul(x, y, adjoint_a=adj_x, adjoint_b=adj_y)
2703  output = _unflatten_first_dim(output, pfor_input.pfor.loop_len_vector)
2704  return wrap(output, True)
2705
2706
2707@RegisterPFor("BatchMatMulV2")
2708def _convert_batch_mat_mul_v2(pfor_input):
2709  pfor_input.expanddim_inputs_for_broadcast()
2710  x = pfor_input.input(0)[0]
2711  y = pfor_input.input(1)[0]
2712  adj_x = pfor_input.get_attr("adj_x")
2713  adj_y = pfor_input.get_attr("adj_y")
2714
2715  output = math_ops.matmul(x, y, adjoint_a=adj_x, adjoint_b=adj_y)
2716  return wrap(output, True)
2717
2718
2719@RegisterPForWithArgs("Sum", math_ops.reduce_sum)
2720@RegisterPForWithArgs("Prod", math_ops.reduce_prod)
2721@RegisterPForWithArgs("Max", math_ops.reduce_max)
2722@RegisterPForWithArgs("Min", math_ops.reduce_min)
2723@RegisterPForWithArgs("Mean", math_ops.reduce_mean)
2724@RegisterPForWithArgs("All", math_ops.reduce_all)
2725@RegisterPForWithArgs("Any", math_ops.reduce_any)
2726def _convert_reduction(pfor_input, _, op_func):
2727  t = pfor_input.stacked_input(0)
2728  indices = pfor_input.unstacked_input(1)
2729  # Shift positive indices by one to account for the extra dimension.
2730  indices += math_ops.cast(indices >= 0, indices.dtype)
2731  keep_dims = pfor_input.get_attr("keep_dims")
2732  return wrap(op_func(t, indices, keepdims=keep_dims), True)
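
# Example (illustrative; assumes `import tensorflow as tf`): with a stacked
# input of shape [loop_len, d0, d1], a per-iteration reduction over axis 0
# becomes a reduction over axis 1 after the shift above:
#
#   t = tf.random.normal([3, 4, 5])                      # [loop_len, 4, 5]
#   per_iter = tf.stack([tf.reduce_sum(t[i], axis=0) for i in range(3)])
#   vectorized = tf.reduce_sum(t, axis=1)
#   # per_iter and vectorized agree up to floating point error.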
2733
2734
2735@RegisterPForWithArgs("ArgMax", math_ops.argmax)
2736@RegisterPForWithArgs("ArgMin", math_ops.argmin)
2737def _convert_argmax_argmin(pfor_input, _, op_func):
2738  t = pfor_input.stacked_input(0)
2739  dimension = pfor_input.unstacked_input(1)
2740  dimension += math_ops.cast(dimension >= 0, dimension.dtype)
2741  output_type = pfor_input.get_attr("output_type")
2742  return wrap(op_func(t, axis=dimension, output_type=output_type), True)
2743
2744
2745@RegisterPFor("Bucketize")
2746def _convert_bucketize(pfor_input):
2747  t = pfor_input.stacked_input(0)
2748  boundaries = pfor_input.get_attr("boundaries")
2749  return wrap(math_ops.bucketize(t, boundaries), True)
2750
2751
2752@RegisterPFor("ClipByValue")
2753def _convert_clip_by_value(pfor_input):
2754  t = pfor_input.stacked_input(0)
2755  clip_value_min = pfor_input.unstacked_input(1)
2756  clip_value_max = pfor_input.unstacked_input(2)
2757  return wrap(gen_math_ops.clip_by_value(t, clip_value_min, clip_value_max),
2758              True)
2759
2760
2761@RegisterPForWithArgs("Cumsum", math_ops.cumsum)
2762@RegisterPForWithArgs("Cumprod", math_ops.cumprod)
2763def _convert_cumfoo(pfor_input, _, op_func):
2764  t = pfor_input.stacked_input(0)
2765  axis = pfor_input.unstacked_input(1)
2766  # Shift positive indices by one to account for the extra dimension.
2767  axis += math_ops.cast(axis >= 0, axis.dtype)
2768  exclusive = pfor_input.get_attr("exclusive")
2769  reverse = pfor_input.get_attr("reverse")
2770  return wrap(op_func(t, axis, exclusive=exclusive, reverse=reverse), True)
2771
2772
2773@RegisterPFor("BiasAdd")
2774def _convert_biasadd(pfor_input):
2775  t, t_stacked, _ = pfor_input.input(0)
2776  bias, bias_stacked, _ = pfor_input.input(1)
2777  data_format = pfor_input.get_attr("data_format").decode()
2778  if bias_stacked:
2779    # BiasAdd only supports 1-D biases, so cast bias to match value and use Add.
2780    pfor_input.expanddim_inputs_for_broadcast()
2781    t, _, _ = pfor_input.input(0)
2782    bias = math_ops.cast(pfor_input.stacked_input(1), t.dtype)
2783    if compat.as_bytes(data_format) == b"NCHW":
2784      b_shape = array_ops.shape(bias)
2785      new_b_shape = array_ops.concat(
2786          [b_shape[:-3], b_shape[-1:], b_shape[-3:-1]], axis=0)
2787      bias = array_ops.reshape(bias, new_b_shape)
2788    return wrap(math_ops.add(t, bias), True)
2789  else:
2790    assert t_stacked, "At least one input to BiasAdd should be loop variant."
2791    if compat.as_bytes(data_format) == b"NCHW":
2792      shape = array_ops.shape(t)
2793      flattened_shape = array_ops.concat([[-1], shape[2:]], axis=0)
2794      t = array_ops.reshape(t, flattened_shape)
2795      t = nn_ops.bias_add(t, bias, data_format="NCHW")
2796      t = array_ops.reshape(t, shape)
2797      return wrap(t, True)
2798    return wrap(nn_ops.bias_add(t, bias, data_format=data_format), True)
2799
2800
2801@RegisterPForWithArgs("UnsortedSegmentSum", math_ops.unsorted_segment_sum)
2802@RegisterPForWithArgs("UnsortedSegmentMax", math_ops.unsorted_segment_max)
2803@RegisterPForWithArgs("UnsortedSegmentMin", math_ops.unsorted_segment_min)
2804@RegisterPForWithArgs("UnsortedSegmentProd", math_ops.unsorted_segment_prod)
2805def _convert_unsortedsegmentsum(pfor_input, _, op_func):
2806  pfor_input.stack_inputs([0, 1])
2807  data = pfor_input.stacked_input(0)
2808  segment_ids = pfor_input.stacked_input(1)
2809  # TODO(agarwal): handle stacked?
2810  num_segments = pfor_input.unstacked_input(2)
2811  if segment_ids.dtype != num_segments.dtype:
2812    segment_ids = math_ops.cast(segment_ids, dtypes.int64)
2813    num_segments = math_ops.cast(num_segments, dtypes.int64)
2814  dtype = segment_ids.dtype
2815  segment_shape = array_ops.shape(segment_ids, out_type=dtype)
2816  n = segment_shape[0]
2817  ones = array_ops.ones_like(segment_shape, dtype=dtype)[1:]
2818  segment_offset = num_segments * math_ops.range(n, dtype=dtype)
2819  segment_offset = array_ops.reshape(segment_offset,
2820                                     array_ops.concat([[n], ones], axis=0))
2821  segment_ids += segment_offset
2822  num_segments = math_ops.cast(num_segments, dtypes.int64) * math_ops.cast(
2823      n, dtypes.int64)
2824  output = op_func(data, segment_ids, num_segments)
2825  new_output_shape = array_ops.concat(
2826      [[n, -1], array_ops.shape(output)[1:]], axis=0)
2827  output = array_ops.reshape(output, new_output_shape)
2828  return wrap(output, True)
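
# Example (illustrative; assumes `import tensorflow as tf`): the rows are made
# to use disjoint segment id ranges by adding `i * num_segments` to row i, so a
# single unsorted_segment_sum covers all pfor iterations at once:
#
#   data = tf.constant([[1., 2., 3.], [4., 5., 6.]])     # [loop_len=2, 3]
#   seg = tf.constant([[0, 0, 1], [1, 0, 0]])
#   num_segments = 2
#   offset = num_segments * tf.range(2)[:, None]         # [[0], [2]]
#   flat = tf.math.unsorted_segment_sum(data, seg + offset, 2 * num_segments)
#   out = tf.reshape(flat, [2, num_segments])
#   # out == [[3., 3.], [11., 4.]], i.e. the two per-row results stacked.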
2829
2830
2831def _flatten_array_with_offset(ids, offset_delta, num_rows):
2832  """Flattens a rank 2 tensor, adding an offset to each row."""
2833  # Note that if `ids` is rank 1, it is broadcast to rank 2.
2834  offset_delta = math_ops.cast(offset_delta, ids.dtype)
2835  n = math_ops.cast(num_rows, dtype=ids.dtype)
2836  offsets = math_ops.range(
2837      start=0, limit=n * offset_delta, delta=offset_delta, dtype=ids.dtype)
2838  offsets = array_ops.expand_dims(offsets, -1)
2839  ids += offsets
2840  return array_ops.reshape(ids, [-1])
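
# Example (illustrative): with int32 ids [[0, 1], [1, 2]], offset_delta=3 and
# num_rows=2, the result is [0, 1, 4, 5]; row i gets i * offset_delta added
# before flattening, so the per-row index ranges (each of size offset_delta)
# stay disjoint.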
2841
2842
2843@RegisterPForWithArgs("SparseSegmentSum", math_ops.sparse_segment_sum_v2)
2844@RegisterPForWithArgs("SparseSegmentMean", math_ops.sparse_segment_mean_v2)
2845@RegisterPForWithArgs("SparseSegmentSqrtN", math_ops.sparse_segment_sqrt_n_v2)
2846@RegisterPForWithArgs("SparseSegmentSumWithNumSegments",
2847                      math_ops.sparse_segment_sum_v2)
2848@RegisterPForWithArgs("SparseSegmentMeanWithNumSegments",
2849                      math_ops.sparse_segment_mean_v2)
2850@RegisterPForWithArgs("SparseSegmentSqrtNWithNumSegments",
2851                      math_ops.sparse_segment_sqrt_n_v2)
2852def _convert_sparse_segment(pfor_input, _, op_func):
2853  _, segment_ids_stacked, _ = pfor_input.input(2)
2854  if segment_ids_stacked:
2855    pfor_input.stack_inputs([1])
2856  data, data_stacked, _ = pfor_input.input(0)
2857  indices, _, _ = pfor_input.input(1)
2858  num_inputs = len(pfor_input.inputs)
2859  assert num_inputs in (3, 4)
2860  if num_inputs == 3:
2861    # `segment_ids` needs to be unstacked since otherwise output sizes could
2862    # differ across pfor iterations.
2863    segment_ids = pfor_input.unstacked_input(2)
2864    num_segments = nn_ops.relu(math_ops.reduce_max(segment_ids) + 1)
2865  else:
2866    segment_ids, _, _ = pfor_input.input(2)
2867    num_segments = pfor_input.unstacked_input(3)
2868
2869  n = pfor_input.pfor.loop_len_vector[0]
2870  if data_stacked:
2871    indices = _flatten_array_with_offset(indices, array_ops.shape(data)[1], n)
2872    data = _flatten_first_two_dims(data)
2873  else:
2874    indices = array_ops.reshape(indices, [-1])
2875  segment_ids = _flatten_array_with_offset(segment_ids, num_segments, n)
2876
2877  if num_inputs == 3:
2878    num_segments = None
2879  else:
2880    num_segments *= n
2881  output = op_func(data, indices, segment_ids, num_segments=num_segments)
2882  output = _unflatten_first_dim(output, [n])
2883  return wrap(output, True)
2884
2885
2886@RegisterPForWithArgs("SparseSegmentSumGrad", math_ops.sparse_segment_sum_grad)
2887@RegisterPForWithArgs("SparseSegmentMeanGrad",
2888                      math_ops.sparse_segment_mean_grad)
2889@RegisterPForWithArgs("SparseSegmentSqrtNGrad",
2890                      math_ops.sparse_segment_sqrt_n_grad)
2891def _convert_sparse_segment_grad(pfor_input, _, op_func):
2892  grad = pfor_input.stacked_input(0)
2893  indices = pfor_input.unstacked_input(1)
2894  segment_ids = pfor_input.unstacked_input(2)
2895  dim0 = pfor_input.unstacked_input(3)
2896
2897  n = pfor_input.pfor.loop_len_vector[0]
2898  indices = _flatten_array_with_offset(indices, dim0, n)
2899  num_segments = nn_ops.relu(math_ops.reduce_max(segment_ids) + 1)
2900  segment_ids = _flatten_array_with_offset(segment_ids, num_segments, n)
2901  grad = _flatten_first_two_dims(grad)
2902  dim0 *= n
2903  output = op_func(grad, indices, segment_ids, dim0)
2904  output = _unflatten_first_dim(output, [n])
2905  return wrap(output, True)
2906
2907
2908@RegisterPFor("Cast")
2909def _convert_cast(pfor_input):
2910  inp = pfor_input.stacked_input(0)
2911  dtype = pfor_input.get_attr("DstT")
2912  return wrap(math_ops.cast(inp, dtype), True)
2913
2914
2915@RegisterPFor("Abs")
2916@RegisterPFor("Acos")
2917@RegisterPFor("Acosh")
2918@RegisterPFor("Add")
2919@RegisterPFor("AddV2")
2920@RegisterPFor("Angle")
2921@RegisterPFor("Asin")
2922@RegisterPFor("Asinh")
2923@RegisterPFor("Atan")
2924@RegisterPFor("Atan2")
2925@RegisterPFor("Atanh")
2926@RegisterPFor("BesselI0")
2927@RegisterPFor("BesselI1")
2928@RegisterPFor("BesselI0e")
2929@RegisterPFor("BesselI1e")
2930@RegisterPFor("BesselK0")
2931@RegisterPFor("BesselK1")
2932@RegisterPFor("BesselK0e")
2933@RegisterPFor("BesselK1e")
2934@RegisterPFor("BesselJ0")
2935@RegisterPFor("BesselJ1")
2936@RegisterPFor("BesselY0")
2937@RegisterPFor("BesselY1")
2938@RegisterPFor("BitwiseAnd")
2939@RegisterPFor("BitwiseOr")
2940@RegisterPFor("BitwiseXor")
2941@RegisterPFor("Ceil")
2942@RegisterPFor("Complex")
2943@RegisterPFor("ComplexAbs")
2944@RegisterPFor("Conj")
2945@RegisterPFor("Cos")
2946@RegisterPFor("Cosh")
2947@RegisterPFor("Dawsn")
2948@RegisterPFor("Digamma")
2949@RegisterPFor("Div")
2950@RegisterPFor("DivNoNan")
2951@RegisterPFor("Elu")
2952@RegisterPFor("Erf")
2953@RegisterPFor("Erfc")
2954@RegisterPFor("Erfinv")
2955@RegisterPFor("Exp")
2956@RegisterPFor("Expint")
2957@RegisterPFor("Expm1")
2958@RegisterPFor("Floor")
2959@RegisterPFor("FloorDiv")
2960@RegisterPFor("FloorMod")
2961@RegisterPFor("FresnelCos")
2962@RegisterPFor("FresnelSin")
2963@RegisterPFor("Greater")
2964@RegisterPFor("GreaterEqual")
2965@RegisterPFor("Igamma")
2966@RegisterPFor("IgammaGradA")
2967@RegisterPFor("Igammac")
2968@RegisterPFor("Imag")
2969@RegisterPFor("Inv")
2970@RegisterPFor("Invert")
2971@RegisterPFor("IsFinite")
2972@RegisterPFor("IsInf")
2973@RegisterPFor("IsNan")
2974@RegisterPFor("LeftShift")
2975@RegisterPFor("Less")
2976@RegisterPFor("LessEqual")
2977@RegisterPFor("Lgamma")
2978@RegisterPFor("Log")
2979@RegisterPFor("Log1p")
2980@RegisterPFor("LogicalAnd")
2981@RegisterPFor("LogicalNot")
2982@RegisterPFor("LogicalOr")
2983@RegisterPFor("LogicalXor")
2984@RegisterPFor("Maximum")
2985@RegisterPFor("Minimum")
2986@RegisterPFor("Mod")
2987@RegisterPFor("Mul")
2988@RegisterPFor("MulNoNan")
2989@RegisterPFor("Ndtri")
2990@RegisterPFor("Neg")
2991@RegisterPFor("Polygamma")
2992@RegisterPFor("Pow")
2993@RegisterPFor("Real")
2994@RegisterPFor("RealDiv")
2995@RegisterPFor("Reciprocal")
2996@RegisterPFor("Relu")
2997@RegisterPFor("Relu6")
2998@RegisterPFor("RightShift")
2999@RegisterPFor("Rint")
3000@RegisterPFor("Round")
3001@RegisterPFor("Rsqrt")
3002@RegisterPFor("Selu")
3003@RegisterPFor("Sigmoid")
3004@RegisterPFor("Sign")
3005@RegisterPFor("Sin")
3006@RegisterPFor("Sinh")
3007@RegisterPFor("Softplus")
3008@RegisterPFor("Softsign")
3009@RegisterPFor("Spence")
3010@RegisterPFor("Sqrt")
3011@RegisterPFor("Square")
3012@RegisterPFor("SquaredDifference")
3013@RegisterPFor("Sub")
3014@RegisterPFor("Tan")
3015@RegisterPFor("Tanh")
3016@RegisterPFor("TruncateDiv")
3017@RegisterPFor("TruncateMod")
3018@RegisterPFor("Xdivy")
3019@RegisterPFor("Xlogy")
3020@RegisterPFor("Xlog1py")
3021@RegisterPFor("Zeta")
3022def _convert_cwise(pfor_input):
3023  if pfor_input.num_inputs > 1:
3024    pfor_input.expanddim_inputs_for_broadcast()
3025
3026  out = _create_op(
3027      pfor_input.op_type, [x.t for x in pfor_input.inputs],
3028      [x.dtype for x in pfor_input.outputs],
3029      attrs=pfor_input.op.node_def.attr).outputs
3030  assert len(out) == 1
3031  out = out[0]
3032
3033  op_output = wrap(out, True)
3034  return op_output
3035
3036
3037@RegisterPFor("XlaSharding")
3038def _convert_xla_sharding(pfor_input):
3039  t = pfor_input.stacked_input(0)
3040  sharding = pfor_input.get_attr("sharding")
3041  return wrap(xla.sharding(t, sharding=sharding), True)
3042
3043
3044@RegisterPFor("LeakyRelu")
3045def _convert_leaky_relu(pfor_input):
3046  t = pfor_input.stacked_input(0)
3047  alpha = pfor_input.get_attr("alpha")
3048  return wrap(gen_nn_ops.leaky_relu(t, alpha=alpha), True)
3049
3050
3051@RegisterPFor("Equal")
3052def _convert_equal(pfor_input):
3053  pfor_input.expanddim_inputs_for_broadcast()
3054  x = pfor_input.input(0)[0]
3055  y = pfor_input.input(1)[0]
3056  incompatible_shape_error = pfor_input.get_attr("incompatible_shape_error")
3057  return wrap(gen_math_ops.equal(
3058      x, y, incompatible_shape_error=incompatible_shape_error), True)
3059
3060
3061@RegisterPFor("NotEqual")
3062def _convert_not_equal(pfor_input):
3063  pfor_input.expanddim_inputs_for_broadcast()
3064  x = pfor_input.input(0)[0]
3065  y = pfor_input.input(1)[0]
3066  incompatible_shape_error = pfor_input.get_attr("incompatible_shape_error")
3067  return wrap(gen_math_ops.not_equal(
3068      x, y, incompatible_shape_error=incompatible_shape_error), True)
3069
3070
3071@RegisterPFor("ApproximateEqual")
3072def _convert_approximate_equal(pfor_input):
3073  pfor_input.expanddim_inputs_for_broadcast()
3074  x = pfor_input.input(0)[0]
3075  y = pfor_input.input(1)[0]
3076  tolerance = pfor_input.get_attr("tolerance")
3077  return wrap(math_ops.approximate_equal(x, y, tolerance=tolerance), True)
3078
3079
3080@RegisterPFor("Shape")
3081def _convert_shape(pfor_input):
3082  out_type = pfor_input.get_attr("out_type")
3083  return wrap(
3084      array_ops.shape(pfor_input.stacked_input(0), out_type=out_type)[1:],
3085      False)
3086
3087
3088@RegisterPFor("ShapeN")
3089def _convert_shape_n(pfor_input):
3090  out_type = pfor_input.get_attr("out_type")
3091  shapes = [
3092      array_ops.shape(x, out_type=out_type)[1:] if stacked else array_ops.shape(
3093          x, out_type=out_type) for x, stacked, _ in pfor_input.inputs
3094  ]
3095  return [wrap(x, False) for x in shapes]
3096
3097
3098@RegisterPFor("Size")
3099def _convert_size(pfor_input):
3100  out_type = pfor_input.get_attr("out_type")
3101  n = math_ops.cast(pfor_input.pfor.loop_len_vector[0], out_type)
3102  return wrap(
3103      array_ops.size(pfor_input.stacked_input(0), out_type=out_type) // n,
3104      False)
3105
3106
3107@RegisterPFor("Rank")
3108def _convert_rank(pfor_input):
3109  return wrap(array_ops.rank(pfor_input.stacked_input(0)) - 1, False)
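
# Example (illustrative; assumes `import tensorflow as tf`): for a stacked
# tensor of shape [n, 4, 5] each pfor iteration sees a [4, 5] slice, so the
# converted Shape/Size/Rank drop the leading loop dimension:
#
#   t = tf.zeros([3, 4, 5])                              # n = 3
#   tf.shape(t)[1:]                                      # [4, 5]
#   tf.size(t) // 3                                      # 20
#   tf.rank(t) - 1                                       # 2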
3110
3111
3112@RegisterPFor("AddN")
3113def _convert_addn(pfor_input):
3114  # AddN does not support broadcasting.
3115  pfor_input.stack_inputs(tile_variants=False)
3116  return _wrap_and_tile_variants(
3117      math_ops.add_n([x.t for x in pfor_input.inputs]),
3118      pfor_input.pfor.loop_len_vector)
3119
3120
3121@RegisterPFor("Cross")
3122def _convert_cross(pfor_input):
3123  pfor_input.stack_inputs()
3124  a = pfor_input.stacked_input(0)
3125  b = pfor_input.stacked_input(1)
3126  return wrap(math_ops.cross(a, b), True)
3127
3128
3129@RegisterPFor("BiasAddGrad")
3130def _convert_biasaddgrad(pfor_input):
3131  grad = pfor_input.stacked_input(0)
3132  fmt = pfor_input.get_attr("data_format")
3133  if fmt == b"NCHW":
3134    output = math_ops.reduce_sum(grad, axis=[1, 3, 4], keepdims=False)
3135  else:
3136    grad_shape = array_ops.shape(grad)
3137    last_dim_shape = grad_shape[-1]
3138    first_dim_shape = grad_shape[0]
3139    output = array_ops.reshape(grad, [first_dim_shape, -1, last_dim_shape])
3140    output = math_ops.reduce_sum(output, axis=[1], keepdims=False)
3141  return wrap(output, True)
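
# Example (illustrative; assumes `import tensorflow as tf`): in the NHWC branch
# above, the per-iteration BiasAddGrad (a reduction over all but the channel
# dimension) becomes a single reduction over the flattened middle dimensions:
#
#   grad = tf.random.normal([5, 2, 3, 4, 7])             # [loop, N, H, W, C]
#   per_iter = tf.stack(
#       [tf.reduce_sum(grad[i], axis=[0, 1, 2]) for i in range(5)])
#   vectorized = tf.reduce_sum(tf.reshape(grad, [5, -1, 7]), axis=1)
#   # per_iter and vectorized agree up to floating point error, shape [5, 7].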
3142
3143
3144# Some required ops are not exposed under the tf namespace. Hence relying on
3145# _create_op to create them.
3146@RegisterPForWithArgs("EluGrad")
3147@RegisterPForWithArgs("LeakyReluGrad")
3148@RegisterPForWithArgs("ReciprocalGrad")
3149@RegisterPForWithArgs("Relu6Grad")
3150@RegisterPForWithArgs("ReluGrad")
3151@RegisterPForWithArgs("RsqrtGrad")
3152@RegisterPForWithArgs("SeluGrad")
3153@RegisterPForWithArgs("SigmoidGrad")
3154@RegisterPForWithArgs("SoftplusGrad")
3155@RegisterPForWithArgs("SoftsignGrad")
3156@RegisterPForWithArgs("SqrtGrad")
3157@RegisterPForWithArgs("TanhGrad")
3158def _convert_grads(pfor_input, op_type, *args, **kw_args):
3159  del args
3160  del kw_args
3161  # TODO(agarwal): Looks like these ops don't support broadcasting. Hence we
3162  # have to use tiling here.
3163  pfor_input.stack_inputs()
3164  outputs = _create_op(
3165      op_type, [x.t for x in pfor_input.inputs],
3166      [x.dtype for x in pfor_input.outputs],
3167      attrs=pfor_input.op.node_def.attr).outputs
3168  return [wrap(x, True) for x in outputs]
3169
3170
3171@RegisterPFor("Select")
3172def _convert_select(pfor_input):
3173  pfor_input.stack_inputs()
3174  cond = pfor_input.stacked_input(0)
3175  t = pfor_input.stacked_input(1)
3176  e = pfor_input.stacked_input(2)
3177  cond_rank = array_ops.rank(cond)
3178  cond, t, e = smart_cond.smart_cond(
3179      cond_rank > 1, lambda: _inputs_with_flattening(pfor_input, [0, 1, 2]),
3180      lambda: [cond, t, e])
3181  outputs = _create_op(
3182      pfor_input.op_type, [cond, t, e], [x.dtype for x in pfor_input.outputs],
3183      attrs=pfor_input.op.node_def.attr).outputs
3184  n = pfor_input.pfor.loop_len_vector
3185  out = smart_cond.smart_cond(cond_rank > 1,
3186                              lambda: _unflatten_first_dim(outputs[0], n),
3187                              lambda: outputs[0])
3188  return [wrap(out, True) for x in outputs]
3189
3190
3191@RegisterPFor("SelectV2")
3192def _convert_selectv2(pfor_input):
3193  pfor_input.expanddim_inputs_for_broadcast()
3194  cond = pfor_input.input(0)[0]
3195  t = pfor_input.input(1)[0]
3196  e = pfor_input.input(2)[0]
3197  out = array_ops.where_v2(cond, t, e)
3198  return wrap(out, True)
3199
3200
3201# random_ops
3202
3203
3204def _transpose_dim_to_front(x, dim):
3205  rank = array_ops.rank(x)
3206  return array_ops.transpose(
3207      x,
3208      perm=array_ops.concat(
3209          [[dim], math_ops.range(0, dim),
3210           math_ops.range(dim + 1, rank)],
3211          axis=0))
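
# Example (illustrative; assumes `import tensorflow as tf`): for x of shape
# [2, 3, 4, 5], _transpose_dim_to_front(x, 2) uses perm [2, 0, 1, 3]:
#
#   x = tf.zeros([2, 3, 4, 5])
#   tf.transpose(x, [2, 0, 1, 3]).shape                  # [4, 2, 3, 5]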
3212
3213
3214@RegisterPForWithArgs("RandomUniform")
3215@RegisterPForWithArgs("RandomUniformInt")
3216@RegisterPForWithArgs("RandomStandardNormal")
3217@RegisterPForWithArgs("TruncatedNormal")
3218def _convert_random(pfor_input, op_type, *args, **kw_args):
3219  del args
3220  del kw_args
3221  inputs = [pfor_input.unstacked_input(i) for i in range(pfor_input.num_inputs)]
3222  # inputs[0] is "shape"
3223  inputs[0] = array_ops.concat([pfor_input.pfor.loop_len_vector, inputs[0]],
3224                               axis=0)
3225  logging.warning(
3226      "Note that %s inside pfor op may not give same output as "
3227      "inside a sequential loop.", op_type)
3228  outputs = _create_op(
3229      op_type,
3230      inputs, [x.dtype for x in pfor_input.outputs],
3231      attrs=pfor_input.op.node_def.attr).outputs
3232  return [wrap(x, True) for x in outputs]
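
# Example (illustrative; assumes `import tensorflow as tf`): a per-iteration
# `tf.random.normal([2, 3])` under a loop of length 5 is converted into a
# single draw whose shape has the loop length prepended (values generally
# differ from a sequential loop, as warned above):
#
#   out = tf.vectorized_map(lambda i: tf.random.normal([2, 3]), tf.range(5))
#   # out.shape == [5, 2, 3]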
3233
3234
3235@RegisterPFor("RandomGamma")
3236@RegisterPFor("RandomPoissonV2")
3237def _convert_random_with_param(pfor_input):
3238  shape = pfor_input.unstacked_input(0)
3239  # param is lam (Poisson rate) or alpha (Gamma shape).
3240  param, param_stacked, _ = pfor_input.input(1)
3241  logging.warning(
3242      "Note that %s inside pfor op may not give same output as "
3243      "inside a sequential loop.", pfor_input.op_type)
3244
3245  if param_stacked:
3246    samples = _create_op(
3247        pfor_input.op_type,
3248        inputs=[shape, param],
3249        op_dtypes=[x.dtype for x in pfor_input.outputs],
3250        attrs=pfor_input.op.node_def.attr).outputs[0]
3251    loop_dim = array_ops.shape(shape)[0]
3252    stacked_samples = _transpose_dim_to_front(samples, loop_dim)
3253  else:
3254    shape = array_ops.concat([pfor_input.pfor.loop_len_vector, shape], axis=0)
3255    stacked_samples = _create_op(
3256        pfor_input.op_type,
3257        inputs=[shape, param],
3258        op_dtypes=[x.dtype for x in pfor_input.outputs],
3259        attrs=pfor_input.op.node_def.attr).outputs[0]
3260
3261  return wrap(stacked_samples, True)
3262
3263
3264@RegisterPFor("Multinomial")
3265def _convert_multinomial(pfor_input):
3266  logits, logits_stacked, _ = pfor_input.input(0)
3267  num_samples = pfor_input.unstacked_input(1)
3268  seed = pfor_input.get_attr("seed")
3269  seed2 = pfor_input.get_attr("seed2")
3270  output_dtype = pfor_input.get_attr("output_dtype")
3271  logging.warning(
3272      "Note that Multinomial inside pfor op may not give same output as "
3273      "inside a sequential loop.")
3274
3275  n = pfor_input.pfor.loop_len_vector[0]
3276  if logits_stacked:
3277    flattened_logits = _flatten_first_two_dims(logits)
3278    samples = gen_random_ops.multinomial(
3279        flattened_logits,
3280        num_samples,
3281        seed=seed,
3282        seed2=seed2,
3283        output_dtype=output_dtype)
3284    stacked_samples = _unflatten_first_dim(samples, [n])
3285  else:
3286    samples = gen_random_ops.multinomial(
3287        logits,
3288        num_samples * n,
3289        seed=seed,
3290        seed2=seed2,
3291        output_dtype=output_dtype)
3292    stacked_samples = array_ops.transpose(
3293        array_ops.reshape(samples, [-1, n, num_samples]), [1, 0, 2])
3294
3295  return wrap(stacked_samples, True)
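
# Example (illustrative; assumes `import tensorflow as tf`): the unstacked
# logits branch above draws n * num_samples samples per batch row and then
# reshapes, which matches the shape a sequential loop would produce:
#
#   logits = tf.math.log([[0.5, 0.5]])                   # [1, 2], unstacked
#   n, num_samples = 3, 4
#   samples = tf.random.categorical(logits, num_samples * n)      # [1, 12]
#   stacked = tf.transpose(
#       tf.reshape(samples, [-1, n, num_samples]), [1, 0, 2])
#   # stacked.shape == [3, 1, 4], i.e. [n, batch, num_samples]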
3296
3297
3298@RegisterPFor("StatelessMultinomial")
3299@RegisterPFor("StatelessParameterizedTruncatedNormal")
3300@RegisterPFor("StatelessRandomBinomial")
3301@RegisterPFor("StatelessRandomGammaV2")
3302@RegisterPFor("StatelessRandomNormal")
3303@RegisterPFor("StatelessRandomPoisson")
3304@RegisterPFor("StatelessRandomUniform")
3305@RegisterPFor("StatelessRandomUniformInt")
3306@RegisterPFor("StatelessRandomUniformFullInt")
3307@RegisterPFor("StatelessTruncatedNormal")
3308def _convert_stateless_multinomial(pfor_input):
3309  # Unlike stateful random ops, for stateless ones we want better
3310  # reproducibility based on seed. Hence we don't want to use a similar strategy
3311  # as used for stateful ones where we generate a possibly different set of
3312  # random numbers under vectorization.
3313  # Unfortunately, the kernels are currently not necessarily set up to do this
3314  # efficiently, and hence we fall back to a sequential loop for vectorization.
3315  return _fallback_converter(pfor_input, warn=False)
3316
3317
3318# linalg_ops
3319
3320
3321@RegisterPForWithArgs("XlaEinsum")
3322@RegisterPForWithArgs("Einsum")
3323def _convert_einsum(pfor_input, op_type):
3324  # Einsum may have either 1 or 2 inputs.
3325  inputs, input_stacked, _ = zip(*[
3326      pfor_input.input(i)
3327      for i in range(pfor_input.num_inputs)])
3328
3329  # Parse the einsum equation.
3330  equation = pfor_input.get_attr("equation").decode("utf-8")
3331  input_expr, output_expr = equation.split("->")
3332  input_exprs = input_expr.split(",")
3333
3334  # Pick a placeholder symbol to use for the new axis.
3335  chosen_symbol = None
3336  for s in string.ascii_letters:
3337    if s in equation:
3338      continue
3339    else:
3340      chosen_symbol = s
3341      break
3342
3343  if chosen_symbol is None:
3344    raise ValueError("Could not figure out what symbol to use for new axis.")
3345
3346  assert any(input_stacked)
3347  for i in range(len(inputs)):
3348    if input_stacked[i]:
3349      input_exprs[i] = "{}{}".format(chosen_symbol, input_exprs[i])
3350  output_expr = "{}{}".format(chosen_symbol, output_expr)
3351
3352  new_equation = "{}->{}".format(",".join(input_exprs), output_expr)
3353
3354  if op_type == "XlaEinsum":
3355    if len(inputs) == 1:
3356      result = xla.einsum(equation=new_equation, a=inputs[0])
3357    else:
3358      result = xla.einsum(equation=new_equation, a=inputs[0], b=inputs[1])
3359  else:
3360    assert op_type == "Einsum"
3361    result = special_math_ops.einsum(new_equation, *inputs)
3362
3363  return wrap(result, True)
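
# Example (illustrative; assumes `import tensorflow as tf`): when both operands
# carry the loop dimension, 'ij,jk->ik' is rewritten by prefixing an unused
# letter (written here as 'a') to every stacked operand and to the output:
#
#   a = tf.random.normal([5, 2, 3])
#   b = tf.random.normal([5, 3, 4])
#   per_iter = tf.stack(
#       [tf.einsum('ij,jk->ik', a[i], b[i]) for i in range(5)])
#   vectorized = tf.einsum('aij,ajk->aik', a, b)
#   # per_iter and vectorized agree up to floating point error.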
3364
3365
3366@RegisterPFor("Cholesky")
3367def _convert_cholesky(pfor_input):
3368  t = pfor_input.stacked_input(0)
3369  return wrap(linalg_ops.cholesky(t), True)
3370
3371
3372@RegisterPFor("LogMatrixDeterminant")
3373def _convert_log_matrix_determinant(pfor_input):
3374  t = pfor_input.stacked_input(0)
3375  return [wrap(x, True) for x in linalg_ops.log_matrix_determinant(t)]
3376
3377
3378@RegisterPFor("MatrixInverse")
3379def _convert_matrix_inverse(pfor_input):
3380  t = pfor_input.stacked_input(0)
3381  adjoint = pfor_input.get_attr("adjoint")
3382  return wrap(gen_linalg_ops.matrix_inverse(t, adjoint=adjoint), True)
3383
3384
3385@RegisterPFor("MatrixSolve")
3386def _convert_matrix_solve(pfor_input):
3387  pfor_input.stack_inputs()
3388  matrix = pfor_input.stacked_input(0)
3389  rhs = pfor_input.stacked_input(1)
3390  adjoint = pfor_input.get_attr("adjoint")
3391  output = gen_linalg_ops.matrix_solve(
3392      matrix, rhs, adjoint=adjoint)
3393  return wrap(output, True)
3394
3395
3396@RegisterPFor("MatrixTriangularSolve")
3397def _convert_matrix_triangular_solve(pfor_input):
3398  pfor_input.expanddim_inputs_for_broadcast()
3399  matrix = pfor_input.input(0)[0]
3400  rhs = pfor_input.input(1)[0]
3401  lower = pfor_input.get_attr("lower")
3402  adjoint = pfor_input.get_attr("adjoint")
3403  output = linalg_ops.matrix_triangular_solve(
3404      matrix, rhs, lower=lower, adjoint=adjoint)
3405  return wrap(output, True)
3406
3407
3408@RegisterPFor("SelfAdjointEigV2")
3409def _convert_self_adjoint_eig(pfor_input):
3410  t = pfor_input.stacked_input(0)
3411  compute_v = pfor_input.get_attr("compute_v")
3412  e, v = gen_linalg_ops.self_adjoint_eig_v2(t, compute_v=compute_v)
3413  # If compute_v is False, v will have shape [0].
3414  return wrap(e, True), wrap(v, compute_v)
3415
3416
3417# logging_ops
3418
3419
3420@RegisterPFor("Assert")
3421def _convert_assert(pfor_input):
3422  cond, cond_stacked, _ = pfor_input.input(0)
3423  if cond_stacked:
3424    cond = math_ops.reduce_all(cond)
3425
3426  data_list = [x.t for x in pfor_input.inputs][1:]
3427  return _create_op(
3428      "Assert", [cond] + data_list, [], attrs=pfor_input.op.node_def.attr)
3429
3430
3431@RegisterPFor("Print")
3432def _convert_print(pfor_input):
3433  # Note that we don't stack all the inputs. Hence unstacked values are printed
3434  # once here vs multiple times in a while_loop.
3435  pfor_input.stack_inputs([0])
3436  outputs = _create_op(
3437      "Print", [x.t for x in pfor_input.inputs],
3438      [x.dtype for x in pfor_input.outputs],
3439      attrs=pfor_input.op.node_def.attr).outputs
3440  return [wrap(x, True) for x in outputs]
3441
3442
3443@RegisterPFor("PrintV2")
3444def _convert_print_v2(pfor_input):
3445  # Print the full input Tensor(s), including the batch dimension if stacked.
3446  return _create_op(
3447      "PrintV2", [x.t for x in pfor_input.inputs],
3448      [x.dtype for x in pfor_input.outputs],
3449      attrs=pfor_input.op.node_def.attr)
3450
3451
3452@RegisterPFor("StringFormat")
3453def _convert_string_format(pfor_input):
3454  # Format using the full input Tensor(s), including the batch dimension if
3455  # stacked.
3456  op = _create_op(
3457      "StringFormat", [x.t for x in pfor_input.inputs],
3458      [x.dtype for x in pfor_input.outputs],
3459      attrs=pfor_input.op.node_def.attr)
3460  return [wrap(output, False) for output in op.outputs]
3461
3462
3463# data_flow_ops
3464
3465# TensorArray conversion is tricky since we don't support arrays of
3466# TensorArrays. For converting them, we consider two distinct cases:
3467#
3468# 1. The array is constructed outside the pfor call, and read/written inside the
3469# loop.
3470# This is an easier case since we don't need to make an array of TensorArrays.
3471# A correctness requirement is that these parallel iterations shouldn't attempt
3472# to write to the same location. Hence at conversion time we disallow indices to
3473# be loop-invariant as that would guarantee a collision. Even if the indices are
3474# not loop-invariant, they could conflict, which would trigger runtime errors.
3475#
3476# 2. The array is constructed and used entirely inside each pfor iteration.
3477# For simplicity, here we require that the indices used for write/scatter are
3478# "unstacked". Otherwise it becomes hard to merge the TensorArrays created in
3479# different pfor iterations. We consider two sub-cases:
3480#
3481# 2a Elements written to the array are "stacked"
3482# To simulate multiple TensorArrays, we may increase the dimension of each
3483# element of the array. i.e. the i_th row of the j_th entry of the converted
3484# TensorArray corresponds to the j_th entry of the TensorArray in the i_th
3485# pfor iteration.
3486#
3487# 2b Elements written to the array are "unstacked"
3488# In this case we don't increase the dimensions to avoid redundant tiling. Each
3489# iteration is trying to write the same value. So we convert that to a single
3490# write.
3491#
3492# Here are some tricks used to implement the above:
3493# - TensorArrayV3 constructor encodes the element shape as an attr. Instead of
3494# trying to trace whether future writes are stacked or unstacked in order to set
3495# this attr, we set it to correspond to an unknown shape.
3496# - We use the "flow" output of the different ops to track whether the array
3497# elements are stacked or unstacked. If a stacked write/scatter is done, we make
3498# the flow stacked as well.
3499# - We use some heuristic traversal of the graph to track whether the
3500# TensorArray handle was created inside or outside the pfor loop.
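#
# Example (illustrative) of case 2, using the public API and assuming
# `import tensorflow as tf`; depending on the TF version the TensorArray below
# lowers to TensorArrayV3 or to TensorList ops, both of which are handled by
# converters in this file. The write indices (0 and 1) are loop-invariant, as
# required above:
#
#   def loop_fn(x):
#     ta = tf.TensorArray(tf.float32, size=2)
#     ta = ta.write(0, x * tf.ones([3]))    # stacked value (case 2a)
#     ta = ta.write(1, tf.zeros([3]))       # unstacked value (case 2b)
#     return ta.read(0)
#
#   out = tf.vectorized_map(loop_fn, tf.constant([1., 2., 3., 4.]))
#   # out.shape == [4, 3]; row i is filled with the i-th input value.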
3501
3502
3503@RegisterPFor("TensorArrayV3")
3504def _convert_tensor_array_v3(pfor_input):
3505  size = pfor_input.unstacked_input(0)
3506  dtype = pfor_input.get_attr("dtype")
3507  dynamic_size = pfor_input.get_attr("dynamic_size")
3508  clear_after_read = pfor_input.get_attr("clear_after_read")
3509  identical_element_shapes = pfor_input.get_attr("identical_element_shapes")
3510  tensor_array_name = pfor_input.get_attr("tensor_array_name")
3511  handle, flow = data_flow_ops.tensor_array_v3(
3512      size,
3513      dtype=dtype,
3514      # We don't set element shape since we don't know if writes are stacked or
3515      # not yet.
3516      element_shape=None,
3517      dynamic_size=dynamic_size,
3518      clear_after_read=clear_after_read,
3519      identical_element_shapes=identical_element_shapes,
3520      tensor_array_name=tensor_array_name)
3521  # Note we keep flow unstacked for now since we don't know if writes will be
3522  # stacked or not.
3523  return wrap(handle, False), wrap(flow, False)
3524
3525
3526@RegisterPFor("TensorArraySizeV3")
3527def _convert_tensor_array_size_v3(pfor_input):
3528  handle = pfor_input.unstacked_input(0)
3529  flow, flow_stacked, _ = pfor_input.input(1)
3530  if flow_stacked:
3531    flow = _unstack_flow(flow)
3532  size = data_flow_ops.tensor_array_size_v3(handle, flow)
3533  return wrap(size, False)
3534
3535
3536def _handle_inside_pfor(pfor_input, handle):
3537  """Returns True if handle was created inside the pfor loop."""
3538  # We use some heuristic to find the original TensorArray creation op.
3539  # The logic should handle the common cases (except cond based subgraphs).
3540  # In theory the user could perform different operations on the handle (like
3541  # Reshape, stack multiple handles, etc) which could break this logic.
3542  # TODO(agarwal): handle Switch/Merge.
3543  while handle.op.type in ("Enter", "Identity"):
3544    handle = handle.op.inputs[0]
3545  if handle.op.type not in [
3546      "TensorArrayV3", "TensorArrayGradV3", "TensorArrayGradWithShape"
3547  ]:
3548    raise ValueError("Unable to find source for handle %s" % handle)
3549  else:
3550    return pfor_input.pfor.op_is_inside_loop(handle.op)
3551
3552
3553def _unstack_flow(value):
3554  # TODO(agarwal): consider looking if this is a Tile op then get its input.
3555  # This may avoid running the Tile operations.
3556  return array_ops.gather(value, 0)
3557
3558
3559@RegisterPFor("TensorArrayReadV3")
3560def _convert_tensor_array_read_v3(pfor_input):
3561  handle = pfor_input.unstacked_input(0)
3562  index, index_stacked, _ = pfor_input.input(1)
3563  dtype = pfor_input.get_attr("dtype")
3564  flow, flow_stacked, _ = pfor_input.input(2)
3565  if flow_stacked:
3566    flow = _unstack_flow(flow)
3567
3568  is_inside_pfor = _handle_inside_pfor(pfor_input, pfor_input.op.inputs[0])
3569  if is_inside_pfor:
3570    # Note that if we are inside a control flow construct inside the pfor, and
3571    # only some of the iterations are doing the read (i.e.
3572    # `all_indices_partitioned` is True), then the read operation should only
3573    # return values for the currently active pfor iterations (`all_indices`
3574    # below). Hence, whenever the returned value is stacked (i.e. `flow` is
3575    # stacked), we may need to do an extra gather after reading the values. Also
3576    # note that if `is_inside_pfor` is false, then values in the tensor
3577    # array are unstacked. So the check is only needed in this branch.
3578    all_indices = pfor_input.pfor.all_indices
3579    all_indices_partitioned = pfor_input.pfor.all_indices_partitioned
3580    # Note: flow_stacked indicates if values in the TensorArray are stacked or
3581    # not.
3582    if index_stacked:
3583      if flow_stacked:
3584        raise ValueError(
3585            "It looks like TensorArrayReadV3 was called on a TensorArray whose"
3586            " values are not loop-invariant, and the read indices were also"
3587            " not loop invariant. This is currently unsupported.")
3588      value = data_flow_ops.tensor_array_gather_v3(
3589          handle, index, flow, dtype=dtype)
3590      return wrap(value, True)
3591    value = data_flow_ops.tensor_array_read_v3(handle, index, flow, dtype=dtype)
3592    if flow_stacked and all_indices_partitioned:
3593      value = array_ops.gather(value, all_indices)
3594    return wrap(value, flow_stacked)
3595  # Values in the TensorArray should be unstacked (since different iterations
3596  # couldn't write to the same location). So whether output is stacked or not
3597  # depends on index_stacked.
3598  if index_stacked:
3599    value = data_flow_ops.tensor_array_gather_v3(
3600        handle, index, flow, dtype=dtype)
3601  else:
3602    value = data_flow_ops.tensor_array_read_v3(handle, index, flow, dtype=dtype)
3603  return wrap(value, index_stacked)
3604
3605
3606@RegisterPFor("TensorArrayWriteV3")
3607def _convert_tensor_array_write_v3(pfor_input):
3608  handle = pfor_input.unstacked_input(0)
3609  index, index_stacked, _ = pfor_input.input(1)
3610  value, value_stacked, _ = pfor_input.input(2)
3611  flow, flow_stacked, _ = pfor_input.input(3)
3612  if value_stacked and pfor_input.pfor.all_indices_partitioned:
3613    # Looks like we are in a control flow in a pfor where not all iterations are
3614    # active now. We don't allow that since that could lead to different indices
3615    # having different shapes which will be hard to merge later.
3616    raise ValueError("Writing non-loop-invariant values to a TensorArray from "
3617                     "inside a while_loop/cond is not supported.")
3618  if flow_stacked:
3619    flow = _unstack_flow(flow)
3620  is_inside = _handle_inside_pfor(pfor_input, pfor_input.op.inputs[0])
3621  if is_inside:
3622    if index_stacked:
3623      raise ValueError("Need indices for %s to be loop invariant" % handle)
3624    if not flow_stacked and not value_stacked:
3625      flow_out = data_flow_ops.tensor_array_write_v3(handle, index, value, flow)
3626      return wrap(flow_out, False)
3627    else:
3628      if not value_stacked:
3629        value = _stack(value, pfor_input.pfor.loop_len_vector).t
3630      # TODO(agarwal): Note that if flow is unstacked and value is stacked, then
3631      # this may or may not be a safe situation. flow is unstacked both for a
3632      # freshly created TensorArray, as well as after unstacked values are
3633      # written to it. If it is the latter, then we cannot write a stacked value
3634      # now since that may cause runtime errors due to different shapes in the
3635      # array. At the moment we are not able to handle this gracefully and
3636      # distinguish between the two cases. That would require some heuristic
3637      # traversal of the graph to figure out whether all the writes are
3638      # unstacked or not.
3639      flow_out = data_flow_ops.tensor_array_write_v3(handle, index, value, flow)
3640      return _stack(flow_out, pfor_input.pfor.loop_len_vector)
3641  else:
3642    if not index_stacked:
3643      raise ValueError("Need indices for %s to be not loop invariant" % handle)
3644    # Note that even when index_stacked is true, actual values in index may
3645    # still not be unique. However, that will cause a runtime error when
3646    # executing the scatter operation below.
3647    if not value_stacked:
3648      value = _stack(value, pfor_input.pfor.loop_len_vector).t
3649    flow_out = data_flow_ops.tensor_array_scatter_v3(handle, index, value, flow)
3650    return _stack(flow_out, pfor_input.pfor.loop_len_vector)
3651
3652
3653def _transpose_first_two_dims(value):
3654  # TODO(agarwal): optimize if one of the dims == 1.
3655  value_shape = array_ops.shape(value)
3656  v0 = value_shape[0]
3657  v1 = value_shape[1]
3658  value = array_ops.reshape(value, [v0, v1, -1])
3659  value = array_ops.transpose(value, [1, 0, 2])
3660  new_shape = array_ops.concat([[v1, v0], value_shape[2:]], axis=0)
3661  return array_ops.reshape(value, new_shape)
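
# Example (illustrative): _transpose_first_two_dims on a tensor of shape
# [2, 3, 4, 5] returns shape [3, 2, 4, 5]; only the first two dimensions are
# swapped, the trailing dimensions are left untouched.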
3662
3663
3664@RegisterPFor("TensorArrayGatherV3")
3665def _convert_tensor_array_gather_v3(pfor_input):
3666  handle = pfor_input.unstacked_input(0)
3667  indices, indices_stacked, _ = pfor_input.input(1)
3668  indices = array_ops.reshape(indices, [-1])
3669  flow, flow_stacked, _ = pfor_input.input(2)
3670  if flow_stacked:
3671    flow = _unstack_flow(flow)
3672  dtype = pfor_input.get_attr("dtype")
3673  # TODO(agarwal): support element_shape attr?
3674
3675  n = pfor_input.pfor.loop_len_vector
3676  value = data_flow_ops.tensor_array_gather_v3(
3677      handle, indices, flow, dtype=dtype)
3678  is_inside = _handle_inside_pfor(pfor_input, pfor_input.op.inputs[0])
3679  if is_inside:
3680    # flow_stacked indicates if values in the TensorArray are stacked or not.
3681    if indices_stacked:
3682      if flow_stacked:
3683        raise ValueError(
3684            "It looks like TensorArrayGatherV3 was called on a TensorArray "
3685            "whose values are not loop-invariant, and the indices were also "
3686            "not loop invariant. This is currently unsupported.")
3687      else:
3688        value = _unflatten_first_dim(value, n)
3689        return wrap(value, True)
3690    else:
3691      if flow_stacked:
3692        # Since elements in this array are stacked and `value` was produced by
3693        # gather, its first two dims are "gathered elements" and "stack
3694        # dimension". Our semantics require these two to be flipped.
3695        value = _transpose_first_two_dims(value)
3696      return wrap(value, flow_stacked)
3697  else:
3698    # Values in the TensorArray should be unstacked (since different iterations
3699    # couldn't write to the same location). So whether output is stacked or not
3700    # depends on indices_stacked.
3701    if indices_stacked:
3702      value = _unflatten_first_dim(value, n)
3703    return wrap(value, indices_stacked)
3704
3705
3706@RegisterPFor("TensorArrayScatterV3")
3707def _convert_tensor_array_scatter_v3(pfor_input):
3708  handle = pfor_input.unstacked_input(0)
3709  indices, indices_stacked, _ = pfor_input.input(1)
3710  indices = array_ops.reshape(indices, [-1])
3711  value, value_stacked, _ = pfor_input.input(2)
3712  flow, flow_stacked, _ = pfor_input.input(3)
3713
3714  if flow_stacked:
3715    flow = _unstack_flow(flow)
3716
3717  is_inside = _handle_inside_pfor(pfor_input, pfor_input.op.inputs[0])
3718  if is_inside:
3719    if indices_stacked:
3720      raise ValueError("Need indices for %s to be loop invariant" % handle)
3721    # Note that flow_stacked indicates if existing values in the array are
3722    # stacked or not.
3723    if not flow_stacked and not value_stacked:
3724      flow_out = data_flow_ops.tensor_array_scatter_v3(handle, indices, value,
3725                                                       flow)
3726      return wrap(flow_out, False)
3727    if not value_stacked:
3728      # TODO(agarwal): tile in the second dimension directly instead of
3729      # transposing below.
3730      value = _stack(value, pfor_input.pfor.loop_len_vector).t
3731
3732    value = _transpose_first_two_dims(value)
3733    # TODO(agarwal): Note that if a previous write was unstacked, flow will be
3734    # unstacked, and a stacked value may be written here which may cause
3735    # runtime error due to different elements having different shape. We do
3736    # not try to prevent that.
3737    flow_out = data_flow_ops.tensor_array_scatter_v3(handle, indices, value,
3738                                                     flow)
3739    return _stack(flow_out, pfor_input.pfor.loop_len_vector)
3740  if not indices_stacked:
3741    raise ValueError("Need indices for %s to be not loop invariant" % handle)
3742  if not value_stacked:
3743    value = _stack(value, pfor_input.pfor.loop_len_vector).t
3744  value = _flatten_first_two_dims(value)
3745  flow_out = data_flow_ops.tensor_array_scatter_v3(handle, indices, value, flow)
3746  return _stack(flow_out, pfor_input.pfor.loop_len_vector)
3747
3748
3749@RegisterPFor("TensorArrayGradV3")
3750def _convert_tensor_array_grad_v3(pfor_input):
3751  handle = pfor_input.unstacked_input(0)
3752  flow, flow_stacked, _ = pfor_input.input(1)
3753  if flow_stacked:
3754    flow = _unstack_flow(flow)
3755  source = pfor_input.get_attr("source")
3756  # TODO(agarwal): For now, we assume that gradients are stacked if the
3757  # TensorArrayGradV3 call is being done inside the pfor. Getting that wrong
3758  # will give runtime error due to incorrect shape being written to the
3759  # accumulator. It is difficult to know in advance if gradients written will be
3760  # stacked or not. Note that flow being stacked is not indicative of the
3761  # gradient being stacked or not. Revisit this later.
3762  shape_to_prepend = pfor_input.pfor.loop_len_vector
3763  grad_handle, flow_out = data_flow_ops.tensor_array_grad_with_shape(
3764      handle=handle,
3765      flow_in=flow,
3766      shape_to_prepend=shape_to_prepend,
3767      source=source)
3768  flow_out = _stack(flow_out, pfor_input.pfor.loop_len_vector).t
3769  return [wrap(grad_handle, False), wrap(flow_out, True)]
3770
3771
3772def _stack_tensor_list_shape(shape, first_dim):
3773  shape_value = tensor_util.constant_value(shape)
3774  # Note that negative values in the shape are used to signify unknown shapes
3775  # and are handled in a special way.
3776  if shape_value is not None:
3777    shape_value = np.asarray(shape_value)
3778    if -1 in shape_value:
3779      return constant_op.constant(-1)
3780    elif not shape_value.size:
3781      return first_dim
3782  else:
3783    shape = array_ops.reshape(shape, [-1])
3784    return control_flow_ops.cond(
3785        math_ops.reduce_any(shape < 0),
3786        lambda: constant_op.constant(-1),
3787        lambda: array_ops.concat([first_dim, shape], axis=0))
  return array_ops.concat([first_dim, shape], axis=0)
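
# Example (illustrative), with loop_len_vector == [3]:
#   element shape [4, 5]       -> [3, 4, 5]
#   element shape [-1, 5]      -> -1  (any unknown dim makes the result unknown)
#   element shape [] (scalars) -> [3]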
3788
3789
3790def _tile_variant_with_length(t, length):
3791  """Stacks `t` `length` times."""
3792  if _is_variant_with_internal_stacking(t):
3793    # The content of TensorLists is vectorized, not the variant itself.
3794    return t
3795  original_tensor = t
3796  t.set_shape([])
3797  t = array_ops.reshape(t, [-1])
3798  with ops.device("CPU:0"):
3799    result = array_ops.tile(t, length)
3800    # TODO(b/169968286): Should regular shape functions do handle data
3801    # propagation here?
3802    handle_data_util.copy_handle_data(original_tensor, result)
3803    return result
3804
3805
3806def _tile_variant(t, pfor_input):
3807  """Stacks `t` according to its loop context."""
3808  return _tile_variant_with_length(t, pfor_input.pfor.loop_len_vector)
3809
3810
3811def _untile_variant(t):
3812  if _is_variant_with_internal_stacking(t):
3813    # The content of TensorLists is vectorized, not the variant itself.
3814    if not t.shape.is_compatible_with([]):
3815      raise AssertionError(
3816          ("Unexpectedly saw a vectorized variant (e.g. TensorList) with "
3817           "non-scalar shape: {!r}").format(t))
3818    return t
3819  return array_ops.gather(t, 0)
3820
3821
3822@RegisterPFor("OptionalFromValue")
3823def _convert_optional_from_value(pfor_input):
3824  pfor_input.stack_inputs()
3825  return wrap(
3826      gen_dataset_ops.optional_from_value([x.t for x in pfor_input.inputs]),
3827      True)
3828
3829
3830@RegisterPFor("OptionalGetValue")
3831def _convert_optional_get_value(pfor_input):
3832  handle = pfor_input.stacked_input(0)
3833  output_types = pfor_input.get_attr("output_types")
3834  original_output_shapes = pfor_input.get_attr("output_shapes")
3835  output_shapes = []
3836  for shape in original_output_shapes:
3837    shape = tensor_shape.TensorShape(shape)
3838    loop_len_shape = tensor_shape.TensorShape(
3839        [tensor_util.constant_value(pfor_input.pfor.loop_len_vector)])
3840    shape = loop_len_shape.concatenate(shape)
3841    output_shapes.append(shape.as_proto())
3842  results = gen_dataset_ops.optional_get_value(handle, output_types,
3843                                               output_shapes)
3844  return [wrap(t, True) for t in results]
3845
3846
3847@RegisterPFor("TensorListReserve")
3848def _convert_tensor_list_reserve(pfor_input):
3849  element_shape = pfor_input.unstacked_input(0)
3850  num_elements = pfor_input.unstacked_input(1)
3851  element_dtype = pfor_input.get_attr("element_dtype")
3852
3853  # Prepend a dimension to element_shape.
3854  element_shape = _stack_tensor_list_shape(element_shape,
3855                                           pfor_input.pfor.loop_len_vector)
3856  handle = list_ops.tensor_list_reserve(
3857      element_shape, num_elements, element_dtype=element_dtype)
3858
3859  return wrap(_tile_variant(handle, pfor_input), True)
3860
3861
3862@RegisterPFor("TensorListElementShape")
3863def _convert_tensor_list_element_shape(pfor_input):
3864  handle = _untile_variant(pfor_input.stacked_input(0))
3865  shape_type = pfor_input.get_attr("shape_type")
3866  shape = list_ops.tensor_list_element_shape(handle, shape_type)
3867  shape = array_ops.reshape(shape, [-1])
3868  shape = shape[1:]
3869  return wrap(shape, False)
3870
3871
3872@RegisterPFor("TensorListLength")
3873def _convert_tensor_list_length(pfor_input):
3874  handle = _untile_variant(pfor_input.stacked_input(0))
3875  return wrap(list_ops.tensor_list_length(handle), False)
3876
3877
3878def _stack_tensor_list(handle, dtype, loop_len_vector, element_shape=None):
3879  if element_shape is None:
3880    element_shape = list_ops.tensor_list_element_shape(handle, dtypes.int32)
3881  length = list_ops.tensor_list_length(handle)
3882  new_handle = list_ops.tensor_list_reserve(
3883      _stack_tensor_list_shape(element_shape, loop_len_vector), length, dtype)
3884
3885  def _body_fn(i, h):
3886    elem = list_ops.tensor_list_get_item(handle, i, dtype, element_shape)
3887    elem = _stack(elem, loop_len_vector).t
3888    return i + 1, list_ops.tensor_list_set_item(h, i, elem)
3889
3890  return control_flow_ops.while_loop(lambda i, _: i < length, _body_fn,
3891                                     [0, new_handle])[1]
3892
3893
3894@RegisterPFor("TensorListGetItem")
3895def _convert_tensor_list_get_item(pfor_input):
3896  handle, handle_stacked, _ = pfor_input.input(0)
3897  index, index_stacked, _ = pfor_input.input(1)
3898  element_shape = pfor_input.unstacked_input(2)
3899  element_dtype = pfor_input.get_attr("element_dtype")
3900
3901  if handle_stacked:
3902    handle = _untile_variant(handle)
3903    element_shape = _stack_tensor_list_shape(element_shape,
3904                                             pfor_input.pfor.loop_len_vector)
3905    if index_stacked:
3906      # We use a sequential loop since that may be more efficient than first
3907      # gathering and concatenating all the elements corresponding to `index`,
3908      # and then doing a gather on it.
3909      def _map_fn(i):
3910        item_i = list_ops.tensor_list_get_item(
3911            handle,
3912            index[i],
3913            element_dtype=element_dtype)
3914        return array_ops.gather(item_i, i)
3915
3916      output = map_fn.map_fn(_map_fn, pfor_input.pfor.all_indices)
3917      return wrap(output, True)
3918    else:
3919      output = list_ops.tensor_list_get_item(
3920          handle,
3921          index,
3922          element_shape=element_shape,
3923          element_dtype=element_dtype)
3924      return wrap(output, True)
3925  else:
3926    assert index_stacked
3927    return wrap(
3928        list_ops.tensor_list_gather(
3929            handle,
3930            index,
3931            element_shape=element_shape,
3932            element_dtype=element_dtype), True)
3933
3934
3935@RegisterPFor("TensorListSetItem")
3936def _convert_tensor_list_set_item(pfor_input):
3937  handle, handle_stacked, _ = pfor_input.input(0)
3938  index, index_stacked, _ = pfor_input.input(1)
3939  item, item_stacked, _ = pfor_input.input(2)
3940
3941  if not handle_stacked:
3942    # Special case where we can statically guarantee that the indices are
3943    # disjoint.
3944    if index is pfor_input.pfor.all_indices:
3945      if not item_stacked:
3946        item = _stack(item, pfor_input.pfor.loop_len_vector).t
3947      return wrap(
3948          list_ops.tensor_list_scatter(item, index, input_handle=handle), False)
3949    else:
3950      handle = _stack_tensor_list(handle, item.dtype,
3951                                  pfor_input.pfor.loop_len_vector)
3952  else:
3953    handle = _untile_variant(handle)
3954
3955  if index_stacked:
3956    # TODO(agarwal): handle this.
3957    raise ValueError("Vectorizing writes to a TensorList with loop "
3958                     "variant indices is currently unsupported.")
3959
3960  else:
3961    if not item_stacked:
3962      item = _stack(item, pfor_input.pfor.loop_len_vector).t
3963    handle = list_ops.tensor_list_set_item(handle, index, item)
3964    return wrap(_tile_variant(handle, pfor_input), True)
3965
3966
3967@RegisterPFor("TensorListPushBack")
3968def _convert_tensor_list_push_back(pfor_input):
3969  handle, handle_stacked, _ = pfor_input.input(0)
3970  tensor, tensor_stacked, _ = pfor_input.input(1)
3971  if handle_stacked:
3972    handle = _untile_variant(handle)
3973  else:
3974    handle = _stack_tensor_list(handle, tensor.dtype,
3975                                pfor_input.pfor.loop_len_vector)
3976  if not tensor_stacked:
3977    tensor = _stack(tensor, pfor_input.pfor.loop_len_vector).t
3978  handle = list_ops.tensor_list_push_back(handle, tensor)
3979  return wrap(_tile_variant(handle, pfor_input), True)
3980
3981
3982@RegisterPFor("TensorListPopBack")
3983def _convert_tensor_list_pop_back(pfor_input):
3984  handle = pfor_input.stacked_input(0)
3985  element_shape = pfor_input.unstacked_input(1)
3986  handle = _untile_variant(handle)
3987
3988  if element_shape.shape.ndims == 0:
3989    # Default / unspecified
3990    vectorized_shape = -1
3991  else:
3992    # PopBack has an element shape set when it's the gradient of PushBack, only
3993    # used when the list is uninitialized.
3994    vectorized_shape = array_ops.concat(
3995        [pfor_input.pfor.loop_len_vector, element_shape], axis=0)
3996
3997  output_handle, tensor = gen_list_ops.tensor_list_pop_back(
3998      input_handle=handle, element_dtype=pfor_input.get_attr("element_dtype"),
3999      element_shape=vectorized_shape)
4000  return wrap(output_handle, True), wrap(tensor, True)
4001
4002
4003@RegisterPFor("TensorListConcatV2")
4004def _convert_tensor_list_concat_v2(pfor_input):
4005  input_handle = pfor_input.stacked_input(0)
4006  element_shape = pfor_input.unstacked_input(1)
4007  leading_dims = pfor_input.unstacked_input(2)
4008  element_dtype = pfor_input.get_attr("element_dtype")
4009
4010  handle = _untile_variant(input_handle)
4011  length = list_ops.tensor_list_length(handle)
4012  # Note that the element_shape attribute can have incomplete shapes. This
4013  # doesn't seem to work well when creating another list and then doing a
4014  # concat on it. Hence we try to find the dynamic shape here.
4015  element_shape = control_flow_ops.cond(
4016      length > 0, lambda: array_ops.shape(
4017          list_ops.tensor_list_get_item(handle, 0, element_dtype, None)),
4018      lambda: constant_op.constant([0, 0], dtype=dtypes.int32))
4019  # The code below creates a copy of the list with each element's first two
4020  # dimensions transposed.
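  #
  # Illustrative shape bookkeeping (an explanatory sketch, not executed): each
  # stored element has shape [loop_len, d0, ...]. Transposing its first two
  # dimensions gives [d0, loop_len, ...]; concatenating the transposed elements
  # along their first dimension yields [sum(d0), loop_len, ...]; and the final
  # transpose below moves the loop dimension back to the front, producing
  # [loop_len, sum(d0), ...].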
4021  new_element_shape = array_ops.concat(
4022      [element_shape[1:2], element_shape[0:1], element_shape[2:]], axis=0)
4023
4024  # Create a new TensorList with elements transposed.
4025  def _transpose_elem(i, h):
4026    elem = list_ops.tensor_list_get_item(handle, i, element_dtype, None)
4027    elem = _transpose_first_two_dims(elem)
4028    return i + 1, list_ops.tensor_list_set_item(h, i, elem)
4029
4030  new_handle = list_ops.tensor_list_reserve(new_element_shape, length,
4031                                            element_dtype)
4032  new_handle = control_flow_ops.while_loop(lambda i, _: i < length,
4033                                           _transpose_elem, [0, new_handle])[1]
4034  output, lengths = gen_list_ops.tensor_list_concat_v2(
4035      input_handle=new_handle,
4036      element_dtype=element_dtype,
4037      element_shape=new_element_shape,
4038      leading_dims=leading_dims)
4039  output = _transpose_first_two_dims(output)
4040  return wrap(output, True), wrap(lengths, False)
4041
4042
4043@RegisterPFor("TensorListStack")
4044def _convert_tensor_list_stack(pfor_input):
4045  handle = pfor_input.stacked_input(0)
4046  input_shape = pfor_input.unstacked_input(1)
4047  element_dtype = pfor_input.get_attr("element_dtype")
4048  num_elements = pfor_input.get_attr("num_elements")
4049
4050  handle = _untile_variant(handle)
4051  input_shape = _stack_tensor_list_shape(input_shape,
4052                                         pfor_input.pfor.loop_len_vector)
4053  output = list_ops.tensor_list_stack(
4054      handle,
4055      element_dtype,
4056      element_shape=input_shape,
4057      num_elements=num_elements)
4058  output = _transpose_first_two_dims(output)
4059  return wrap(output, True)
4060
4061
4062@RegisterPFor("TensorListGather")
4063def _convert_tensor_list_gather(pfor_input):
4064  handle, handle_stacked, _ = pfor_input.input(0)
4065  index, index_stacked, _ = pfor_input.input(1)
4066  element_shape = pfor_input.unstacked_input(2)
4067  element_dtype = pfor_input.get_attr("element_dtype")
4068
4069  if handle_stacked:
4070    handle = _untile_variant(handle)
4071    element_shape = _stack_tensor_list_shape(element_shape,
4072                                             pfor_input.pfor.loop_len_vector)
4073    if index_stacked:
4074      # We use a sequential loop since that may be more efficient than first
4075      # gathering and concatenating all the elements corresponding to `index`,
4076      # and then doing a gather on the result.
4077      def _map_fn(i):
4078        item_i = list_ops.tensor_list_gather(
4079            handle,
4080            index[i],
4081            element_dtype=element_dtype)
4082        axis = array_ops.rank(index) - 1
4083        return array_ops.gather(item_i, i, axis=axis)
4084
4085      output = map_fn.map_fn(_map_fn, pfor_input.pfor.all_indices)
4086      return wrap(output, True)
4087    else:
4088      output = list_ops.tensor_list_gather(
4089          handle,
4090          index,
4091          element_shape=element_shape,
4092          element_dtype=element_dtype)
4093      return wrap(output, True)
4094  else:
4095    assert index_stacked
4096    index_shape = array_ops.shape(index)
4097    index = array_ops.reshape(index, [-1])
4098    values = list_ops.tensor_list_gather(
4099        handle, index, element_shape=element_shape, element_dtype=element_dtype)
4100    final_shape = array_ops.concat(
4101        [index_shape, array_ops.shape(values)[1:]], axis=0)
4102    return wrap(array_ops.reshape(values, final_shape), True)
4103
4104
4105@RegisterPFor("TensorListScatterIntoExistingList")
4106def _convert_tensor_list_scatter(pfor_input):
4107  pfor_input.stack_inputs([1])
4108  handle, handle_stacked, _ = pfor_input.input(0)
4109  item = pfor_input.stacked_input(1)
4110  indices, indices_stacked, _ = pfor_input.input(2)
4111  if handle_stacked:
4112    handle = _untile_variant(handle)
4113  else:
4114    handle = _stack_tensor_list(handle, item.dtype,
4115                                pfor_input.pfor.loop_len_vector)
4116
4117  item = _transpose_first_two_dims(item)
4118  if indices_stacked:
4119    # Pretend the list is a dense tensor:
4120    #   list_as_dense: Tensor[list_len, loop_len, ...]
4121    # And indices are a tensor with shape (before transpose):
4122    #   indices: Tensor[loop_len, num_scatters]
4123    # The item to scatter has shape (before transpose):
4124    #   item: Tensor[loop_len, num_scatters, ...]
4125    #
4126    # We want list_as_dense[indices[i, j], i] = item[i, j]
4127    #
4128    # Since we're not just indexing along the first axis of `list_as_dense`, we
4129    # need to first extract the relevant list entries based on `indices`,
4130    # scatter into them according to the loop index, and re-scatter the chunks
4131    # we updated back into the list.
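    #
    # Illustrative example (an explanatory sketch, not executed): with
    # loop_len = 2 and num_scatters = 1, suppose
    #   indices = [[3], [5]]  and  item = [[a], [b]]
    # (shapes as seen before the transposes). Then after conversion we need
    #   list_as_dense[3, 0] == a  and  list_as_dense[5, 1] == b,
    # i.e. iteration i only touches slice [:, i] of the chunks gathered below.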
4132    indices = _transpose_first_two_dims(indices)
4133    indices_flat = array_ops.reshape(indices, [-1])
4134    # In many cases `indices` will be unique across pfor iterations, but this is
4135    # not guaranteed. If there are duplicates, we need to map multiple updates
4136    # to a single chunk extracted from the list. The last update should win.
4137    unique_indices = array_ops.unique(indices_flat)
4138    gathered_items = list_ops.tensor_list_gather(
4139        handle, unique_indices.y, element_dtype=item.dtype,
4140        element_shape=array_ops.shape(item)[1:])
4141    loop_idx = math_ops.range(pfor_input.pfor.loop_len_vector[0])
4142    scatters_per_op = array_ops.shape(indices)[0]
4143
4144    unique_indices_loop_idx = array_ops.reshape(array_ops.tile(
4145        loop_idx[None, :], [scatters_per_op, 1]), [-1])
4146    scatter_indices = array_ops.stack(
4147        [unique_indices.idx, unique_indices_loop_idx],
4148        axis=1)
4149    # This op does *not* guarantee last-update-wins on GPU, so semantics may not
4150    # be exactly preserved for duplicate updates there.
4151    scattered = array_ops.tensor_scatter_nd_update(
4152        tensor=gathered_items,
4153        indices=scatter_indices,
4154        updates=_flatten_first_two_dims(item))
4155    handle = list_ops.tensor_list_scatter(
4156        scattered, unique_indices.y, input_handle=handle)
4157  else:
4158    handle = list_ops.tensor_list_scatter(item, indices, input_handle=handle)
4159  return wrap(_tile_variant(handle, pfor_input), True)
4160
4161
4162@RegisterPFor("TensorListFromTensor")
4163def _convert_tensor_list_from_tensor(pfor_input):
4164  tensor = pfor_input.stacked_input(0)
4165  element_shape = pfor_input.unstacked_input(1)
4166  tensor = _transpose_first_two_dims(tensor)
4167  element_shape = _stack_tensor_list_shape(element_shape,
4168                                           pfor_input.pfor.loop_len_vector)
4169  handle = list_ops.tensor_list_from_tensor(tensor, element_shape)
4170  return wrap(_tile_variant(handle, pfor_input), True)
4171
4172
4173# StackV2 conversion is tricky since we don't have arrays of StackV2. So,
4174# similar to TensorArrays, we convert them by adding a loop dimension to the
4175# elements inside the stack.
4176#
4177# We consider two cases:
4178#
4179# 1. StackV2 is constructed and used entirely inside the pfor loop.
4180# We keep a single Stack and perform the push/pop operations of all the
4181# iterations in lock-step. We also assume that all the iterations perform these
4182# operations. In case of dynamic control flow, if only some of the iterations
4183# try to perform a push/pop, then the conversion may not work correctly and may
4184# cause undefined behavior.
4185# TODO(agarwal): test StackV2 with dynamic control flow.
4186#
4187# 2. StackV2 is constructed outside the pfor loop.
4188# Performing stack push/pop in a parallel fashion is ill-defined. However,
4189# given that reading stacks created externally is a common operation when
4190# computing Jacobians, we provide the following special semantics:
4191#  - disallow push operations to the stack
4192#  - pop operations are performed in lock step by all iterations, similar to the
4193#  case when the stack is created inside. A single value is popped during the
4194#  lock-step operation and broadcast to all the iterations. Values in the stack
4195#  are assumed to be loop-invariant.
4196#
4197# Some other implementation details:
4198# We use ad-hoc logic to determine whether values in the Stack data structure
4199# are loop invariant or not. When converting push/pop operations, we keep track
4200# of whether the last conversion used a stacked value or not (see _stack_cache
4201# below). As a result, if an unstacked value is written first, subsequent
4202# stacked writes are disallowed when they could have been allowed in theory.
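#
# For example (an illustrative sketch of case 2): if a stack created outside
# the pfor loop holds pushed values [v0, v1, v2] and every iteration performs
# two pops, then each iteration observes v2 on its first pop and v1 on its
# second pop; a single value is popped per lock-step pop and broadcast to all
# the iterations, which is only meaningful if the stack values are loop
# invariant.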
4203
4204# Map from cache key based on StackV2 handle to a bool indicating whether values
4205# are stacked or not.
4206# TODO(agarwal): move _stack_cache inside pfor?
4207_stack_cache = {}
4208
4209
4210def _stack_cache_key(pfor_input):
4211  """Create cache key corresponding to a stack handle."""
4212  op_type = pfor_input.op_type
4213  assert op_type in ["StackPushV2", "StackPopV2"], op_type
4214  orig_handle = pfor_input.op.inputs[0]
4215  while orig_handle.op.type in ["Identity", "Enter"]:
4216    orig_handle = orig_handle.op.inputs[0]
4217  assert orig_handle.op.type == "StackV2", orig_handle.op
4218  return ops.get_default_graph(), pfor_input.pfor, orig_handle
4219
4220
4221def _stack_handle_inside_pfor(handle, pfor_input):
4222  while handle.op.type in ["Identity", "Enter"]:
4223    handle = handle.op.inputs[0]
4224  assert handle.op.type == "StackV2", ("Unable to find StackV2 op. Got %s" %
4225                                       handle.op)
4226  return pfor_input.pfor.op_is_inside_loop(handle.op)
4227
4228
4229@RegisterPFor("StackPushV2")
4230def _convert_stack_push_v2(pfor_input):
4231  handle = pfor_input.unstacked_input(0)
4232  elem, elem_stacked, _ = pfor_input.input(1)
4233  swap_memory = pfor_input.get_attr("swap_memory")
4234
4235  if not _stack_handle_inside_pfor(pfor_input.op.inputs[0], pfor_input):
4236    raise ValueError("StackPushV2 not allowed on stacks created outside pfor")
4237  stack_cache_key = _stack_cache_key(pfor_input)
4238  stacked = _stack_cache.get(stack_cache_key, None)
4239  if stacked is None:
4240    stacked = elem_stacked
4241    _stack_cache[stack_cache_key] = stacked
4242  else:
4243    # If we previously made it unstacked then we can't revert to being stacked.
4244    if not stacked and elem_stacked:
4245      raise ValueError(
4246          "It looks like the stack was previously determined to be loop"
4247          " invariant, but we are now trying to push a loop dependent value"
4248          " to it. This is currently unsupported.")
4249    if stacked and not elem_stacked:
4250      elem = _stack(elem, pfor_input.pfor.loop_len_vector).t
4251  out = data_flow_ops.stack_push_v2(handle, elem, swap_memory=swap_memory)
4252  return wrap(out, stacked)
4253
4254
4255# Note that inputs to this converter will be unstacked. However, it should
4256# still get called since it is a stateful op.
4257@RegisterPFor("StackPopV2")
4258def _convert_stack_pop_v2(pfor_input):
4259  handle = pfor_input.unstacked_input(0)
4260  stack_cache_key = _stack_cache_key(pfor_input)
4261  stacked = _stack_cache.get(stack_cache_key, None)
4262  # If a StackPushV2 has not been converted yet, we default to unstacked since
4263  # the push could be outside of pfor, or the converter may not be called if the
4264  # inputs are unconverted.
4265  if stacked is None:
4266    stacked = False
4267    _stack_cache[stack_cache_key] = False
4268  elem_type = pfor_input.get_attr("elem_type")
4269  out = data_flow_ops.stack_pop_v2(handle, elem_type)
4270  return wrap(out, stacked)
4271
4272
4273# parsing_ops
4274
4275
4276@RegisterPFor("DecodeCSV")
4277def _convert_decode_csv(pfor_input):
4278  lines = pfor_input.stacked_input(0)
4279  record_defaults = [
4280      pfor_input.unstacked_input(i) for i in range(1, pfor_input.num_inputs)
4281  ]
4282  field_delim = pfor_input.get_attr("field_delim")
4283  use_quote_delim = pfor_input.get_attr("use_quote_delim")
4284  select_cols = pfor_input.get_attr("select_cols")
4285  if not select_cols:
4286    select_cols = None
4287  return [
4288      wrap(t, True) for t in parsing_ops.decode_csv(
4289          lines,
4290          record_defaults,
4291          field_delim=field_delim,
4292          use_quote_delim=use_quote_delim,
4293          select_cols=select_cols)
4294  ]
4295
4296
4297@RegisterPFor("ParseSingleExample")
4298def _convert_parse_single_example(pfor_input):
4299  serialized = pfor_input.stacked_input(0)
4300  dense_defaults = [
4301      pfor_input.unstacked_input(i) for i in range(1, pfor_input.num_inputs)
4302  ]
4303  sparse_keys = pfor_input.get_attr("sparse_keys")
4304  dense_keys = pfor_input.get_attr("dense_keys")
4305  sparse_types = pfor_input.get_attr("sparse_types")
4306  dense_shapes = pfor_input.get_attr("dense_shapes")
4307  output = gen_parsing_ops.parse_example(
4308      serialized=serialized,
4309      names=[],
4310      dense_defaults=dense_defaults,
4311      sparse_keys=sparse_keys,
4312      dense_keys=dense_keys,
4313      sparse_types=sparse_types,
4314      dense_shapes=dense_shapes)
4315  return [wrap(t, True, True) for t in nest.flatten(output)]
4316
4317
4318@RegisterPFor("ParseExampleV2")
4319def _convert_parse_example_v2(pfor_input):
4320  serialized = pfor_input.stacked_input(0)
4321  sparse_keys = pfor_input.unstacked_input(2)
4322  dense_keys = pfor_input.unstacked_input(3)
4323  ragged_keys = pfor_input.unstacked_input(4)
4324  dense_defaults = [
4325      pfor_input.unstacked_input(i) for i in range(5, pfor_input.num_inputs)
4326  ]
4327  num_sparse = pfor_input.get_attr("num_sparse")
4328  sparse_types = pfor_input.get_attr("sparse_types")
4329  ragged_value_types = pfor_input.get_attr("ragged_value_types")
4330  ragged_split_types = pfor_input.get_attr("ragged_split_types")
4331  dense_shapes = pfor_input.get_attr("dense_shapes")
4332  if serialized.shape.ndims not in (None, 1):
4333    raise ValueError("ParseExampleV2 can only be converted if `serialized` "
4334                     "is scalar.")
4335  output = gen_parsing_ops.parse_example_v2(
4336      serialized=serialized,
4337      names=[],
4338      sparse_keys=sparse_keys,
4339      dense_keys=dense_keys,
4340      ragged_keys=ragged_keys,
4341      dense_defaults=dense_defaults,
4342      num_sparse=num_sparse,
4343      sparse_types=sparse_types,
4344      ragged_value_types=ragged_value_types,
4345      ragged_split_types=ragged_split_types,
4346      dense_shapes=dense_shapes)
4347  return [wrap(t, True, True) for t in nest.flatten(output)]
4348
4349
4350# functional_ops
4351
4352
4353def _convert_function_call(func, converter, inputs):
4354  assert isinstance(func.graph, func_graph.FuncGraph), func
4355  assert isinstance(converter, PFor)
4356
4357  # TODO(agarwal): consider caching this function definition.
4358  @def_function.function
4359  def f(*args):
4360    assert all(isinstance(arg, WrappedTensor) for arg in args), args
4361    assert len(args) == len(func.graph.inputs), (args, func.graph.inputs)
4362    #  Map inputs to function arguments.
4363    for inp, arg in zip(func.graph.inputs, args):
4364      converter._add_conversion(inp, arg)
4365    # Convert output tensors.
4366    return tuple(
4367        [converter._convert_helper(x).t for x in func._func_graph_outputs])
4368
4369  call_outputs = f(*inputs)
4370  assert len(call_outputs) == len(func._func_graph_outputs)
4371  outputs = []
4372  for call_output, output_tensor in zip(call_outputs, func._func_graph_outputs):
4373    func_output = converter._convert_helper(output_tensor)
4374    outputs.append(
4375        wrap(call_output, func_output.is_stacked,
4376             func_output.is_sparse_stacked))
4377  return outputs
4378
4379
4380@RegisterPFor("StatefulPartitionedCall")
4381@RegisterPFor("PartitionedCall")
4382def _convert_partitioned_call(pfor_input):
4383  func_name = pfor_input.get_attr("f").name
4384  func = pfor_input.op.graph._get_function(compat.as_bytes(func_name))
4385  assert isinstance(func.graph, func_graph.FuncGraph), (
4386      "Could not find FuncGraph object for %s. Got func %s" % (func_name, func))
4387  pfor = pfor_input.pfor
4388  converter = PFor(
4389      loop_var=pfor.loop_var,
4390      loop_len=pfor.loop_len_vector[0],
4391      pfor_ops=func.graph.get_operations(),
4392      fallback_to_while_loop=pfor.fallback_to_while_loop,
4393      all_indices=pfor.all_indices,
4394      all_indices_partitioned=pfor.all_indices_partitioned,
4395      pfor_config=pfor.pfor_config)
4396  return _convert_function_call(func, converter, pfor_input.inputs)
4397
4398
4399def _partition_inputs_for_indices(inputs, indices):
4400  new_inputs = []
4401  for inp in inputs:
4402    if inp.is_stacked:
4403      new_inputs.append(wrap(array_ops.gather(inp.t, indices), True))
4404    else:
4405      new_inputs.append(inp)
4406  return new_inputs
4407
4408
4409def _outputs_for_branch(func_name, indices, pfor_input, inputs):
4410  if indices is None:
4411    indices = pfor_input.pfor.all_indices
4412    partitioned = pfor_input.pfor.all_indices_partitioned
4413  else:
4414    partitioned = True
4415  func = pfor_input.op.graph._get_function(func_name)
4416  converter = PFor(
4417      loop_var=pfor_input.pfor.loop_var,
4418      loop_len=array_ops.size(indices),
4419      pfor_ops=func.graph.get_operations(),
4420      fallback_to_while_loop=pfor_input.pfor.fallback_to_while_loop,
4421      all_indices=indices,
4422      all_indices_partitioned=partitioned,
4423      pfor_config=pfor_input.pfor.pfor_config)
4424  outputs = _convert_function_call(func, converter, inputs)
4425  stacked_outputs = []
4426  for out in outputs:
4427    if not out.is_stacked:
4428      stacked_outputs.append(_stack(out.t, [array_ops.size(indices)]).t)
4429    else:
4430      stacked_outputs.append(out.t)
4431  return stacked_outputs
4432
4433
4434# TODO(agarwal): Currently the converted code aggressively tiles the outputs
4435# of the then/else branches to make them loop variant. Instead, it could do so
4436# only if at least one of the branch outputs is loop variant.
4437@RegisterPFor("StatelessIf")
4438@RegisterPFor("If")
4439def _convert_if(pfor_input):
4440  cond, cond_stacked, _ = pfor_input.input(0)
4441  inputs = pfor_input.inputs[1:]
4442  then_branch = pfor_input.get_attr("then_branch")
4443  else_branch = pfor_input.get_attr("else_branch")
4444
4445  if cond_stacked:
4446    cond_int = math_ops.cast(cond, dtypes.int32)
4447    # Compute loop indices for the different branches
4448    false_indices, true_indices = data_flow_ops.dynamic_partition(
4449        pfor_input.pfor.all_indices, cond_int, 2)
4450    # Compute indices for cond being True or False.
4451    if pfor_input.pfor.all_indices_partitioned:
4452      else_indices, then_indices = data_flow_ops.dynamic_partition(
4453          math_ops.range(pfor_input.pfor.loop_len_vector[0]),
4454          cond_int, 2)
4455    else:
4456      else_indices, then_indices = false_indices, true_indices
4457    # Partition inputs
4458    then_inputs = _partition_inputs_for_indices(inputs, then_indices)
4459    else_inputs = _partition_inputs_for_indices(inputs, else_indices)
4460
4461    # Convert "then" branch.
4462    then_outputs = _outputs_for_branch(then_branch.name, true_indices,
4463                                       pfor_input, then_inputs)
4464
4465    # Convert "else" branch.
4466    else_outputs = _outputs_for_branch(else_branch.name, false_indices,
4467                                       pfor_input, else_inputs)
4468
4469    assert len(then_outputs) == len(else_outputs)
4470    # Note that if the "then" and "else" branches are updating the same state,
4471    # and possibly reading them as well, it could lead to undefined behavior
4472    # since the ordering of those operations is not well defined.
4473    # One possibility is to order all the "then" branches to execute before all
4474    # the "else" branches so that the side-effects in the former are visible to
4475    # the latter. For now, we leave that as undefined behavior.
4476    outputs = []
4477    # Merge outputs
4478    for then_output, else_output in zip(then_outputs, else_outputs):
4479      out = data_flow_ops.dynamic_stitch([then_indices, else_indices],
4480                                         [then_output, else_output])
4481      outputs.append(wrap(out, True))
4482    return outputs
4483  else:
4484    outputs = control_flow_ops.cond(
4485        cond,
4486        lambda: _outputs_for_branch(then_branch.name, None, pfor_input, inputs),
4487        lambda: _outputs_for_branch(else_branch.name, None, pfor_input, inputs))
4488    return [wrap(t, True) for t in outputs]
4489
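# Illustrative usage sketch (assuming the public tf.vectorized_map API): the
# snippet below traces a StatelessIf/If op whose condition depends on the loop
# variable, so the stacked-cond path above partitions the loop indices with
# dynamic_partition, converts each branch separately, and merges the results
# with dynamic_stitch:
#
#   x = tf.constant([-2.0, 3.0, -4.0])
#   y = tf.vectorized_map(
#       lambda e: tf.cond(e > 0.0, lambda: 2.0 * e, lambda: -e), x)
#   # y == [2.0, 6.0, 4.0]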
4490
4491class WhileV2(object):
4492  """Object for vectorizing V2 while_loop op."""
4493
4494  def __init__(self, pfor_input):
4495    self._pfor_input = pfor_input
4496    self._pfor = pfor_input.pfor
4497    cond_func_name = pfor_input.get_attr("cond").name
4498    self._cond_func = pfor_input.op.graph._get_function(compat.as_bytes(
4499        cond_func_name))
4500    body_func_name = pfor_input.get_attr("body").name
4501    self._body_func = pfor_input.op.graph._get_function(compat.as_bytes(
4502        body_func_name))
4503    if self._cond_func is None or self._body_func is None:
4504      raise ValueError("Error extracting cond and body functions for op %s." % (
4505          self._pfor_input.op))
4506    # Indices of inputs that are passed unchanged through the while loop body.
4507    # Typically these are tensors captured from outside the body context.
4508    self._body_pass_through_indices = set()
4509    for i, (inp, out) in enumerate(zip(self._body_func.graph.inputs,
4510                                       self._body_func.graph.outputs)):
4511      if id(inp) == id(out):
4512        self._body_pass_through_indices.add(i)
4513    self._parallel_iterations = self._pfor_input.get_attr("parallel_iterations")
4514
4515  def _output_shapes(self):
4516    # Calculate output shape for vectorized loop. This will be used as
4517    # shape_invariant. Merges shape inference outputs with the `output_shapes`
4518    # attribute of the op.
4519    output_shapes = [out.shape for out in self._pfor_input.op.outputs]
4520    shapes = self._pfor_input.get_attr("output_shapes")
4521    if not shapes:
4522      shapes = [tensor_shape.TensorShape(None) for _ in output_shapes]
4523    else:
4524      shapes = [tensor_shape.TensorShape(shape) for shape in shapes]
4525    for i, shape in enumerate(shapes):
4526      shape = shape.merge_with(output_shapes[i])
4527      pfor_input = self._pfor_input.input(i)
4528      if pfor_input.is_stacked:
4529        if _is_variant_with_internal_stacking(pfor_input.t):
4530          shape = tensor_shape.TensorShape([]).concatenate(shape)
4531        else:
4532          shape = tensor_shape.TensorShape([None]).concatenate(shape)
4533      output_shapes[i] = shape
4534    assert len(output_shapes) == self._pfor_input.num_inputs
4535    return output_shapes
4536
4537  def _init_values(self):
4538    """Create arguments passed to converted while_loop."""
4539    loop_len = self._pfor.loop_len_vector[0]
4540    inputs = []
4541    # TensorArrays for outputs of converted while loop
4542    output_tas = []
4543
4544    with ops.name_scope("while_init"):
4545      for inp in self._pfor_input.inputs:
4546        inputs.append(inp.t)
4547        variant_type_id = _variant_type_id(inp.t)
4548        if variant_type_id in _INTERNAL_STACKING_TYPE_IDS:
4549          if variant_type_id != full_type_pb2.TFT_ARRAY:
4550            raise NotImplementedError(
4551                ("While loop conversion is only supported for TensorLists. Got "
4552                 "another variant {}, probably an optional. Please file a bug.")
4553                .format(inp.t))
4554          # For TensorLists, the input format is:
4555          #
4556          #   List[user_list_len, Tensor[loop_len, ...]]
4557          #
4558          # rather than the usual
4559          #
4560          #   Tensor[loop_len, ...]
4561          #
4562          # The body of the loop will take and return lists in this "internal
4563          # vectorization" format, so we want to keep it that way as much as
4564          # possible. We'll accumulate finished iterations (only relevant for
4565          # pfor-loop-variant while_loop conditions) in an accumulator with
4566          # type:
4567          #
4568          #   List[user_list_len, List[loop_len, Tensor[...]]]
4569          #
4570          # This means that each while_loop iteration, we'll iterate over the
4571          # length of the TensorList, dividing done/remaining pfor loop indices
4572          # and scattering the done indices into the inner nested list of the
4573          # accumulator.
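          #
          # Illustrative example (an explanatory sketch): a TensorList of
          # length 3 whose per-iteration elements have shape [2], used with
          # loop_len = 4, arrives here as List[3, Tensor[4, 2]]; its
          # accumulator below is a TensorArray with 3 entries, each of type
          # List[4, Tensor[2]].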
4574          element_shape = list_ops.tensor_list_element_shape(
4575              inp.t, dtypes.int32)
4576          if inp.is_stacked:
4577            # Shapes may be tf.constant(-1) for fully dynamic, in which case
4578            # slicing is an error.
4579            element_shape = control_flow_ops.cond(
4580                math_ops.equal(array_ops.rank(element_shape), 0),
4581                lambda: element_shape,
4582                lambda: element_shape[1:])
4583          dtype = _parse_variant_shapes_and_types(inp.t)[0].dtype
4584
4585          def _init_loop_body(index, output_ta):
4586            output_ta = output_ta.write(
4587                index,
4588                list_ops.tensor_list_reserve(element_shape, loop_len, dtype))
4589            return index + 1, output_ta
4590
4591          length = list_ops.tensor_list_length(inp.t)
4592          output_ta = tensor_array_ops.TensorArray(
4593              inp.t.dtype,  # Variant; this is a nested TensorList
4594              size=length,
4595              dynamic_size=True,
4596              infer_shape=False)
4597          _, output_ta = control_flow_ops.while_loop(
4598              lambda index, _: index < length,
4599              _init_loop_body,
4600              [0, output_ta])
4601        else:
4602          output_ta = tensor_array_ops.TensorArray(
4603              inp.t.dtype,
4604              size=loop_len,
4605              dynamic_size=False,
4606              infer_shape=True)
4607        output_tas.append(output_ta)
4608    # See documentation for __call__ for the structure of init_values.
4609    indices = (
4610        math_ops.range(self._pfor.loop_len_vector[0])
4611        if self._pfor.all_indices_partitioned else self._pfor.all_indices)
4612    return [True, indices] + inputs + output_tas
4613
4614  def _process_cond_unstacked(self, conditions, indices, inputs, output_tas):
4615    """Handles case when condition is pfor loop invariant."""
4616    # Note that all iterations end together. So we don't need to partition the
4617    # inputs.
4618    not_all_done = array_ops.reshape(conditions, [])
4619    return not_all_done, indices, inputs, output_tas
4620
4621  def _process_cond_stacked(self, conditions, indices, inputs, inputs_stacked,
4622                            output_tas):
4623    """Handles case when condition is pfor loop dependent."""
4624    # Compute if all iterations are done.
4625    not_all_done = math_ops.reduce_any(conditions)
4626    conditions_int = math_ops.cast(conditions, dtypes.int32)
4627    # Partition the indices.
4628    done_indices, new_indices = data_flow_ops.dynamic_partition(
4629        indices, conditions_int, 2)
4630
4631    new_inputs = []
4632    new_output_tas = []
4633    for i, (inp, stacked) in enumerate(zip(inputs, inputs_stacked)):
4634      pass_through = i in self._body_pass_through_indices
4635      if not pass_through and _variant_type_id(inp) == full_type_pb2.TFT_ARRAY:
4636        shape_and_type = _parse_variant_shapes_and_types(inp)[0]
4637        element_shape = list_ops.tensor_list_element_shape(inp, dtypes.int32)
4638        user_list_len = list_ops.tensor_list_length(inp)
4639
4640        def _split_vectorized_ta_element(index, new_inp, new_out_ta):
4641          elem = list_ops.tensor_list_get_item(inp, index, shape_and_type.dtype,
4642                                               element_shape)
4643          if stacked:
4644            done_elem, new_elem = data_flow_ops.dynamic_partition(
4645                elem, conditions_int, 2)
4646            new_inp = list_ops.tensor_list_set_item(new_inp, index, new_elem)
4647          else:
4648            done_elem = _stack(elem, [array_ops.size(done_indices)]).t
4649          done_accum = new_out_ta.read(index)
4650          done_accum = list_ops.tensor_list_scatter(
4651              tensor=done_elem, indices=done_indices, input_handle=done_accum)
4652          new_out_ta = new_out_ta.write(index, done_accum)
4653          return index + 1, new_inp, new_out_ta
4654
4655        length = list_ops.tensor_list_length(inp)
4656        new_inp = list_ops.tensor_list_reserve(
4657            tensor_shape.TensorShape([None])
4658            + tensor_shape.TensorShape(shape_and_type.shape)[1:],
4659            user_list_len, shape_and_type.dtype)
4660        _, new_inp, out_ta = control_flow_ops.while_loop(
4661            lambda index, unused_new_inp, unused_new_out_ta: index < length,
4662            _split_vectorized_ta_element,
4663            [0, new_inp, output_tas[i]])
4664      else:
4665        # Partition the inputs.
4666        if stacked:
4667          done_inp, new_inp = data_flow_ops.dynamic_partition(
4668              inp, conditions_int, 2)
4669        else:
4670          if not pass_through:
4671            done_inp = _stack(inp, [array_ops.size(done_indices)]).t
4672          new_inp = inp
4673
4674        out_ta = output_tas[i]
4675        if not pass_through:
4676          # Note that done_indices can be empty. done_inp should also be empty
4677          # in that case.
4678          out_ta = out_ta.scatter(done_indices, done_inp)
4679      new_inputs.append(new_inp)
4680      new_output_tas.append(out_ta)
4681
4682    assert len(new_output_tas) == len(output_tas)
4683    assert len(new_inputs) == len(inputs)
4684    return not_all_done, new_indices, new_inputs, new_output_tas
4685
4686  def _process_body(self, inputs_stacked, new_indices, cond_stacked,
4687                    new_inputs, not_all_done):
4688    """Convert the body function."""
4689    # This is used to store the indices of inputs to the while op that need to
4690    # be stacked. This stacking may be needed in cases where the input to the
4691    # while_loop is loop_invariant but the corresponding output is not.
4692    mismatching_stacked_indices = []
4693
4694    def true_fn():
4695      """Converts the body function for all but last iteration."""
4696      wrapped_inputs = [wrap(inp, stacked) for inp, stacked in
4697                        zip(new_inputs, inputs_stacked)]
4698      # Note the iterative process below used to figure out loop invariance.
4699      # Here we iterate the vectorization process until it reaches a fixed
4700      # point. The issue is that the while body can take pfor loop invariant
4701      # inputs but return loop variant outputs. For any loop variant output,
4702      # the corresponding input then has to be made loop variant (since
4703      # subsequent while iterations will need to see loop variant values).
4704      # However, once we make a new input loop variant, other outputs may
4705      # become loop variant too. Hence we iterate until we reach a fixed point.
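      #
      # For example (an illustrative sketch): if a loop invariant input `x`
      # flows into a body output computed as `x + <loop dependent value>`, that
      # output is loop variant, so `x` is recorded in
      # `mismatching_stacked_indices` and stacked before the conversion is
      # retried.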
4706      while True:
4707        if self._pfor.all_indices_partitioned:
4708          indices = array_ops.gather(self._pfor.all_indices, new_indices)
4709        else:
4710          indices = new_indices
4711        body_pfor = PFor(
4712            loop_var=self._pfor.loop_var,
4713            loop_len=array_ops.size(new_indices),
4714            pfor_ops=self._body_func.graph.get_operations(),
4715            fallback_to_while_loop=self._pfor.fallback_to_while_loop,
4716            all_indices=indices,
4717            all_indices_partitioned=(self._pfor.all_indices_partitioned or
4718                                     cond_stacked),
4719            pfor_config=self._pfor.pfor_config)
4720        stacking_mismatch = False
4721        outputs = _convert_function_call(self._body_func,
4722                                         body_pfor,
4723                                         wrapped_inputs)
4724        for i, (out, inp) in enumerate(zip(outputs, wrapped_inputs)):
4725          if out.is_stacked != inp.is_stacked:
4726            stacking_mismatch = True
4727            mismatching_stacked_indices.append(i)
4728            stacked = _stack(inp.t, [array_ops.size(new_indices)])
4729            if inp.t.dtype == dtypes.variant:
4730              stacked = wrap(
4731                  _tile_variant_with_length(stacked.t,
4732                                            [array_ops.size(new_indices)]))
4733            wrapped_inputs[i] = stacked
4734        if not stacking_mismatch:
4735          if mismatching_stacked_indices:
4736            # We needed to stack some inputs. This code will be abandoned and
4737            # should not get executed. Hence we simply return `new_inputs` to
4738            # make sure the graph construction code completes.
4739            with ops.control_dependencies([
4740                control_flow_ops.Assert(
4741                    False, ["pfor ERROR: this branch should never execute"])]):
4742              return [array_ops.identity(x) for x in new_inputs]
4743          else:
4744            return [out.t for out in outputs]
4745
4746    # If all are done, we simply return `new_inputs`. Else we need to run the
4747    # body function.
4748    return control_flow_ops.cond(
4749        not_all_done,
4750        true_fn,
4751        lambda: list(new_inputs)), mismatching_stacked_indices
4752
4753  def __call__(self):
4754    """Converter for the V2 while_loop.
4755
4756    The conversion of a while_loop is another while_loop.
4757
4758    The arguments to this converted while_loop are as follows:
4759    not_all_done: Boolean scalar Tensor that is True while some pfor iterations
4760      are still not done.
4761    indices: int32 1-D Tensor storing the ids of the pfor iterations that are
4762      not done.
4763    args: Remaining arguments. These can be divided into 2 categories:
4764      - The first set of arguments corresponds one-to-one to the inputs of the
4765        unvectorized while_loop.
4766      - The second set consists of TensorArrays, corresponding one-to-one to
4767        each output of the unvectorized while_loop. Each TensorArray has
4768        `PFor.loop_len` elements, i.e. the number of pfor iterations. At the
4769        end, the i'th element of each TensorArray will contain the output
4770        computed by the i'th iteration of pfor. Note that elements can be
4771        written into these TensorArrays in any order, depending on when the
4772        corresponding pfor iteration is done.
4773    In each iteration, the while_loop body recomputes the condition for all
4774    active pfor iterations to see which of them are now done. It then partitions
4775    all the inputs and passes them along to the converted body. Values for all
4776    the iterations that are done are written to TensorArrays indexed by the pfor
4777    iteration number. When all iterations are done, the TensorArrays are stacked
4778    to get the final value.
4779
4780    Returns:
4781      List of converted outputs.
4782    """
4783    output_shapes = self._output_shapes()
4784    # Note that we use these lists as a hack since we need the `body` to compute
4785    # these values during construction of the while_loop graph.
4786    cond_is_stacked = [None]
4787    indices_to_stack = []
4788
4789    def cond(not_all_done, *_):
4790      return not_all_done
4791
4792    def body(not_all_done, indices, *args):
4793      # See documentation for __call__ for the structure of *args.
4794      num_inputs = self._pfor_input.num_inputs
4795      inputs = args[:num_inputs]
4796      output_tas = args[num_inputs:]
4797      inputs_stacked = [x.is_stacked for x in self._pfor_input.inputs]
4798      assert len(inputs) >= len(output_tas)
4799      assert len(inputs) == len(inputs_stacked)
4800      # Convert condition
4801      with ops.name_scope("while_cond"):
4802        # Note that we set all_indices_partitioned to True here. At this point
4803        # we don't know if indices will be partitioned. Hence we use the
4804        # conservative value.
4805        cond_pfor = PFor(
4806            loop_var=self._pfor.loop_var,
4807            loop_len=array_ops.size(indices),
4808            pfor_ops=self._cond_func.graph.get_operations(),
4809            fallback_to_while_loop=self._pfor.fallback_to_while_loop,
4810            all_indices=indices,
4811            all_indices_partitioned=True,
4812            pfor_config=self._pfor.pfor_config)
4813
4814        wrapped_inputs = [wrap(inp, stacked) for inp, stacked
4815                          in zip(inputs, inputs_stacked)]
4816        conditions, cond_stacked, _ = _convert_function_call(
4817            self._cond_func,
4818            cond_pfor,
4819            wrapped_inputs)[0]
4820        cond_is_stacked[0] = cond_stacked
4821
4822      # Recompute the new condition, write outputs of done iterations, and
4823      # partition the inputs if needed.
4824      if not cond_stacked:
4825        (not_all_done, new_indices, new_inputs,
4826         new_output_tas) = self._process_cond_unstacked(conditions, indices,
4827                                                        inputs, output_tas)
4828      else:
4829        (not_all_done, new_indices, new_inputs,
4830         new_output_tas) = self._process_cond_stacked(conditions, indices,
4831                                                      inputs, inputs_stacked,
4832                                                      output_tas)
4833      # Convert body
4834      with ops.name_scope("while_body"):
4835        #  Compute the outputs from the body.
4836        new_outputs, mismatching_stacked_indices = self._process_body(
4837            inputs_stacked, new_indices, cond_stacked, new_inputs, not_all_done)
4838
4839      indices_to_stack[:] = mismatching_stacked_indices
4840      for i, new_output in enumerate(new_outputs):
4841        new_output.set_shape(output_shapes[i])
4842      new_args = ([not_all_done, new_indices] + new_outputs +
4843                  list(new_output_tas))
4844      return tuple(new_args)
4845
4846    # Note that we run the code below in a function since we might abandon the
4847    # generated code in cases where the conversion dictates that some inputs be
4848    # further stacked. Hence we run the graph construction using
4849    # `get_concrete_function` and avoid calling the constructed function if not
4850    # needed.
4851    @def_function.function
4852    def while_fn():
4853      # Create init_values that will be passed to the while_loop.
4854      init_values = self._init_values()
4855      ta_shape_invariants = [tensor_shape.TensorShape([]) for _ in
4856                             self._pfor_input.outputs]
4857      shape_invariants = (
4858          [tensor_shape.TensorShape([]), tensor_shape.TensorShape([None])]
4859          + output_shapes + ta_shape_invariants)
4860
4861      while_outputs = control_flow_ops.while_loop(
4862          cond, body, init_values,
4863          shape_invariants=shape_invariants,
4864          parallel_iterations=self._parallel_iterations)
4865      if indices_to_stack:
4866        # This function will be abandoned.
4867        return while_outputs
4868      else:
4869        num_inputs = self._pfor_input.num_inputs
4870        new_inputs = while_outputs[2:num_inputs+2]
4871        output_tas = while_outputs[num_inputs+2:]
4872        assert cond_is_stacked[0] is not None
4873        outputs = []
4874        for i, inp in enumerate(new_inputs):
4875          if cond_is_stacked[0]:
4876            if i in self._body_pass_through_indices:
4877              outputs.append(init_values[i + 2])
4878            else:
4879              ta = output_tas[i]
4880              if _variant_type_id(inp) == full_type_pb2.TFT_ARRAY:
4881                shape_and_type = _parse_variant_shapes_and_types(inp)[0]
4882                length = list_ops.tensor_list_length(inp)
4883
4884                # We have been accumulating values in a:
4885                #
4886                #   List[user_list_len, List[loop_len, Tensor[...]]]
4887                #
4888                # We want to return an output in the same format as the input:
4889                #
4890                #   List[user_list_len, Tensor[loop_len, ...]]
4891                #
4892                # So we need to loop over the list and stack its contents.
4893                def _stack_loop_body(index, output_list):
4894                  current_value = ta.read(index)
4895                  output_list = list_ops.tensor_list_set_item(
4896                      output_list, index,
4897                      list_ops.tensor_list_stack(
4898                          current_value, shape_and_type.dtype))
4899                  return index + 1, output_list
4900
4901                output_list = list_ops.tensor_list_reserve(
4902                    tensor_shape.TensorShape(shape_and_type.shape), length,
4903                    shape_and_type.dtype)
4904                _, output_list = control_flow_ops.while_loop(
4905                    lambda index, _: index < length,
4906                    _stack_loop_body,
4907                    [0, output_list])
4908                outputs.append(output_list)
4909              else:
4910                outputs.append(ta.stack())
4911          else:
4912            outputs.append(inp)
4913        return outputs
4914
4915    _ = while_fn.get_concrete_function()
4916    if indices_to_stack:
4917      # Need to abandon the current conversion, stack some inputs and restart.
4918      self._pfor_input.stack_inputs(
4919          stack_indices=indices_to_stack, tile_variants=True)
4920      # Note that this call will recurse at most one time. The first call will
4921      # do the required stacking, based on the iterative procedure in
4922      # _process_body, and the next invocation to __call__ should not need to do
4923      # any more stacking.
4924      # We invoke `self()` here as a way to discard any corrupted state.
4925      return self()
4926    else:
4927      outputs = while_fn()
4928      wrapped_outputs = []
4929      for i, (out, inp) in enumerate(zip(outputs, self._pfor_input.inputs)):
4930        if i not in self._body_pass_through_indices and cond_is_stacked[0]:
4931          wrapped_outputs.append(wrap(out, True))
4932        else:
4933          wrapped_outputs.append(wrap(out, inp.is_stacked))
4934      return wrapped_outputs
4935
4936
4937@RegisterPFor("StatelessWhile")
4938@RegisterPFor("While")
4939def _convert_while(pfor_input):
4940  converter = WhileV2(pfor_input)
4941  return converter()
4942
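# Illustrative usage sketch (assuming the public tf.vectorized_map API): a
# while_loop whose trip count depends on the loop variable exercises the
# stacked-cond machinery in WhileV2 above:
#
#   def fn(x):
#     _, total = tf.while_loop(lambda i, t: i < x,
#                              lambda i, t: [i + 1, t + i],
#                              [0, 0])
#     return total
#   tf.vectorized_map(fn, tf.constant([2, 5, 3]))  # == [1, 10, 3]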
4943
4944# spectral_ops
4945
4946
4947@RegisterPForWithArgs("FFT", gen_spectral_ops.fft)
4948@RegisterPForWithArgs("FFT2D", gen_spectral_ops.fft2d)
4949@RegisterPForWithArgs("FFT3D", gen_spectral_ops.fft3d)
4950@RegisterPForWithArgs("IFFT", gen_spectral_ops.ifft)
4951@RegisterPForWithArgs("IFFT2D", gen_spectral_ops.ifft2d)
4952@RegisterPForWithArgs("IFFT3D", gen_spectral_ops.ifft3d)
4953def _convert_fft(pfor_input, _, op_func):
4954  return wrap(op_func(pfor_input.stacked_input(0)), True)
4955
4956
4957@RegisterPForWithArgs("RFFT", gen_spectral_ops.rfft, "Tcomplex")
4958@RegisterPForWithArgs("RFFT2D", gen_spectral_ops.rfft2d, "Tcomplex")
4959@RegisterPForWithArgs("RFFT3D", gen_spectral_ops.rfft3d, "Tcomplex")
4960@RegisterPForWithArgs("IRFFT", gen_spectral_ops.irfft, "Treal")
4961@RegisterPForWithArgs("IRFFT2D", gen_spectral_ops.irfft2d, "Treal")
4962@RegisterPForWithArgs("IRFFT3D", gen_spectral_ops.irfft3d, "Treal")
4963def _convert_rfft(pfor_input, _, op_func, attr_name):
4964  inp = pfor_input.stacked_input(0)
4965  fft_length = pfor_input.unstacked_input(1)
4966  attr = pfor_input.get_attr(attr_name)
4967  return wrap(op_func(inp, fft_length, attr), True)
4968