# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Compiled parallel-for loop."""
# pylint: disable=missing-docstring,g-direct-tensorflow-import

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import string
import sys
import traceback

import six

from tensorflow.compiler.tf2xla.python import xla
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import execute
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import func_graph
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import bitwise_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import gen_linalg_ops
from tensorflow.python.ops import gen_nn_ops
from tensorflow.python.ops import gen_parsing_ops
from tensorflow.python.ops import gen_random_ops
from tensorflow.python.ops import gen_sparse_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import map_fn
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import parsing_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import special_math_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.platform import flags
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import compat
from tensorflow.python.util import nest
from tensorflow.python.util import object_identity

flags.DEFINE_bool(
    "op_conversion_fallback_to_while_loop", False,
    "If true, falls back to using a while loop for ops for "
    "which a converter is not defined.")


def _stack(t, length):
  """stacks `t` `length` times."""
  ones = array_ops.ones_like(array_ops.shape(t))
  multiples = array_ops.concat([length, ones], 0)
  t = array_ops.tile(array_ops.expand_dims(t, 0), multiples)
  return wrap(t, True)

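# Illustrative sketch of `_stack` (hypothetical shapes): for `t` of shape
# [2, 3] and `length` equal to the 1-D tensor [n] (e.g. a loop_len_vector),
# the returned WrappedTensor wraps a tensor of shape [n, 2, 3] and is marked
# as stacked.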

# The following stateful ops can be safely called once, and with the same
# signature as the unconverted version, if their inputs are loop invariant.
# TODO(agarwal): implement a strategy for converting Variable reads/writes. The
# plan is to map each read/write in the loop_fn to a corresponding merged
# read/write in the converted graph. Writes need to be mergeable (e.g.
# AssignAdd) to be used in `pfor`. Given a certain read/write order in the
# loop_fn, doing a one-to-one conversion will simulate executing such
# instructions in lock-step across all iterations.
passthrough_stateful_ops = set([
    "VariableV2",
    "VarHandleOp",
    "ReadVariableOp",
    "StackV2",
    "TensorArrayWriteV3",
    "TensorArrayReadV3",
    "TensorArraySizeV3",
])


def _is_stateful_pfor_op(op):
  if isinstance(op, WhileOp):
    return op.is_stateful
  if op.type == "Const":
    # Const didn't have an op_def.
    return False
  if op.type in passthrough_stateful_ops:
    return False
  assert hasattr(op, "op_def") and op.op_def is not None, op
  return op.op_def.is_stateful


# pylint: disable=protected-access
class WhileOp(object):
  """Object for storing state for converting the outputs of a while_loop."""

  def __init__(self, exit_node, pfor_ops, pfor_config):
    """Initializer.

    Args:
      exit_node: A tensor output from the while_loop.
      pfor_ops: list of ops inside the current pfor loop.
      pfor_config: PForConfig object used while constructing loop body.
    """
    self._pfor_config = pfor_config
    self._pfor_ops = set(pfor_ops)
    self._pfor_op_ids = set(x._id for x in pfor_ops)
    assert isinstance(exit_node, ops.Tensor)
    self._while_context = exit_node.op._get_control_flow_context()
    assert isinstance(self._while_context, control_flow_ops.WhileContext)
    self._context_name = self._while_context.name
    self._condition = self._while_context.pivot.op.inputs[0]
    # Parts of an external while_loop could be created inside a pfor loop.
    # However for the purpose here, we declare such loops to be external. Also
    # note that we check if the condition was created inside or outside to
    # determine if the while_loop was first created inside or outside.
    # TODO(agarwal): check that the Enter and Exit of this loop are unstacked.
    self._is_inside_loop = self.op_is_inside_loop(self._condition.op)
    if self._is_inside_loop:
      for e in self._while_context.loop_exits:
        assert self.op_is_inside_loop(e.op)

    # Note the code below tries to reverse engineer an existing while_loop graph
    # by assuming the following pattern of nodes.
    #
    #          NextIteration <---- Body <--- Enter
    #              |                ^
    #              V             ___| Y
    #    Enter -> Merge -> Switch___
    #                       ^       | N
    #                       |       V
    #                  LoopCond    Exit

    # Note that elements in the list below correspond one-to-one with each
    # other. i.e. these lists are the same size, and the i_th entry corresponds
    # to different Operations/Tensors of a single cycle as illustrated above.
    # List of Switch ops (ops.Operation) that feed into an Exit Node.
    self._exit_switches = []
    # List of inputs (ops.Tensor) to NextIteration.
    self._body_outputs = []
    # List of list of control inputs of the NextIteration nodes.
    self._next_iter_control_inputs = []
    # List of Merge ops (ops.Operation).
    self._enter_merges = []
    # List of output (ops.Tensor) of Exit nodes.
    self._outputs = []

    # List of Enter Tensors.
    # There are two types of Enter nodes:
    # - The Enter nodes that are used in the `loop_vars` argument to
    # `while_loop` (see
    # https://www.tensorflow.org/api_docs/python/tf/while_loop). We collect
    # these Enter nodes immediately below by tracing backwards from the Exit
    # nodes via Exit <- Switch <- Merge <- Enter. You can see this chain in the
    # diagram above. This allows us to have a 1:1 correspondence between the
    # self._outputs and the first elements in self._enters.
    # - The Enter nodes that are used only by the body. They don't appear in the
    # `loop_vars` and are not returned from the `while_loop`. In Python code,
    # they are usually captured by the body lambda. We collect them below by
    # iterating over all the ops in the graph. They are appended to the end of
    # self._enters or self._direct_enters, and don't correspond to any outputs
    # in self._outputs. Note that we keep the resource/variant Enter nodes in
    # self._direct_enters and the constructed while_loop's body uses them
    # directly as opposed to passing them as loop variables. This is done
    # because the while_body cannot partition the resource/variant Tensors, so
    # it has to leave them unchanged.
    self._enters = []
    self._direct_enters = []

    for e in self._while_context.loop_exits:
      self._outputs.append(e.op.outputs[0])
      switch = e.op.inputs[0].op
      assert switch.type == "Switch", switch
      self._exit_switches.append(switch)
      merge = switch.inputs[0].op
      assert merge.type == "Merge", merge
      self._enter_merges.append(merge)
      enter = merge.inputs[0].op
      assert enter.type == "Enter", enter
      self._enters.append(enter.outputs[0])
      next_iter = merge.inputs[1].op
      assert next_iter.type == "NextIteration", next_iter
      self._body_outputs.append(next_iter.inputs[0])
      self._next_iter_control_inputs.append(next_iter.control_inputs)

    # Collect all the Enter nodes that are not part of `loop_vars`, the second
    # category described above.
    # Also track whether the loop body has any stateful ops.
    self._is_stateful = False
    for op in ops.get_default_graph().get_operations():
      # TODO(agarwal): make sure this works with nested case.
      control_flow_context = op._get_control_flow_context()
      if control_flow_context is None:
        continue
      if control_flow_context.name == self._context_name:
        self._is_stateful |= _is_stateful_pfor_op(op)
        if op.type == "Enter":
          output = op.outputs[0]
          if output not in self._enters:
            if output.dtype in (dtypes.resource, dtypes.variant):
              if output not in self._direct_enters:
                self._direct_enters.append(output)
            else:
              self._enters.append(output)

  def __str__(self):
    """String representation."""
    return "while_loop(%s)" % self.name

  @property
  def inputs(self):
    """Input to all the Enter nodes."""
    return [x.op.inputs[0] for x in self._enters + self._direct_enters]

  @property
  def control_inputs(self):
    """Control input to all the Enter nodes."""
    control_inputs = []
    for x in self._enters + self._direct_enters:
      control_inputs.extend(x.op.control_inputs)
    return control_inputs

  @property
  def outputs(self):
    """Outputs of all the Exit nodes."""
    return self._outputs

  @property
  def name(self):
    """Context name for the while loop."""
    return self._context_name

  @property
  def is_inside_loop(self):
    """Returns true if the while_loop was created inside the pfor."""
    return self._is_inside_loop

  def op_is_inside_loop(self, op):
    """True if op was created inside the pfor loop body."""
    assert isinstance(op, ops.Operation)
    # Note that we use self._pfor_op_ids for the check and not self._pfor_ops
    # since it appears the tensorflow API could return different python
    # objects representing the same Operation node.
    return op._id in self._pfor_op_ids

  @property
  def is_stateful(self):
    return self._is_stateful

  @property
  def pfor_converter(self):
    """Return a converter for the while loop."""
    return self

  def _init_pfor(self, parent_pfor, indices, cond_stacked, inputs,
                 inputs_stacked):
    """Create a PFor object for converting parts of the while_loop.

    Args:
      parent_pfor: PFor object being used for converting the while_loop.
      indices: int32 Tensor of ids for the iterations that are still active
        (i.e. did not exit the while_loop).
      cond_stacked: True if the while_loop condition is stacked.
      inputs: list of input Tensors corresponding 1-to-1 with self._enters. Note
        that these Tensors are a subset of the loop variables for the generated
        while_loop.
      inputs_stacked: List of booleans corresponding 1-to-1 with `inputs`,
        indicating if the value is stacked or not.

    Returns:
      A PFor instance. The instance is initialized by adding conversion mappings
        of nodes that will be external to the conversion that the returned
        instance will be used for. e.g. Enter nodes as well as Merge and Switch
        outputs are mapped to converted values.
    """
    num_outputs = len(self._outputs)
    assert len(inputs) == len(self._enters)
    assert len(inputs_stacked) == len(self._enters)
    loop_var = parent_pfor.loop_var
    loop_len = array_ops.size(indices)
    pfor = PFor(
        loop_var,
        loop_len,
        pfor_ops=self._pfor_ops,
        all_indices=indices,
        all_indices_partitioned=cond_stacked,
        pfor_config=self._pfor_config)
    # Map all inputs of Enter nodes in self._direct_enters to their converted
    # values.
    for enter in self._direct_enters:
      enter_input = enter.op.inputs[0]
      converted_enter, stacked, is_sparse_stacked = parent_pfor._convert_helper(
          enter_input)
      # Since these are resources / variants, they should be unstacked.
      assert not stacked and not is_sparse_stacked, (enter, converted_enter)
      pfor._add_conversion(enter, wrap(converted_enter, False))

    # Map all Enter nodes to the inputs.
    for enter, inp, stacked in zip(self._enters, inputs, inputs_stacked):
      pfor._add_conversion(enter, wrap(inp, stacked))
    # Map outputs of Switch and Merge.
    for i in range(num_outputs):
      wrapped_inp = wrap(inputs[i], inputs_stacked[i])
      merge = self._enter_merges[i]
      pfor._add_conversion(merge.outputs[0], wrapped_inp)
      # Note that second output of Merge is typically not used, except possibly
      # as a control dependency. To avoid trying to output the correct value, we
      # employ a hack here. We output a dummy invalid value with an incorrect
      # dtype. This will allow control dependency to work but if using it as an
      # input, it should typically lead to errors during graph construction due
      # to dtype mismatch.
      # TODO(agarwal): Check in the original graph to see if there are any
      # consumers of this Tensor that use it as an input.
      pfor._add_conversion(merge.outputs[1],
                           wrap(constant_op.constant(-1.0), False))
      switch = self._exit_switches[i]
      # Don't need to worry about switch.output[0] which will feed to Exit node.
      pfor._add_conversion(switch.outputs[1], wrapped_inp)
    return pfor

  def _convert_enter(self, parent_pfor, enter):
    """Converts an Enter node."""
    inp, stacked, _ = parent_pfor._convert_helper(enter.op.inputs[0])
    control_inputs = []
    for x in enter.op.control_inputs:
      converted = parent_pfor._convert_helper(x)
      if not isinstance(converted, ops.Operation):
        converted = converted.t
      control_inputs.append(converted)
    if control_inputs:
      with ops.control_dependencies(control_inputs):
        inp = array_ops.identity(inp)
    return inp, stacked

  def _maybe_stacked(self, cache, inp):
    """Heuristic to figure out if converting `inp` leads to a stacked value.

    Args:
      cache: map from Tensor to boolean indicating stacked/unstacked.
      inp: input Tensor.

    Returns:
      True if `inp` could get stacked. If the function returns False, the
      converted value should be guaranteed to be unstacked. If returning True,
      it may or may not be stacked.
    """
    if inp in cache:
      return cache[inp]
    if not self.op_is_inside_loop(inp.op):
      return False
    op = inp.op
    output = False
    if op.type in [
        "Shape",
        "Rank",
        "ShapeN",
        "ZerosLike",
        "TensorArrayV3",
        "TensorArraySizeV3",
    ]:
      output = False
    elif _is_stateful_pfor_op(op):
      # This may be fairly aggressive.
      output = True
    elif op.type == "Exit":
      # This may be fairly aggressive.
      output = True
    else:
      for t in op.inputs:
        if self._maybe_stacked(cache, t):
          output = True
          break
    cache[inp] = output
    return output

  def _create_init_values(self, pfor_input):
    """Create arguments passed to converted while_loop."""
    with ops.name_scope("while_init"):
      loop_len_vector = pfor_input.pfor.loop_len_vector
      loop_len = loop_len_vector[0]
      num_outputs = len(self._outputs)

      inputs = []
      maybe_stacked_cache = {}
      # Convert all the Enters. Need to do this before checking for stacking
      # below.
      for i, enter in enumerate(self._enters):
        inp, stacked = self._convert_enter(pfor_input.pfor, enter)
        inputs.append(inp)
        maybe_stacked_cache[enter] = stacked
        # Since this enter node is part of the `loop_vars`, it corresponds to an
        # output and its preceding switch. We mark this switch's output with the
        # same stackedness, to act as the base case for the logic below. Below,
        # we will be going through the body figuring out which inputs might need
        # to be stacked and which inputs can safely remain unstacked.
        if i < num_outputs:
          maybe_stacked_cache[self._exit_switches[i].outputs[1]] = stacked

      # Shape invariants for init_values corresponding to self._enters.
      input_shape_invariants = []
      # TensorArrays for outputs of converted while loop
      output_tas = []
      # Shape invariants for output TensorArrays.
      ta_shape_invariants = []
      # List of booleans indicating stackedness of inputs, i.e. tensors
      # corresponding to self._enters.
      inputs_stacked = []
      for i, inp in enumerate(inputs):
        enter = self._enters[i]
        inp_stacked = self._maybe_stacked(maybe_stacked_cache, enter)
        # Note that even when an input is unstacked, the body could make it
        # stacked. We use a heuristic below to figure out if body may be making
        # it stacked.
        if i < num_outputs:
          body_output = self._body_outputs[i]
          if enter.op in self._pfor_ops:
            body_output_stacked = self._maybe_stacked(maybe_stacked_cache,
                                                      body_output)
          else:
            # If constructed outside of pfor loop, then the output would not be
            # stacked.
            body_output_stacked = False
          if body_output_stacked and not inp_stacked:
            inp = _stack(inp, loop_len_vector).t
            inputs[i] = inp
            inp_stacked = True
          # TODO(agarwal): other attributes for the TensorArray ?
          output_tas.append(tensor_array_ops.TensorArray(inp.dtype, loop_len))
          ta_shape_invariants.append(tensor_shape.TensorShape(None))

        inputs_stacked.append(inp_stacked)
        input_shape_invariants.append(tensor_shape.TensorShape(None))

      # See documentation for __call__ for the structure of init_values.
      init_values = [True, pfor_input.pfor.all_indices] + inputs + output_tas
      # TODO(agarwal): try stricter shape invariants
      shape_invariants = (
          [tensor_shape.TensorShape(None),
           tensor_shape.TensorShape(None)] + input_shape_invariants +
          ta_shape_invariants)

      return init_values, inputs_stacked, shape_invariants

  def _process_cond_unstacked(self, conditions, indices, inputs, output_tas):
    """Handles case when condition is unstacked.

    Note that all iterations end together. So we don't need to partition the
    inputs. When all iterations are done, we write the inputs to the
    TensorArrays. Note that we only write to index 0 of output_tas. Since all
    iterations end together, they can all be output together.
    """
    not_all_done = array_ops.reshape(conditions, [])
    new_output_tas = []
    # pylint: disable=cell-var-from-loop
    for i, out_ta in enumerate(output_tas):
      inp = inputs[i]
      new_output_tas.append(
          control_flow_ops.cond(not_all_done, lambda: out_ta,
                                lambda: out_ta.write(0, inp)))
    # pylint: enable=cell-var-from-loop
    return not_all_done, indices, inputs, new_output_tas

  def _process_cond_stacked(self, conditions, indices, inputs, inputs_stacked,
                            output_tas):
    num_outputs = len(self._outputs)
    # Compute if all iterations are done.
    not_all_done = math_ops.reduce_any(conditions)
    conditions_int = math_ops.cast(conditions, dtypes.int32)
    # Partition the indices.
    done_indices, new_indices = data_flow_ops.dynamic_partition(
        indices, conditions_int, 2)

    new_inputs = []
    new_output_tas = []
    for i, (inp, stacked) in enumerate(zip(inputs, inputs_stacked)):
      # Partition the inputs.
      if stacked:
        done_inp, new_inp = data_flow_ops.dynamic_partition(
            inp, conditions_int, 2)
      else:
        # TODO(agarwal): avoid this stacking. See TODO earlier in
        # _process_cond_unstacked.
        done_inp = _stack(inp, [array_ops.size(done_indices)]).t
        new_inp = inp
      new_inputs.append(new_inp)
      # For iterations that are done, write them to TensorArrays.
      if i < num_outputs:
        out_ta = output_tas[i]
        # Note that done_indices can be empty. done_inp should also be empty in
        # that case.
        new_output_tas.append(out_ta.scatter(done_indices, done_inp))
    return not_all_done, new_indices, new_inputs, new_output_tas

  def _process_body(self, pfor_input, inputs_stacked, new_indices, cond_stacked,
                    new_inputs, not_all_done):
    """Convert the body function."""

    def true_fn(control_inputs, body_pfor, body_output, stacked):
      """Converts the body function for all but last iteration.

      This essentially converts body_output. Additionally, it needs to handle
      any control dependencies on the NextIteration node. So it creates another
      Identity node with the converted dependencies.
      """
      converted_control_inp = []
      for x in control_inputs:
        for t in x.outputs:
          converted_control_inp.append(body_pfor._convert_helper(t).t)
      if stacked:
        # Note convert always does the stacking.
        output = body_pfor.convert(body_output)
      else:
        output, convert_stacked, _ = body_pfor._convert_helper(body_output)
        assert convert_stacked == stacked, body_output
      with ops.control_dependencies(converted_control_inp):
        return array_ops.identity(output)

    body_pfor = self._init_pfor(pfor_input.pfor, new_indices, cond_stacked,
                                new_inputs, inputs_stacked)
    new_outputs = []

    for i, (body_output,
            stacked) in enumerate(zip(self._body_outputs, inputs_stacked)):
      control_inp = self._next_iter_control_inputs[i]
      out_dtype = body_output.dtype
      # Note that we want to run the body only if not all pfor iterations are
      # done. If all are done, we return empty tensors since these values will
      # not be used. Notice that the value returned by the loop is based on
      # TensorArrays and not directly on these returned values.
      # pylint: disable=cell-var-from-loop
      new_output = control_flow_ops.cond(
          not_all_done,
          lambda: true_fn(control_inp, body_pfor, body_output, stacked),
          lambda: constant_op.constant([], dtype=out_dtype))
      # pylint: enable=cell-var-from-loop
      new_outputs.append(new_output)
    return new_outputs

  def __call__(self, pfor_input):
    """Converter for the while_loop.

    The conversion of a while_loop is another while_loop.

    The arguments to this converted while_loop are as follows:
    not_all_done: Boolean scalar Tensor indicating if all the pfor iterations
      are done.
    indices: int32 1-D Tensor storing the id of the iterations that are not
      done.
    args: Remaining arguments. These can be divided into 3 categories:
      - First set of arguments are the tensors that correspond to the initial
        elements of self._enters. The elements that appear in original while
        loop's `loop_vars`.
      - The second set of arguments are the tensors that correspond to the
        remaining elements of self._enters. These are the tensors that directly
        enter the original while loop body.
      - Finally, the last set of arguments are TensorArrays. These TensorArrays
        correspond to the outputs of the original while_loop, i.e. to the
        elements in self._outputs. Each TensorArray has `PFor.loop_len`
        elements, i.e. the number of pfor iterations. At the end, the i'th
        element of each TensorArray will contain the output computed by the
        i'th iteration of pfor. Note that elements can be written into these
        tensor arrays in any order, depending on when the corresponding pfor
        iteration is done.
      If the original while_loop had `k` tensors in its `loop_vars` and its body
      directly captured `m` tensors, the `args` will contain `2 * k + m` values.

    In each iteration, the while_loop body recomputes the condition for all
    active pfor iterations to see which of them are now done. It then partitions
    all the inputs and passes them along to the converted body. Values for all
    the iterations that are done are written to TensorArrays indexed by the pfor
    iteration number. When all iterations are done, the TensorArrays are stacked
    to get the final value.

    Args:
      pfor_input: A PForInput object corresponding to the output of any Exit
        node from this while loop.

    Returns:
      List of converted outputs.
    """
    # Create init_values that will be passed to the while_loop.
    init_values, inputs_stacked, shape_invariants = self._create_init_values(
        pfor_input)
    # Note that we use a list as a hack since we need the nested function body
    # to set the value of cond_is_stacked. python2.x doesn't support nonlocal
    # variables.
    cond_is_stacked = [None]

    def cond(not_all_done, *_):
      return not_all_done

    def body(not_all_done, indices, *args):
      # See documentation for __call__ for the structure of *args.
      num_enters = len(self._enters)
      inputs = args[:num_enters]
      output_tas = args[num_enters:]
      # TODO(agarwal): see which outputs have consumers and only populate the
      # TensorArrays corresponding to those. Or do those paths get trimmed out
      # from inside the while_loop body?
      assert len(inputs) >= len(output_tas)
      assert len(inputs) == len(inputs_stacked)

      # Convert condition
      with ops.name_scope("while_cond"):
        # Note that we set cond_stacked to True here. At this point we don't
        # know if it could be loop invariant, hence the conservative value is
        # to assume stacked.
        cond_pfor = self._init_pfor(
            pfor_input.pfor,
            indices,
            cond_stacked=True,
            inputs=inputs,
            inputs_stacked=inputs_stacked)
        conditions, cond_stacked, _ = cond_pfor._convert_helper(self._condition)
        cond_is_stacked[0] = cond_stacked

      # Recompute the new condition, write outputs of done iterations, and
      # partition the inputs if needed.
      if not cond_stacked:
        (not_all_done, new_indices, new_inputs,
         new_output_tas) = self._process_cond_unstacked(conditions, indices,
                                                        inputs, output_tas)
      else:
        (not_all_done, new_indices, new_inputs,
         new_output_tas) = self._process_cond_stacked(conditions, indices,
                                                      inputs, inputs_stacked,
                                                      output_tas)

      # Convert body
      with ops.name_scope("while_body"):
        # Compute the outputs from the body.
        new_outputs = self._process_body(pfor_input, inputs_stacked,
                                         new_indices, cond_stacked, new_inputs,
                                         not_all_done)

      # Note that the first num_outputs new values of inputs are computed using
      # the body. Rest of them were direct Enters into the condition/body and
      # the partitioning done earlier is sufficient to give the new value.
      num_outputs = len(self._outputs)
      new_args = ([not_all_done, new_indices] + new_outputs +
                  list(new_inputs[num_outputs:]) + new_output_tas)
      return tuple(new_args)

    while_outputs = control_flow_ops.while_loop(
        cond, body, init_values, shape_invariants=shape_invariants)
    output_tas = while_outputs[-len(self._outputs):]
    outputs = []
    assert cond_is_stacked[0] is not None
    for inp_stacked, ta in zip(inputs_stacked, output_tas):
      if cond_is_stacked[0]:
        outputs.append(wrap(ta.stack(), True))
      else:
        # Note that if while_loop condition is unstacked, all iterations exit at
        # the same time and we wrote those outputs in index 0 of the tensor
        # array.
        outputs.append(wrap(ta.read(0), inp_stacked))
    return outputs


class _PforInput(object):
  """Input object passed to registered pfor converters."""

  def __init__(self, pfor, op, inputs):
    """Creates a _PforInput object.

    Args:
      pfor: PFor converter object.
      op: the Operation object that is being converted.
      inputs: list of WrappedTensor objects representing converted values of the
        inputs of `op`.
    """
    self.pfor = pfor
    self._op = op
    self._inputs = inputs

  def stack_inputs(self, stack_indices=None):
    """Stacks unstacked inputs at `stack_indices`.

    Args:
      stack_indices: indices of inputs at which stacking is done. If None,
        stacking is done at all indices.
    """
    if stack_indices is None:
      stack_indices = range(len(self._inputs))
    length = self.pfor.loop_len_vector
    for i in stack_indices:
      inp = self._inputs[i]
      if not inp.is_stacked:
        self._inputs[i] = _stack(inp.t, length)

  def expanddim_inputs_for_broadcast(self):
    """Reshapes stacked inputs to prepare them for broadcast.

    Since stacked inputs have an extra leading dimension, automatic broadcasting
    rules could incorrectly try to expand dimensions before that leading
    dimension. To avoid that, we reshape these stacked inputs to the maximum
    rank they will need to be broadcasted to.
    """
    if not self._inputs:
      return

    # Find max rank
    def _get_rank(x):
      rank = array_ops.rank(x.t)
      if not x.is_stacked:
        rank += 1
      return rank

    ranks = [_get_rank(x) for x in self._inputs]
    max_rank = ranks[0]
    for rank in ranks[1:]:
      max_rank = math_ops.maximum(rank, max_rank)

    for i, inp in enumerate(self._inputs):
      if inp.is_stacked:
        shape = array_ops.shape(inp.t)
        rank_diff = array_ops.reshape(max_rank - ranks[i], [1])
        ones = array_ops.tile([1], rank_diff)
        new_shape = array_ops.concat([shape[:1], ones, shape[1:]], axis=0)
        self._inputs[i] = wrap(array_ops.reshape(inp.t, new_shape), True)

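  # Illustrative shape sketch (hypothetical values): given a stacked input of
  # shape [N, 3] and an unstacked input of shape [2, 3], the effective maximum
  # rank is 3, so the stacked input is reshaped to [N, 1, 3]. It then
  # broadcasts against [2, 3] to [N, 2, 3], instead of misaligning the leading
  # N dimension with the unstacked input's dimensions.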
  @property
  def inputs(self):
    return self._inputs

  @property
  def num_inputs(self):
    return len(self._inputs)

  def input(self, index):
    assert len(self._inputs) > index, (index, self._inputs)
    return self._inputs[index]

  def stacked_input(self, index):
    t, is_stacked, _ = self.input(index)
    if not is_stacked:
      op_type = self.op_type
      op_def = getattr(self._op, "op_def", None)
      if op_def is None:
        input_name = "at index %d" % index
      else:
        input_name = "\"%s\"" % op_def.input_arg[index].name
      raise ValueError(
          "Input %s of op \"%s\" expected to be not loop invariant" % (
              input_name, op_type))
    return t

  def unstacked_input(self, index):
    t, is_stacked, _ = self.input(index)
    if is_stacked:
      op_type = self.op_type
      op_def = getattr(self._op, "op_def", None)
      if op_def is None:
        input_name = "at index %d" % index
      else:
        input_name = "\"%s\"" % op_def.input_arg[index].name
      raise ValueError("Input %s of op \"%s\" expected to be loop invariant" % (
          input_name, op_type))
    return t

  @property
  def op(self):
    return self._op

  @property
  def op_type(self):
    return self._op.type

  def get_attr(self, attr):
    return self._op.get_attr(attr)

  @property
  def outputs(self):
    return self._op.outputs

  def output(self, index):
    assert index < len(self._op.outputs)
    return self._op.outputs[index]


_pfor_converter_registry = {}


class RegisterPFor(object):
  """Utility to register converters for pfor.

  Usage:
  @RegisterPFor(foo_op_type)
  def _foo_converter(pfor_input):
    ...

  The above will register conversion function `_foo_converter` for handling
  conversion of `foo_op_type`. These converters are called during vectorization
  of a `pfor` loop body. For each operation node in this loop body,
  the vectorization process will call the converter corresponding to the
  operation type of the node.

  During conversion, the registered function will be called with a single
  argument `pfor_input`, of type `PForInput`, which will contain state needed
  for the conversion.  When the converter is called for a node, all its inputs
  should already have been converted and these converted values are stored in
  `pfor_input.inputs`.  This registered function should output a list of
  WrappedTensor objects with the same length as the number of outputs of the
  node being converted. If the node had zero outputs, then it should return an
  ops.Operation object.  These new sets of nodes should implement the
  functionality of running that operation for the number of iterations specified
  by `pfor_input.pfor.loop_len_vector[0]` where the inputs of the node for each
  iteration are picked from `pfor_input.inputs`.

  One tricky aspect of the conversion process is keeping track of, and
  leveraging, loop invariance of computation. Each converted input is a
  WrappedTensor which indicates whether the input was loop invariant or not. If
  the converted value is loop invariant, its rank should match the rank of the
  corresponding tensor in the loop body, else its rank is larger by 1. The
  converter should look at the loop invariance of the inputs and generate new
  nodes based on that. Note that the converter will not be called if all inputs
  are loop invariant and the operation is not stateful. The converter should
  determine if its own output is loop invariant and `wrap` its output
  accordingly.

  Example:

  Here, the converter is trying to convert a Reshape node in the loop body. This
  node will have two inputs: the tensor to reshape, and the new shape.  The
  example here only handles the case where the shape is loop invariant.

  @RegisterPFor("Reshape")
  def _convert_reshape(pfor_input):
    # We assume that input is not loop invariant. Call to `stacked_input`
    # asserts that and returns the converted value. This value will have a rank
    # larger by 1 compared to the rank of the input in the loop body.
    t = pfor_input.stacked_input(0)

    # We assume that shape input is loop invariant. Call to `unstacked_input`
    # asserts that and returns the converted value.
    shape = pfor_input.unstacked_input(1)

    # We compute `new_shape` by prepending the number of iterations to the
    # original shape.
    new_shape = array_ops.concat([pfor_input.pfor.loop_len_vector, shape],
                                 axis=0)

    # The vectorized output involves reshaping the converted input `t` using
    # `new_shape`.
    new_output = array_ops.reshape(t, new_shape)

    # The converted output is marked as not loop invariant using the call to
    # wrap.
    return wrap(new_output, True)
  """

  def __init__(self, op_type):
    """Creates an object to register a converter for op with type `op_type`."""
    self.op_type = op_type

  def __call__(self, converter):
    name = self.op_type
    assert name not in _pfor_converter_registry, "Re-registering %s " % name
    _pfor_converter_registry[name] = converter
    return converter


class RegisterPForWithArgs(RegisterPFor):
  """Utility to register converters for pfor.

  Usage:
  @RegisterPForWithArgs(foo_op_type, foo=value, ...)
  def _foo_converter(pfor_input, op_type, foo=None, ...):
    ...

  See RegisterPFor for details on the conversion function.
  `RegisterPForWithArgs` allows binding extra arguments to the
  conversion function at registration time.
  """

  def __init__(self, op_type, *args, **kw_args):
    super(RegisterPForWithArgs, self).__init__(op_type)
    self._args = args
    self._kw_args = kw_args

  def __call__(self, converter):

    def _f(pfor_input):
      return converter(pfor_input, self.op_type, *self._args, **self._kw_args)

    super(RegisterPForWithArgs, self).__call__(_f)
    return converter

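# A minimal, hypothetical sketch of how `RegisterPForWithArgs` can bind extra
# values at registration time. The `op_func` keyword is illustrative only and
# not a name used elsewhere in this module; note that the registered `op_type`
# is passed to the converter as its second positional argument.
#
#   @RegisterPForWithArgs("Tanh", op_func=math_ops.tanh)
#   @RegisterPForWithArgs("Relu", op_func=gen_nn_ops.relu)
#   def _convert_unary_elementwise(pfor_input, op_type, op_func):
#     # Elementwise ops can be applied directly to the stacked input.
#     del op_type
#     return wrap(op_func(pfor_input.stacked_input(0)), True)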

# TODO(agarwal): call raw_ops instead of calling these low level routines.
def _create_op(op_type, inputs, op_dtypes, attrs=None):
  """Utility to create an op."""
  op = ops.get_default_graph().create_op(
      op_type, inputs, op_dtypes, attrs=attrs, compute_device=True)
  flat_attrs = nest.flatten([(str(a), op.get_attr(str(a))) for a in attrs])
  execute.record_gradient(op_type, op.inputs, tuple(flat_attrs), op.outputs[:])
  return op


WrappedTensor = collections.namedtuple("WrappedTensor",
                                       ["t", "is_stacked", "is_sparse_stacked"])
"""Wrapper around the result of a Tensor conversion.

The additional fields are useful for keeping track of the conversion state as
data flows through the ops in the loop body. For every op whose output is a
Tensor, its converter should return either a WrappedTensor or a list of
WrappedTensors.

Args:
  t: The converted tensor
  is_stacked: True if the tensor is stacked, i.e. represents the results of all
    the iterations of the loop, where each row i of the tensor corresponds to
    that op's output on iteration i of the loop. False if the tensor is not
    stacked, i.e. represents the result of the op on a single iteration of
    the loop, where the result does not vary between iterations.
  is_sparse_stacked: True if the tensor corresponds to a component tensor
    (indices, values, or dense_shape) of a sparse tensor, and has been logically
    stacked via a sparse conversion.
"""


def wrap(tensor, is_stacked=True, is_sparse_stacked=False):
  """Helper to create a WrappedTensor object."""
  assert isinstance(is_stacked, bool)
  assert isinstance(is_sparse_stacked, bool)
  assert isinstance(tensor, ops.Tensor)
  assert not is_sparse_stacked or is_stacked, ("If the wrapped tensor is "
                                               "stacked via a sparse "
                                               "conversion, it must also be "
                                               "stacked.")
  return WrappedTensor(tensor, is_stacked, is_sparse_stacked)

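# Illustrative sketch of the wrapping convention (names below are hypothetical
# and not defined in this module):
#
#   per_iter = array_ops.zeros([3])           # value seen inside the loop body
#   stacked = array_ops.zeros([loop_len, 3])  # one row per pfor iteration
#   invariant_out = wrap(per_iter, False)     # loop invariant: rank unchanged
#   stacked_out = wrap(stacked, True)         # stacked: rank is larger by 1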

def _fallback_converter(pfor_input):
  logging.warn("Using a while_loop for converting %s", pfor_input.op_type)
  output_dtypes = [x.dtype for x in pfor_input.outputs]
  iters = pfor_input.pfor.loop_len_vector[0]

  def while_body(i, *ta_list):
    """Body of while loop."""
    inputs = [
        x[i, ...] if stacked else x for x, stacked, _ in pfor_input.inputs
    ]
    op_outputs = _create_op(
        pfor_input.op_type,
        inputs,
        output_dtypes,
        attrs=pfor_input.op.node_def.attr).outputs

    outputs = []
    for out, ta in zip(op_outputs, ta_list):
      assert isinstance(out, ops.Tensor)
      outputs.append(ta.write(i, array_ops.expand_dims(out, 0)))
    return tuple([i + 1] + outputs)

  ta_list = control_flow_ops.while_loop(
      lambda i, *ta: i < iters, while_body, [0] +
      [tensor_array_ops.TensorArray(dtype, iters) for dtype in output_dtypes
      ])[1:]
  return tuple([wrap(ta.concat(), True) for ta in ta_list])


class PForConfig(object):
  """A configuration object used to communicate with loop body function."""

  def __init__(self):
    # This may be set to the number of iterations.
    self._maybe_iters = None
    # Map from reduction node, created by `reduce`, to the bundle of reduction
    # function and arguments.
    self._reduce_map = {}

  def _has_reductions(self):
    """True if some reductions were performed by loop body."""
    return len(self._reduce_map)

  def _set_iters(self, iters):
    """Set number of pfor iterations."""
    self._maybe_iters = iters

  def reduce(self, fn, *args):
    """Performs reduction `fn` on `args` vectorized across pfor iterations.

    Note that `fn` is traced once inside the loop function context. Hence any
    captures or side-effects will happen in that context. Call to the traced
    version of `fn` happens during the construction of the vectorized code.

    Note that this currently may not work inside a control flow construct.

    Args:
      fn: a reduction function. It will be called with arguments that have the
        same structure as *args but with individual values whose rank may be
        higher by 1 since they represent loop invariant vectorized versions of
        the corresponding Tensors in *args.
      *args: unvectorized Tensors.

    Returns:
      The result of running `fn` on the vectorized versions of `*args`. These
      outputs will be available as loop invariant values to all the iterations.
    """
    assert not context.executing_eagerly()
    # Creates a concrete function that will be used for reduction.
    tensor_specs = []
    for arg in args:
      if not isinstance(arg, ops.Tensor):
        raise ValueError("Got a non-Tensor argument %s in reduce" % arg)
      batched_shape = tensor_shape.TensorShape(
          [self._maybe_iters]).concatenate(arg.shape)
      tensor_specs.append(
          tensor_spec.TensorSpec(shape=batched_shape, dtype=arg.dtype))
    concrete_function = def_function.function(fn).get_concrete_function(
        *tensor_specs)

    # Creates PlaceholderWithDefault and IdentityN nodes corresponding to the
    # reduction.
    pl_outputs = []
    with ops.control_dependencies(args):
      for output in concrete_function.outputs:
        if not isinstance(output, ops.Tensor):
          raise ValueError("Got a non-Tensor output %s while running reduce" %
                           output)
        # Note that we use placeholder_with_default just to make XLA happy since
        # it does not like placeholder ops.
        if output.shape.is_fully_defined():
          dummy = array_ops.zeros(output.shape.as_list(), dtype=output.dtype)
          pl_outputs.append(
              array_ops.placeholder_with_default(dummy, shape=output.shape))
        else:
          # TODO(agarwal): support case when under XLA and output.shape is not
          # fully defined.
          pl_outputs.append(
              array_ops.placeholder(output.dtype, shape=output.shape))

      reduction_op = array_ops.identity_n(pl_outputs)[0].op
    self._reduce_map[reduction_op] = (concrete_function, args)
    if len(reduction_op.outputs) == 1:
      return reduction_op.outputs[0]
    else:
      return tuple(reduction_op.outputs)

  # TODO(agarwal): handle reductions inside control flow constructs.
  def reduce_concat(self, x):
    """Performs a concat reduction on `x` across pfor iterations.

    Note that this currently may not work inside a control flow construct.

    Args:
      x: an unvectorized Tensor.

    Returns:
      A Tensor that has rank one higher than `x`. The value is the vectorized
      version of `x`, i.e. stacking the value of `x` across different pfor
      iterations.
    """
    return self.reduce(lambda y: y, x)

  def reduce_mean(self, x):
    """Performs a mean reduction on `x` across pfor iterations.

    Note that this currently may not work inside a control flow construct.

    Args:
      x: an unvectorized Tensor.

    Returns:
      A Tensor that has same rank as `x`. The value is the mean of the values
      of `x` across the pfor iterations.
    """
    return self.reduce(lambda y: math_ops.reduce_mean(y, axis=0), x)

  def reduce_sum(self, x):
    """Performs a sum reduction on `x` across pfor iterations.

    Note that this currently may not work inside a control flow construct.

    Args:
      x: an unvectorized Tensor.

    Returns:
      A Tensor that has same rank as `x`. The value is the sum of the values
      of `x` across the pfor iterations.
    """
    return self.reduce(lambda y: math_ops.reduce_sum(y, axis=0), x)

  def _lookup_reduction(self, t):
    """Looks up Tensor `t` in the reduction map."""
    assert isinstance(t, ops.Tensor), t
    return self._reduce_map.get(t.op)

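# Illustrative sketch of how a loop body might use the reductions above. The
# wiring that supplies `pfor_config` (e.g. a `loop_fn(i, pfor_config)`
# signature) and the `images` tensor are assumptions of this sketch only:
#
#   def loop_fn(i, pfor_config):
#     x = array_ops.gather(images, i)
#     # The same mean value is made available to every pfor iteration.
#     mean_x = pfor_config.reduce_mean(x)
#     return x - mean_x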
1105
1106class PFor(object):
1107  """Implementation of rewrite of parallel-for loops.
1108
1109  This class takes a DAG or a set of DAGs representing the body of a
1110  parallel-for loop, and adds new operations to the graph that implements
1111  functionality equivalent to running that loop body for a specified number of
1112  iterations. This new set of nodes may or may not use a tensorflow loop
1113  construct.
1114
1115  The process of conversion does not delete or change any existing operations.
1116  It only adds operations that efficiently implement the equivalent
1117  functionality. We refer to the added ops as "converted ops".
1118
1119  The conversion process uses a simple greedy heuristic. It walks the loop body
1120  and tries to express the functionality of running each node in a loop with a
1121  new set of nodes. When converting an op several cases are possible:
1122  - The op is not inside the loop body. Hence it can be used as is.
1123  - The op does not depend on the iteration number and is stateless. In this
1124    case, it can be used as is.
1125  - The op is not stateful, and depends on iteration number only through control
1126    dependencies. In this case, we can create a single op with same inputs and
1127    attributes, but with "converted" control dependencies.
1128  - The op is not stateful, and all its inputs are loop invariant. In this
1129    case, similar to above, we can create a single op with same inputs and
1130    attributes, but with "converted" control dependencies.
1131  - The op is stateful or at least one of the inputs is not loop invariant. In
1132    this case, we run the registered converter for that op to create a set of
1133    converted ops. All nodes in the set will have converted control dependencies
1134    corresponding to control dependencies of the original op. If the op returned
1135    multiple outputs, "converted outputs" could be produced by different ops in
1136    this set.
1137  """
1138
1139  def __init__(self,
1140               loop_var,
1141               loop_len,
1142               pfor_ops,
1143               all_indices=None,
1144               all_indices_partitioned=False,
1145               pfor_config=None):
1146    """Creates an object to rewrite a parallel-for loop.
1147
1148    Args:
1149      loop_var: ops.Tensor output of a Placeholder operation. The value should
1150        be an int32 scalar representing the loop iteration number.
1151      loop_len: A scalar or scalar Tensor representing the number of iterations
1152        the loop is run for.
1153      pfor_ops: List of all ops inside the loop body.
1154      all_indices: If not None, an int32 vector with size `loop_len`
1155        representing the iteration ids that are still active. These values
1156        should be unique and sorted. However they may not be contiguous. This is
1157        typically the case when inside a control flow construct which has
1158        partitioned the indices of the iterations that are being converted.
1159      all_indices_partitioned: If True, this object is being constructed from a
1160        control flow construct where not all the pfor iterations are guaranteed
1161        to be active.
1162      pfor_config: PForConfig object used while constructing the loop body.
1163    """
1164    assert isinstance(loop_var, ops.Tensor)
1165    assert loop_var.op.type == "PlaceholderWithDefault"
1166    self._loop_var = loop_var
1167    loop_len_value = tensor_util.constant_value(loop_len)
1168    if loop_len_value is not None:
1169      loop_len = loop_len_value
1170    self._loop_len_vector = array_ops.reshape(loop_len, [1])
1171    self._all_indices_partitioned = all_indices_partitioned
1172    if all_indices_partitioned:
1173      assert all_indices is not None
1174    self.all_indices = (
1175        math_ops.range(loop_len) if all_indices is None else all_indices)
1176
1177    self._conversion_map = object_identity.ObjectIdentityDictionary()
1178    self._conversion_map[loop_var] = wrap(self.all_indices, True)
1179    self._pfor_ops = set(pfor_ops)
1180    self._pfor_op_ids = set(x._id for x in pfor_ops)
1181    self._pfor_config = pfor_config
1182
1183  def op_is_inside_loop(self, op):
1184    """True if op was created inside the pfor loop body."""
1185    assert isinstance(op, ops.Operation)
1186    # Note that we use self._pfor_op_ids for the check and not self._pfor_ops
1187    # since it appears there tensorflow API could return different python
1188    # objects representing the same Operation node.
1189    return op._id in self._pfor_op_ids
1190
1191  def _convert_sparse(self, y):
1192    """Returns the converted value corresponding to SparseTensor y.
1193
1194    For SparseTensors, instead of stacking the component tensors separately,
1195    resulting in component tensors with shapes (N, m, rank), (N, m), and (N,
1196    rank) respectively for indices, values, and dense_shape (where N is the loop
1197    length and m is the number of sparse tensor values per loop iter), we want
1198    to logically stack the SparseTensors, to create a SparseTensor whose
1199    components are size (N * m, rank + 1), (N * m, ), and (rank + 1,)
1200    respectively.
1201
1202    Here, we try to get the conversion of each component tensor.
1203    If the tensors are stacked via a sparse conversion, return the resulting
1204    SparseTensor composed of the converted components. Otherwise, the component
1205    tensors are either unstacked or stacked naively. In the latter case, we
1206    unstack the component tensors to reform loop_len SparseTensor elements,
1207    then correctly batch them.
1208
1209    The unstacked tensors must have the same rank. Each dimension of each
1210    SparseTensor will expand to be the largest among all SparseTensor elements
1211    for that dimension. For example, if there are N SparseTensors of rank 3
1212    being stacked, with N dense shapes, where the i_th shape is (x_i, y_i, z_i),
1213    the new dense shape will be (N, max_i(x_i), max_i(y_i), max_i(z_i)).
1214
1215    Args:
1216      y: A tf.SparseTensor.
1217
1218    Returns:
1219      A tf.SparseTensor that is the converted value corresponding to y.
1220    """
1221    outputs = [
1222        self._convert_helper(t) for t in (y.indices, y.values, y.dense_shape)
1223    ]
1224    assert all(isinstance(o, WrappedTensor) for o in outputs)
1225
1226    if all(w.is_sparse_stacked for w in outputs):
1227      return sparse_tensor.SparseTensor(*[w.t for w in outputs])
1228
1229    assert not any(w.is_sparse_stacked for w in outputs), (
1230        "Error converting SparseTensor. All components should be logically "
1231        "stacked, or none.")
1232
1233    # If component tensors were not sparsely stacked, they are either unstacked
1234    # or stacked without knowledge that they are components of sparse tensors.
1235    # In this case, we have to restack them.
1236    return self._restack_sparse_tensor_logically(
1237        *[self._unwrap_or_tile(w) for w in outputs])
1238
1239  def _restack_sparse_tensor_logically(self, indices, values, shape):
1240    sparse_tensor_rank = indices.get_shape().dims[-1].value
1241    if sparse_tensor_rank is not None:
1242      sparse_tensor_rank += 1
1243
1244    def fn(args):
1245      res = gen_sparse_ops.serialize_sparse(
1246          args[0], args[1], args[2], out_type=dtypes.variant)
1247      return res
1248
1249    # Applies a map function to the component tensors to serialize each
1250    # sparse tensor element and batch them all, then deserializes the batch.
1251    # TODO(rachelim): Try to do this without map_fn -- add the right offsets
1252    # to shape and indices tensors instead.
1253    result = map_fn.map_fn(fn, [indices, values, shape], dtype=dtypes.variant)
1254    return sparse_ops.deserialize_sparse(
1255        result, dtype=values.dtype, rank=sparse_tensor_rank)
1256
1257  def _unwrap_or_tile(self, wrapped_tensor):
1258    """Given a wrapped tensor, unwrap if stacked. Otherwise, tiles it."""
1259    output, is_stacked = wrapped_tensor.t, wrapped_tensor.is_stacked
1260    if is_stacked:
1261      return output
1262    else:
1263      return _stack(output, self._loop_len_vector).t
1264
1265  def convert(self, y):
1266    """Returns the converted value corresponding to y.
1267
1268    Args:
1269      y: A ops.Tensor or a ops.Operation object. If latter, y should not have
1270        any outputs.
1271
1272    Returns:
1273      If y does not need to be converted, it returns y as is. Else it returns
1274      the "converted value" corresponding to y.
1275    """
1276    if y is None:
1277      return None
1278    if isinstance(y, sparse_tensor.SparseTensor):
1279      return self._convert_sparse(y)
1280    assert isinstance(y, (ops.Tensor, ops.Operation)), y
1281    output = self._convert_helper(y)
1282    if isinstance(output, WrappedTensor):
1283      assert isinstance(y, ops.Tensor)
1284      return self._unwrap_or_tile(output)
1285    else:
1286      assert isinstance(y, ops.Operation)
1287      assert not y.outputs
1288      assert isinstance(output, ops.Operation)
1289    return output
1290
1291  def _was_converted(self, t):
1292    """True if t is not a conversion of itself."""
1293    converted_t = self._conversion_map[t]
1294    return converted_t.t is not t
1295
1296  def _add_conversion(self, old_output, new_output):
1297    assert isinstance(old_output, (ops.Tensor, ops.Operation)), old_output
1298    assert isinstance(new_output, (WrappedTensor, ops.Operation)), new_output
1299    self._conversion_map[old_output] = new_output
1300
1301  def _convert_reduction(self, y):
1302    # Handle reductions.
1303    if self._pfor_config is None:
1304      return None
1305    reduction = self._pfor_config._lookup_reduction(y)
1306    if reduction is None:
1307      return None
1308    (reduction_fn, reduction_args) = reduction
1309    batched_args = []
1310    for reduction_arg in reduction_args:
1311      assert isinstance(reduction_arg, ops.Tensor), reduction_arg
1312      # Tensor being reduced should already be converted due to a control
1313      # dependency on the created placeholder.
1314      # Note that in cases where reduction_arg is in an outer context, one
1315      # needs to locate the corresponding Enter node and use that to look up
1316      # the conversion.
1317      # TODO(agarwal): handle reductions inside control flow constructs.
1318      assert reduction_arg in self._conversion_map, (
1319          "Unable to handle reduction of %s, possibly as it was used "
1320          "inside a control flow construct. Note that reductions across "
1321          "pfor iterations are currently not supported inside control flow "
1322          "constructs." % reduction_arg)
1323      batched_arg = self._conversion_map[reduction_arg]
1324      batched_args.append(self._unwrap_or_tile(batched_arg))
1325    outputs = reduction_fn(*batched_args)
1326    return [wrap(output, False) for output in nest.flatten(outputs)]
1327
1328  def _convert_helper(self, op_or_tensor):
1329    stack = [op_or_tensor]
1330    while stack:
1331      y = stack[0]
1332      if y in self._conversion_map:
1333        assert isinstance(self._conversion_map[y],
1334                          (WrappedTensor, ops.Operation))
1335        stack.pop(0)
1336        continue
1337      if isinstance(y, ops.Operation):
1338        assert not y.outputs, (
1339            "We only support converting Operation objects with no outputs. "
1340            "Got %s" % y)
1341        y_op = y
1342      else:
1343        assert isinstance(y, ops.Tensor), y
1344        y_op = y.op
1345
1346      is_while_loop = y_op.type == "Exit"
1347      if is_while_loop:
1348        while_op = WhileOp(
1349            y, pfor_ops=self._pfor_ops, pfor_config=self._pfor_config)
1350        is_inside_loop = while_op.is_inside_loop
1351        # If all nodes in the while_loop graph were created inside the pfor, we
1352        # treat the whole loop subgraph as a single op (y_op) and try to convert
1353        # it. For while_loops that are created completely or partially outside,
1354        # we treat them as external and should be able to simply return the Exit
1355        # node output as is without needing any conversion. Note that for
1356        # while_loops that are partially constructed inside, we assume they will
1357        # be loop invariant. If that is not the case, it will create runtime
1358        # errors since the converted graph would depend on the self._loop_var
1359        # placeholder.
1360        if is_inside_loop:
1361          y_op = while_op
1362      else:
1363        is_inside_loop = self.op_is_inside_loop(y_op)
1364
1365      # If this op was not created inside the loop body, we return it as is.
1366      # 1. Convert inputs and control inputs.
1367
1368      def _add_to_stack(x):
1369        if x not in self._conversion_map:
1370          stack.insert(0, x)
1371          return True
1372        else:
1373          return False
1374
1375      if is_inside_loop:
1376        added_to_stack = False
1377        for inp in y_op.inputs:
1378          added_to_stack |= _add_to_stack(inp)
1379        for cinp in y_op.control_inputs:
1380          if cinp.outputs:
1381            for t in cinp.outputs:
1382              added_to_stack |= _add_to_stack(t)
1383          else:
1384            added_to_stack |= _add_to_stack(cinp)
1385        if added_to_stack:
1386          continue
1387
1388        converted_inputs = [self._conversion_map[inp] for inp in y_op.inputs]
1389        some_input_converted = any(self._was_converted(x) for x in y_op.inputs)
1390        some_input_stacked = any(x.is_stacked for x in converted_inputs)
1391
1392        converted_control_ops = set()
1393        some_control_input_converted = False
1394        for cinp in y_op.control_inputs:
1395          if cinp.outputs:
1396            for t in cinp.outputs:
1397              converted_t = self._conversion_map[t]
1398              if self._was_converted(t):
1399                some_control_input_converted = True
1400              converted_control_ops.add(converted_t.t.op)
1401          else:
1402            converted_cinp = self._conversion_map[cinp]
1403            assert isinstance(converted_cinp, ops.Operation)
1404            if converted_cinp != cinp:
1405              some_control_input_converted = True
1406            converted_control_ops.add(converted_cinp)
1407        converted_control_ops = list(converted_control_ops)
1408        is_stateful = _is_stateful_pfor_op(y_op)
1409      else:
1410        converted_inputs = []
1411        converted_control_ops = []
1412      logging.vlog(3, "converting op:%s\ninputs:%s\ncontrol_inputs:%s", y_op,
1413                   converted_inputs, converted_control_ops)
1414
1415      # 2. Convert y_op
1416      # If converting a while_loop, we let the while_loop converter deal with
1417      # putting the control dependencies appropriately.
1418      control_dependencies = [] if is_while_loop else converted_control_ops
1419      with ops.control_dependencies(control_dependencies), ops.name_scope(
1420          y_op.name + "/pfor/"), ops.get_default_graph()._original_op(y_op):
1421        # Op is a placeholder for a reduction.
1422        reduce_output = self._convert_reduction(y)
1423        if reduce_output is not None:
1424          new_outputs = reduce_output
1425        # None of the inputs or control inputs were converted.
1426        elif ((not is_inside_loop or
1427               (not is_stateful and not some_input_converted and
1428                not some_control_input_converted)) and
1429              y.graph == ops.get_default_graph()):
1430          if y is y_op:
1431            assert not isinstance(y_op, WhileOp)
1432            new_outputs = y_op
1433          else:
1434            new_outputs = [wrap(x, False) for x in y_op.outputs]
1435        elif not (is_stateful or is_while_loop or some_input_stacked):
1436          # All inputs are unstacked or unconverted but some control inputs
1437          # are converted.
1438          # TODO(rachelim): Handle the case where some inputs are sparsely
1439          # stacked (i.e. any(x.is_sparse_stacked for x in converted_inputs))
1440          new_op = _create_op(y_op.type, [x.t for x in converted_inputs],
1441                              [x.dtype for x in y_op.outputs],
1442                              y_op.node_def.attr)
1443          if y is y_op:
1444            new_outputs = new_op
1445          else:
1446            new_outputs = [wrap(x, False) for x in new_op.outputs]
1447        else:
1448          # Either some inputs are not loop invariant or op is stateful.
1449          if hasattr(y_op, "pfor_converter"):
1450            converter = y_op.pfor_converter
1451          else:
1452            converter = _pfor_converter_registry.get(y_op.type, None)
1453          if converter is None:
1454            if flags.FLAGS.op_conversion_fallback_to_while_loop:
1455              converter = _fallback_converter
1456            else:
1457              raise ValueError("No converter defined for %s\n%s\ninputs: %s.\n"
1458                               "Either add a converter or set "
1459                               "--op_conversion_fallback_to_while_loop=True, "
1460                               "which may run slower." %
1461                               (y_op.type, y_op, converted_inputs))
1462          # TODO(rachelim): Handle the case where some inputs are sparsely
1463          # stacked. We should only call the converter if it supports handling
1464          # those inputs.
1465          pfor_inputs = _PforInput(self, y_op, converted_inputs)
1466          try:
1467            new_outputs = converter(pfor_inputs)
1468          except Exception as e:  # pylint: disable=broad-except
1469            logging.error("Got error while pfor was converting op %s "
1470                          "with inputs %s,\nconverted inputs %s\n"
1471                          "%s\n"
1472                          "Here are the pfor conversion stack traces:" % (
1473                              y_op,
1474                              y_op.inputs[:],
1475                              pfor_inputs.inputs,
1476                              str(e)))
1477            original_op = y_op
1478            while isinstance(original_op, ops.Operation):
1479              logging.error("%s\ncreated at:\n  %s" % (
1480                  original_op,
1481                  "  ".join(traceback.format_list(original_op.traceback))))
1482              original_op = original_op._original_op
1483            six.reraise(e.__class__, e, sys.exc_info()[2])
1484
1485          if isinstance(new_outputs, WrappedTensor):
1486            new_outputs = [new_outputs]
1487          assert isinstance(new_outputs,
1488                            (list, tuple, ops.Operation)), new_outputs
1489        logging.vlog(2, "converted %s %s", y_op, new_outputs)
1490
1491        # Insert into self._conversion_map
1492        if y is y_op:
1493          assert isinstance(new_outputs, ops.Operation)
1494          self._add_conversion(y_op, new_outputs)
1495        else:
1496          assert len(y_op.outputs) == len(new_outputs), (y_op, y_op.outputs,
1497                                                         new_outputs)
1498          for old_output, new_output in zip(y_op.outputs, new_outputs):
1499            assert isinstance(new_output, WrappedTensor), (new_output, y, y_op)
1500            assert old_output.dtype == new_output.t.dtype, (new_output, y, y_op)
1501            # Set shape for converted output.
1502            output_shape = old_output.shape
1503            if not new_output.is_sparse_stacked:
1504              if new_output.is_stacked:
1505                loop_len = tensor_util.constant_value(self.loop_len_vector)
1506                if loop_len is None:
1507                  batch_dim = tensor_shape.TensorShape([None])
1508                else:
1509                  batch_dim = tensor_shape.TensorShape(loop_len)
1510                output_shape = batch_dim.concatenate(output_shape)
1511              new_output.t.set_shape(output_shape)
1512            self._add_conversion(old_output, new_output)
1513        stack.pop(0)
1514
1515    return self._conversion_map[op_or_tensor]
1516
1517  @property
1518  def loop_len_vector(self):
1519    """Returns a single element vector whose value is number of iterations."""
1520    return self._loop_len_vector
1521
1522  @property
1523  def loop_var(self):
1524    """Returns placeholder loop variable."""
1525    return self._loop_var
1526
1527  @property
1528  def pfor_ops(self):
1529    return self._pfor_ops
1530
1531  @property
1532  def pfor_config(self):
1533    return self._pfor_config
1534
1535  @property
1536  def all_indices_partitioned(self):
1537    """all_indices_partitioned property.
1538
1539    Returns:
1540      True if we are inside a control flow construct in which not all pfor
1541      iterations may be active.
1542    """
1543    return self._all_indices_partitioned
1544
1545
1546# The code below defines converters for different operations. Please see the
1547# comment for RegisterPFor on how converters should be defined.
1548
1549# nn_ops
1550
1551
1552def _flatten_first_two_dims(x):
1553  """Merges first two dimensions."""
1554  old_shape = array_ops.shape(x)
1555  new_shape = array_ops.concat([[-1], old_shape[2:]], axis=0)
1556  return array_ops.reshape(x, new_shape)
1557
1558
1559def _unflatten_first_dim(x, first_dim):
1560  """Splits first dimension into [first_dim, -1]."""
1561  old_shape = array_ops.shape(x)
1562  new_shape = array_ops.concat([first_dim, [-1], old_shape[1:]], axis=0)
1563  return array_ops.reshape(x, new_shape)
1564
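# A shape-only sketch of the two helpers above (assuming loop_len = 2):
#   _flatten_first_two_dims(t)       # t: [2, 3, 4]  ->  result: [6, 4]
#   _unflatten_first_dim(t2, [2])    # t2: [6, 4]    ->  result: [2, 3, 4]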
1565
1566def _inputs_with_flattening(pfor_input, input_indices):
1567  """Stacks and flattens first dim of inputs at indices `input_indices`."""
1568  if input_indices is None:
1569    input_indices = []
1570  pfor_input.stack_inputs(stack_indices=input_indices)
1571  inputs = []
1572  for i in range(pfor_input.num_inputs):
1573    if i in input_indices:
1574      inp = pfor_input.stacked_input(i)
1575      inp = _flatten_first_two_dims(inp)
1576    else:
1577      inp = pfor_input.unstacked_input(i)
1578    inputs.append(inp)
1579  return inputs
1580
1581
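# For the converters below, the stacked inputs listed in `dims` have their
# loop dimension folded into their first (batch-like) dimension, the original
# kernel is run once on the merged batch, and the loop dimension is split back
# out of every output.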
1582@RegisterPForWithArgs("Conv2D", dims=[0])
1583@RegisterPForWithArgs("AvgPool", dims=[0])
1584@RegisterPForWithArgs("MaxPool", dims=[0])
1585@RegisterPForWithArgs("MaxPool3D", dims=[0])
1586@RegisterPForWithArgs("MaxPool3DGrad", dims=[0, 1, 2])
1587@RegisterPForWithArgs("MaxPoolGrad", dims=[0, 1, 2])
1588@RegisterPForWithArgs("MaxPool3DGradGrad", dims=[0, 1, 2])
1589@RegisterPForWithArgs("MaxPoolGradGrad", dims=[0, 1, 2])
1590@RegisterPForWithArgs("SoftmaxCrossEntropyWithLogits", dims=[0, 1])
1591def _convert_flatten_batch(pfor_input, op_type, dims):
1592  del op_type
1593  inputs = _inputs_with_flattening(pfor_input, dims)
1594  outputs = _create_op(
1595      pfor_input.op_type,
1596      inputs, [x.dtype for x in pfor_input.outputs],
1597      attrs=pfor_input.op.node_def.attr).outputs
1598  n = pfor_input.pfor.loop_len_vector
1599  outputs = [_unflatten_first_dim(x, n) for x in outputs]
1600  return [wrap(x, True) for x in outputs]
1601
1602
1603_channel_flatten_input_cache = {}
1604
1605
1606def _channel_flatten_input(x, data_format):
1607  """Merge the stack dimension with the channel dimension.
1608
1609  If S is pfor's stacking dimension, then,
1610    - for SNCHW, we transpose to NSCHW. If the N dimension has size 1, the
1611      transpose should be cheap.
1612    - for SNHWC, we transpose to NHWSC.
1613  We then merge the S and C dimension.
1614
1615  Args:
1616    x: ops.Tensor to transform.
1617    data_format: "NCHW" or "NHWC".
1618
1619  Returns:
1620    A 3-element tuple: the transformed value, the transpose order and the
1621    reshape shape needed to transform back.
1622  """
1623
1624  graph = ops.get_default_graph()
1625  cache_key = (graph, x.experimental_ref(), data_format)
1626  if cache_key not in _channel_flatten_input_cache:
1627    x_shape = array_ops.shape(x)
1628    if data_format == b"NCHW":
1629      order = [1, 0, 2, 3, 4]
1630      shape = array_ops.concat([x_shape[1:2], [-1], x_shape[3:]], axis=0)
1631      reverse_order = order
1632    else:
1633      order = [1, 2, 3, 0, 4]
1634      shape = array_ops.concat([x_shape[1:4], [-1]], axis=0)
1635      reverse_order = [3, 0, 1, 2, 4]
1636    # Move S dimension next to C dimension.
1637    x = array_ops.transpose(x, order)
1638    reverse_shape = array_ops.shape(x)
1639    # Reshape to merge the S and C dimension.
1640    x = array_ops.reshape(x, shape)
1641    outputs = x, reverse_order, reverse_shape
1642    _channel_flatten_input_cache[cache_key] = outputs
1643  else:
1644    outputs = _channel_flatten_input_cache[cache_key]
1645  return outputs
1646
1647
1648# Note that with training=True, running FusedBatchNormV3 on individual examples
1649# is very different from running FusedBatchNormV3 on a batch of those examples.
1650# This is because, for the latter case, the operation can be considered as first
1651# computing the mean and variance over all the examples and then using these
1652# to scale all those examples. This creates a data dependency between these
1653# different "iterations" since the inputs to the scaling step depend on the
1654# statistics coming from all these inputs.
1655# As with other kernels, the conversion here effectively runs the kernel
1656# independently for each iteration, and returns outputs by stacking outputs from
1657# each of those iterations.
1658@RegisterPFor("FusedBatchNormV3")
1659def _convert_fused_batch_norm(pfor_input):
1660  is_training = pfor_input.get_attr("is_training")
1661  # When BatchNorm is used with training=False, mean and variance are provided
1662  # externally and used as is by the op. Thus, we can merge the S and N
1663  # dimensions as we do for regular operations.
1664  # When BatchNorm is used with training=True, mean and variance are computed
1665  # for each channel across the batch dimension (first one). If we merge S and N
1666  # dimensions, means and variances will be computed over a larger set. So, we
1667  # merge the S and C dimensions instead.
1668  if not is_training:
1669    # We return zeros for batch_mean and batch_variance output. Note that CPU
1670    # and GPU seem to have different behavior for those two outputs. CPU outputs
1671    # zero because these values are not used during inference. GPU outputs
1672    # something, probably real means and variances.
1673    inputs = _inputs_with_flattening(pfor_input, [0])
1674    outputs = _create_op(
1675        pfor_input.op_type,
1676        inputs, [x.dtype for x in pfor_input.outputs],
1677        attrs=pfor_input.op.node_def.attr).outputs
1678    y = outputs[0]
1679    n = pfor_input.pfor.loop_len_vector
1680    y = _unflatten_first_dim(y, n)
1681    mean = pfor_input.unstacked_input(3)
1682    zeros = array_ops.zeros_like(mean)
1683    return [wrap(y, True)] + [wrap(zeros, False)] * 5
1684
1685  pfor_input.stack_inputs()
1686  data_format = pfor_input.get_attr("data_format")
1687  # We merge the first dimension with the "C" dimension, run FusedBatchNormV3,
1688  # and then transpose back.
1689  x = pfor_input.stacked_input(0)
1690  x, reverse_order, reverse_shape = _channel_flatten_input(x, data_format)
1691  # Note that we stack and flatten all the other inputs as well so that their
1692  # size matches the new size of the merged channel dimension.
1693  inputs = [x] + [
1694      array_ops.reshape(pfor_input.stacked_input(i), [-1])
1695      for i in range(1, pfor_input.num_inputs)
1696  ]
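  # After the reshape above, scale, offset, mean and variance each have shape
  # [loop_len * C], matching the merged channel dimension of `x`.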
1697  outputs = _create_op(
1698      pfor_input.op_type,
1699      inputs, [x.dtype for x in pfor_input.outputs],
1700      attrs=pfor_input.op.node_def.attr).outputs
1701  y = outputs[0]
1702  y = array_ops.reshape(y, reverse_shape)
1703  y = array_ops.transpose(y, reverse_order)
1704  n = pfor_input.pfor.loop_len_vector
1705  outputs = [_unflatten_first_dim(x, n) for x in outputs[1:]]
1706  outputs = [y] + outputs
1707  return [wrap(x, True) for x in outputs]
1708
1709
1710@RegisterPFor("FusedBatchNormGradV3")
1711def _convert_fused_batch_norm_grad(pfor_input):
1712  pfor_input.stack_inputs()
1713  data_format = pfor_input.get_attr("data_format")
1714  y_backprop = pfor_input.stacked_input(0)
1715  y_backprop, _, _ = _channel_flatten_input(y_backprop, data_format)
1716  x = pfor_input.stacked_input(1)
1717  x, x_reverse_order, x_reverse_shape = _channel_flatten_input(x, data_format)
1718  inputs = [y_backprop, x] + [
1719      array_ops.reshape(pfor_input.stacked_input(i), [-1])
1720      for i in range(2, pfor_input.num_inputs)
1721  ]
1722  outputs = _create_op(
1723      pfor_input.op_type,
1724      inputs, [x.dtype for x in pfor_input.outputs],
1725      attrs=pfor_input.op.node_def.attr).outputs
1726  x_backprop = outputs[0]
1727  x_backprop = array_ops.reshape(x_backprop, x_reverse_shape)
1728  x_backprop = array_ops.transpose(x_backprop, x_reverse_order)
1729  n = pfor_input.pfor.loop_len_vector
1730  outputs = [_unflatten_first_dim(x, n) for x in outputs[1:]]
1731  outputs = [x_backprop] + outputs
1732  return [wrap(output, True) for output in outputs]
1733
1734
1735@RegisterPForWithArgs("Conv2DBackpropInput", flatten_dims=[2], shape_dim=0)
1736@RegisterPForWithArgs("AvgPoolGrad", flatten_dims=[1], shape_dim=0)
1737def _convert_flatten_batch_shape_input(pfor_input, op_type, flatten_dims,
1738                                       shape_dim):
1739  del op_type
1740  inputs = _inputs_with_flattening(pfor_input, flatten_dims)
1741  n = pfor_input.pfor.loop_len_vector
1742  # Adjust the `input_sizes` input.
1743  ones = array_ops.ones([array_ops.shape(inputs[shape_dim])[0] - 1],
1744                        dtype=n.dtype)
1745  inputs[shape_dim] *= array_ops.concat([n, ones], axis=0)
1746  outputs = _create_op(
1747      pfor_input.op_type,
1748      inputs, [x.dtype for x in pfor_input.outputs],
1749      attrs=pfor_input.op.node_def.attr).outputs
1750  outputs = [_unflatten_first_dim(x, n) for x in outputs]
1751  return [wrap(x, True) for x in outputs]
1752
1753
1754@RegisterPFor("Conv2DBackpropFilter")
1755def _convert_conv2d_backprop_filter(pfor_input):
1756  pfor_input.stack_inputs(stack_indices=[2])
1757  inputs, inputs_stacked, _ = pfor_input.input(0)
1758  filter_sizes = pfor_input.unstacked_input(1)
1759  grads = pfor_input.stacked_input(2)
1760  strides = pfor_input.get_attr("strides")
1761  padding = pfor_input.get_attr("padding")
1762  use_cudnn_on_gpu = pfor_input.get_attr("use_cudnn_on_gpu")
1763  data_format = pfor_input.get_attr("data_format")
1764  dilations = pfor_input.get_attr("dilations")
1765  if inputs_stacked:
1766    # TODO(agarwal): Implement this efficiently.
1767    logging.warn("Conv2DBackpropFilter uses a while_loop. Fix that!")
1768
1769    def while_body(i, ta):
1770      inp_i = inputs[i, ...]
1771      grad_i = grads[i, ...]
1772      output = nn_ops.conv2d_backprop_filter(
1773          inp_i,
1774          filter_sizes,
1775          grad_i,
1776          strides=strides,
1777          padding=padding,
1778          use_cudnn_on_gpu=use_cudnn_on_gpu,
1779          data_format=data_format,
1780          dilations=dilations)
1781      return i + 1, ta.write(i, array_ops.expand_dims(output, 0))
1782
1783    n = array_ops.reshape(pfor_input.pfor.loop_len_vector, [])
1784    _, ta = control_flow_ops.while_loop(
1785        lambda i, ta: i < n, while_body,
1786        (0, tensor_array_ops.TensorArray(inputs.dtype, n)))
1787    output = ta.concat()
1788    return wrap(output, True)
1789  else:
1790    # We merge the stack dimension with the channel dimension of the gradients
1791    # and pretend we had a larger filter (see change to filter_sizes below).
1792    # Once the filter backprop is computed, we reshape and transpose back
1793    # appropriately.
1794    grads, _, _ = _channel_flatten_input(grads, data_format)
1795    n = pfor_input.pfor.loop_len_vector
1796    old_filter_sizes = filter_sizes
1797    filter_sizes *= array_ops.concat([[1, 1, 1], n], axis=0)
1798    output = nn_ops.conv2d_backprop_filter(
1799        inputs,
1800        filter_sizes,
1801        grads,
1802        strides=strides,
1803        padding=padding,
1804        use_cudnn_on_gpu=use_cudnn_on_gpu,
1805        data_format=data_format,
1806        dilations=dilations)
1807    new_filter_shape = array_ops.concat([old_filter_sizes[:3], n, [-1]], axis=0)
1808    output = array_ops.reshape(output, new_filter_shape)
1809    output = array_ops.transpose(output, [3, 0, 1, 2, 4])
1810    return wrap(output, True)
1811
1812
1813@RegisterPForWithArgs("LogSoftmax", gen_nn_ops.log_softmax)
1814@RegisterPForWithArgs("Softmax", gen_nn_ops.softmax)
1815def _convert_softmax(pfor_input, op_type, op_func):
1816  del op_type
1817  return wrap(op_func(pfor_input.stacked_input(0)), True)
1818
1819
1820# array_ops
1821
1822
1823@RegisterPForWithArgs("Identity", array_ops.identity)
1824@RegisterPForWithArgs("StopGradient", array_ops.stop_gradient)
1825@RegisterPForWithArgs("MatrixDiag", array_ops.matrix_diag)
1826@RegisterPForWithArgs("MatrixDiagPart", array_ops.matrix_diag_part)
1827def _convert_identity(pfor_input, op_type, op_func):
1828  del op_type
1829  return wrap(op_func(*[x.t for x in pfor_input.inputs]), True)
1830
1831
1832@RegisterPFor("IdentityN")
1833def _convert_identity_n(pfor_input):
1834  outputs = array_ops.identity_n([x.t for x in pfor_input.inputs])
1835  return [
1836      wrap(out, inp.is_stacked) for out, inp in zip(outputs, pfor_input.inputs)
1837  ]
1838
1839
1840@RegisterPFor("Reshape")
1841def _convert_reshape(pfor_input):
1842  t = pfor_input.stacked_input(0)
1843  shape = pfor_input.unstacked_input(1)
1844  new_shape = array_ops.concat([pfor_input.pfor.loop_len_vector, shape], axis=0)
1845  return wrap(array_ops.reshape(t, new_shape), True)
1846
1847
1848@RegisterPFor("BroadcastTo")
1849def _convert_broadcast_to(pfor_input):
1850  t = pfor_input.stacked_input(0)
1851  shape = pfor_input.unstacked_input(1)
1852  new_shape = array_ops.concat([pfor_input.pfor.loop_len_vector, shape], axis=0)
1853
1854  # Expand dims of stacked t to broadcast against the new shape.
1855  # TODO(davmre): consider factoring out common code with
1856  # `expanddim_inputs_for_broadcast`, which has similar logic but with
1857  # implicit shapes (of input Tensors) rather than explicit shapes.
1858  rank_diff = array_ops.shape(new_shape)[0] - array_ops.rank(t)
1859  ones = array_ops.tile([1], array_ops.reshape(rank_diff, [1]))
1860  t_shape = array_ops.shape(t)
1861  t_expanded_shape = array_ops.concat([t_shape[:1], ones, t_shape[1:]], axis=0)
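  # E.g. if t has shape [loop_len, 3] and the target shape is [2, 5, 3], t is
  # reshaped to [loop_len, 1, 1, 3] and broadcast to [loop_len, 2, 5, 3].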
1862
1863  return wrap(
1864      array_ops.broadcast_to(array_ops.reshape(t, t_expanded_shape), new_shape),
1865      True)
1866
1867
1868@RegisterPFor("ExpandDims")
1869def _convert_expanddims(pfor_input):
1870  t = pfor_input.stacked_input(0)
1871  dim = pfor_input.unstacked_input(1)
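  # Non-negative axes are shifted by one for the new leading loop dimension;
  # negative axes already count from the end and need no adjustment.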
1872  dim += math_ops.cast(dim >= 0, dtypes.int32)
1873  return wrap(array_ops.expand_dims(t, axis=dim), True)
1874
1875
1876@RegisterPForWithArgs("LowerBound", gen_array_ops.lower_bound)
1877@RegisterPForWithArgs("UpperBound", gen_array_ops.upper_bound)
1878def _convert_searchsorted(pfor_input, _, op_func):
1879  pfor_input.stack_inputs()
1880  sorted_inputs = _flatten_first_two_dims(pfor_input.stacked_input(0))
1881  values = _flatten_first_two_dims(pfor_input.stacked_input(1))
1882  out_type = pfor_input.get_attr("out_type")
1883  output = op_func(sorted_inputs, values, out_type)
1884  return wrap(
1885      _unflatten_first_dim(output, pfor_input.pfor.loop_len_vector), True)
1886
1887
1888@RegisterPFor("MatrixBandPart")
1889def _convert_matrix_band_part(pfor_input):
1890  t = pfor_input.stacked_input(0)
1891  num_lower = pfor_input.unstacked_input(1)
1892  num_upper = pfor_input.unstacked_input(2)
1893  return wrap(
1894      array_ops.matrix_band_part(t, num_lower=num_lower, num_upper=num_upper),
1895      True)
1896
1897
1898@RegisterPFor("MatrixSetDiag")
1899def _convert_matrix_set_diag(pfor_input):
1900  pfor_input.stack_inputs()
1901  t = pfor_input.stacked_input(0)
1902  diag = pfor_input.stacked_input(1)
1903  return wrap(array_ops.matrix_set_diag(t, diag), True)
1904
1905
1906# Registrations for Matrix{Diag,DiagPart,SetDiag}V2-3.
1907# The input orders defined in the OpKernel and the actual python API are
1908# different (for compatibility with V1), so we cannot use _convert_identity.
1909# v2 is not compatible with v3 and is never exposed on the public API.
1910@RegisterPFor("MatrixDiagV2")
1911@RegisterPFor("MatrixDiagV3")
1912def _convert_matrix_diag_v2(pfor_input):
1913  params = {
1914      "diagonal": pfor_input.stacked_input(0),
1915      "k": pfor_input.unstacked_input(1),
1916      "num_rows": pfor_input.unstacked_input(2),
1917      "num_cols": pfor_input.unstacked_input(3),
1918      "padding_value": pfor_input.unstacked_input(4)
1919  }
1920  if pfor_input.op_type == "MatrixDiagV2":
1921    return wrap(array_ops.matrix_diag_v2(**params), True)
1922  params["align"] = pfor_input.get_attr("align")
1923  return wrap(array_ops.matrix_diag(**params), True)
1924
1925
1926# See notes for MatrixDiagV2
1927@RegisterPFor("MatrixDiagPartV2")
1928@RegisterPFor("MatrixDiagPartV3")
1929def _convert_matrix_diag_part_v2(pfor_input):
1930  params = {
1931      "input": pfor_input.stacked_input(0),
1932      "k": pfor_input.unstacked_input(1),
1933      "padding_value": pfor_input.unstacked_input(2)
1934  }
1935  if pfor_input.op_type == "MatrixDiagPartV2":
1936    return wrap(array_ops.matrix_diag_part_v2(**params), True)
1937  params["align"] = pfor_input.get_attr("align")
1938  return wrap(array_ops.matrix_diag_part(**params), True)
1939
1940
1941# See notes for MatrixDiagV2
1942@RegisterPFor("MatrixSetDiagV2")
1943@RegisterPFor("MatrixSetDiagV3")
1944def _convert_matrix_set_diag_v2(pfor_input):
1945  pfor_input.stack_inputs([0, 1])
1946  params = {
1947      "input": pfor_input.stacked_input(0),
1948      "diagonal": pfor_input.stacked_input(1),
1949      "k": pfor_input.unstacked_input(2)
1950  }
1951  if pfor_input.op_type == "MatrixSetDiagV2":
1952    return wrap(array_ops.matrix_set_diag_v2(**params), True)
1953  params["align"] = pfor_input.get_attr("align")
1954  return wrap(array_ops.matrix_set_diag(**params), True)
1955
1956
1957@RegisterPFor("OneHot")
1958def _convert_one_hot(pfor_input):
1959  indices = pfor_input.stacked_input(0)
1960  depth = pfor_input.unstacked_input(1)
1961  on_value = pfor_input.unstacked_input(2)
1962  off_value = pfor_input.unstacked_input(3)
1963  axis = pfor_input.get_attr("axis")
1964  if axis >= 0:
1965    axis += 1
1966  return wrap(
1967      array_ops.one_hot(indices, depth, on_value, off_value, axis), True)
1968
1969
1970@RegisterPFor("Slice")
1971def _convert_slice(pfor_input):
1972  t = pfor_input.stacked_input(0)
1973  begin = pfor_input.unstacked_input(1)
1974  size = pfor_input.unstacked_input(2)
1975  begin = array_ops.concat([[0], begin], axis=0)
1976  size = array_ops.concat([[-1], size], axis=0)
1977  return wrap(array_ops.slice(t, begin, size), True)
1978
1979
1980@RegisterPFor("Tile")
1981def _convert_tile(pfor_input):
1982  t = pfor_input.stacked_input(0)
1983  multiples = pfor_input.unstacked_input(1)
1984  multiples = array_ops.concat([[1], multiples], 0)
1985  return wrap(array_ops.tile(t, multiples), True)
1986
1987
1988@RegisterPFor("Pack")
1989def _convert_pack(pfor_input):
1990  pfor_input.stack_inputs()
1991  axis = pfor_input.get_attr("axis")
1992  if axis >= 0:
1993    axis += 1
1994  return wrap(
1995      array_ops.stack([x.t for x in pfor_input.inputs], axis=axis), True)
1996
1997
1998@RegisterPFor("Unpack")
1999def _convert_unpack(pfor_input):
2000  value = pfor_input.stacked_input(0)
2001  axis = pfor_input.get_attr("axis")
2002  if axis >= 0:
2003    axis += 1
2004  num = pfor_input.get_attr("num")
2005  return [wrap(x, True) for x in array_ops.unstack(value, axis=axis, num=num)]
2006
2007
2008@RegisterPFor("Pad")
2009def _convert_pad(pfor_input):
2010  t = pfor_input.stacked_input(0)
2011  paddings = pfor_input.unstacked_input(1)
2012  paddings = array_ops.concat([[[0, 0]], paddings], 0)
2013  return wrap(array_ops.pad(t, paddings, mode="CONSTANT"), True)
2014
2015
2016@RegisterPFor("Split")
2017def _convert_split(pfor_input):
2018  split_dim = pfor_input.unstacked_input(0)
2019  t = pfor_input.stacked_input(1)
2020  num_split = pfor_input.get_attr("num_split")
2021  split_dim += math_ops.cast(split_dim >= 0, dtypes.int32)
2022  return [wrap(x, True) for x in array_ops.split(t, num_split, axis=split_dim)]
2023
2024
2025@RegisterPFor("SplitV")
2026def _convert_split_v(pfor_input):
2027  t = pfor_input.stacked_input(0)
2028  splits = pfor_input.unstacked_input(1)
2029  split_dim = pfor_input.unstacked_input(2)
2030  split_dim += math_ops.cast(split_dim >= 0, dtypes.int32)
2031  return [wrap(x, True) for x in array_ops.split(t, splits, axis=split_dim)]
2032
2033
2034@RegisterPFor("Squeeze")
2035def _convert_squeeze(pfor_input):
2036  t = pfor_input.stacked_input(0)
2037  squeeze_dims = pfor_input.get_attr("squeeze_dims")
2038  squeeze_dims = [i + 1 if i >= 0 else i for i in squeeze_dims]
2039  return wrap(array_ops.squeeze(t, axis=squeeze_dims), True)
2040
2041
2042@RegisterPFor("Transpose")
2043def _convert_transpose(pfor_input):
2044  t = pfor_input.stacked_input(0)
2045  perm = pfor_input.unstacked_input(1)
2046  new_perm = array_ops.concat([[0], perm + 1], axis=0)
2047  return wrap(array_ops.transpose(t, new_perm), True)
2048
2049
2050@RegisterPFor("ZerosLike")
2051def _convert_zeroslike(pfor_input):
2052  t = pfor_input.stacked_input(0)
2053  shape = array_ops.shape(t)[1:]
2054  return wrap(array_ops.zeros(shape, dtype=t.dtype), False)
2055
2056
2057@RegisterPFor("Gather")
2058@RegisterPFor("GatherV2")
2059def _convert_gather(pfor_input):
2060  param, param_stacked, _ = pfor_input.input(0)
2061  indices, indices_stacked, _ = pfor_input.input(1)
2062  op_type = pfor_input.op_type
2063  if op_type == "Gather":
2064    validate_indices = pfor_input.get_attr("validate_indices")
2065    axis = 0
2066  else:
2067    validate_indices = None
2068    # Assume we will never have a Tensor with rank > 2**32.
2069    axis = math_ops.cast(pfor_input.unstacked_input(2), dtypes.int32)
2070    axis_value = tensor_util.constant_value(axis)
2071    if axis_value is not None:
2072      axis = axis_value
2073  if indices_stacked and not param_stacked:
2074    if indices is pfor_input.pfor.all_indices and axis == 0:
2075      param_shape0 = tensor_shape.dimension_value(param.shape[0])
2076      indices_shape0 = tensor_shape.dimension_value(indices.shape[0])
2077      if param_shape0 is not None and indices_shape0 == param_shape0:
2078        # Note that with loops and conditionals, indices may not be contiguous.
2079        # However, they will be sorted and unique. So if the shape matches,
2080        # then it must be picking up all the rows of param.
2081        return wrap(param, True)
2082      # TODO(agarwal): use array_ops.slice here.
2083    output = array_ops.gather(
2084        param, indices, validate_indices=validate_indices, axis=axis)
2085    if axis != 0:
2086      axis = control_flow_ops.cond(axis < 0,
2087                                   lambda: axis + array_ops.rank(param),
2088                                   lambda: axis)
2089      order = array_ops.concat(
2090          [[axis],
2091           math_ops.range(axis),
2092           math_ops.range(axis + 1, array_ops.rank(output))],
2093          axis=0)
2094      output = control_flow_ops.cond(
2095          math_ops.equal(axis, 0), lambda: output,
2096          lambda: array_ops.transpose(output, order))
2097    return wrap(output, True)
2098  if param_stacked:
2099    loop_len_vector = pfor_input.pfor.loop_len_vector
2100    pfor_input.stack_inputs(stack_indices=[1])
2101    indices = pfor_input.stacked_input(1)
2102    param_flat = _flatten_first_two_dims(param)
2103
2104    # Recompute indices to handle stacked param.
2105    indices_offset = (math_ops.range(math_ops.cast(loop_len_vector[0],
2106                                                   dtype=indices.dtype)) *
2107                      math_ops.cast(array_ops.shape(param)[1], indices.dtype))
2108    # Reshape indices_offset to allow broadcast addition
2109    ones = array_ops.ones([array_ops.rank(indices) - 1], dtype=dtypes.int32)
2110    new_shape = array_ops.concat([loop_len_vector, ones], axis=0)
2111    indices_offset = array_ops.reshape(indices_offset, new_shape)
2112    indices += indices_offset
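    # E.g. with loop_len 2 and param of shape [2, 5, ...], param_flat has a
    # leading dimension of 10, and iteration i's indices are shifted by 5 * i
    # so that they index into iteration i's rows of param_flat.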
2113
2114    # TODO(agarwal): handle axis != 0. May need to transpose param or
2115    # array_ops.gather_nd.
2116    if isinstance(axis, ops.Tensor):
2117      axis_value = tensor_util.constant_value(axis)
2118    else:
2119      try:
2120        axis_value = int(axis)
2121      except TypeError:
2122        axis_value = None
2123    msg = ("Gather, where indices and param are both loop dependent, currently "
2124           "requires axis=0")
2125    if axis_value is not None and axis_value != 0:
2126      raise ValueError("Error while converting %s. %s. Got axis=%d" %
2127                       (pfor_input.op, msg, axis))
2128    with ops.control_dependencies(
2129        [check_ops.assert_equal(axis, 0, message=msg)]):
2130      output = array_ops.gather(param_flat, indices)
2131    return wrap(output, True)
2132
2133
2134@RegisterPFor("GatherNd")
2135def _convert_gather_nd(pfor_input):
2136  # TODO(jmenick): Add support for unstacked params.
2137  pfor_input.stack_inputs(stack_indices=[1])
2138  params = pfor_input.stacked_input(0)
2139  indices = pfor_input.stacked_input(1)
2140  stacked_result = array_ops.gather_nd(params, indices, batch_dims=1)
2141  return wrap(stacked_result, True)
2142
2143
2144@RegisterPFor("ConcatV2")
2145def _convert_concatv2(pfor_input):
2146  n = pfor_input.num_inputs
2147  pfor_input.stack_inputs(stack_indices=range(n - 1))
2148  axis = pfor_input.unstacked_input(n - 1)
2149  axis += math_ops.cast(axis >= 0, axis.dtype)
2150  return wrap(
2151      array_ops.concat([x.t for x in pfor_input.inputs[:n - 1]], axis=axis),
2152      True)
2153
2154
2155@RegisterPFor("StridedSlice")
2156def _convert_strided_slice(pfor_input):
2157  inp = pfor_input.stacked_input(0)
2158  begin = pfor_input.unstacked_input(1)
2159  end = pfor_input.unstacked_input(2)
2160  strides = pfor_input.unstacked_input(3)
2161  begin_mask = pfor_input.get_attr("begin_mask")
2162  end_mask = pfor_input.get_attr("end_mask")
2163  ellipsis_mask = pfor_input.get_attr("ellipsis_mask")
2164  new_axis_mask = pfor_input.get_attr("new_axis_mask")
2165  shrink_axis_mask = pfor_input.get_attr("shrink_axis_mask")
2166
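  # Prepend a full slice over the new leading loop dimension: begin/end of 0
  # with their mask bits set and a stride of 1. All attribute masks are shifted
  # left by one bit to account for that extra dimension.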
2167  begin = array_ops.concat([[0], begin], axis=0)
2168  end = array_ops.concat([[0], end], axis=0)
2169  strides = array_ops.concat([[1], strides], axis=0)
2170  begin_mask = begin_mask << 1 | 1
2171  end_mask = end_mask << 1 | 1
2172  ellipsis_mask <<= 1
2173  new_axis_mask <<= 1
2174  shrink_axis_mask <<= 1
2175  return wrap(
2176      array_ops.strided_slice(
2177          inp,
2178          begin,
2179          end,
2180          strides,
2181          begin_mask=begin_mask,
2182          end_mask=end_mask,
2183          ellipsis_mask=ellipsis_mask,
2184          new_axis_mask=new_axis_mask,
2185          shrink_axis_mask=shrink_axis_mask), True)
2186
2187
2188@RegisterPFor("StridedSliceGrad")
2189def _convert_strided_slice_grad(pfor_input):
2190  shape = pfor_input.unstacked_input(0)
2191  begin = pfor_input.unstacked_input(1)
2192  end = pfor_input.unstacked_input(2)
2193  strides = pfor_input.unstacked_input(3)
2194  dy = pfor_input.stacked_input(4)
2195  begin_mask = pfor_input.get_attr("begin_mask")
2196  end_mask = pfor_input.get_attr("end_mask")
2197  ellipsis_mask = pfor_input.get_attr("ellipsis_mask")
2198  new_axis_mask = pfor_input.get_attr("new_axis_mask")
2199  shrink_axis_mask = pfor_input.get_attr("shrink_axis_mask")
2200
2201  shape = array_ops.concat([pfor_input.pfor.loop_len_vector, shape], axis=0)
2202  begin = array_ops.concat([[0], begin], axis=0)
2203  end = array_ops.concat([[0], end], axis=0)
2204  strides = array_ops.concat([[1], strides], axis=0)
2205  begin_mask = begin_mask << 1 | 1
2206  end_mask = end_mask << 1 | 1
2207  ellipsis_mask <<= 1
2208  new_axis_mask <<= 1
2209  shrink_axis_mask <<= 1
2210  return wrap(
2211      array_ops.strided_slice_grad(
2212          shape,
2213          begin,
2214          end,
2215          strides,
2216          dy,
2217          begin_mask=begin_mask,
2218          end_mask=end_mask,
2219          ellipsis_mask=ellipsis_mask,
2220          new_axis_mask=new_axis_mask,
2221          shrink_axis_mask=shrink_axis_mask), True)
2222
2223
2224# math_ops
2225
2226
2227@RegisterPFor("MatMul")
2228def _convert_matmul(pfor_input):
2229  # TODO(agarwal): Check if tiling is faster than two transposes.
2230  a, a_stacked, _ = pfor_input.input(0)
2231  b, b_stacked, _ = pfor_input.input(1)
2232  tr_a = pfor_input.get_attr("transpose_a")
2233  tr_b = pfor_input.get_attr("transpose_b")
2234  if a_stacked and b_stacked:
2235    output = wrap(math_ops.matmul(a, b, adjoint_a=tr_a, adjoint_b=tr_b), True)
2236    return output
2237  elif a_stacked:
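    # Only `a` carries the loop dimension here. Its first two dimensions are
    # flattened so a single rank-2 matmul can be used, and the loop dimension
    # is restored on the product afterwards.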
2238    if tr_a:
2239      a = array_ops.transpose(a, [0, 2, 1])
2240    if a.shape.is_fully_defined():
2241      x, y, z = a.shape
2242    else:
2243      x, y, z = [
2244          array_ops.reshape(i, [])
2245          for i in array_ops.split(array_ops.shape(a), 3)
2246      ]
2247    a = array_ops.reshape(a, [x * y, z])
2248    prod = math_ops.matmul(a, b, transpose_b=tr_b)
2249    return wrap(array_ops.reshape(prod, [x, y, -1]), True)
2250  else:
2251    assert b_stacked
2252    if tr_b:
2253      perm = [2, 0, 1]
2254      b = array_ops.transpose(b, perm)
2255    else:
2256      # As an optimization, if one of the first two dimensions is 1, then we can
2257      # reshape instead of transpose.
2258      # TODO(agarwal): This check can be done inside Transpose kernel.
2259      b_shape = array_ops.shape(b)
2260      min_dim = math_ops.minimum(b_shape[0], b_shape[1])
2261      perm = control_flow_ops.cond(
2262          math_ops.equal(min_dim, 1), lambda: [0, 1, 2], lambda: [1, 0, 2])
2263      new_shape = array_ops.stack([b_shape[1], b_shape[0], b_shape[2]])
2264      b = array_ops.transpose(b, perm)
2265      b = array_ops.reshape(b, new_shape)
2266
2267    if b.shape.is_fully_defined():
2268      x, y, z = b.shape
2269    else:
2270      x, y, z = [
2271          array_ops.reshape(i, [])
2272          for i in array_ops.split(array_ops.shape(b), 3)
2273      ]
2274    b = array_ops.reshape(b, [x, y * z])
2275    prod = math_ops.matmul(a, b, transpose_a=tr_a)
2276    prod = array_ops.reshape(prod, [-1, y, z])
2277    prod = array_ops.transpose(prod, [1, 0, 2])
2278    return wrap(prod, True)
2279
2280
2281# TODO(rmlarsen): Use the converter of BatchMatMulV2 once compatibility window
2282# is met.
2283@RegisterPFor("BatchMatMul")
2284def _convert_batch_mat_mul(pfor_input):
2285  # TODO(agarwal): There may be a more efficient way to do this instead of
2286  # stacking the inputs.
2287  pfor_input.stack_inputs()
2288  x = pfor_input.stacked_input(0)
2289  y = pfor_input.stacked_input(1)
2290  adj_x = pfor_input.get_attr("adj_x")
2291  adj_y = pfor_input.get_attr("adj_y")
2292
2293  x = _flatten_first_two_dims(x)
2294  y = _flatten_first_two_dims(y)
2295  output = math_ops.matmul(x, y, adjoint_a=adj_x, adjoint_b=adj_y)
2296  output = _unflatten_first_dim(output, pfor_input.pfor.loop_len_vector)
2297  return wrap(output, True)
2298
2299
2300@RegisterPFor("BatchMatMulV2")
2301def _convert_batch_mat_mul_v2(pfor_input):
2302  pfor_input.expanddim_inputs_for_broadcast()
2303  x = pfor_input.input(0)[0]
2304  y = pfor_input.input(1)[0]
2305  adj_x = pfor_input.get_attr("adj_x")
2306  adj_y = pfor_input.get_attr("adj_y")
2307
2308  output = math_ops.matmul(x, y, adjoint_a=adj_x, adjoint_b=adj_y)
2309  return wrap(output, True)
2310
2311
2312@RegisterPForWithArgs("Sum", math_ops.reduce_sum)
2313@RegisterPForWithArgs("Prod", math_ops.reduce_prod)
2314@RegisterPForWithArgs("Max", math_ops.reduce_max)
2315@RegisterPForWithArgs("Min", math_ops.reduce_min)
2316@RegisterPForWithArgs("Mean", math_ops.reduce_mean)
2317@RegisterPForWithArgs("All", math_ops.reduce_all)
2318@RegisterPForWithArgs("Any", math_ops.reduce_any)
2319def _convert_reduction(pfor_input, _, op_func):
2320  t = pfor_input.stacked_input(0)
2321  indices = pfor_input.unstacked_input(1)
2322  # Shift non-negative indices by one to account for the extra dimension.
2323  indices += math_ops.cast(indices >= 0, dtypes.int32)
2324  keep_dims = pfor_input.get_attr("keep_dims")
2325  return wrap(op_func(t, indices, keepdims=keep_dims), True)
2326
2327
2328@RegisterPForWithArgs("Cumsum", math_ops.cumsum)
2329@RegisterPForWithArgs("Cumprod", math_ops.cumprod)
2330def _convert_cumfoo(pfor_input, _, op_func):
2331  t = pfor_input.stacked_input(0)
2332  axis = pfor_input.unstacked_input(1)
2333  # Shift non-negative indices by one to account for the extra dimension.
2334  axis += math_ops.cast(axis >= 0, dtypes.int32)
2335  exclusive = pfor_input.get_attr("exclusive")
2336  reverse = pfor_input.get_attr("reverse")
2337  return wrap(op_func(t, axis, exclusive=exclusive, reverse=reverse), True)
2338
2339
2340@RegisterPFor("BiasAdd")
2341def _convert_biasadd(pfor_input):
2342  t, t_stacked, _ = pfor_input.input(0)
2343  bias, bias_stacked, _ = pfor_input.input(1)
2344  data_format = pfor_input.get_attr("data_format").decode()
2345  if bias_stacked:
2346    # BiasAdd only supports 1-D biases, so cast bias to match value and use Add.
2347    pfor_input.expanddim_inputs_for_broadcast()
2348    t, _, _ = pfor_input.input(0)
2349    bias = math_ops.cast(pfor_input.stacked_input(1), t.dtype)
2350    if compat.as_bytes(data_format) == b"NCHW":
2351      b_shape = array_ops.shape(bias)
2352      new_b_shape = array_ops.concat(
2353          [b_shape[:-3], b_shape[-1:], b_shape[-3:-1]], axis=0)
2354      bias = array_ops.reshape(bias, new_b_shape)
2355    return wrap(math_ops.add(t, bias), True)
2356  else:
2357    assert t_stacked, "At least one input to BiasAdd should be loop variant."
2358    if compat.as_bytes(data_format) == b"NCHW":
2359      shape = array_ops.shape(t)
2360      flattened_shape = array_ops.concat([[-1], shape[2:]], axis=0)
2361      t = array_ops.reshape(t, flattened_shape)
2362      t = nn_ops.bias_add(t, bias, data_format="NCHW")
2363      t = array_ops.reshape(t, shape)
2364      return wrap(t, True)
2365    return wrap(nn_ops.bias_add(t, bias, data_format=data_format), True)
2366
2367
2368@RegisterPFor("UnsortedSegmentSum")
2369def _convert_unsortedsegmentsum(pfor_input):
2370  pfor_input.stack_inputs([0, 1])
2371  data = pfor_input.stacked_input(0)
2372  segment_ids = pfor_input.stacked_input(1)
2373  # TODO(agarwal): handle stacked?
2374  num_segments = pfor_input.unstacked_input(2)
2375  if segment_ids.dtype != num_segments.dtype:
2376    segment_ids = math_ops.cast(segment_ids, dtypes.int64)
2377    num_segments = math_ops.cast(num_segments, dtypes.int64)
2378  dtype = segment_ids.dtype
2379  segment_shape = array_ops.shape(segment_ids, out_type=dtype)
2380  n = segment_shape[0]
2381  ones = array_ops.ones_like(segment_shape, dtype=dtype)[1:]
2382  segment_offset = num_segments * math_ops.range(n, dtype=dtype)
2383  segment_offset = array_ops.reshape(segment_offset,
2384                                     array_ops.concat([[n], ones], axis=0))
2385  segment_ids += segment_offset
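  # E.g. with n = 2 iterations and num_segments = 3, iteration 1's segment ids
  # are shifted by 3, the merged op uses 2 * 3 = 6 segments, and the output is
  # reshaped back to [2, 3, ...] below.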
2386  num_segments = math_ops.cast(num_segments, dtypes.int64) * math_ops.cast(
2387      n, dtypes.int64)
2388  output = math_ops.unsorted_segment_sum(data, segment_ids, num_segments)
2389  new_output_shape = array_ops.concat(
2390      [[n, -1], array_ops.shape(output)[1:]], axis=0)
2391  output = array_ops.reshape(output, new_output_shape)
2392  return wrap(output, True)
2393
2394
2395def _flatten_array_with_offset(ids, offset_delta, num_rows):
2396  """Flattens a rank 2 tensor, adding an offset to each row."""
2397  # Note that if `ids` is rank 1, it is broadcast to rank 2.
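  # For example, ids=[[0, 1], [2, 0]] with offset_delta=3 and num_rows=2 gets
  # per-row offsets [0, 3] and flattens to [0, 1, 5, 3].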
2398  offset_delta = math_ops.cast(offset_delta, ids.dtype)
2399  n = math_ops.cast(num_rows, dtype=ids.dtype)
2400  offsets = math_ops.range(
2401      start=0, limit=n * offset_delta, delta=offset_delta, dtype=ids.dtype)
2402  offsets = array_ops.expand_dims(offsets, -1)
2403  ids += offsets
2404  return array_ops.reshape(ids, [-1])
2405
2406
2407@RegisterPForWithArgs("SparseSegmentSum", math_ops.sparse_segment_sum_v2)
2408@RegisterPForWithArgs("SparseSegmentMean", math_ops.sparse_segment_mean_v2)
2409@RegisterPForWithArgs("SparseSegmentSqrtN", math_ops.sparse_segment_sqrt_n_v2)
2410@RegisterPForWithArgs("SparseSegmentSumWithNumSegments",
2411                      math_ops.sparse_segment_sum_v2)
2412@RegisterPForWithArgs("SparseSegmentMeanWithNumSegments",
2413                      math_ops.sparse_segment_mean_v2)
2414@RegisterPForWithArgs("SparseSegmentSqrtNWithNumSegments",
2415                      math_ops.sparse_segment_sqrt_n_v2)
2416def _convert_sparse_segment(pfor_input, _, op_func):
2417  _, segment_ids_stacked, _ = pfor_input.input(2)
2418  if segment_ids_stacked:
2419    pfor_input.stack_inputs([1])
2420  data, data_stacked, _ = pfor_input.input(0)
2421  indices, _, _ = pfor_input.input(1)
2422  num_inputs = len(pfor_input.inputs)
2423  assert num_inputs in (3, 4)
2424  if num_inputs == 3:
2425    # `segment_ids` needs to be unstacked since otherwise output sizes could
2426    # differ across pfor iterations.
2427    segment_ids = pfor_input.unstacked_input(2)
2428    num_segments = nn_ops.relu(math_ops.reduce_max(segment_ids) + 1)
2429  else:
2430    segment_ids, _, _ = pfor_input.input(2)
2431    num_segments = pfor_input.unstacked_input(3)
2432
2433  n = pfor_input.pfor.loop_len_vector[0]
2434  if data_stacked:
2435    indices = _flatten_array_with_offset(indices, array_ops.shape(data)[1], n)
2436    data = _flatten_first_two_dims(data)
2437  else:
2438    indices = array_ops.reshape(indices, [-1])
2439  segment_ids = _flatten_array_with_offset(segment_ids, num_segments, n)
2440
2441  if num_inputs == 3:
2442    num_segments = None
2443  else:
2444    num_segments *= n
2445  output = op_func(data, indices, segment_ids, num_segments=num_segments)
2446  output = _unflatten_first_dim(output, [n])
2447  return wrap(output, True)
2448
2449
2450@RegisterPForWithArgs("SparseSegmentMeanGrad",
2451                      math_ops.sparse_segment_mean_grad)
2452@RegisterPForWithArgs("SparseSegmentSqrtNGrad",
2453                      math_ops.sparse_segment_sqrt_n_grad)
2454def _convert_sparse_segment_grad(pfor_input, _, op_func):
2455  grad = pfor_input.stacked_input(0)
2456  indices = pfor_input.unstacked_input(1)
2457  segment_ids = pfor_input.unstacked_input(2)
2458  dim0 = pfor_input.unstacked_input(3)
2459
2460  n = pfor_input.pfor.loop_len_vector[0]
2461  indices = _flatten_array_with_offset(indices, dim0, n)
2462  num_segments = nn_ops.relu(math_ops.reduce_max(segment_ids) + 1)
2463  segment_ids = _flatten_array_with_offset(segment_ids, num_segments, n)
2464  grad = _flatten_first_two_dims(grad)
2465  dim0 *= n
2466  output = op_func(grad, indices, segment_ids, dim0)
2467  output = _unflatten_first_dim(output, [n])
2468  return wrap(output, True)
2469
2470
2471@RegisterPFor("Cast")
2472def _convert_cast(pfor_input):
2473  inp = pfor_input.stacked_input(0)
2474  dtype = pfor_input.get_attr("DstT")
2475  return wrap(math_ops.cast(inp, dtype), True)
2476
2477
2478@RegisterPForWithArgs("Abs", math_ops.abs)
2479@RegisterPForWithArgs("Acos", math_ops.acos)
2480@RegisterPForWithArgs("Acosh", math_ops.acosh)
2481@RegisterPForWithArgs("Add", math_ops.add)
2482@RegisterPForWithArgs("AddV2", math_ops.add_v2)
2483@RegisterPForWithArgs("Angle", math_ops.angle)
2484@RegisterPForWithArgs("Asin", math_ops.asin)
2485@RegisterPForWithArgs("Asinh", math_ops.asinh)
2486@RegisterPForWithArgs("Atan", math_ops.atan)
2487@RegisterPForWithArgs("Atan2", math_ops.atan2)
2488@RegisterPForWithArgs("Atanh", math_ops.atanh)
2489@RegisterPForWithArgs("BesselI0e", math_ops.bessel_i0e)
2490@RegisterPForWithArgs("BesselI1e", math_ops.bessel_i1e)
2491@RegisterPForWithArgs("BitwiseAnd", bitwise_ops.bitwise_and)
2492@RegisterPForWithArgs("BitwiseOr", bitwise_ops.bitwise_or)
2493@RegisterPForWithArgs("BitwiseXor", bitwise_ops.bitwise_xor)
2494@RegisterPForWithArgs("Ceil", math_ops.ceil)
2495@RegisterPForWithArgs("Complex", math_ops.complex)
2496@RegisterPForWithArgs("ComplexAbs", math_ops.complex_abs)
2497@RegisterPForWithArgs("Conj", math_ops.conj)
2498@RegisterPForWithArgs("Cos", math_ops.cos)
2499@RegisterPForWithArgs("Cosh", math_ops.cosh)
2500@RegisterPForWithArgs("Dawsn", special_math_ops.dawsn)
2501@RegisterPForWithArgs("Digamma", math_ops.digamma)
2502@RegisterPForWithArgs("Div", math_ops.div)
2503@RegisterPForWithArgs("DivNoNan", math_ops.div_no_nan)
2504@RegisterPForWithArgs("Elu", nn_ops.elu)
2505@RegisterPForWithArgs("Erf", math_ops.erf)
2506@RegisterPForWithArgs("Erfc", math_ops.erfc)
2507@RegisterPForWithArgs("Erfinv", math_ops.erfinv)
2508@RegisterPForWithArgs("Exp", math_ops.exp)
2509@RegisterPForWithArgs("Expint", special_math_ops.expint)
2510@RegisterPForWithArgs("Expm1", math_ops.expm1)
2511@RegisterPForWithArgs("Floor", math_ops.floor)
2512@RegisterPForWithArgs("FloorDiv", math_ops.floor_div)
2513@RegisterPForWithArgs("FloorMod", math_ops.floor_mod)
2514@RegisterPForWithArgs("FresnelCos", special_math_ops.fresnel_cos)
2515@RegisterPForWithArgs("FresnelSin", special_math_ops.fresnel_sin)
2516@RegisterPForWithArgs("Greater", math_ops.greater)
2517@RegisterPForWithArgs("GreaterEqual", math_ops.greater_equal)
2518@RegisterPForWithArgs("Igamma", math_ops.igamma)
2519@RegisterPForWithArgs("IgammaGradA", math_ops.igamma_grad_a)
2520@RegisterPForWithArgs("Igammac", math_ops.igammac)
2521@RegisterPForWithArgs("Imag", math_ops.imag)
2522@RegisterPForWithArgs("Inv", math_ops.inv)
2523@RegisterPForWithArgs("Invert", bitwise_ops.invert)
2524@RegisterPForWithArgs("IsFinite", math_ops.is_finite)
2525@RegisterPForWithArgs("IsInf", math_ops.is_inf)
2526@RegisterPForWithArgs("IsNan", math_ops.is_nan)
2527@RegisterPForWithArgs("LeftShift", bitwise_ops.left_shift)
2528@RegisterPForWithArgs("Less", math_ops.less)
2529@RegisterPForWithArgs("LessEqual", math_ops.less_equal)
2530@RegisterPForWithArgs("Lgamma", math_ops.lgamma)
2531@RegisterPForWithArgs("Log", math_ops.log)
2532@RegisterPForWithArgs("Log1p", math_ops.log1p)
2533@RegisterPForWithArgs("LogicalAnd", math_ops.logical_and)
2534@RegisterPForWithArgs("LogicalNot", math_ops.logical_not)
2535@RegisterPForWithArgs("LogicalOr", math_ops.logical_or)
2536@RegisterPForWithArgs("LogicalXor", math_ops.logical_xor)
2537@RegisterPForWithArgs("Maximum", math_ops.maximum)
2538@RegisterPForWithArgs("Minimum", math_ops.minimum)
2539@RegisterPForWithArgs("Mod", math_ops.mod)
2540@RegisterPForWithArgs("Mul", math_ops.multiply)
2541@RegisterPForWithArgs("MulNoNan", math_ops.mul_no_nan)
2542@RegisterPForWithArgs("Ndtri", math_ops.ndtri)
2543@RegisterPForWithArgs("Neg", math_ops.negative)
2544@RegisterPForWithArgs("Polygamma", math_ops.polygamma)
2545@RegisterPForWithArgs("Pow", math_ops.pow)
2546@RegisterPForWithArgs("Real", math_ops.real)
2547@RegisterPForWithArgs("RealDiv", math_ops.divide)
2548@RegisterPForWithArgs("Reciprocal", math_ops.reciprocal)
2549@RegisterPForWithArgs("Relu", nn_ops.relu)
2550@RegisterPForWithArgs("Relu6", nn_ops.relu6)
2551@RegisterPForWithArgs("RightShift", bitwise_ops.right_shift)
2552@RegisterPForWithArgs("Rint", math_ops.rint)
2553@RegisterPForWithArgs("Round", math_ops.round)
2554@RegisterPForWithArgs("Rsqrt", math_ops.rsqrt)
2555@RegisterPForWithArgs("Selu", nn_ops.selu)
2556@RegisterPForWithArgs("Sigmoid", math_ops.sigmoid)
2557@RegisterPForWithArgs("Sign", math_ops.sign)
2558@RegisterPForWithArgs("Sin", math_ops.sin)
2559@RegisterPForWithArgs("Sinh", math_ops.sinh)
2560@RegisterPForWithArgs("Softplus", nn_ops.softplus)
2561@RegisterPForWithArgs("Softsign", nn_ops.softsign)
2562@RegisterPForWithArgs("Spence", special_math_ops.spence)
2563@RegisterPForWithArgs("Sqrt", math_ops.sqrt)
2564@RegisterPForWithArgs("Square", math_ops.square)
2565@RegisterPForWithArgs("SquaredDifference", math_ops.squared_difference)
2566@RegisterPForWithArgs("Sub", math_ops.subtract)
2567@RegisterPForWithArgs("Tan", math_ops.tan)
2568@RegisterPForWithArgs("Tanh", math_ops.tanh)
2569@RegisterPForWithArgs("TruncateDiv", math_ops.truncate_div)
2570@RegisterPForWithArgs("TruncateMod", math_ops.truncate_mod)
2571@RegisterPForWithArgs("Xdivy", math_ops.xdivy)
2572@RegisterPForWithArgs("Xlogy", math_ops.xlogy)
2573@RegisterPForWithArgs("Xlog1py", math_ops.xlog1py)
2574@RegisterPForWithArgs("Zeta", math_ops.zeta)
2575def _convert_cwise(pfor_input, op_type, op_func):
2576  # Note that ops handled here do not have attributes except those listed below
2577  # and hence don't need extra arguments passed to the cwise_op call below.
2578  for attr in pfor_input.op.node_def.attr.keys():
2579    assert attr in [u"T", u"Tout", u"_xla_compile_id"], (op_type, attr)
2580  pfor_input.expanddim_inputs_for_broadcast()
2581  return wrap(op_func(*[x.t for x in pfor_input.inputs]), True)
2582
2583
2584@RegisterPFor("Equal")
2585def _convert_equal(pfor_input):
2586  pfor_input.expanddim_inputs_for_broadcast()
2587  x = pfor_input.input(0)[0]
2588  y = pfor_input.input(1)[0]
2589  incompatible_shape_error = pfor_input.get_attr("incompatible_shape_error")
2590  assert incompatible_shape_error
2591  return wrap(math_ops.equal(x, y), True)
2592
2593
2594@RegisterPFor("NotEqual")
2595def _convert_not_equal(pfor_input):
2596  pfor_input.expanddim_inputs_for_broadcast()
2597  x = pfor_input.input(0)[0]
2598  y = pfor_input.input(1)[0]
2599  incompatible_shape_error = pfor_input.get_attr("incompatible_shape_error")
2600  assert incompatible_shape_error
2601  return wrap(math_ops.not_equal(x, y), True)
2602
2603
2604@RegisterPFor("ApproximateEqual")
2605def _convert_approximate_equal(pfor_input):
2606  pfor_input.expanddim_inputs_for_broadcast()
2607  x = pfor_input.input(0)[0]
2608  y = pfor_input.input(1)[0]
2609  tolerance = pfor_input.get_attr("tolerance")
2610  return wrap(math_ops.approximate_equal(x, y, tolerance=tolerance), True)
2611
2612
2613@RegisterPFor("Shape")
2614def _convert_shape(pfor_input):
2615  out_type = pfor_input.get_attr("out_type")
2616  return wrap(
2617      array_ops.shape(pfor_input.stacked_input(0), out_type=out_type)[1:],
2618      False)
2619
2620
2621@RegisterPFor("ShapeN")
2622def _convert_shape_n(pfor_input):
2623  out_type = pfor_input.get_attr("out_type")
2624  shapes = [
2625      array_ops.shape(x, out_type=out_type)[1:] if stacked else array_ops.shape(
2626          x, out_type=out_type) for x, stacked, _ in pfor_input.inputs
2627  ]
2628  return [wrap(x, False) for x in shapes]
2629
2630
2631@RegisterPFor("Size")
2632def _convert_size(pfor_input):
2633  out_type = pfor_input.get_attr("out_type")
2634  n = math_ops.cast(pfor_input.pfor.loop_len_vector[0], out_type)
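  # The size of the stacked input includes the loop dimension, so divide by the
  # loop length to recover the per-iteration size.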
2635  return wrap(
2636      array_ops.size(pfor_input.stacked_input(0), out_type=out_type) // n,
2637      False)
2638
2639
2640@RegisterPFor("Rank")
2641def _convert_rank(pfor_input):
2642  return wrap(array_ops.rank(pfor_input.stacked_input(0)) - 1, False)
2643
2644
2645@RegisterPFor("AddN")
2646def _convert_addn(pfor_input):
2647  # AddN does not support broadcasting.
2648  pfor_input.stack_inputs()
2649  return wrap(math_ops.add_n([x.t for x in pfor_input.inputs]), True)
2650
2651
2652@RegisterPFor("Cross")
2653def _convert_cross(pfor_input):
2654  pfor_input.stack_inputs()
2655  a = pfor_input.stacked_input(0)
2656  b = pfor_input.stacked_input(1)
2657  return wrap(math_ops.cross(a, b), True)
2658
2659
2660@RegisterPFor("BiasAddGrad")
2661def _convert_biasaddgrad(pfor_input):
2662  grad = pfor_input.stacked_input(0)
2663  fmt = pfor_input.get_attr("data_format")
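  # `grad` carries an extra leading loop dimension, so the reduction axes below
  # are the usual BiasAddGrad axes shifted by one, with the loop dim preserved.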
2664  if fmt == b"NCHW":
2665    output = math_ops.reduce_sum(grad, axis=[1, 3, 4], keepdims=False)
2666  else:
2667    grad_shape = array_ops.shape(grad)
2668    last_dim_shape = grad_shape[-1]
2669    first_dim_shape = grad_shape[0]
2670    output = array_ops.reshape(grad, [first_dim_shape, -1, last_dim_shape])
2671    output = math_ops.reduce_sum(output, axis=[1], keepdims=False)
2672  return wrap(output, True)
2673
2674
2675# Some required ops are not exposed under the tf namespace. Hence we rely on
2676# _create_op to create them.
2677@RegisterPForWithArgs("EluGrad")
2678@RegisterPForWithArgs("Relu6Grad")
2679@RegisterPForWithArgs("ReluGrad")
2680@RegisterPForWithArgs("SeluGrad")
2681@RegisterPForWithArgs("SigmoidGrad")
2682@RegisterPForWithArgs("SoftplusGrad")
2683@RegisterPForWithArgs("SoftsignGrad")
2684@RegisterPForWithArgs("TanhGrad")
2685@RegisterPForWithArgs("SqrtGrad")
2686@RegisterPForWithArgs("RsqrtGrad")
2687@RegisterPForWithArgs("ReciprocalGrad")
2688def _convert_grads(pfor_input, op_type, *args, **kw_args):
2689  del args
2690  del kw_args
2691  # TODO(agarwal): Looks like these ops don't support broadcasting. Hence we
2692  # have to use tiling here.
2693  pfor_input.stack_inputs()
2694  outputs = _create_op(
2695      op_type, [x.t for x in pfor_input.inputs],
2696      [x.dtype for x in pfor_input.outputs],
2697      attrs=pfor_input.op.node_def.attr).outputs
2698  return [wrap(x, True) for x in outputs]
2699
2700
2701@RegisterPFor("Select")
2702def _convert_select(pfor_input):
2703  pfor_input.stack_inputs()
2704  cond = pfor_input.stacked_input(0)
2705  t = pfor_input.stacked_input(1)
2706  e = pfor_input.stacked_input(2)
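  # After stacking, a scalar per-iteration cond becomes a vector, which Select
  # already handles by selecting rows of t/e. For higher ranks, the loop and
  # leading element dimensions are flattened before the op and the result is
  # unflattened afterwards.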
2707  cond_rank = array_ops.rank(cond)
2708  cond, t, e = control_flow_ops.cond(
2709      cond_rank > 1, lambda: _inputs_with_flattening(pfor_input, [0, 1, 2]),
2710      lambda: [cond, t, e])
2711  outputs = _create_op(
2712      pfor_input.op_type, [cond, t, e], [x.dtype for x in pfor_input.outputs],
2713      attrs=pfor_input.op.node_def.attr).outputs
2714  n = pfor_input.pfor.loop_len_vector
2715  out = control_flow_ops.cond(cond_rank > 1,
2716                              lambda: _unflatten_first_dim(outputs[0], n),
2717                              lambda: outputs[0])
2718  return [wrap(out, True) for _ in outputs]
2719
2720
2721@RegisterPFor("SelectV2")
2722def _convert_selectv2(pfor_input):
2723  pfor_input.expanddim_inputs_for_broadcast()
2724  cond = pfor_input.input(0)[0]
2725  t = pfor_input.input(1)[0]
2726  e = pfor_input.input(2)[0]
2727  out = array_ops.where_v2(cond, t, e)
2728  return wrap(out, True)
2729
2730
2731# random_ops
2732
2733
2734def _transpose_dim_to_front(x, dim):
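  """Transposes `x` so that dimension `dim` becomes the leading dimension."""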
2735  rank = array_ops.rank(x)
2736  return array_ops.transpose(
2737      x,
2738      perm=array_ops.concat(
2739          [[dim], math_ops.range(0, dim),
2740           math_ops.range(dim + 1, rank)],
2741          axis=0))
2742
2743
2744@RegisterPForWithArgs("RandomUniform")
2745@RegisterPForWithArgs("RandomUniformInt")
2746@RegisterPForWithArgs("RandomStandardNormal")
2747@RegisterPForWithArgs("TruncatedNormal")
2748def _convert_random(pfor_input, op_type, *args, **kw_args):
2749  del args
2750  del kw_args
2751  inputs = [pfor_input.unstacked_input(i) for i in range(pfor_input.num_inputs)]
2752  # inputs[0] is "shape"
2753  inputs[0] = array_ops.concat([pfor_input.pfor.loop_len_vector, inputs[0]],
2754                               axis=0)
2755  logging.warning(
2756      "Note that %s inside pfor op may not give same output as "
2757      "inside a sequential loop.", op_type)
2758  outputs = _create_op(
2759      op_type,
2760      inputs, [x.dtype for x in pfor_input.outputs],
2761      attrs=pfor_input.op.node_def.attr).outputs
2762  return [wrap(x, True) for x in outputs]
2763
2764
2765@RegisterPFor("RandomGamma")
2766@RegisterPFor("RandomPoissonV2")
2767def _convert_random_with_param(pfor_input):
2768  shape = pfor_input.unstacked_input(0)
2769  # param is lam (Poisson rate) or alpha (Gamma shape).
2770  param, param_stacked, _ = pfor_input.input(1)
2771  logging.warning(
2772      "Note that %s inside pfor op may not give same output as "
2773      "inside a sequential loop.", pfor_input.op_type)
2774
2775  if param_stacked:
2776    samples = _create_op(
2777        pfor_input.op_type,
2778        inputs=[shape, param],
2779        op_dtypes=[x.dtype for x in pfor_input.outputs],
2780        attrs=pfor_input.op.node_def.attr).outputs[0]
2781    loop_dim = array_ops.shape(shape)[0]
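    # `samples` has shape `shape + param.shape`; since `param` is stacked, its
    # leading (loop) dimension sits at index size(shape) in `samples` and is
    # moved to the front below.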
2782    stacked_samples = _transpose_dim_to_front(samples, loop_dim)
2783  else:
2784    shape = array_ops.concat([pfor_input.pfor.loop_len_vector, shape], axis=0)
2785    stacked_samples = _create_op(
2786        pfor_input.op_type,
2787        inputs=[shape, param],
2788        op_dtypes=[x.dtype for x in pfor_input.outputs],
2789        attrs=pfor_input.op.node_def.attr).outputs[0]
2790
2791  return wrap(stacked_samples, True)
2792
2793
2794@RegisterPFor("Multinomial")
2795def _convert_multinomial(pfor_input):
2796  logits, logits_stacked, _ = pfor_input.input(0)
2797  num_samples = pfor_input.unstacked_input(1)
2798  seed = pfor_input.get_attr("seed")
2799  seed2 = pfor_input.get_attr("seed2")
2800  output_dtype = pfor_input.get_attr("output_dtype")
2801  logging.warning(
2802      "Note that Multinomial inside pfor op may not give same output as "
2803      "inside a sequential loop.")
2804
2805  n = pfor_input.pfor.loop_len_vector[0]
2806  if logits_stacked:
2807    flattened_logits = _flatten_first_two_dims(logits)
2808    samples = gen_random_ops.multinomial(
2809        flattened_logits,
2810        num_samples,
2811        seed=seed,
2812        seed2=seed2,
2813        output_dtype=output_dtype)
2814    stacked_samples = _unflatten_first_dim(samples, [n])
2815  else:
2816    samples = gen_random_ops.multinomial(
2817        logits,
2818        num_samples * n,
2819        seed=seed,
2820        seed2=seed2,
2821        output_dtype=output_dtype)
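    # A single call draws n * num_samples samples per logits row; split the
    # sample axis into [n, num_samples] and move the loop dimension to the front.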
2822    stacked_samples = array_ops.transpose(
2823        array_ops.reshape(samples, [-1, n, num_samples]), [1, 0, 2])
2824
2825  return wrap(stacked_samples, True)
2826
2827
2828# linalg_ops
2829
2830
2831# TODO(jmenick) - the same logic applies to other einsums. Generalize this
2832# in a future CL.
2833@RegisterPFor("XlaEinsum")
2834def _convert_einsum(pfor_input):
2835  first_input, first_input_stacked, _ = pfor_input.input(0)
2836  second_input, second_input_stacked, _ = pfor_input.input(1)
2837
2838  # Parse the einsum equation.
2839  equation = pfor_input.get_attr("equation").decode("utf-8")
2840  input_expr, output_expr = equation.split("->")
2841  input_a_expr, input_b_expr = input_expr.split(",")
2842
2843  # pick a placeholder symbol to use for the new axis
2844  chosen_symbol = None
2845  for s in string.ascii_letters:
2846    if s in equation:
2847      continue
2848    else:
2849      chosen_symbol = s
2850      break
2851
2852  if chosen_symbol is None:
2853    raise ValueError("Could not figure out what symbol to use for new axis.")
2854
2855  assert first_input_stacked or second_input_stacked
2856  if first_input_stacked:
2857    input_a_expr = "{}{}".format(chosen_symbol, input_a_expr)
2858  if second_input_stacked:
2859    input_b_expr = "{}{}".format(chosen_symbol, input_b_expr)
2860  output_expr = "{}{}".format(chosen_symbol, output_expr)
2861
2862  new_equation = "{},{}->{}".format(input_a_expr, input_b_expr, output_expr)
2863  result = xla.einsum(equation=new_equation, a=first_input, b=second_input)
2864  return wrap(result, True)
2865
2866
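# A minimal sketch (illustration only; `_example_einsum_equation_rewrite` is not
# part of this module) of the equation rewrite performed above, assuming both
# inputs are stacked and "a" is the unused symbol chosen:
def _example_einsum_equation_rewrite():
  equation = "ij,jk->ik"
  input_expr, output_expr = equation.split("->")
  input_a_expr, input_b_expr = input_expr.split(",")
  # The loop dimension is threaded through as a leading batch axis everywhere.
  new_equation = "a{},a{}->a{}".format(input_a_expr, input_b_expr, output_expr)
  assert new_equation == "aij,ajk->aik"
  return new_equation

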
2867@RegisterPFor("Cholesky")
2868def _convert_cholesky(pfor_input):
2869  t = pfor_input.stacked_input(0)
2870  return wrap(linalg_ops.cholesky(t), True)
2871
2872
2873@RegisterPFor("LogMatrixDeterminant")
2874def _convert_log_matrix_determinant(pfor_input):
2875  t = pfor_input.stacked_input(0)
2876  return [wrap(x, True) for x in linalg_ops.log_matrix_determinant(t)]
2877
2878
2879@RegisterPFor("MatrixTriangularSolve")
2880def _convert_matrix_triangular_solve(pfor_input):
2881  pfor_input.expanddim_inputs_for_broadcast()
2882  matrix = pfor_input.input(0)[0]
2883  rhs = pfor_input.input(1)[0]
2884  lower = pfor_input.get_attr("lower")
2885  adjoint = pfor_input.get_attr("adjoint")
2886  output = linalg_ops.matrix_triangular_solve(
2887      matrix, rhs, lower=lower, adjoint=adjoint)
2888  return wrap(output, True)
2889
2890
2891@RegisterPFor("SelfAdjointEigV2")
2892def _convert_self_adjoint_eig(pfor_input):
2893  t = pfor_input.stacked_input(0)
2894  compute_v = pfor_input.get_attr("compute_v")
2895  e, v = gen_linalg_ops.self_adjoint_eig_v2(t, compute_v=compute_v)
2896  # If compute_v is False, v will have shape [0].
2897  return wrap(e, True), wrap(v, compute_v)
2898
2899
2900# logging_ops
2901
2902
2903@RegisterPFor("Assert")
2904def _convert_assert(pfor_input):
2905  cond, cond_stacked, _ = pfor_input.input(0)
2906  if cond_stacked:
2907    cond = math_ops.reduce_all(cond)
2908
2909  data_list = [x.t for x in pfor_input.inputs][1:]
2910  return _create_op(
2911      "Assert", [cond] + data_list, [], attrs=pfor_input.op.node_def.attr)
2912
2913
2914@RegisterPFor("Print")
2915def _convert_print(pfor_input):
2916  # Note that we don't stack all the inputs. Hence unstacked values are printed
2917  # once here vs multiple times in a while_loop.
2918  pfor_input.stack_inputs([0])
2919  outputs = _create_op(
2920      "Print", [x.t for x in pfor_input.inputs],
2921      [x.dtype for x in pfor_input.outputs],
2922      attrs=pfor_input.op.node_def.attr).outputs
2923  return [wrap(x, True) for x in outputs]
2924
2925
2926# data_flow_ops
2927
2928# TensorArray conversion is tricky since we don't support arrays of
2929# TensorArrays. For converting them, we consider two distinct cases:
2930#
2931# 1. The array is constructed outside the pfor call, and read/written inside the
2932# loop.
2933# This is an easier case since we don't need to make an array of TensorArrays.
2934# A correctness requirement is that these parallel iterations shouldn't attempt
2935# to write to the same location. Hence at conversion time we disallow indices to
2936# be loop-invariant as that would guarantee a collision. Even if the indices are
2937# not loop-invariant, they could still conflict, which would cause runtime errors.
2938#
2939# 2. The array is constructed and used entirely inside each pfor iteration.
2940# For simplicity, here we require that the indices used for write/scatter are
2941# "unstacked". Otherwise it becomes hard to merge the TensorArrays created in
2942# different pfor iterations. We consider two sub_cases:
2943#
2944# 2a Elements written to the array are "stacked"
2945# To simulate multiple TensorArrays, we add a leading dimension to each element
2946# of the array, i.e. the i_th row of the j_th entry of the converted
2947# TensorArray corresponds to the j_th entry of the TensorArray in the i_th
2948# pfor iteration (see the layout sketch after this comment block).
2949#
2950# 2b Elements written to the array are "unstacked"
2951# In this case we don't increase the dimensions to avoid redundant tiling. Each
2952# iteration is trying to write the same value. So we convert that to a single
2953# write.
2954#
2955# Here are some tricks used to implement the above:
2956# - TensorArrayV3 constructor encodes the element shape as an attr. Instead of
2957# trying to trace whether future writes are stacked or unstacked in order to set
2958# this attr, we set it to correspond to unknown shape.
2959# - We use the "flow" output of the different ops to track whether the array
2960# elements are stacked or unstacked. If a stacked write/scatter is done, we make
2961# the flow stacked as well.
2962# - We use some heuristic traversal of the graph to track whether the
2963# TensorArray handle was created inside or outside the pfor loop.
2964
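# A minimal sketch (illustration only; `_example_stacked_tensor_array_layout` is
# not part of this module) of the case 2a layout described above: after
# conversion, entry j of the single TensorArray holds, stacked along a new
# leading dimension, the values that entry j would have held in each of the
# per-iteration TensorArrays.
def _example_stacked_tensor_array_layout():
  # Simulate n = 3 pfor iterations, where iteration i writes the scalar 10 * i
  # to entry 0 of its (conceptual) per-iteration TensorArray.
  n = 3
  stacked_values = 10. * math_ops.cast(math_ops.range(n), dtypes.float32)
  ta = tensor_array_ops.TensorArray(dtypes.float32, size=1)
  ta = ta.write(0, stacked_values)
  return ta.read(0)  # Shape [3]: one value per pfor iteration.
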
2965
2966@RegisterPFor("TensorArrayV3")
2967def _convert_tensor_array_v3(pfor_input):
2968  size = pfor_input.unstacked_input(0)
2969  dtype = pfor_input.get_attr("dtype")
2970  dynamic_size = pfor_input.get_attr("dynamic_size")
2971  clear_after_read = pfor_input.get_attr("clear_after_read")
2972  identical_element_shapes = pfor_input.get_attr("identical_element_shapes")
2973  tensor_array_name = pfor_input.get_attr("tensor_array_name")
2974  handle, flow = data_flow_ops.tensor_array_v3(
2975      size,
2976      dtype=dtype,
2977      # We don't set element shape since we don't know if writes are stacked or
2978      # not yet.
2979      element_shape=None,
2980      dynamic_size=dynamic_size,
2981      clear_after_read=clear_after_read,
2982      identical_element_shapes=identical_element_shapes,
2983      tensor_array_name=tensor_array_name)
2984  # Note we keep flow unstacked for now since we don't know if writes will be
2985  # stacked or not.
2986  return wrap(handle, False), wrap(flow, False)
2987
2988
2989@RegisterPFor("TensorArraySizeV3")
2990def _convert_tensor_array_size_v3(pfor_input):
2991  handle = pfor_input.unstacked_input(0)
2992  flow, flow_stacked, _ = pfor_input.input(1)
2993  if flow_stacked:
2994    flow = _unstack_flow(flow)
2995  size = data_flow_ops.tensor_array_size_v3(handle, flow)
2996  return wrap(size, False)
2997
2998
2999def _handle_inside_pfor(pfor_input, handle):
3000  """Returns True if handle was created inside the pfor loop."""
3001  # We use some heuristic to find the original TensorArray creation op.
3002  # The logic should handle the common cases (except cond based subgraphs).
3003  # In theory the user could perform different operations on the handle (like
3004  # Reshape, stack multiple handles, etc) which could break this logic.
3005  # TODO(agarwal): handle Switch/Merge.
3006  while handle.op.type in ("Enter", "Identity"):
3007    handle = handle.op.inputs[0]
3008  if handle.op.type not in [
3009      "TensorArrayV3", "TensorArrayGradV3", "TensorArrayGradWithShape"
3010  ]:
3011    raise ValueError("Unable to find source for handle %s" % handle)
3012  else:
3013    return pfor_input.pfor.op_is_inside_loop(handle.op)
3014
3015
3016def _unstack_flow(value):
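  """Recovers a scalar flow from a stacked (tiled) flow by taking element 0."""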
3017  # TODO(agarwal): consider looking if this is a Tile op then get its input.
3018  # This may avoid running the Tile operations.
3019  return array_ops.gather(value, 0)
3020
3021
3022@RegisterPFor("TensorArrayReadV3")
3023def _convert_tensor_array_read_v3(pfor_input):
3024  handle = pfor_input.unstacked_input(0)
3025  index, index_stacked, _ = pfor_input.input(1)
3026  dtype = pfor_input.get_attr("dtype")
3027  flow, flow_stacked, _ = pfor_input.input(2)
3028  if flow_stacked:
3029    flow = _unstack_flow(flow)
3030
3031  is_inside_pfor = _handle_inside_pfor(pfor_input, pfor_input.op.inputs[0])
3032  if is_inside_pfor:
3033    # Note that if we are inside a control flow construct inside the pfor, and
3034    # only some of the iterations are doing the read (i.e.
3035    # `all_indices_partitioned` is True), then the read operation should only
3036    # return values for the currently active pfor iterations (`all_indices`
3037    # below). Hence, whenever the returned value is stacked (i.e. `flow` is
3038    # stacked), we may need to do an extra gather after reading the values. Also
3039    # note that if `is_inside_pfor` is false, then values in the tensor array are
3040    # unstacked. So the check is only needed in this branch.
3041    all_indices = pfor_input.pfor.all_indices
3042    all_indices_partitioned = pfor_input.pfor.all_indices_partitioned
3043    # Note: flow_stacked indicates if values in the TensorArray are stacked or
3044    # not.
3045    if index_stacked:
3046      if flow_stacked:
3047        raise ValueError(
3048            "It looks like TensorArrayReadV3 was called on a TensorArray whose"
3049            " values are not loop-invariant, and the read indices were also"
3050            " not loop invariant. This is currently unsupported.")
3051      value = data_flow_ops.tensor_array_gather_v3(
3052          handle, index, flow, dtype=dtype)
3053      return wrap(value, True)
3054    value = data_flow_ops.tensor_array_read_v3(handle, index, flow, dtype=dtype)
3055    if flow_stacked and all_indices_partitioned:
3056      value = array_ops.gather(value, all_indices)
3057    return wrap(value, flow_stacked)
3058  # Values in the TensorArray should be unstacked (since different iterations
3059  # couldn't write to the same location). So whether output is stacked or not
3060  # depends on index_stacked.
3061  if index_stacked:
3062    value = data_flow_ops.tensor_array_gather_v3(
3063        handle, index, flow, dtype=dtype)
3064  else:
3065    value = data_flow_ops.tensor_array_read_v3(handle, index, flow, dtype=dtype)
3066  return wrap(value, index_stacked)
3067
3068
3069@RegisterPFor("TensorArrayWriteV3")
3070def _convert_tensor_array_write_v3(pfor_input):
3071  handle = pfor_input.unstacked_input(0)
3072  index, index_stacked, _ = pfor_input.input(1)
3073  value, value_stacked, _ = pfor_input.input(2)
3074  flow, flow_stacked, _ = pfor_input.input(3)
3075  if value_stacked and pfor_input.pfor.all_indices_partitioned:
3076    # Looks like we are in a control flow in a pfor where not all iterations are
3077    # active now. We don't allow that since it could lead to different entries
3078    # holding values of different shapes, which would be hard to merge later.
3079    raise ValueError("Writing non loop invariant values to TensorArray from "
3080                     "inside a while_loop/cond not supported.")
3081  if flow_stacked:
3082    flow = _unstack_flow(flow)
3083  is_inside = _handle_inside_pfor(pfor_input, pfor_input.op.inputs[0])
3084  if is_inside:
3085    if index_stacked:
3086      raise ValueError("Need indices for %s to be loop invariant" % handle)
3087    if not flow_stacked and not value_stacked:
3088      flow_out = data_flow_ops.tensor_array_write_v3(handle, index, value, flow)
3089      return wrap(flow_out, False)
3090    else:
3091      if not value_stacked:
3092        value = _stack(value, pfor_input.pfor.loop_len_vector).t
3093      # TODO(agarwal): Note that if flow is unstacked and value is stacked, then
3094      # this may or may not be a safe situation. flow is unstacked both for a
3095      # freshly created TensorArray, as well as after unstacked values are
3096      # written to it. If it is the latter, then we cannot write a stacked value
3097      # now since that may cause runtime errors due to different shapes in the
3098      # array. At the moment we are not able to handle this gracefully and
3099      # distinguish between the two cases. That would require some heuristic
3100      # traversal of the graph to figure out whether all the writes are
3101      # unstacked or not.
3102      flow_out = data_flow_ops.tensor_array_write_v3(handle, index, value, flow)
3103      return _stack(flow_out, pfor_input.pfor.loop_len_vector)
3104  else:
3105    if not index_stacked:
3106      raise ValueError("Need indices for %s to be not loop invariant" % handle)
3107    # Note that even when index_stacked is true, actual values in index may
3108    # still not be unique. However, that would cause a runtime error when
3109    # executing the scatter operation below.
3110    if not value_stacked:
3111      value = _stack(value, pfor_input.pfor.loop_len_vector).t
3112    flow_out = data_flow_ops.tensor_array_scatter_v3(handle, index, value, flow)
3113    return _stack(flow_out, pfor_input.pfor.loop_len_vector)
3114
3115
3116def _transpose_first_two_dims(value):
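  """Transposes the first two dimensions of `value`, keeping the rest intact."""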
3117  # TODO(agarwal): optimize if one of the dims == 1.
3118  value_shape = array_ops.shape(value)
3119  v0 = value_shape[0]
3120  v1 = value_shape[1]
3121  value = array_ops.reshape(value, [v0, v1, -1])
3122  value = array_ops.transpose(value, [1, 0, 2])
3123  new_shape = array_ops.concat([[v1, v0], value_shape[2:]], axis=0)
3124  return array_ops.reshape(value, new_shape)
3125
3126
3127@RegisterPFor("TensorArrayGatherV3")
3128def _convert_tensor_array_gather_v3(pfor_input):
3129  handle = pfor_input.unstacked_input(0)
3130  indices, indices_stacked, _ = pfor_input.input(1)
3131  indices = array_ops.reshape(indices, [-1])
3132  flow, flow_stacked, _ = pfor_input.input(2)
3133  if flow_stacked:
3134    flow = _unstack_flow(flow)
3135  dtype = pfor_input.get_attr("dtype")
3136  # TODO(agarwal): support element_shape attr?
3137
3138  n = pfor_input.pfor.loop_len_vector
3139  value = data_flow_ops.tensor_array_gather_v3(
3140      handle, indices, flow, dtype=dtype)
3141  is_inside = _handle_inside_pfor(pfor_input, pfor_input.op.inputs[0])
3142  if is_inside:
3143    # flow_stacked indicates if values in the TensorArray are stacked or not.
3144    if indices_stacked:
3145      if flow_stacked:
3146        raise ValueError(
3147            "It looks like TensorArrayGatherV3 was called on a TensorArray "
3148            "whose values are not loop-invariant, and the indices were also "
3149            "not loop invariant. This is currently unsupported.")
3150      else:
3151        value = _unflatten_first_dim(value, n)
3152        return wrap(value, True)
3153    else:
3154      if flow_stacked:
3155        # Since elements in this array are stacked and `value` was produced by
3156        # gather, its first two dims are "gathered elements" and "stack
3157        # dimension". Our semantics require these two to be flipped.
3158        value = _transpose_first_two_dims(value)
3159      return wrap(value, flow_stacked)
3160  else:
3161    # Values in the TensorArray should be unstacked (since different iterations
3162    # couldn't write to the same location). So whether output is stacked or not
3163    # depends on indices_stacked.
3164    if indices_stacked:
3165      value = _unflatten_first_dim(value, n)
3166    return wrap(value, indices_stacked)
3167
3168
3169@RegisterPFor("TensorArrayScatterV3")
3170def _convert_tensor_array_scatter_v3(pfor_input):
3171  handle = pfor_input.unstacked_input(0)
3172  indices, indices_stacked, _ = pfor_input.input(1)
3173  indices = array_ops.reshape(indices, [-1])
3174  value, value_stacked, _ = pfor_input.input(2)
3175  flow, flow_stacked, _ = pfor_input.input(3)
3176
3177  if flow_stacked:
3178    flow = _unstack_flow(flow)
3179
3180  is_inside = _handle_inside_pfor(pfor_input, pfor_input.op.inputs[0])
3181  if is_inside:
3182    if indices_stacked:
3183      raise ValueError("Need indices for %s to be loop invariant" % handle)
3184    # Note that flow_stacked indicates if existing values in the array are
3185    # stacked or not.
3186    if not flow_stacked and not value_stacked:
3187      flow_out = data_flow_ops.tensor_array_scatter_v3(handle, indices, value,
3188                                                       flow)
3189      return wrap(flow_out, False)
3190    if not value_stacked:
3191      # TODO(agarwal): tile in the second dimension directly instead of
3192      # transposing below.
3193      value = _stack(value, pfor_input.pfor.loop_len_vector).t
3194
3195    value = _transpose_first_two_dims(value)
3196    # TODO(agarwal): Note that if a previous write was unstacked, flow will be
3197    # unstacked, and a stacked value may be written here which may cause
3198    # runtime error due to different elements having different shape. We do
3199    # not try to prevent that.
3200    flow_out = data_flow_ops.tensor_array_scatter_v3(handle, indices, value,
3201                                                     flow)
3202    return _stack(flow_out, pfor_input.pfor.loop_len_vector)
3203  if not indices_stacked:
3204    raise ValueError("Need indices for %s to be not loop invariant" % handle)
3205  if not value_stacked:
3206    value = _stack(value, pfor_input.pfor.loop_len_vector).t
3207  value = _flatten_first_two_dims(value)
3208  flow_out = data_flow_ops.tensor_array_scatter_v3(handle, indices, value, flow)
3209  return _stack(flow_out, pfor_input.pfor.loop_len_vector)
3210
3211
3212@RegisterPFor("TensorArrayGradV3")
3213def _convert_tensor_array_grad_v3(pfor_input):
3214  handle = pfor_input.unstacked_input(0)
3215  flow, flow_stacked, _ = pfor_input.input(1)
3216  if flow_stacked:
3217    flow = _unstack_flow(flow)
3218  source = pfor_input.get_attr("source")
3219  # TODO(agarwal): For now, we assume that gradients are stacked if the
3220  # TensorArrayGradV3 call is being done inside the pfor. Getting that wrong
3221  # will give runtime error due to incorrect shape being written to the
3222  # accumulator. It is difficult to know in advance if gradients written will be
3223  # stacked or not. Note that flow being stacked is not indicative of the
3224  # gradient being stacked or not. Revisit this later.
3225  shape_to_prepend = pfor_input.pfor.loop_len_vector
3226  grad_handle, flow_out = data_flow_ops.tensor_array_grad_with_shape(
3227      handle=handle,
3228      flow_in=flow,
3229      shape_to_prepend=shape_to_prepend,
3230      source=source)
3231  flow_out = _stack(flow_out, pfor_input.pfor.loop_len_vector).t
3232  return [wrap(grad_handle, False), wrap(flow_out, True)]
3233
3234
3235# StackV2 conversion is tricky since we don't have arrays of StackV2. So similar
3236# to TensorArrays, we convert them by changing the dimension of the elements
3237# inside the stack.
3238#
3239# We consider two cases:
3240#
3241# 1. StackV2 is constructed and used entirely inside the pfor loop.
3242# We keep a single Stack and perform the push/pop operations of all the
3243# iterations in lock-step. We also assume that all the iterations perform these
3244# operations. In case of dynamic control flow, if only some of the iterations
3245# try to perform a push/pop, then the conversion may not work correctly and may
3246# cause undefined behavior.
3247# TODO(agarwal): test StackV2 with dynamic control flow.
3248#
3249# 2. StackV2 is constructed outside the pfor loop.
3250# Performing stack push/pop in a parallel fashion is ill-defined. However given
3251# that reading stacks created externally is a common operation when computing
3252# jacobians, we provide some special semantics here as follows.
3253#  - disallow push operations to the stack
3254#  - pop operations are performed in lock step by all iterations, similar to the
3255#  case when the stack is created inside. A single value is popped during the
3256#  lock-step operation and broadcast to all the iterations. Values in the stack
3257#  are assumed to be loop-invariant.
3258#
3259# Some other implementation details:
3260# We use some ad-hoc logic to determine whether values in the Stack are
3261# loop invariant or not. When converting push/pop operations, we keep track of
3262# whether the last conversion used a stacked value or not (see _stack_cache
3263# below). As a result if an unstacked value is written first, subsequent stacked
3264# writes are disallowed when they could have been allowed in theory.
3265
3266# Map from cache key based on StackV2 handle to a bool indicating whether values
3267# are stacked or not.
3268# TODO(agarwal): move _stack_cache inside pfor?
3269_stack_cache = {}
3270
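# A minimal sketch (pure Python, illustration only; `_example_stack_cache_gating`
# is not part of this module) of the stackedness gating described above: once a
# stack has been recorded as unstacked, a later stacked push is rejected. This
# mirrors the checks in the StackPushV2 converter below.
def _example_stack_cache_gating(cache, key, elem_stacked):
  stacked = cache.get(key)
  if stacked is None:
    # The first push seen for this stack decides its stackedness.
    cache[key] = elem_stacked
    return elem_stacked
  if not stacked and elem_stacked:
    raise ValueError("Stack was previously determined to be loop invariant.")
  # An unstacked element pushed onto a stacked stack gets tiled instead.
  return stacked
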
3271
3272def _stack_cache_key(pfor_input):
3273  """Create cache key corresponding to a stack handle."""
3274  op_type = pfor_input.op_type
3275  assert op_type in ["StackPushV2", "StackPopV2"], op_type
3276  orig_handle = pfor_input.op.inputs[0]
3277  while orig_handle.op.type in ["Identity", "Enter"]:
3278    orig_handle = orig_handle.op.inputs[0]
3279  assert orig_handle.op.type == "StackV2", orig_handle.op
3280  return ops.get_default_graph(), pfor_input.pfor, orig_handle
3281
3282
3283def _stack_handle_inside_pfor(handle, pfor_input):
3284  while handle.op.type in ["Identity", "Enter"]:
3285    handle = handle.op.inputs[0]
3286  assert handle.op.type == "StackV2", ("Unable to find StackV2 op. Got %s" %
3287                                       handle.op)
3288  return pfor_input.pfor.op_is_inside_loop(handle.op)
3289
3290
3291@RegisterPFor("StackPushV2")
3292def _convert_stack_push_v2(pfor_input):
3293  handle = pfor_input.unstacked_input(0)
3294  elem, elem_stacked, _ = pfor_input.input(1)
3295  swap_memory = pfor_input.get_attr("swap_memory")
3296
3297  if not _stack_handle_inside_pfor(pfor_input.op.inputs[0], pfor_input):
3298    raise ValueError("StackPushV2 not allowed on stacks created outside pfor")
3299  stack_cache_key = _stack_cache_key(pfor_input)
3300  stacked = _stack_cache.get(stack_cache_key, None)
3301  if stacked is None:
3302    stacked = elem_stacked
3303    _stack_cache[stack_cache_key] = stacked
3304  else:
3305    # If we previously made it unstacked then we can't revert to being stacked.
3306    if not stacked and elem_stacked:
3307      raise ValueError(
3308          "It looks like the stack was previously determined to be loop"
3309          " invariant, but we are now trying to push a loop dependent value"
3310          " to it. This is currently unsupported.")
3311    if stacked and not elem_stacked:
3312      elem = _stack(elem, pfor_input.pfor.loop_len_vector).t
3313  out = data_flow_ops.stack_push_v2(handle, elem, swap_memory=swap_memory)
3314  return wrap(out, stacked)
3315
3316
3317# Note that inputs to this converter will be unstacked. However, it should
3318# still get called since it is a stateful op.
3319@RegisterPFor("StackPopV2")
3320def _convert_stack_pop_v2(pfor_input):
3321  handle = pfor_input.unstacked_input(0)
3322  stack_cache_key = _stack_cache_key(pfor_input)
3323  stacked = _stack_cache.get(stack_cache_key, None)
3324  # If a StackPushV2 has not been converted yet, we default to unstacked since
3325  # the push could be outside of pfor, or the converter may not be called if the
3326  # inputs are unconverted.
3327  if stacked is None:
3328    stacked = False
3329    _stack_cache[stack_cache_key] = False
3330  elem_type = pfor_input.get_attr("elem_type")
3331  out = data_flow_ops.stack_pop_v2(handle, elem_type)
3332  return wrap(out, stacked)
3333
3334
3335# parsing_ops
3336
3337
3338@RegisterPFor("DecodeCSV")
3339def _convert_decode_csv(pfor_input):
3340  lines = pfor_input.stacked_input(0)
3341  record_defaults = [
3342      pfor_input.unstacked_input(i) for i in range(1, pfor_input.num_inputs)
3343  ]
3344  field_delim = pfor_input.get_attr("field_delim")
3345  use_quote_delim = pfor_input.get_attr("use_quote_delim")
3346  select_cols = pfor_input.get_attr("select_cols")
3347  if not select_cols:
3348    select_cols = None
3349  return [
3350      wrap(t, True) for t in parsing_ops.decode_csv(
3351          lines,
3352          record_defaults,
3353          field_delim=field_delim,
3354          use_quote_delim=use_quote_delim,
3355          select_cols=select_cols)
3356  ]
3357
3358
3359@RegisterPFor("ParseSingleExample")
3360def _convert_parse_single_example(pfor_input):
3361  serialized = pfor_input.stacked_input(0)
3362  dense_defaults = [
3363      pfor_input.unstacked_input(i) for i in range(1, pfor_input.num_inputs)
3364  ]
3365  sparse_keys = pfor_input.get_attr("sparse_keys")
3366  dense_keys = pfor_input.get_attr("dense_keys")
3367  sparse_types = pfor_input.get_attr("sparse_types")
3368  dense_shapes = pfor_input.get_attr("dense_shapes")
3369  output = gen_parsing_ops.parse_example(
3370      serialized=serialized,
3371      names=[],
3372      dense_defaults=dense_defaults,
3373      sparse_keys=sparse_keys,
3374      dense_keys=dense_keys,
3375      sparse_types=sparse_types,
3376      dense_shapes=dense_shapes)
3377  return [wrap(t, True, True) for t in nest.flatten(output)]
3378
3379
3380@RegisterPFor("ParseExampleV2")
3381def _convert_parse_example_v2(pfor_input):
3382  serialized = pfor_input.stacked_input(0)
3383  sparse_keys = pfor_input.unstacked_input(2)
3384  dense_keys = pfor_input.unstacked_input(3)
3385  ragged_keys = pfor_input.unstacked_input(4)
3386  dense_defaults = [
3387      pfor_input.unstacked_input(i) for i in range(5, pfor_input.num_inputs)
3388  ]
3389  num_sparse = pfor_input.get_attr("num_sparse")
3390  sparse_types = pfor_input.get_attr("sparse_types")
3391  ragged_value_types = pfor_input.get_attr("ragged_value_types")
3392  ragged_split_types = pfor_input.get_attr("ragged_split_types")
3393  dense_shapes = pfor_input.get_attr("dense_shapes")
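  # `serialized` here is the stacked input, so a scalar per-iteration input
  # shows up as a rank-1 tensor (or a tensor of unknown rank).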
3394  if serialized.shape.ndims not in (None, 1):
3395    raise ValueError("ParseExampleV2 can only be converted if `serialized` "
3396                     "is scalar.")
3397  output = gen_parsing_ops.parse_example_v2(
3398      serialized=serialized,
3399      names=[],
3400      sparse_keys=sparse_keys,
3401      dense_keys=dense_keys,
3402      ragged_keys=ragged_keys,
3403      dense_defaults=dense_defaults,
3404      num_sparse=num_sparse,
3405      sparse_types=sparse_types,
3406      ragged_value_types=ragged_value_types,
3407      ragged_split_types=ragged_split_types,
3408      dense_shapes=dense_shapes)
3409  return [wrap(t, True, True) for t in nest.flatten(output)]
3410
3411
3412# functional_ops
3413
3414
3415@RegisterPFor("StatefulPartitionedCall")
3416@RegisterPFor("PartitionedCall")
3417def _convert_partitioned_call(pfor_input):
3418  func_name = pfor_input.get_attr("f").name
3419  func = pfor_input.op.graph._get_function(compat.as_bytes(func_name))
3420  assert isinstance(func.graph, func_graph.FuncGraph), (
3421      "Could not find FuncGraph object for %s. Got func %s" % (func_name, func))
3422  pfor = pfor_input.pfor
3423  converter = PFor(
3424      loop_var=pfor.loop_var,
3425      loop_len=pfor.loop_len_vector[0],
3426      pfor_ops=func.graph.get_operations(),
3427      all_indices=pfor.all_indices,
3428      all_indices_partitioned=pfor.all_indices_partitioned,
3429      pfor_config=pfor.pfor_config)
3430
3431  # TODO(agarwal): consider caching this function definition.
3432  @def_function.function
3433  def f(*args):
3434    assert all(isinstance(arg, WrappedTensor) for arg in args), args
3435    assert len(args) == len(func.graph.inputs), (args, func.graph.inputs)
3436    #  Map inputs to function arguments.
3437    for inp, arg in zip(func.graph.inputs, args):
3438      converter._add_conversion(inp, arg)
3439    # Convert output tensors.
3440    return tuple(
3441        [converter._convert_helper(x).t for x in func._func_graph_outputs])
3442
3443  call_outputs = f(*pfor_input.inputs)
3444  assert len(call_outputs) == len(func._func_graph_outputs)
3445  outputs = []
3446  for call_output, output_tensor in zip(call_outputs, func._func_graph_outputs):
3447    func_output = converter._convert_helper(output_tensor)
3448    outputs.append(
3449        wrap(call_output, func_output.is_stacked,
3450             func_output.is_sparse_stacked))
3451  return outputs
3452