1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# =============================================================================
15"""Functional operations."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from tensorflow.core.framework import attr_value_pb2
22from tensorflow.python.eager import context
23from tensorflow.python.framework import auto_control_deps_utils as acd
24from tensorflow.python.framework import constant_op
25from tensorflow.python.framework import dtypes
26from tensorflow.python.framework import function
27from tensorflow.python.framework import ops
28from tensorflow.python.framework import tensor_shape
29from tensorflow.python.ops import array_ops
30from tensorflow.python.ops import control_flow_ops
31from tensorflow.python.ops import gen_functional_ops
32from tensorflow.python.ops import math_ops
33from tensorflow.python.ops import tensor_array_ops
34from tensorflow.python.ops import variable_scope as vs
35# pylint: disable=unused-import
36from tensorflow.python.ops.gen_functional_ops import remote_call
37# pylint: enable=unused-import
38from tensorflow.python.ops.gen_functional_ops import symbolic_gradient
39from tensorflow.python.util import compat
40from tensorflow.python.util import deprecation
41from tensorflow.python.util import dispatch
42from tensorflow.python.util import function_utils
43from tensorflow.python.util import nest
44from tensorflow.python.util.tf_export import tf_export
45
46
47# TODO(yuanbyu, mrry): Handle stride to support sliding windows.
48@tf_export(v1=["foldl"])
49@dispatch.add_dispatch_support
50def foldl(fn,
51          elems,
52          initializer=None,
53          parallel_iterations=10,
54          back_prop=True,
55          swap_memory=False,
56          name=None):
57  """foldl on the list of tensors unpacked from `elems` on dimension 0.
58
59  This foldl operator repeatedly applies the callable `fn` to a sequence
60  of elements from first to last. The elements are made of the tensors
61  unpacked from `elems` on dimension 0. The callable fn takes two tensors as
62  arguments. The first argument is the accumulated value computed from the
63  preceding invocation of fn, and the second is the value at the current
64  position of `elems`. If `initializer` is None, `elems` must contain at least
65  one element, and its first element is used as the initializer.
66
67  Suppose that `elems` is unpacked into `values`, a list of tensors. The shape
68  of the result tensor is `fn(initializer, values[0]).shape`.
69
70  This method also allows multi-arity `elems` and output of `fn`.  If `elems`
71  is a (possibly nested) list or tuple of tensors, then each of these tensors
72  must have a matching first (unpack) dimension.  The signature of `fn` may
73  match the structure of `elems`.  That is, if `elems` is
74  `(t1, [t2, t3, [t4, t5]])`, then an appropriate signature for `fn` is:
75  `fn = lambda (t1, [t2, t3, [t4, t5]]):`.
76
77  Args:
78    fn: The callable to be performed.
79    elems: A tensor or (possibly nested) sequence of tensors, each of which will
80      be unpacked along their first dimension.  The nested sequence of the
81      resulting slices will be the first argument to `fn`.
82    initializer: (optional) A tensor or (possibly nested) sequence of tensors,
83      as the initial value for the accumulator.
84    parallel_iterations: (optional) The number of iterations allowed to run in
85      parallel.
86    back_prop: (optional) True enables support for back propagation.
87    swap_memory: (optional) True enables GPU-CPU memory swapping.
88    name: (optional) Name prefix for the returned tensors.
89
90  Returns:
91    A tensor or (possibly nested) sequence of tensors, resulting from applying
92    `fn` consecutively to the list of tensors unpacked from `elems`, from first
93    to last.
94
95  Raises:
96    TypeError: if `fn` is not callable.
97
98  Example:
99    ```python
100    elems = tf.constant([1, 2, 3, 4, 5, 6])
101    sum = foldl(lambda a, x: a + x, elems)
102    # sum == 21
103    ```
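
    A further sketch (illustrative only) with multi-arity `elems` and an
    explicit `initializer`; the accumulator here is a running dot product of
    the two unpacked sequences:
    ```python
    xs = tf.constant([1, 2, 3])
    ys = tf.constant([10, 20, 30])
    dot = foldl(lambda acc, xy: acc + xy[0] * xy[1], (xs, ys),
                initializer=tf.constant(0))
    # dot == 1*10 + 2*20 + 3*30 == 140
    ```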
104  """
105  if not callable(fn):
106    raise TypeError(
107        f"{fn.__name__} is not callable. Please provide a callable function.")
108
109  def create_ta(elem):
110    return tensor_array_ops.TensorArray(
111        dtype=elem.dtype, size=n, dynamic_size=False,
112        infer_shape=True).unstack(elem)
113
114  in_graph_mode = not context.executing_eagerly()
115  with ops.name_scope(name, "foldl", [elems]):
116    # TODO(akshayka): Remove the in_graph_mode check once caching devices are
117    # supported in Eager
118    if in_graph_mode:
119      # Any get_variable calls in fn will cache the first call locally
120      # and not issue repeated network I/O requests for each iteration.
121      varscope = vs.get_variable_scope()
122      varscope_caching_device_was_none = False
123      if varscope.caching_device is None:
124        # TODO(ebrevdo): Change to using colocate_with here and in other
125        # methods.
126        varscope.set_caching_device(lambda op: op.device)
127        varscope_caching_device_was_none = True
128
129    # Convert elems to tensor array. n may be known statically.
130    elems_flat = [
131        ops.convert_to_tensor(elem, name="elem") for elem in nest.flatten(elems)
132    ]
133    n = (
134        tensor_shape.dimension_value(elems_flat[0].shape[0]) or
135        array_ops.shape(elems_flat[0])[0])
136
137    elems_ta = nest.map_structure(create_ta, elems)
138
139    if initializer is None:
140      a = nest.map_structure(lambda elem: elem.read(0), elems_ta)
141      i = constant_op.constant(1)
142    else:
143      a = initializer
144      i = constant_op.constant(0)
145
146    def compute(i, a):
147      elem_i = nest.map_structure(lambda elem: elem.read(i), elems_ta)
148      a = fn(a, elem_i)
149      return [i + 1, a]
150
151    _, r_a = control_flow_ops.while_loop(
152        lambda i, a: i < n,
153        compute, [i, a],
154        parallel_iterations=parallel_iterations,
155        back_prop=back_prop,
156        swap_memory=swap_memory,
157        maximum_iterations=n)
158
159    # TODO(akshayka): Remove the in_graph_mode check once caching devices are
160    # supported in Eager
161    if in_graph_mode and varscope_caching_device_was_none:
162      varscope.set_caching_device(None)
163
164    return r_a
165
166
167@tf_export("foldl", v1=[])
168@dispatch.add_dispatch_support
169@deprecation.deprecated_arg_values(
170    None,
171    """back_prop=False is deprecated. Consider using tf.stop_gradient instead.
172Instead of:
173results = tf.foldl(fn, elems, back_prop=False)
174Use:
175results = tf.nest.map_structure(tf.stop_gradient, tf.foldl(fn, elems))""",
176    warn_once=True,
177    back_prop=False)
178def foldl_v2(fn,
179             elems,
180             initializer=None,
181             parallel_iterations=10,
182             back_prop=True,
183             swap_memory=False,
184             name=None):
185  """foldl on the list of tensors unpacked from `elems` on dimension 0.
186
187  This foldl operator repeatedly applies the callable `fn` to a sequence
188  of elements from first to last. The elements are made of the tensors
189  unpacked from `elems` on dimension 0. The callable fn takes two tensors as
190  arguments. The first argument is the accumulated value computed from the
191  preceding invocation of fn, and the second is the value at the current
192  position of `elems`. If `initializer` is None, `elems` must contain at least
193  one element, and its first element is used as the initializer.
194
195  Suppose that `elems` is unpacked into `values`, a list of tensors. The shape
196  of the result tensor is `fn(initializer, values[0]).shape`.
197
198  This method also allows multi-arity `elems` and output of `fn`.  If `elems`
199  is a (possibly nested) list or tuple of tensors, then each of these tensors
200  must have a matching first (unpack) dimension.  The signature of `fn` may
201  match the structure of `elems`.  That is, if `elems` is
202  `(t1, [t2, t3, [t4, t5]])`, then an appropriate signature for `fn` is:
203  `fn = lambda (t1, [t2, t3, [t4, t5]]):`.
204
205  Args:
206    fn: The callable to be performed.
207    elems: A tensor or (possibly nested) sequence of tensors, each of which will
208      be unpacked along their first dimension.  The nested sequence of the
209      resulting slices will be the first argument to `fn`.
210    initializer: (optional) A tensor or (possibly nested) sequence of tensors,
211      as the initial value for the accumulator.
212    parallel_iterations: (optional) The number of iterations allowed to run in
213      parallel.
214    back_prop: (optional) Deprecated. False disables support for back
215      propagation. Prefer using `tf.stop_gradient` instead.
216    swap_memory: (optional) True enables GPU-CPU memory swapping.
217    name: (optional) Name prefix for the returned tensors.
218
219  Returns:
220    A tensor or (possibly nested) sequence of tensors, resulting from applying
221    `fn` consecutively to the list of tensors unpacked from `elems`, from first
222    to last.
223
224  Raises:
225    TypeError: if `fn` is not callable.
226
227  Example:
228    ```python
229    elems = tf.constant([1, 2, 3, 4, 5, 6])
230    sum = foldl(lambda a, x: a + x, elems)
231    # sum == 21
232    ```
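
    As noted in the deprecation of `back_prop=False`, gradients can instead
    be stopped on the result; a minimal sketch:
    ```python
    elems = tf.constant([1.0, 2.0, 3.0])
    total = tf.foldl(lambda a, x: a + x, elems)
    total_no_grad = tf.nest.map_structure(tf.stop_gradient, total)
    # total_no_grad == 6.0, with no gradient flowing back through the fold.
    ```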
233  """
234  return foldl(
235      fn=fn,
236      elems=elems,
237      initializer=initializer,
238      parallel_iterations=parallel_iterations,
239      back_prop=back_prop,
240      swap_memory=swap_memory,
241      name=name)
242
243
244@tf_export(v1=["foldr"])
245@dispatch.add_dispatch_support
246def foldr(fn,
247          elems,
248          initializer=None,
249          parallel_iterations=10,
250          back_prop=True,
251          swap_memory=False,
252          name=None):
253  """foldr on the list of tensors unpacked from `elems` on dimension 0.
254
255  This foldr operator repeatedly applies the callable `fn` to a sequence
256  of elements from last to first. The elements are made of the tensors
257  unpacked from `elems`. The callable fn takes two tensors as arguments.
258  The first argument is the accumulated value computed from the preceding
259  invocation of fn, and the second is the value at the current position of
260  `elems`. If `initializer` is None, `elems` must contain at least one element,
261  and its last element is used as the initializer.
262
263  Suppose that `elems` is unpacked into `values`, a list of tensors. The shape
264  of the result tensor is `fn(initializer, values[0]).shape`.
265
266  This method also allows multi-arity `elems` and output of `fn`.  If `elems`
267  is a (possibly nested) list or tuple of tensors, then each of these tensors
268  must have a matching first (unpack) dimension.  The signature of `fn` may
269  match the structure of `elems`.  That is, if `elems` is
270  `(t1, [t2, t3, [t4, t5]])`, then an appropriate signature for `fn` is:
271  `fn = lambda (t1, [t2, t3, [t4, t5]]):`.
272
273  Args:
274    fn: The callable to be performed.
275    elems: A tensor or (possibly nested) sequence of tensors, each of which will
276      be unpacked along their first dimension.  The nested sequence of the
277      resulting slices will be the first argument to `fn`.
278    initializer: (optional) A tensor or (possibly nested) sequence of tensors,
279      as the initial value for the accumulator.
280    parallel_iterations: (optional) The number of iterations allowed to run in
281      parallel.
282    back_prop: (optional) True enables support for back propagation.
283    swap_memory: (optional) True enables GPU-CPU memory swapping.
284    name: (optional) Name prefix for the returned tensors.
285
286  Returns:
287    A tensor or (possibly nested) sequence of tensors, resulting from applying
288    `fn` consecutively to the list of tensors unpacked from `elems`, from last
289    to first.
290
291  Raises:
292    TypeError: if `fn` is not callable.
293
294  Example:
295    ```python
296    elems = [1, 2, 3, 4, 5, 6]
297    sum = foldr(lambda a, x: a + x, elems)
298    # sum == 21
299    ```
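
    A further sketch (illustrative only) with an explicit `initializer`;
    because the reduction runs from last to first, the grouping of the
    subtraction differs from a left fold:
    ```python
    elems = tf.constant([1, 2, 3, 4])
    alt_sum = foldr(lambda a, x: x - a, elems, initializer=tf.constant(0))
    # alt_sum == 1 - (2 - (3 - (4 - 0))) == -2
    ```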
300  """
301  if not callable(fn):
302    raise TypeError(
303        f"{fn.__name__} is not callable. Please provide a callable function.")
304
305  def create_ta(elem):
306    return tensor_array_ops.TensorArray(
307        dtype=elem.dtype, size=n, dynamic_size=False,
308        infer_shape=True).unstack(elem)
309
310  in_graph_mode = not context.executing_eagerly()
311  with ops.name_scope(name, "foldr", [elems]):
312    # TODO(akshayka): Remove the in_graph_mode check once caching devices are
313    # supported in Eager
314    if in_graph_mode:
315      # Any get_variable calls in fn will cache the first call locally and not
316      # issue repeated network I/O requests for each iteration.
317      varscope = vs.get_variable_scope()
318      varscope_caching_device_was_none = False
319      if varscope.caching_device is None:
320        # TODO(ebrevdo): Change to using colocate_with here and in other
321        # methods.
322        varscope.set_caching_device(lambda op: op.device)
323        varscope_caching_device_was_none = True
324
325    # Convert elems to tensor array. n may be known statically.
326    elems_flat = [
327        ops.convert_to_tensor(elem, name="elem") for elem in nest.flatten(elems)
328    ]
329    n = (
330        tensor_shape.dimension_value(elems_flat[0].shape[0]) or
331        array_ops.shape(elems_flat[0])[0])
332
333    elems_ta = nest.map_structure(create_ta, elems)
334
335    if initializer is None:
336      i = n - 1
337      a = nest.map_structure(lambda elem: elem.read(i), elems_ta)
338    else:
339      i = n
340      a = initializer
341
342    def compute(i, a):
343      i -= 1
344      elem = nest.map_structure(lambda elem: elem.read(i), elems_ta)
345      a_out = fn(a, elem)
346      return [i, a_out]
347
348    _, r_a = control_flow_ops.while_loop(
349        lambda i, a: i > 0,
350        compute, [i, a],
351        parallel_iterations=parallel_iterations,
352        back_prop=back_prop,
353        swap_memory=swap_memory,
354        maximum_iterations=n)
355
356    # TODO(akshayka): Remove the in_graph_mode check once caching devices are
357    # supported in Eager
358    if in_graph_mode and varscope_caching_device_was_none:
359      varscope.set_caching_device(None)
360
361    return r_a
362
363
364@tf_export("foldr", v1=[])
365@dispatch.add_dispatch_support
366@deprecation.deprecated_arg_values(
367    None,
368    """back_prop=False is deprecated. Consider using tf.stop_gradient instead.
369Instead of:
370results = tf.foldr(fn, elems, back_prop=False)
371Use:
372results = tf.nest.map_structure(tf.stop_gradient, tf.foldr(fn, elems))""",
373    warn_once=True,
374    back_prop=False)
375def foldr_v2(fn,
376             elems,
377             initializer=None,
378             parallel_iterations=10,
379             back_prop=True,
380             swap_memory=False,
381             name=None):
382  """foldr on the list of tensors unpacked from `elems` on dimension 0.
383
384  This foldr operator repeatedly applies the callable `fn` to a sequence
385  of elements from last to first. The elements are made of the tensors
386  unpacked from `elems`. The callable fn takes two tensors as arguments.
387  The first argument is the accumulated value computed from the preceding
388  invocation of fn, and the second is the value at the current position of
389  `elems`. If `initializer` is None, `elems` must contain at least one element,
390  and its last element is used as the initializer.
391
392  Suppose that `elems` is unpacked into `values`, a list of tensors. The shape
393  of the result tensor is `fn(initializer, values[0]).shape`.
394
395  This method also allows multi-arity `elems` and output of `fn`.  If `elems`
396  is a (possibly nested) list or tuple of tensors, then each of these tensors
397  must have a matching first (unpack) dimension.  The signature of `fn` may
398  match the structure of `elems`.  That is, if `elems` is
399  `(t1, [t2, t3, [t4, t5]])`, then an appropriate signature for `fn` is:
400  `fn = lambda (t1, [t2, t3, [t4, t5]]):`.
401
402  Args:
403    fn: The callable to be performed.
404    elems: A tensor or (possibly nested) sequence of tensors, each of which will
405      be unpacked along their first dimension.  The nested sequence of the
406      resulting slices will be the first argument to `fn`.
407    initializer: (optional) A tensor or (possibly nested) sequence of tensors,
408      as the initial value for the accumulator.
409    parallel_iterations: (optional) The number of iterations allowed to run in
410      parallel.
411    back_prop: (optional) Deprecated. False disables support for back
412      propagation. Prefer using `tf.stop_gradient` instead.
413    swap_memory: (optional) True enables GPU-CPU memory swapping.
414    name: (optional) Name prefix for the returned tensors.
415
416  Returns:
417    A tensor or (possibly nested) sequence of tensors, resulting from applying
418    `fn` consecutively to the list of tensors unpacked from `elems`, from last
419    to first.
420
421  Raises:
422    TypeError: if `fn` is not callable.
423
424  Example:
425    ```python
426    elems = [1, 2, 3, 4, 5, 6]
427    sum = foldr(lambda a, x: a + x, elems)
428    # sum == 21
429    ```
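
    A further sketch (illustrative only) showing the right-to-left order: the
    last element is consumed first, so the digits come out reversed relative
    to the input:
    ```python
    elems = tf.constant([1, 2, 3])
    digits = tf.foldr(lambda a, x: 10 * a + x, elems)
    # digits == 321  (start with 3, then 10*3 + 2, then 10*32 + 1)
    ```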
430  """
431  return foldr(
432      fn=fn,
433      elems=elems,
434      initializer=initializer,
435      parallel_iterations=parallel_iterations,
436      back_prop=back_prop,
437      swap_memory=swap_memory,
438      name=name)
439
440
441@tf_export(v1=["scan"])
442@dispatch.add_dispatch_support
443def scan(fn,
444         elems,
445         initializer=None,
446         parallel_iterations=10,
447         back_prop=True,
448         swap_memory=False,
449         infer_shape=True,
450         reverse=False,
451         name=None):
452  """scan on the list of tensors unpacked from `elems` on dimension 0.
453
454  See also `tf.map_fn`.
455
456  The simplest version of `scan` repeatedly applies the callable `fn` to a
457  sequence of elements from first to last. The elements are made of the tensors
458  unpacked from `elems` on dimension 0. The callable fn takes two tensors as
459  arguments. The first argument is the accumulated value computed from the
460  preceding invocation of fn, and the second is the value at the current
461  position of `elems`. If `initializer` is None, `elems` must contain at least
462  one element, and its first element is used as the initializer.
463
464  Suppose that `elems` is unpacked into `values`, a list of tensors. The shape
465  of the result tensor is `[len(values)] + fn(initializer, values[0]).shape`.
466  If `reverse=True`, it's `[len(values)] + fn(initializer, values[-1]).shape`.
467
468  This method also allows multi-arity `elems` and accumulator.  If `elems`
469  is a (possibly nested) list or tuple of tensors, then each of these tensors
470  must have a matching first (unpack) dimension.  The second argument of
471  `fn` must match the structure of `elems`.
472
473  If no `initializer` is provided, the output structure and dtypes of `fn`
474  are assumed to be the same as its input; and in this case, the first
475  argument of `fn` must match the structure of `elems`.
476
477  If an `initializer` is provided, then the output of `fn` must have the same
478  structure as `initializer`; and the first argument of `fn` must match
479  this structure.
480
481  For example, if `elems` is `(t1, [t2, t3])` and `initializer` is
482  `[i1, i2]` then an appropriate signature for `fn` in `python2` is:
483  `fn = lambda (acc_p1, acc_p2), (t1, [t2, t3]):` and `fn` must return a list,
484  `[acc_n1, acc_n2]`.  An alternative correct signature for `fn`, and the
485   one that works in `python3`, is:
486  `fn = lambda a, t:`, where `a` and `t` correspond to the input tuples.
487
488  Args:
489    fn: The callable to be performed.  It accepts two arguments.  The first will
490      have the same structure as `initializer` if one is provided, otherwise it
491      will have the same structure as `elems`.  The second will have the same
492      (possibly nested) structure as `elems`.  Its output must have the same
493      structure as `initializer` if one is provided, otherwise it must have the
494      same structure as `elems`.
495    elems: A tensor or (possibly nested) sequence of tensors, each of which will
496      be unpacked along their first dimension.  The nested sequence of the
497      resulting slices will be the first argument to `fn`.
498    initializer: (optional) A tensor or (possibly nested) sequence of tensors,
499      initial value for the accumulator, and the expected output type of `fn`.
500    parallel_iterations: (optional) The number of iterations allowed to run in
501      parallel.
502    back_prop: (optional) True enables support for back propagation.
503    swap_memory: (optional) True enables GPU-CPU memory swapping.
504    infer_shape: (optional) False disables tests for consistent output shapes.
505    reverse: (optional) True scans the tensor last to first (instead of first to
506      last).
507    name: (optional) Name prefix for the returned tensors.
508
509  Returns:
510    A tensor or (possibly nested) sequence of tensors.  Each tensor packs the
511    results of applying `fn` to tensors unpacked from `elems` along the first
512    dimension, and the previous accumulator value(s), from first to last (or
513    last to first, if `reverse=True`).
514
515  Raises:
516    TypeError: if `fn` is not callable or the structure of the output of
517      `fn` and `initializer` do not match.
518    ValueError: if the lengths of the output of `fn` and `initializer`
519      do not match.
520
521  Examples:
522    ```python
523    elems = np.array([1, 2, 3, 4, 5, 6])
524    sum = scan(lambda a, x: a + x, elems)
525    # sum == [1, 3, 6, 10, 15, 21]
526    sum = scan(lambda a, x: a + x, elems, reverse=True)
527    # sum == [21, 20, 18, 15, 11, 6]
528    ```
529
530    ```python
531    elems = np.array([1, 2, 3, 4, 5, 6])
532    initializer = np.array(0)
533    sum_one = scan(
534        lambda a, x: x[0] - x[1] + a, (elems + 1, elems), initializer)
535    # sum_one == [1, 2, 3, 4, 5, 6]
536    ```
537
538    ```python
539    elems = np.array([1, 0, 0, 0, 0, 0])
540    initializer = (np.array(0), np.array(1))
541    fibonaccis = scan(lambda a, _: (a[1], a[0] + a[1]), elems, initializer)
542    # fibonaccis == ([1, 1, 2, 3, 5, 8], [1, 2, 3, 5, 8, 13])
543    ```
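
    A further sketch (illustrative only): `fn` need not be a sum; here it
    tracks the running maximum of the elements seen so far:
    ```python
    elems = np.array([3, 1, 4, 1, 5])
    running_max = scan(lambda a, x: tf.maximum(a, x), elems)
    # running_max == [3, 3, 4, 4, 5]
    ```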
544  """
545  if not callable(fn):
546    raise TypeError(
547        f"{fn.__name__} is not callable. Please provide a callable function.")
548
549  input_is_sequence = nest.is_sequence(elems)
550  input_flatten = lambda x: nest.flatten(x) if input_is_sequence else [x]
551
552  def input_pack(x):
553    return nest.pack_sequence_as(elems, x) if input_is_sequence else x[0]
554
555  if initializer is None:
556    output_is_sequence = input_is_sequence
557    output_flatten = input_flatten
558    output_pack = input_pack
559  else:
560    output_is_sequence = nest.is_sequence(initializer)
561    output_flatten = lambda x: nest.flatten(x) if output_is_sequence else [x]
562
563    def output_pack(x):
564      return (nest.pack_sequence_as(initializer, x)
565              if output_is_sequence else x[0])
566
567  elems_flat = input_flatten(elems)
568
569  in_graph_mode = not context.executing_eagerly()
570  with ops.name_scope(name, "scan", elems_flat):
571    # TODO(akshayka): Remove the in_graph_mode check once caching devices are
572    # supported in Eager
573    if in_graph_mode:
574      # Any get_variable calls in fn will cache the first call locally
575      # and not issue repeated network I/O requests for each iteration.
576      varscope = vs.get_variable_scope()
577      varscope_caching_device_was_none = False
578      if varscope.caching_device is None:
579        # TODO(ebrevdo): Change to using colocate_with here and in other
580        # methods.
581        varscope.set_caching_device(lambda op: op.device)
582        varscope_caching_device_was_none = True
583
584    # Convert elems to tensor array.
585    elems_flat = [
586        ops.convert_to_tensor(elem, name="elem") for elem in elems_flat
587    ]
588
589    # Convert elems to tensor array. n may be known statically.
590    n = tensor_shape.dimension_value(elems_flat[0].shape[0])
591    if n is None:
592      n = array_ops.shape(elems_flat[0])[0]
593
594    # TensorArrays are always flat
595    elems_ta = [
596        tensor_array_ops.TensorArray(
597            dtype=elem.dtype,
598            size=n,
599            dynamic_size=False,
600            element_shape=elem.shape[1:],
601            infer_shape=True) for elem in elems_flat
602    ]
603    # Unpack elements
604    elems_ta = [
605        elem_ta.unstack(elem) for elem_ta, elem in zip(elems_ta, elems_flat)
606    ]
607
608    if initializer is None:
609      a_flat = [elem.read(n - 1 if reverse else 0) for elem in elems_ta]
610      i = 1
611    else:
612      initializer_flat = output_flatten(initializer)
613      a_flat = [ops.convert_to_tensor(init) for init in initializer_flat]
614      i = 0
615
616    # Create a tensor array to store the intermediate values.
617    accs_ta = [
618        tensor_array_ops.TensorArray(
619            dtype=init.dtype,
620            size=n,
621            element_shape=init.shape if infer_shape else None,
622            dynamic_size=False,
623            infer_shape=infer_shape) for init in a_flat
624    ]
625
626    if initializer is None:
627      accs_ta = [
628          acc_ta.write(n - 1 if reverse else 0, a)
629          for (acc_ta, a) in zip(accs_ta, a_flat)
630      ]
631
632    def compute(i, a_flat, tas):
633      """The loop body of scan.
634
635      Args:
636        i: the loop counter.
637        a_flat: the accumulator value(s), flattened.
638        tas: the output accumulator TensorArray(s), flattened.
639
640      Returns:
641        [i + 1, a_flat, tas]: the updated counter + new accumulator values +
642          updated TensorArrays
643
644      Raises:
645        TypeError: if initializer and fn() output structure do not match
646        ValueError: if initializer and fn() output lengths do not match
647      """
648      packed_elems = input_pack([elem_ta.read(i) for elem_ta in elems_ta])
649      packed_a = output_pack(a_flat)
650      a_out = fn(packed_a, packed_elems)
651      nest.assert_same_structure(elems if initializer is None else initializer,
652                                 a_out)
653      flat_a_out = output_flatten(a_out)
654      tas = [ta.write(i, value) for (ta, value) in zip(tas, flat_a_out)]
655      if reverse:
656        next_i = i - 1
657      else:
658        next_i = i + 1
659      return (next_i, flat_a_out, tas)
660
661    if reverse:
662      initial_i = n - 1 - i
663      condition = lambda i, _1, _2: i >= 0
664    else:
665      initial_i = i
666      condition = lambda i, _1, _2: i < n
667    _, _, r_a = control_flow_ops.while_loop(
668        condition,
669        compute, (initial_i, a_flat, accs_ta),
670        parallel_iterations=parallel_iterations,
671        back_prop=back_prop,
672        swap_memory=swap_memory,
673        maximum_iterations=n)
674
675    results_flat = [r.stack() for r in r_a]
676
677    n_static = tensor_shape.Dimension(
678        tensor_shape.dimension_value(
679            elems_flat[0].get_shape().with_rank_at_least(1)[0]))
680    for elem in elems_flat[1:]:
681      n_static.assert_is_compatible_with(
682          tensor_shape.Dimension(
683              tensor_shape.dimension_value(
684                  elem.get_shape().with_rank_at_least(1)[0])))
685    for r in results_flat:
686      r.set_shape(
687          tensor_shape.TensorShape(n_static).concatenate(r.get_shape()[1:]))
688
689    # TODO(akshayka): Remove the in_graph_mode check once caching devices are
690    # supported in Eager
691    if in_graph_mode and varscope_caching_device_was_none:
692      varscope.set_caching_device(None)
693
694    return output_pack(results_flat)
695
696
697@tf_export("scan", v1=[])
698@dispatch.add_dispatch_support
699@deprecation.deprecated_arg_values(
700    None,
701    """back_prop=False is deprecated. Consider using tf.stop_gradient instead.
702Instead of:
703results = tf.scan(fn, elems, back_prop=False)
704Use:
705results = tf.nest.map_structure(tf.stop_gradient, tf.scan(fn, elems))""",
706    warn_once=True,
707    back_prop=False)
708def scan_v2(fn,
709            elems,
710            initializer=None,
711            parallel_iterations=10,
712            back_prop=True,
713            swap_memory=False,
714            infer_shape=True,
715            reverse=False,
716            name=None):
717  """scan on the list of tensors unpacked from `elems` on dimension 0.
718
719  The simplest version of `scan` repeatedly applies the callable `fn` to a
720  sequence of elements from first to last. The elements are made of the tensors
721  unpacked from `elems` on dimension 0. The callable fn takes two tensors as
722  arguments. The first argument is the accumulated value computed from the
723  preceding invocation of fn, and the second is the value at the current
724  position of `elems`. If `initializer` is None, `elems` must contain at least
725  one element, and its first element is used as the initializer.
726
727  Suppose that `elems` is unpacked into `values`, a list of tensors. The shape
728  of the result tensor is `[len(values)] + fn(initializer, values[0]).shape`.
729  If `reverse=True`, it's `[len(values)] + fn(initializer, values[-1]).shape`.
730
731  This method also allows multi-arity `elems` and accumulator.  If `elems`
732  is a (possibly nested) list or tuple of tensors, then each of these tensors
733  must have a matching first (unpack) dimension.  The second argument of
734  `fn` must match the structure of `elems`.
735
736  If no `initializer` is provided, the output structure and dtypes of `fn`
737  are assumed to be the same as its input; and in this case, the first
738  argument of `fn` must match the structure of `elems`.
739
740  If an `initializer` is provided, then the output of `fn` must have the same
741  structure as `initializer`; and the first argument of `fn` must match
742  this structure.
743
744  For example, if `elems` is `(t1, [t2, t3])` and `initializer` is
745  `[i1, i2]` then an appropriate signature for `fn` in `python2` is:
746  `fn = lambda (acc_p1, acc_p2), (t1, [t2, t3]):` and `fn` must return a list,
747  `[acc_n1, acc_n2]`.  An alternative correct signature for `fn`, and the
748   one that works in `python3`, is:
749  `fn = lambda a, t:`, where `a` and `t` correspond to the input tuples.
750
751  Args:
752    fn: The callable to be performed.  It accepts two arguments.  The first will
753      have the same structure as `initializer` if one is provided, otherwise it
754      will have the same structure as `elems`.  The second will have the same
755      (possibly nested) structure as `elems`.  Its output must have the same
756      structure as `initializer` if one is provided, otherwise it must have the
757      same structure as `elems`.
758    elems: A tensor or (possibly nested) sequence of tensors, each of which will
759      be unpacked along their first dimension.  The nested sequence of the
760      resulting slices will be the first argument to `fn`.
761    initializer: (optional) A tensor or (possibly nested) sequence of tensors,
762      initial value for the accumulator, and the expected output type of `fn`.
763    parallel_iterations: (optional) The number of iterations allowed to run in
764      parallel.
765    back_prop: (optional) Deprecated. False disables support for back
766      propagation. Prefer using `tf.stop_gradient` instead.
767    swap_memory: (optional) True enables GPU-CPU memory swapping.
768    infer_shape: (optional) False disables tests for consistent output shapes.
769    reverse: (optional) True scans the tensor last to first (instead of first to
770      last).
771    name: (optional) Name prefix for the returned tensors.
772
773  Returns:
774    A tensor or (possibly nested) sequence of tensors.  Each tensor packs the
775    results of applying `fn` to tensors unpacked from `elems` along the first
776    dimension, and the previous accumulator value(s), from first to last (or
777    last to first, if `reverse=True`).
778
779  Raises:
780    TypeError: if `fn` is not callable or the structure of the output of
781      `fn` and `initializer` do not match.
782    ValueError: if the lengths of the output of `fn` and `initializer`
783      do not match.
784
785  Examples:
786    ```python
787    elems = np.array([1, 2, 3, 4, 5, 6])
788    sum = scan(lambda a, x: a + x, elems)
789    # sum == [1, 3, 6, 10, 15, 21]
790    sum = scan(lambda a, x: a + x, elems, reverse=True)
791    # sum == [21, 20, 18, 15, 11, 6]
792    ```
793
794    ```python
795    elems = np.array([1, 2, 3, 4, 5, 6])
796    initializer = np.array(0)
797    sum_one = scan(
798        lambda a, x: x[0] - x[1] + a, (elems + 1, elems), initializer)
799    # sum_one == [1, 2, 3, 4, 5, 6]
800    ```
801
802    ```python
803    elems = np.array([1, 0, 0, 0, 0, 0])
804    initializer = (np.array(0), np.array(1))
805    fibonaccis = scan(lambda a, _: (a[1], a[0] + a[1]), elems, initializer)
806    # fibonaccis == ([1, 1, 2, 3, 5, 8], [1, 2, 3, 5, 8, 13])
807    ```
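
    A further sketch (illustrative only): the accumulator can carry state
    between steps, e.g. an exponentially decayed running sum:
    ```python
    elems = np.array([1.0, 1.0, 1.0, 1.0])
    decayed = tf.scan(lambda a, x: 0.5 * a + x, elems)
    # decayed == [1.0, 1.5, 1.75, 1.875]
    ```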
808  """
809  return scan(
810      fn=fn,
811      elems=elems,
812      initializer=initializer,
813      parallel_iterations=parallel_iterations,
814      back_prop=back_prop,
815      swap_memory=swap_memory,
816      infer_shape=infer_shape,
817      reverse=reverse,
818      name=name)
819
820
821# pylint: disable=invalid-name
822def If(cond, inputs, then_branch, else_branch, name=None):
823  r"""output = Cond(inputs) ?
824
825  then_branch(inputs) : else_branch(inputs).
826
827  Args:
828    cond: A `Tensor`. A scalar. If the scalar is not a boolean, the scalar is
829      converted to a boolean according to the following rule: if the scalar is a
830        numerical value, non-zero means True and zero means False; if the scalar
831        is a string, non-empty means True and empty means False.
832    inputs: A list of input tensors.
833    then_branch: A function that takes 'inputs' and returns a list of tensors
834      whose types are the same as what else_branch returns.
835    else_branch: A function that takes 'inputs' and returns a list of tensors
836      whose types are the same as what then_branch returns.
837    name: A name for the operation (optional).
838
839  Returns:
840    A list of tensors returned by either then_branch(inputs)
841    or else_branch(inputs).
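
  Example:
    A rough graph-mode sketch using `function.Defun` branches; the names
    below are illustrative only:
    ```python
    @function.Defun(dtypes.float32)
    def TimesTwo(x):
      return x * 2.0

    @function.Defun(dtypes.float32)
    def PlusOne(x):
      return x + 1.0

    pred = constant_op.constant(True)
    out = If(pred, [constant_op.constant(3.0)], TimesTwo, PlusOne)
    # out == [6.0]
    ```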
842  """
843  # pylint: disable=protected-access
844  # Handle the Defun case until users have transitioned to tf.function. Note
845  # that composites may need to be re-packed by the caller.
846  if isinstance(then_branch, function._DefinedFunction):
847    tlist = [_.type for _ in then_branch.definition.signature.output_arg]
848    return gen_functional_ops._if(
849        cond, inputs, tlist, then_branch, else_branch, name=name)
850
851  # We assume that `then_branch` is a ConcreteFunction here.
852  then_out = then_branch.structured_outputs
853  else_out = else_branch.structured_outputs
854
855  # Ensure then/else are the same type of composites to avoid an invalid call
856  # to pack_sequence_as later on.
857  nest.assert_same_structure(then_out, else_out, expand_composites=True)
858
859  tlist = nest.flatten(then_branch.output_dtypes)
860  ret = gen_functional_ops._if(
861      cond, inputs, tlist, then_branch, else_branch, name=name)
862
863  # Re-pack the outputs to restore any CompositeTensors
864  return nest.pack_sequence_as(then_out, ret, expand_composites=True)
865
866
867def Gradient(inputs, f, name=None):
868  r"""Computes the gradient function for function f via backpropagation.
869
870  Args:
871    inputs: A list of tensors of size N + M.
872    f: The function we want to compute the gradient for.  The function 'f' must
873      be a numerical function which takes N inputs and produces M outputs. Its
874      gradient function 'g' is a function that takes N + M inputs and
875      produces N outputs.  I.e. if we have (y1, y2, ..., yM) = f(x1, x2, ...,
876      xN), then g is (dL/dx1, dL/dx2, ..., dL/dxN) = g(x1, x2, ..., xN, dL/dy1,
877      dL/dy2, ..., dL/dyM), where L is a scalar-valued function of (x1, x2, ...,
878      xN) (e.g., the loss function). dL/dxi is the partial derivative of L with
879      respect to xi.
880    name: A name for the operation (optional).
881
882  Returns:
883    A list of tensors of size N.
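
  Example:
    A minimal sketch (illustrative names): for f(x) = x * x, the symbolic
    gradient takes (x, dL/dy) and returns dL/dx = 2 * x * dL/dy:
    ```python
    @function.Defun(dtypes.float32)
    def Square(x):
      return x * x

    x = constant_op.constant(3.0)
    dy = constant_op.constant(1.0)
    dx = Gradient([x, dy], Square)
    # dx == [6.0]
    ```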
884  """
885  # TODO(zhifengc): Pretty-print the above spec in latex.
886  # TODO(zhifengc): Needs a math expert to state the comment above more precisely.
887  tlist = [_.type for _ in f.definition.signature.input_arg]
888  return symbolic_gradient(input=inputs, Tout=tlist, f=f, name=name)
889
890
891def _GetInputDtypes(func):
892  """Returns the input dtypes of func, excluding dtypes for captured inputs."""
893  if isinstance(func, function._DefinedFunction):  # pylint: disable=protected-access
894    return func.declared_input_types
895
896  # We assume that `func` is a ConcreteFunction here, but we are not able to
897  # verify since importing eager function library will cause cyclic dependence.
898  #
899  # ConcreteFunction.inputs includes captured inputs.
900  num_non_captured_inputs = len(func.inputs) - len(func.captured_inputs)
901  inputs_without_captured = func.inputs[:num_non_captured_inputs]
902  return [t.dtype for t in inputs_without_captured]
903
904
905def _LoopBodyCaptureWrapper(func):
906  """Returns a wrapper for `func` that handles loop-carried captured inputs."""
907
908  @function.Defun(*_GetInputDtypes(func), func_name="%s_Wrapper" % func.name)
909  def Wrapper(*args):
910    """A wrapper that handles loop-carried captured inputs."""
911    result = func(*args)
912    extra_args = tuple(function.get_extra_args())
913    # Nullary functions return an Operation. Normal functions can't do this
914    # because their return values are converted to Tensors.
915    if isinstance(result, ops.Operation):
916      return extra_args
917    # Unary functions return a single Tensor value.
918    elif not isinstance(result, (list, tuple)):
919      return (result,) + extra_args
920    # N-ary functions return a tuple of Tensors.
921    else:
922      return result + type(result)(extra_args)
923
924  return Wrapper
925
926
927# pylint: disable=invalid-name,protected-access
928def While(input_, cond, body, name=None, hostmem=None):
929  r"""output = input; While (Cond(output)) { output = Body(output) }.
930
931  Args:
932    input_: A list of `Tensor` objects. A list of input tensors whose types are
933      T.
934    cond: A function that takes 'input' and returns a tensor.  If the tensor
935      is a non-boolean scalar, the scalar is converted to a boolean
936      according to the following rule: if the scalar is a numerical value,
937        non-zero means True and zero means False; if the scalar is a string,
938        non-empty means True and empty means False. If the tensor is not a
939        scalar, non-emptiness means True and emptiness means False.
940    body: A function that takes a list of tensors and returns another list of
941      tensors. Both lists have the same types as specified by T.
942    name: A name for the operation (optional).
943    hostmem: A list of integers. If i is in the list, input[i] is a host memory
944      tensor.
945
946  Raises:
947    ValueError: if `cond` has implicitly captured inputs or if `cond` and `body`
948      have different signatures.
949
950  Returns:
951    A list of `Tensor` objects. Has the same type as `input`.
952    A list of output tensors whose types are T.
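
  Example:
    A rough graph-mode sketch (illustrative only) that sums the integers 0
    through 9 with two int32 loop variables:
    ```python
    @function.Defun(dtypes.int32, dtypes.int32)
    def Cond(i, unused_acc):
      return i < 10

    @function.Defun(dtypes.int32, dtypes.int32)
    def Body(i, acc):
      return i + 1, acc + i

    result = While([0, 0], Cond, Body)
    # result == [10, 45]
    ```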
953  """
954  if cond.captured_inputs:
955    raise ValueError(
956        "The 'cond' argument can not have implicitly captured inputs. Received "
957        f"captured_inputs: {cond.captured_inputs}")
958
959  cond_input_types = _GetInputDtypes(cond)
960  body_input_types = _GetInputDtypes(body)
961
962  if cond_input_types != body_input_types:
963    raise ValueError(
964        "The 'cond' and 'body' signatures do not match. Received: "
965        f"cond_input_types={cond_input_types}, body_input_types="
966        f"{body_input_types}")
967
968  if body.captured_inputs:
969    cond_dtypes = list(body_input_types) + [
970        t.dtype for t in body.captured_inputs
971    ]
972
973    @function.Defun(*cond_dtypes, func_name="%s_Wrapper" % cond.name)
974    def CondWrapper(*args):
975      """A wrapper that handles loop-carried captured inputs."""
976      return cond(*args[:len(body_input_types)])
977
978    ret = gen_functional_ops._while(
979        input_ + body.captured_inputs,
980        CondWrapper,
981        _LoopBodyCaptureWrapper(body),
982        name=name)
983    # Slice off the loop-carried captured inputs.
984    ret = ret[:-len(body.captured_inputs)]
985  else:
986    ret = gen_functional_ops._while(input_, cond, body, name=name)
987  if hostmem:
988    input_attr = attr_value_pb2.AttrValue()
989    input_attr.list.i.extend(hostmem)
990    ret[0].op._set_attr("_input_hostmem", input_attr)  # pylint: disable=protected-access
991
992    output_attr = attr_value_pb2.AttrValue()
993    output_attr.list.i.extend(hostmem)
994    ret[0].op._set_attr("_output_hostmem", output_attr)  # pylint: disable=protected-access
995  return ret
996
997
998# b/36459430
999#
1000# Ideally, we would not need to rewrite the For loop into a While loop.
1001# However, today, if a While runs on GPU and the condition returns a
1002# boolean, the While kernel crashes. Even if we fix the crash, the
1003# bool needs to be copied between GPU and CPU. So, a for loop is much
1004# preferred when running on GPU.
1005#
1006# On the other hand, the For op has no direct XLA kernel. So, when we run
1007# a for loop, we need to rewrite it using a While op.
1008#
1009# It should be possible and probably better to write an XLA C++ kernel
1010# implementing the logic in _ForUsingWhile.
1011def _ForUsingWhile(start,
1012                   limit,
1013                   delta,
1014                   inputs,
1015                   forbody,
1016                   name=None,
1017                   hostmem=None):
1018  """Helper to implement a For loop using a While."""
1019  # To support negative delta (e.g., range(100, 0, -3)), we iterate
1020  # over the range(n) and use iter * delta + start as the real
1021  # iteration index. (e.g., for i in range(34): iter = i * (-3) +
1022  # 100).
1023  d = math_ops.abs(delta)
1024  # XLA on TPUs doesn't support integer division
1025  n = math_ops.cast(
1026      math_ops.cast((math_ops.abs(limit - start) + d - 1), dtypes.float32) /
1027      math_ops.cast(d, dtypes.float32), dtypes.int32)
1028
1029  # Carried loop variables ("extra_args") are implicitly added to the input list
1030  # of the WhileBody function. WhileCond does not call forbody, and so does not
1031  # depend on any of forbody's extra_args. Since WhileCond and WhileBody
1032  # must have identical inputs, we have to augment the cond signature to take
1033  # the same types as the carried loop variables.
1034  body_sig = [dtypes.int32] * 4 + list(forbody.declared_input_types)[1:]
1035
1036  cond_name = "%s_Cond" % forbody.name
1037
1038  @function.Defun(*body_sig, func_name=cond_name)
1039  def WhileCond(i, n, *args):
1040    del args
1041    return i < n
1042
1043  body_name = "%s_Body" % forbody.name
1044
1045  @function.Defun(*body_sig, func_name=body_name)
1046  def WhileBody(i, n, start, delta, *args):
1047    """A While wrapper for forbody that handles loop-carried captured inputs."""
1048    for_result = forbody(start + i * delta, *args)
1049    # Nullary functions return an Operation. Normal functions can't do this
1050    # because their return values are converted to Tensors.
1051    if isinstance(for_result, ops.Operation):
1052      for_result = ()
1053    # Unary functions return a single Tensor value.
1054    elif isinstance(for_result, ops.Tensor):
1055      for_result = (for_result,)
1056    return (i + 1, n, start, delta) + tuple(for_result)
1057
1058  if hostmem is not None:
1059    hostmem = [0, 1, 2, 3] + [(4 + _) for _ in hostmem]
1060  else:
1061    hostmem = [0, 1, 2, 3]
1062
1063  results = While(
1064      input_=[0, n, start, delta] + inputs,
1065      cond=WhileCond,
1066      body=WhileBody,
1067      name=name,
1068      hostmem=hostmem)
1069  # Slice off the loop-carried captured inputs.
1070  return list(results[4:len(results)])
1071
1072
1073def For(start,
1074        limit,
1075        delta,
1076        inputs,
1077        body,
1078        name=None,
1079        hostmem=None,
1080        rewrite_with_while=None):
1081  r"""out = input; for i in range(start, limit, delta) out = body(i, out).
1082
1083  Args:
1084    start: A `Tensor` of type `int32`.
1085    limit: A `Tensor` of type `int32`.
1086    delta: A `Tensor` of type `int32`.
1087    inputs: A list of `Tensor` objects. A list of input tensors whose types are
1088      T.
1089    body: A function takes a list of tensors and returns another list of
1090      tensors. Both lists have the same types as (int32, T...).
1091    name: A name for the operation (optional).
1092    hostmem: A list of integers. If i is in the list, inputs[i] is a host memory
1093      tensor. In other words, the (i+1)-th argument of the body function
1094      expects a host memory tensor.
1095    rewrite_with_while: If True, use a While op to implement the For loop.
1096
1097  Returns:
1098    A list of `Tensor` objects. Has the same type as `input`.
1099    A list of output tensors whose types are T.
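
  Example:
    A rough graph-mode sketch (illustrative only); the body receives the loop
    index as its first argument, followed by the loop-carried tensors:
    ```python
    @function.Defun(dtypes.int32, dtypes.float32)
    def Body(i, acc):
      return acc + math_ops.cast(i, dtypes.float32)

    result = For(0, 10, 1, [constant_op.constant(0.0)], Body)
    # result == [45.0]
    ```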
1100  """
1101  if rewrite_with_while:
1102    return _ForUsingWhile(start, limit, delta, inputs, body, name, hostmem)
1103  if body.captured_inputs:
1104    ret = gen_functional_ops._for(
1105        start,
1106        limit,
1107        delta,
1108        inputs + body.captured_inputs,
1109        _LoopBodyCaptureWrapper(body),
1110        name=name)
1111    # Slice off the loop-carried captured inputs.
1112    ret = ret[:-len(body.captured_inputs)]
1113  else:
1114    ret = gen_functional_ops._for(start, limit, delta, inputs, body, name=name)
1115  if hostmem:
1116    num_for_params = 3  # start/limit/delta
1117
1118    input_attr = attr_value_pb2.AttrValue()
1119    input_attr.list.i.extend([num_for_params + i for i in hostmem])
1120    ret[0].op._set_attr("_input_hostmem", input_attr)  # pylint: disable=protected-access
1121
1122    output_attr = attr_value_pb2.AttrValue()
1123    output_attr.list.i.extend(hostmem)
1124    ret[0].op._set_attr("_output_hostmem", output_attr)  # pylint: disable=protected-access
1125  return ret
1126
1127
1128# pylint: enable=invalid-name,protected-access
1129
1130
1131def partitioned_call(args,
1132                     f,
1133                     tout=None,
1134                     executing_eagerly=None,
1135                     config=None,
1136                     executor_type=None):
1137  """Executes a function while respecting device annotations.
1138
1139  Currently, only those functions that execute within the same address space
1140  can be executed.
1141
1142  Args:
1143    args: The arguments of the function, including captured inputs.
1144    f: The function to execute; an instance of `_DefinedFunction` or
1145      `_EagerDefinedFunction`.
1146    tout: a list containing the output dtypes enums; if `None`, inferred from
1147      the signature of `f`.
1148    executing_eagerly: (Optional) A boolean indicating whether the context is
1149      executing eagerly. If `None`, fetched from the global context.
1150    config: (Optional) A `tensorflow::ConfigProto` proto, serialized. If `None`,
1151      all optimizations are disabled. Currently only handled for eager defined
1152      functions.
1153    executor_type: (Optional) A string for the name of the executor to be used
1154      in the function call. If not set, or set to an empty string, the default
1155      tensorflow executor will be used.
1156
1157  Returns:
1158    The list of `Tensor`s returned by invoking `f(args)`. If the function does
1159    not return anything, then returns `None` if eager execution is enabled, or
1160    the `Operation` if not.
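
  Example:
    A rough graph-mode sketch (illustrative only) with a `Defun`-defined
    function; `tout` is inferred from the function signature:
    ```python
    @function.Defun(dtypes.float32, dtypes.float32)
    def Plus(x, y):
      return x + y

    outputs = partitioned_call(
        [constant_op.constant(1.0), constant_op.constant(2.0)], Plus)
    # outputs[0] == 3.0
    ```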
1161  """
1162
1163  if tout is None:
1164    tout = tuple(x.type for x in f.definition.signature.output_arg)
1165
1166  if executing_eagerly is None:
1167    executing_eagerly = context.executing_eagerly()
1168
1169  if config is None:
1170    config = function_utils.get_disabled_rewriter_config()
1171
1172  if executor_type is None:
1173    executor_type = ""
1174
1175  if executing_eagerly:
1176    if f.stateful_ops:
1177      outputs = gen_functional_ops.stateful_partitioned_call(
1178          args=args,
1179          Tout=tout,
1180          f=f,
1181          config_proto=config,
1182          executor_type=executor_type)
1183    else:
1184      outputs = gen_functional_ops.partitioned_call(
1185          args=args,
1186          Tout=tout,
1187          f=f,
1188          config_proto=config,
1189          executor_type=executor_type)
1190    return outputs if outputs else None
1191
1192  # The generated binding returns an empty list for functions that don't
1193  # return any Tensors, hence the need to use `create_op` directly.
1194  args = [ops.convert_to_tensor(x) for x in args]
1195  tin_attr = attr_value_pb2.AttrValue(
1196      list=attr_value_pb2.AttrValue.ListValue(
1197          type=[x.dtype.as_datatype_enum for x in args]))
1198  tout_attr = attr_value_pb2.AttrValue(
1199      list=attr_value_pb2.AttrValue.ListValue(type=tout))
1200  func_attr = attr_value_pb2.AttrValue(
1201      func=attr_value_pb2.NameAttrList(name=f.name))
1202  executor_type_attr = attr_value_pb2.AttrValue(
1203      s=compat.as_bytes(executor_type))
1204
1205  # When running in graph mode, the graph and function graphs are optimized
1206  # (i.e. run through grappler) per the session options, so we can disable any
1207  # eager-specific rewriting.
1208  config_proto = attr_value_pb2.AttrValue(s=config)
1209
1210  graph = ops.get_default_graph()
1211  f.add_to_graph(graph)
1212  op_name = "StatefulPartitionedCall" if f.stateful_ops else "PartitionedCall"
1213
1214  # Propagate the attribute indicating the need to compile from function to the
1215  # call itself.
1216  xla_compile_attr = "_XlaMustCompile"
1217  op_attrs = {
1218      "Tin": tin_attr,
1219      "Tout": tout_attr,
1220      "f": func_attr,
1221      "config_proto": config_proto,
1222      "executor_type": executor_type_attr,
1223  }
1224  if xla_compile_attr in f.definition.attr:
1225    op_attrs[xla_compile_attr] = f.definition.attr[xla_compile_attr]
1226  op = graph.create_op(op_name, args, tout, name=op_name, attrs=op_attrs)
1227  outputs = op.outputs
1228  if hasattr(f, "graph"):
1229    _set_read_only_resource_inputs_attr(op, f.graph)
1230    if hasattr(f.graph, "collective_manager_ids_used"):
1231      ops.set_int_list_attr(op, acd.COLLECTIVE_MANAGER_IDS,
1232                            f.graph.collective_manager_ids_used)
1233  return outputs if outputs else op
1234
1235
1236def _set_read_only_resource_inputs_attr(op, func_graph):
1237  """Sets the list of resource inputs which are read-only.
1238
1239  This is used by AutomaticControlDependencies.
1240
1241  Args:
1242    op: PartitionedCall Operation.
1243    func_graph: FuncGraph.
1244  """
1245  read_only_indices = acd.get_read_only_resource_input_indices_graph(func_graph)
1246  ops.set_int_list_attr(op, acd.READ_ONLY_RESOURCE_INPUTS_ATTR,
1247                        read_only_indices)
1248