• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# =============================================================================
15"""Functional operations."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from tensorflow.core.framework import attr_value_pb2
22from tensorflow.python.eager import context
23from tensorflow.python.framework import auto_control_deps_utils as acd
24from tensorflow.python.framework import constant_op
25from tensorflow.python.framework import dtypes
26from tensorflow.python.framework import function
27from tensorflow.python.framework import ops
28from tensorflow.python.framework import tensor_shape
29from tensorflow.python.ops import array_ops
30from tensorflow.python.ops import control_flow_ops
31from tensorflow.python.ops import gen_functional_ops
32from tensorflow.python.ops import math_ops
33from tensorflow.python.ops import tensor_array_ops
34from tensorflow.python.ops import variable_scope as vs
35# pylint: disable=unused-import
36from tensorflow.python.ops.gen_functional_ops import remote_call
37# pylint: enable=unused-import
38from tensorflow.python.ops.gen_functional_ops import symbolic_gradient
39from tensorflow.python.util import compat
40from tensorflow.python.util import deprecation
41from tensorflow.python.util import dispatch
42from tensorflow.python.util import function_utils
43from tensorflow.python.util import nest
44from tensorflow.python.util.tf_export import tf_export
45
46
47# TODO(yuanbyu, mrry): Handle stride to support sliding windows.
@tf_export(v1=["foldl"])
@dispatch.add_dispatch_support
def foldl(fn,
          elems,
          initializer=None,
          parallel_iterations=10,
          back_prop=True,
          swap_memory=False,
          name=None):
  """foldl on the list of tensors unpacked from `elems` on dimension 0.

  This foldl operator repeatedly applies the callable `fn` to a sequence
  of elements from first to last. The elements are made of the tensors
  unpacked from `elems` on dimension 0. The callable fn takes two tensors as
  arguments. The first argument is the accumulated value computed from the
  preceding invocation of fn, and the second is the value at the current
  position of `elems`. If `initializer` is None, `elems` must contain at least
  one element, and its first element is used as the initializer.

  Suppose that `elems` is unpacked into `values`, a list of tensors. The shape
  of the result tensor is `fn(initializer, values[0]).shape`.

  This method also allows multi-arity `elems` and output of `fn`.  If `elems`
  is a (possibly nested) list or tuple of tensors, then each of these tensors
  must have a matching first (unpack) dimension.  The second argument of `fn`
  then has the same (possibly nested) structure as `elems`.  That is, if
  `elems` is `(t1, [t2, t3, [t4, t5]])`, then an appropriate signature for
  `fn` is `fn = lambda a, x: ...`, where `x` has the structure
  `(x1, [x2, x3, [x4, x5]])`.

  Args:
    fn: The callable to be performed.
    elems: A tensor or (possibly nested) sequence of tensors, each of which will
      be unpacked along their first dimension.  The nested sequence of the
      resulting slices will be the second argument to `fn`.
    initializer: (optional) A tensor or (possibly nested) sequence of tensors,
      as the initial value for the accumulator.
    parallel_iterations: (optional) The number of iterations allowed to run in
      parallel.
    back_prop: (optional) True enables support for back propagation.
    swap_memory: (optional) True enables GPU-CPU memory swapping.
    name: (optional) Name prefix for the returned tensors.

  Returns:
    A tensor or (possibly nested) sequence of tensors, resulting from applying
    `fn` consecutively to the list of tensors unpacked from `elems`, from first
    to last.

  Raises:
    TypeError: if `fn` is not callable.

  Example:
    ```python
    elems = tf.constant([1, 2, 3, 4, 5, 6])
    sum = foldl(lambda a, x: a + x, elems)
    # sum == 21
    ```
  """
  if not callable(fn):
    raise TypeError("fn must be callable.")

  def create_ta(elem):
    # Builds a fixed-size TensorArray holding the slices of `elem` along
    # dimension 0. `n` is assigned below, before this helper is first called.
    return tensor_array_ops.TensorArray(
        dtype=elem.dtype, size=n, dynamic_size=False,
        infer_shape=True).unstack(elem)

  in_graph_mode = not context.executing_eagerly()
  with ops.name_scope(name, "foldl", [elems]):
    varscope = None
    varscope_caching_device_was_none = False
    # TODO(akshayka): Remove the in_graph_mode check once caching devices are
    # supported in Eager
    if in_graph_mode:
      # Any get_variable calls in fn will cache the first call locally
      # and not issue repeated network I/O requests for each iteration.
      varscope = vs.get_variable_scope()
      if varscope.caching_device is None:
        # TODO(ebrevdo): Change to using colocate_with here and in other
        # methods.
        varscope.set_caching_device(lambda op: op.device)
        varscope_caching_device_was_none = True

    # The caching device installed above mutates the caller's variable scope,
    # so it must be removed again even if `fn` (or graph construction) raises.
    try:
      # Convert elems to tensors exactly once; the TensorArrays are built from
      # these converted tensors so non-Tensor inputs are not converted twice.
      elems_flat = [
          ops.convert_to_tensor(elem, name="elem")
          for elem in nest.flatten(elems)
      ]
      # n may be known statically; otherwise fall back to the runtime shape.
      # An explicit `is None` test keeps a static length of 0 static.
      n = tensor_shape.dimension_value(elems_flat[0].shape[0])
      if n is None:
        n = array_ops.shape(elems_flat[0])[0]

      elems_ta = nest.map_structure(
          create_ta, nest.pack_sequence_as(elems, elems_flat))

      if initializer is None:
        # Seed the accumulator with the first element and start at index 1.
        a = nest.map_structure(lambda ta: ta.read(0), elems_ta)
        i = constant_op.constant(1)
      else:
        a = initializer
        i = constant_op.constant(0)

      def compute(i, a):
        elem_i = nest.map_structure(lambda ta: ta.read(i), elems_ta)
        a = fn(a, elem_i)
        return [i + 1, a]

      _, r_a = control_flow_ops.while_loop(
          lambda i, a: i < n,
          compute, [i, a],
          parallel_iterations=parallel_iterations,
          back_prop=back_prop,
          swap_memory=swap_memory,
          maximum_iterations=n)
    finally:
      # TODO(akshayka): Remove the in_graph_mode check once caching devices
      # are supported in Eager
      if varscope_caching_device_was_none:
        varscope.set_caching_device(None)

    return r_a
164
165
@tf_export("foldl", v1=[])
@dispatch.add_dispatch_support
@deprecation.deprecated_arg_values(
    None,
    """back_prop=False is deprecated. Consider using tf.stop_gradient instead.
Instead of:
results = tf.foldl(fn, elems, back_prop=False)
Use:
results = tf.nest.map_structure(tf.stop_gradient, tf.foldl(fn, elems))""",
    warn_once=True,
    back_prop=False)
def foldl_v2(fn,
             elems,
             initializer=None,
             parallel_iterations=10,
             back_prop=True,
             swap_memory=False,
             name=None):
  """foldl on the list of tensors unpacked from `elems` on dimension 0.

  This foldl operator repeatedly applies the callable `fn` to a sequence
  of elements from first to last. The elements are made of the tensors
  unpacked from `elems` on dimension 0. The callable fn takes two tensors as
  arguments. The first argument is the accumulated value computed from the
  preceding invocation of fn, and the second is the value at the current
  position of `elems`. If `initializer` is None, `elems` must contain at least
  one element, and its first element is used as the initializer.

  Suppose that `elems` is unpacked into `values`, a list of tensors. The shape
  of the result tensor is `fn(initializer, values[0]).shape`.

  This method also allows multi-arity `elems` and output of `fn`.  If `elems`
  is a (possibly nested) list or tuple of tensors, then each of these tensors
  must have a matching first (unpack) dimension.  The second argument of `fn`
  then has the same (possibly nested) structure as `elems`.  That is, if
  `elems` is `(t1, [t2, t3, [t4, t5]])`, then an appropriate signature for
  `fn` is `fn = lambda a, x: ...`, where `x` has the structure
  `(x1, [x2, x3, [x4, x5]])`.

  This is the TF2 endpoint; it forwards every argument unchanged to the v1
  `foldl` implementation.

  Args:
    fn: The callable to be performed.
    elems: A tensor or (possibly nested) sequence of tensors, each of which will
      be unpacked along their first dimension.  The nested sequence of the
      resulting slices will be the second argument to `fn`.
    initializer: (optional) A tensor or (possibly nested) sequence of tensors,
      as the initial value for the accumulator.
    parallel_iterations: (optional) The number of iterations allowed to run in
      parallel.
    back_prop: (optional) Deprecated. False disables support for back
      propagation. Prefer using `tf.stop_gradient` instead.
    swap_memory: (optional) True enables GPU-CPU memory swapping.
    name: (optional) Name prefix for the returned tensors.

  Returns:
    A tensor or (possibly nested) sequence of tensors, resulting from applying
    `fn` consecutively to the list of tensors unpacked from `elems`, from first
    to last.

  Raises:
    TypeError: if `fn` is not callable.

  Example:
    ```python
    elems = tf.constant([1, 2, 3, 4, 5, 6])
    sum = tf.foldl(lambda a, x: a + x, elems)
    # sum == 21
    ```
  """
  return foldl(
      fn=fn,
      elems=elems,
      initializer=initializer,
      parallel_iterations=parallel_iterations,
      back_prop=back_prop,
      swap_memory=swap_memory,
      name=name)
241
242
@tf_export(v1=["foldr"])
@dispatch.add_dispatch_support
def foldr(fn,
          elems,
          initializer=None,
          parallel_iterations=10,
          back_prop=True,
          swap_memory=False,
          name=None):
  """foldr on the list of tensors unpacked from `elems` on dimension 0.

  This foldr operator repeatedly applies the callable `fn` to a sequence
  of elements from last to first. The elements are made of the tensors
  unpacked from `elems`. The callable fn takes two tensors as arguments.
  The first argument is the accumulated value computed from the preceding
  invocation of fn, and the second is the value at the current position of
  `elems`. If `initializer` is None, `elems` must contain at least one element,
  and its last element is used as the initializer.

  Suppose that `elems` is unpacked into `values`, a list of tensors. The shape
  of the result tensor is `fn(initializer, values[0]).shape`.

  This method also allows multi-arity `elems` and output of `fn`.  If `elems`
  is a (possibly nested) list or tuple of tensors, then each of these tensors
  must have a matching first (unpack) dimension.  The second argument of `fn`
  then has the same (possibly nested) structure as `elems`.  That is, if
  `elems` is `(t1, [t2, t3, [t4, t5]])`, then an appropriate signature for
  `fn` is `fn = lambda a, x: ...`, where `x` has the structure
  `(x1, [x2, x3, [x4, x5]])`.

  Args:
    fn: The callable to be performed.
    elems: A tensor or (possibly nested) sequence of tensors, each of which will
      be unpacked along their first dimension.  The nested sequence of the
      resulting slices will be the second argument to `fn`.
    initializer: (optional) A tensor or (possibly nested) sequence of tensors,
      as the initial value for the accumulator.
    parallel_iterations: (optional) The number of iterations allowed to run in
      parallel.
    back_prop: (optional) True enables support for back propagation.
    swap_memory: (optional) True enables GPU-CPU memory swapping.
    name: (optional) Name prefix for the returned tensors.

  Returns:
    A tensor or (possibly nested) sequence of tensors, resulting from applying
    `fn` consecutively to the list of tensors unpacked from `elems`, from last
    to first.

  Raises:
    TypeError: if `fn` is not callable.

  Example:
    ```python
    elems = tf.constant([1, 2, 3, 4, 5, 6])
    sum = foldr(lambda a, x: a + x, elems)
    # sum == 21
    ```
  """
  if not callable(fn):
    raise TypeError("fn must be callable.")

  def create_ta(elem):
    # Builds a fixed-size TensorArray holding the slices of `elem` along
    # dimension 0. `n` is assigned below, before this helper is first called.
    return tensor_array_ops.TensorArray(
        dtype=elem.dtype, size=n, dynamic_size=False,
        infer_shape=True).unstack(elem)

  in_graph_mode = not context.executing_eagerly()
  with ops.name_scope(name, "foldr", [elems]):
    varscope = None
    varscope_caching_device_was_none = False
    # TODO(akshayka): Remove the in_graph_mode check once caching devices are
    # supported in Eager
    if in_graph_mode:
      # Any get_variable calls in fn will cache the first call locally and not
      # issue repeated network I/O requests for each iteration.
      varscope = vs.get_variable_scope()
      if varscope.caching_device is None:
        # TODO(ebrevdo): Change to using colocate_with here and in other
        # methods.
        varscope.set_caching_device(lambda op: op.device)
        varscope_caching_device_was_none = True

    # The caching device installed above mutates the caller's variable scope,
    # so it must be removed again even if `fn` (or graph construction) raises.
    try:
      # Convert elems to tensors exactly once; the TensorArrays are built from
      # these converted tensors so non-Tensor inputs are not converted twice.
      elems_flat = [
          ops.convert_to_tensor(elem, name="elem")
          for elem in nest.flatten(elems)
      ]
      # n may be known statically; otherwise fall back to the runtime shape.
      # An explicit `is None` test keeps a static length of 0 static.
      n = tensor_shape.dimension_value(elems_flat[0].shape[0])
      if n is None:
        n = array_ops.shape(elems_flat[0])[0]

      elems_ta = nest.map_structure(
          create_ta, nest.pack_sequence_as(elems, elems_flat))

      if initializer is None:
        # Seed the accumulator with the last element; the loop then visits
        # indices n - 2 down to 0.
        i = n - 1
        a = nest.map_structure(lambda ta: ta.read(i), elems_ta)
      else:
        i = n
        a = initializer

      def compute(i, a):
        # `i` counts down; decrement first so the element read is i - 1.
        i -= 1
        elem = nest.map_structure(lambda ta: ta.read(i), elems_ta)
        a_out = fn(a, elem)
        return [i, a_out]

      _, r_a = control_flow_ops.while_loop(
          lambda i, a: i > 0,
          compute, [i, a],
          parallel_iterations=parallel_iterations,
          back_prop=back_prop,
          swap_memory=swap_memory,
          maximum_iterations=n)
    finally:
      # TODO(akshayka): Remove the in_graph_mode check once caching devices
      # are supported in Eager
      if varscope_caching_device_was_none:
        varscope.set_caching_device(None)

    return r_a
360
361
@tf_export("foldr", v1=[])
@dispatch.add_dispatch_support
@deprecation.deprecated_arg_values(
    None,
    """back_prop=False is deprecated. Consider using tf.stop_gradient instead.
Instead of:
results = tf.foldr(fn, elems, back_prop=False)
Use:
results = tf.nest.map_structure(tf.stop_gradient, tf.foldr(fn, elems))""",
    warn_once=True,
    back_prop=False)
def foldr_v2(fn,
             elems,
             initializer=None,
             parallel_iterations=10,
             back_prop=True,
             swap_memory=False,
             name=None):
  """foldr on the list of tensors unpacked from `elems` on dimension 0.

  This foldr operator repeatedly applies the callable `fn` to a sequence
  of elements from last to first. The elements are made of the tensors
  unpacked from `elems`. The callable fn takes two tensors as arguments.
  The first argument is the accumulated value computed from the preceding
  invocation of fn, and the second is the value at the current position of
  `elems`. If `initializer` is None, `elems` must contain at least one element,
  and its last element is used as the initializer.

  Suppose that `elems` is unpacked into `values`, a list of tensors. The shape
  of the result tensor is `fn(initializer, values[0]).shape`.

  This method also allows multi-arity `elems` and output of `fn`.  If `elems`
  is a (possibly nested) list or tuple of tensors, then each of these tensors
  must have a matching first (unpack) dimension.  The second argument of `fn`
  then has the same (possibly nested) structure as `elems`.  That is, if
  `elems` is `(t1, [t2, t3, [t4, t5]])`, then an appropriate signature for
  `fn` is `fn = lambda a, x: ...`, where `x` has the structure
  `(x1, [x2, x3, [x4, x5]])`.

  This is the TF2 endpoint; it forwards every argument unchanged to the v1
  `foldr` implementation.

  Args:
    fn: The callable to be performed.
    elems: A tensor or (possibly nested) sequence of tensors, each of which will
      be unpacked along their first dimension.  The nested sequence of the
      resulting slices will be the second argument to `fn`.
    initializer: (optional) A tensor or (possibly nested) sequence of tensors,
      as the initial value for the accumulator.
    parallel_iterations: (optional) The number of iterations allowed to run in
      parallel.
    back_prop: (optional) Deprecated. False disables support for back
      propagation. Prefer using `tf.stop_gradient` instead.
    swap_memory: (optional) True enables GPU-CPU memory swapping.
    name: (optional) Name prefix for the returned tensors.

  Returns:
    A tensor or (possibly nested) sequence of tensors, resulting from applying
    `fn` consecutively to the list of tensors unpacked from `elems`, from last
    to first.

  Raises:
    TypeError: if `fn` is not callable.

  Example:
    ```python
    elems = tf.constant([1, 2, 3, 4, 5, 6])
    sum = tf.foldr(lambda a, x: a + x, elems)
    # sum == 21
    ```
  """
  return foldr(
      fn=fn,
      elems=elems,
      initializer=initializer,
      parallel_iterations=parallel_iterations,
      back_prop=back_prop,
      swap_memory=swap_memory,
      name=name)
437
438
@tf_export(v1=["scan"])
@dispatch.add_dispatch_support
def scan(fn,
         elems,
         initializer=None,
         parallel_iterations=10,
         back_prop=True,
         swap_memory=False,
         infer_shape=True,
         reverse=False,
         name=None):
  """scan on the list of tensors unpacked from `elems` on dimension 0.

  See also `tf.map_fn`.

  The simplest version of `scan` repeatedly applies the callable `fn` to a
  sequence of elements from first to last. The elements are made of the tensors
  unpacked from `elems` on dimension 0. The callable fn takes two tensors as
  arguments. The first argument is the accumulated value computed from the
  preceding invocation of fn, and the second is the value at the current
  position of `elems`. If `initializer` is None, `elems` must contain at least
  one element, and its first element is used as the initializer.

  Suppose that `elems` is unpacked into `values`, a list of tensors. The shape
  of the result tensor is `[len(values)] + fn(initializer, values[0]).shape`.
  If reverse=True, it's `fn(initializer, values[-1]).shape`.

  This method also allows multi-arity `elems` and accumulator.  If `elems`
  is a (possibly nested) list or tuple of tensors, then each of these tensors
  must have a matching first (unpack) dimension.  The second argument of
  `fn` must match the structure of `elems`.

  If no `initializer` is provided, the output structure and dtypes of `fn`
  are assumed to be the same as its input; and in this case, the first
  argument of `fn` must match the structure of `elems`.

  If an `initializer` is provided, then the output of `fn` must have the same
  structure as `initializer`; and the first argument of `fn` must match
  this structure.

  For example, if `elems` is `(t1, [t2, t3])` and `initializer` is
  `[i1, i2]` then an appropriate signature for `fn` in `python2` is:
  `fn = lambda (acc_p1, acc_p2), (t1, [t2, t3]):` and `fn` must return a list,
  `[acc_n1, acc_n2]`.  An alternative correct signature for `fn`, and the
  one that works in `python3`, is:
  `fn = lambda a, t:`, where `a` and `t` correspond to the input tuples.

  Args:
    fn: The callable to be performed.  It accepts two arguments.  The first will
      have the same structure as `initializer` if one is provided, otherwise it
      will have the same structure as `elems`.  The second will have the same
      (possibly nested) structure as `elems`.  Its output must have the same
      structure as `initializer` if one is provided, otherwise it must have the
      same structure as `elems`.
    elems: A tensor or (possibly nested) sequence of tensors, each of which will
      be unpacked along their first dimension.  The nested sequence of the
      resulting slices will be the second argument to `fn`.
    initializer: (optional) A tensor or (possibly nested) sequence of tensors,
      initial value for the accumulator, and the expected output type of `fn`.
    parallel_iterations: (optional) The number of iterations allowed to run in
      parallel.
    back_prop: (optional) True enables support for back propagation.
    swap_memory: (optional) True enables GPU-CPU memory swapping.
    infer_shape: (optional) False disables tests for consistent output shapes.
    reverse: (optional) True scans the tensor last to first (instead of first to
      last).
    name: (optional) Name prefix for the returned tensors.

  Returns:
    A tensor or (possibly nested) sequence of tensors.  Each tensor packs the
    results of applying `fn` to tensors unpacked from `elems` along the first
    dimension, and the previous accumulator value(s), from first to last (or
    last to first, if `reverse=True`).

  Raises:
    TypeError: if `fn` is not callable or the structure of the output of
      `fn` and `initializer` do not match.
    ValueError: if the lengths of the output of `fn` and `initializer`
      do not match.

  Examples:
    ```python
    elems = np.array([1, 2, 3, 4, 5, 6])
    sum = scan(lambda a, x: a + x, elems)
    # sum == [1, 3, 6, 10, 15, 21]
    sum = scan(lambda a, x: a + x, elems, reverse=True)
    # sum == [21, 20, 18, 15, 11, 6]
    ```

    ```python
    elems = np.array([1, 2, 3, 4, 5, 6])
    initializer = np.array(0)
    sum_one = scan(
        lambda a, x: x[0] - x[1] + a, (elems + 1, elems), initializer)
    # sum_one == [1, 2, 3, 4, 5, 6]
    ```

    ```python
    elems = np.array([1, 0, 0, 0, 0, 0])
    initializer = (np.array(0), np.array(1))
    fibonaccis = scan(lambda a, _: (a[1], a[0] + a[1]), elems, initializer)
    # fibonaccis == ([1, 1, 2, 3, 5, 8], [1, 2, 3, 5, 8, 13])
    ```
  """
  if not callable(fn):
    raise TypeError("fn must be callable.")

  # Flatten/pack helpers for the input structure (`elems`); a non-sequence
  # input is treated as a one-element flat list.
  input_is_sequence = nest.is_sequence(elems)
  input_flatten = lambda x: nest.flatten(x) if input_is_sequence else [x]

  def input_pack(x):
    return nest.pack_sequence_as(elems, x) if input_is_sequence else x[0]

  # Flatten/pack helpers for the accumulator/output structure: the structure of
  # `initializer` when given, otherwise the structure of `elems`.
  if initializer is None:
    output_is_sequence = input_is_sequence
    output_flatten = input_flatten
    output_pack = input_pack
  else:
    output_is_sequence = nest.is_sequence(initializer)
    output_flatten = lambda x: nest.flatten(x) if output_is_sequence else [x]

    def output_pack(x):
      return (nest.pack_sequence_as(initializer, x)
              if output_is_sequence else x[0])

  elems_flat = input_flatten(elems)

  in_graph_mode = not context.executing_eagerly()
  with ops.name_scope(name, "scan", elems_flat):
    varscope = None
    varscope_caching_device_was_none = False
    # TODO(akshayka): Remove the in_graph_mode check once caching devices are
    # supported in Eager
    if in_graph_mode:
      # Any get_variable calls in fn will cache the first call locally
      # and not issue repeated network I/O requests for each iteration.
      varscope = vs.get_variable_scope()
      if varscope.caching_device is None:
        # TODO(ebrevdo): Change to using colocate_with here and in other
        # methods.
        varscope.set_caching_device(lambda op: op.device)
        varscope_caching_device_was_none = True

    # The caching device installed above mutates the caller's variable scope,
    # so it must be removed again even if `fn` (or graph construction) raises.
    try:
      # Convert elems to tensors.
      elems_flat = [
          ops.convert_to_tensor(elem, name="elem") for elem in elems_flat
      ]

      # n may be known statically; otherwise fall back to the runtime shape.
      n = tensor_shape.dimension_value(elems_flat[0].shape[0])
      if n is None:
        n = array_ops.shape(elems_flat[0])[0]

      # TensorArrays are always flat
      elems_ta = [
          tensor_array_ops.TensorArray(
              dtype=elem.dtype,
              size=n,
              dynamic_size=False,
              element_shape=elem.shape[1:],
              infer_shape=True) for elem in elems_flat
      ]
      # Unpack elements
      elems_ta = [
          elem_ta.unstack(elem) for elem_ta, elem in zip(elems_ta, elems_flat)
      ]

      if initializer is None:
        # Seed the accumulator with the first (or, when reversed, last)
        # element; that element is therefore skipped by the loop.
        a_flat = [elem.read(n - 1 if reverse else 0) for elem in elems_ta]
        i = 1
      else:
        initializer_flat = output_flatten(initializer)
        a_flat = [ops.convert_to_tensor(init) for init in initializer_flat]
        i = 0

      # Create a tensor array to store the intermediate values.
      accs_ta = [
          tensor_array_ops.TensorArray(
              dtype=init.dtype,
              size=n,
              element_shape=init.shape if infer_shape else None,
              dynamic_size=False,
              infer_shape=infer_shape) for init in a_flat
      ]

      if initializer is None:
        # The seed element is also the first (or last) output.
        accs_ta = [
            acc_ta.write(n - 1 if reverse else 0, a)
            for (acc_ta, a) in zip(accs_ta, a_flat)
        ]

      def compute(i, a_flat, tas):
        """The loop body of scan.

        Args:
          i: the loop counter (index of the element processed in this step).
          a_flat: the accumulator value(s), flattened.
          tas: the output accumulator TensorArray(s), flattened.

        Returns:
          (next_i, flat_a_out, tas): the next counter (`i + 1`, or `i - 1`
            when `reverse` is True), the new accumulator values, and the
            updated TensorArrays.

        Raises:
          TypeError: if the structure of fn()'s output does not match
            `initializer` (or `elems` when no initializer is given).
          ValueError: if the number of fn() outputs does not match
            `initializer` (or `elems` when no initializer is given).
        """
        packed_elems = input_pack([elem_ta.read(i) for elem_ta in elems_ta])
        packed_a = output_pack(a_flat)
        a_out = fn(packed_a, packed_elems)
        nest.assert_same_structure(
            elems if initializer is None else initializer, a_out)
        flat_a_out = output_flatten(a_out)
        tas = [ta.write(i, value) for (ta, value) in zip(tas, flat_a_out)]
        if reverse:
          next_i = i - 1
        else:
          next_i = i + 1
        return (next_i, flat_a_out, tas)

      if reverse:
        initial_i = n - 1 - i
        condition = lambda i, _1, _2: i >= 0
      else:
        initial_i = i
        condition = lambda i, _1, _2: i < n
      _, _, r_a = control_flow_ops.while_loop(
          condition,
          compute, (initial_i, a_flat, accs_ta),
          parallel_iterations=parallel_iterations,
          back_prop=back_prop,
          swap_memory=swap_memory,
          maximum_iterations=n)

      results_flat = [r.stack() for r in r_a]

      # Restore the static leading dimension on the stacked outputs, checking
      # that every input agrees on it.
      n_static = tensor_shape.Dimension(
          tensor_shape.dimension_value(
              elems_flat[0].get_shape().with_rank_at_least(1)[0]))
      for elem in elems_flat[1:]:
        n_static.assert_is_compatible_with(
            tensor_shape.Dimension(
                tensor_shape.dimension_value(
                    elem.get_shape().with_rank_at_least(1)[0])))
      for r in results_flat:
        r.set_shape(
            tensor_shape.TensorShape(n_static).concatenate(r.get_shape()[1:]))
    finally:
      # TODO(akshayka): Remove the in_graph_mode check once caching devices
      # are supported in Eager
      if varscope_caching_device_was_none:
        varscope.set_caching_device(None)

    return output_pack(results_flat)
692
693
694@tf_export("scan", v1=[])
695@dispatch.add_dispatch_support
696@deprecation.deprecated_arg_values(
697    None,
698    """back_prop=False is deprecated. Consider using tf.stop_gradient instead.
699Instead of:
700results = tf.scan(fn, elems, back_prop=False)
701Use:
702results = tf.nest.map_structure(tf.stop_gradient, tf.scan(fn, elems))""",
703    warn_once=True,
704    back_prop=False)
705def scan_v2(fn,
706            elems,
707            initializer=None,
708            parallel_iterations=10,
709            back_prop=True,
710            swap_memory=False,
711            infer_shape=True,
712            reverse=False,
713            name=None):
714  """scan on the list of tensors unpacked from `elems` on dimension 0.
715
716  The simplest version of `scan` repeatedly applies the callable `fn` to a
717  sequence of elements from first to last. The elements are made of the tensors
718  unpacked from `elems` on dimension 0. The callable fn takes two tensors as
719  arguments. The first argument is the accumulated value computed from the
720  preceding invocation of fn, and the second is the value at the current
721  position of `elems`. If `initializer` is None, `elems` must contain at least
722  one element, and its first element is used as the initializer.
723
724  Suppose that `elems` is unpacked into `values`, a list of tensors. The shape
725  of the result tensor is `[len(values)] + fn(initializer, values[0]).shape`.
726  If reverse=True, it's fn(initializer, values[-1]).shape.
727
728  This method also allows multi-arity `elems` and accumulator.  If `elems`
729  is a (possibly nested) list or tuple of tensors, then each of these tensors
730  must have a matching first (unpack) dimension.  The second argument of
731  `fn` must match the structure of `elems`.
732
733  If no `initializer` is provided, the output structure and dtypes of `fn`
734  are assumed to be the same as its input; and in this case, the first
735  argument of `fn` must match the structure of `elems`.
736
737  If an `initializer` is provided, then the output of `fn` must have the same
738  structure as `initializer`; and the first argument of `fn` must match
739  this structure.
740
741  For example, if `elems` is `(t1, [t2, t3])` and `initializer` is
742  `[i1, i2]` then an appropriate signature for `fn` in `python2` is:
743  `fn = lambda (acc_p1, acc_p2), (t1, [t2, t3]):` and `fn` must return a list,
744  `[acc_n1, acc_n2]`.  An alternative correct signature for `fn`, and the
745   one that works in `python3`, is:
746  `fn = lambda a, t:`, where `a` and `t` correspond to the input tuples.
747
748  Args:
749    fn: The callable to be performed.  It accepts two arguments.  The first will
750      have the same structure as `initializer` if one is provided, otherwise it
751      will have the same structure as `elems`.  The second will have the same
752      (possibly nested) structure as `elems`.  Its output must have the same
753      structure as `initializer` if one is provided, otherwise it must have the
754      same structure as `elems`.
755    elems: A tensor or (possibly nested) sequence of tensors, each of which will
756      be unpacked along their first dimension.  The nested sequence of the
757      resulting slices will be the first argument to `fn`.
758    initializer: (optional) A tensor or (possibly nested) sequence of tensors,
759      initial value for the accumulator, and the expected output type of `fn`.
760    parallel_iterations: (optional) The number of iterations allowed to run in
761      parallel.
762    back_prop: (optional) Deprecated. False disables support for back
763      propagation. Prefer using `tf.stop_gradient` instead.
764    swap_memory: (optional) True enables GPU-CPU memory swapping.
765    infer_shape: (optional) False disables tests for consistent output shapes.
766    reverse: (optional) True scans the tensor last to first (instead of first to
767      last).
768    name: (optional) Name prefix for the returned tensors.
769
770  Returns:
771    A tensor or (possibly nested) sequence of tensors.  Each tensor packs the
772    results of applying `fn` to tensors unpacked from `elems` along the first
773    dimension, and the previous accumulator value(s), from first to last (or
774    last to first, if `reverse=True`).
775
776  Raises:
777    TypeError: if `fn` is not callable or the structure of the output of
778      `fn` and `initializer` do not match.
779    ValueError: if the lengths of the output of `fn` and `initializer`
780      do not match.
781
782  Examples:
783    ```python
784    elems = np.array([1, 2, 3, 4, 5, 6])
785    sum = scan(lambda a, x: a + x, elems)
786    # sum == [1, 3, 6, 10, 15, 21]
787    sum = scan(lambda a, x: a + x, elems, reverse=True)
788    # sum == [21, 20, 18, 15, 11, 6]
789    ```
790
791    ```python
792    elems = np.array([1, 2, 3, 4, 5, 6])
793    initializer = np.array(0)
794    sum_one = scan(
795        lambda a, x: x[0] - x[1] + a, (elems + 1, elems), initializer)
796    # sum_one == [1, 2, 3, 4, 5, 6]
797    ```
798
799    ```python
800    elems = np.array([1, 0, 0, 0, 0, 0])
801    initializer = (np.array(0), np.array(1))
802    fibonaccis = scan(lambda a, _: (a[1], a[0] + a[1]), elems, initializer)
803    # fibonaccis == ([1, 1, 2, 3, 5, 8], [1, 2, 3, 5, 8, 13])
804    ```
805  """
806  return scan(
807      fn=fn,
808      elems=elems,
809      initializer=initializer,
810      parallel_iterations=parallel_iterations,
811      back_prop=back_prop,
812      swap_memory=swap_memory,
813      infer_shape=infer_shape,
814      reverse=reverse,
815      name=name)
816
817
818# pylint: disable=invalid-name
def If(cond, inputs, then_branch, else_branch, name=None):
  r"""output = Cond(inputs) ? then_branch(inputs) : else_branch(inputs).

  Args:
    cond: A `Tensor`. A scalar. If the scalar is not a boolean, the scalar is
      converted to a boolean according to the following rule: if the scalar is
      a numerical value, non-zero means True and zero means False; if the
      scalar is a string, non-empty means True and empty means False.
    inputs: A list of input tensors.
    then_branch: A function that takes `inputs` and returns a list of tensors
      whose types are the same as what `else_branch` returns.
    else_branch: A function that takes `inputs` and returns a list of tensors
      whose types are the same as what `then_branch` returns.
    name: A name for the operation (optional).

  Returns:
    A list of tensors returned by either `then_branch(inputs)` or
    `else_branch(inputs)`. When the branches are not `Defun`s, the flat op
    outputs are re-packed into the structure of
    `then_branch.structured_outputs`.
  """
  # pylint: disable=protected-access
  if isinstance(then_branch, function._DefinedFunction):
    # Legacy Defun path, kept until users have transitioned to tf.function.
    # Output dtypes come straight from the function signature; composite
    # outputs are left flat and may need to be re-packed by the caller.
    out_types = [
        arg.type for arg in then_branch.definition.signature.output_arg
    ]
    return gen_functional_ops._if(
        cond, inputs, out_types, then_branch, else_branch, name=name)

  # Otherwise both branches are assumed to be ConcreteFunctions.
  then_structure = then_branch.structured_outputs
  # Both branches must produce the same kind of composites; otherwise the
  # pack_sequence_as call below would be invalid.
  nest.assert_same_structure(
      then_structure, else_branch.structured_outputs, expand_composites=True)

  flat_outputs = gen_functional_ops._if(
      cond,
      inputs,
      nest.flatten(then_branch.output_dtypes),
      then_branch,
      else_branch,
      name=name)

  # Restore any CompositeTensors from the flat op outputs.
  return nest.pack_sequence_as(
      then_structure, flat_outputs, expand_composites=True)
862
863
def Gradient(inputs, f, name=None):
  r"""Computes the gradient function for function f via backpropagation.

  Args:
    inputs: A list of tensors of size N + M.
    f: The function we want to compute the gradient for. The function `f` must
      be a numerical function which takes N inputs and produces M outputs. Its
      gradient function `g` takes N + M inputs and produces N outputs. That
      is, if (y1, y2, ..., yM) = f(x1, x2, ..., xN), then g is
      (dL/dx1, dL/dx2, ..., dL/dxN) = g(x1, x2, ..., xN, dL/dy1, dL/dy2, ...,
      dL/dyM), where L is a scalar-valued function of (x1, x2, ..., xN) (e.g.,
      the loss function). dL/dxi is the partial derivative of L with respect
      to xi.
    name: A name for the operation (optional).

  Returns:
    A list of tensors of size N.
  """
  # TODO(zhifengc): Pretty-print the above spec in latex.
  # TODO(zhifengc): Needs some math expert to say the comment above better.
  # The gradient yields one tensor per input of `f`, with the same dtypes.
  grad_types = [arg.type for arg in f.definition.signature.input_arg]
  return symbolic_gradient(input=inputs, Tout=grad_types, f=f, name=name)
886
887
def _GetInputDtypes(func):
  """Returns the input dtypes of func, excluding dtypes for captured inputs."""
  if not isinstance(func, function._DefinedFunction):  # pylint: disable=protected-access
    # Anything that is not a Defun is assumed to be a ConcreteFunction. This
    # cannot be verified here, because importing the eager function library
    # would create a cyclic dependency.
    #
    # ConcreteFunction.inputs includes captured inputs; keep only the leading
    # non-captured ones.
    num_explicit = len(func.inputs) - len(func.captured_inputs)
    return [tensor.dtype for tensor in func.inputs[:num_explicit]]
  return func.declared_input_types
900
901
def _LoopBodyCaptureWrapper(func):
  """Returns a wrapper for `func` that handles loop-carried captured inputs.

  The wrapper is a Defun named "<func.name>_Wrapper" that takes the
  non-captured inputs of `func`. It returns the outputs of `func` followed by
  the extra (captured) args gathered while tracing the call.
  """
  input_dtypes = _GetInputDtypes(func)
  wrapper_name = "%s_Wrapper" % func.name

  @function.Defun(*input_dtypes, func_name=wrapper_name)
  def Wrapper(*args):
    """A wrapper that handles loop-carried captured inputs."""
    body_out = func(*args)
    # Must be read after `func` has been called, so that its captures are
    # registered.
    captured = tuple(function.get_extra_args())
    if isinstance(body_out, ops.Operation):
      # Nullary functions return an Operation (ordinary functions cannot,
      # because their return values are converted to Tensors); only the
      # captured values are passed through.
      return captured
    if isinstance(body_out, (list, tuple)):
      # N-ary functions: append the captured values, keeping the original
      # sequence type.
      return body_out + type(body_out)(captured)
    # Unary functions return a single Tensor value.
    return (body_out,) + captured

  return Wrapper
922
923
924# pylint: disable=invalid-name,protected-access
def While(input_, cond, body, name=None, hostmem=None):
  r"""output = input; While (Cond(output)) { output = Body(output) }.

  Args:
    input_: A list of `Tensor` objects. A list of input tensors whose types are
      T.
    cond: A function that takes `input_` and returns a tensor. If the tensor is
      a scalar of non-boolean, the scalar is converted to a boolean according
      to the following rule: if the scalar is a numerical value, non-zero
      means True and zero means False; if the scalar is a string, non-empty
      means True and empty means False. If the tensor is not a scalar,
      non-emptiness means True and False otherwise.
    body: A function that takes a list of tensors and returns another list of
      tensors. Both lists have the same types as specified by T.
    name: A name for the operation (optional).
    hostmem: A list of integers. If i is in the list, input[i] is a host memory
      tensor.

  Raises:
    ValueError: if `cond` has implicitly captured inputs or if `cond` and `body`
      have different signatures.

  Returns:
    A list of output `Tensor` objects whose types are T (the same as
    `input_`).
  """
  if cond.captured_inputs:
    raise ValueError("While op 'cond' argument must be a function "
                     "without implicitly captured inputs.")

  cond_var_types = _GetInputDtypes(cond)
  body_var_types = _GetInputDtypes(body)
  if cond_var_types != body_var_types:
    raise ValueError(
        "While op 'cond' and 'body' signatures do not match. %r vs %r" %
        (cond_var_types, body_var_types))

  if not body.captured_inputs:
    ret = gen_functional_ops._while(input_, cond, body, name=name)
  else:
    # The body's captured inputs become extra loop-carried variables, so cond
    # is wrapped to accept (and ignore) them as well.
    num_loop_vars = len(body_var_types)
    cond_wrapper_types = list(body_var_types) + [
        t.dtype for t in body.captured_inputs
    ]

    @function.Defun(*cond_wrapper_types, func_name="%s_Wrapper" % cond.name)
    def CondWrapper(*args):
      """A wrapper that handles loop-carried captured inputs."""
      return cond(*args[:num_loop_vars])

    ret = gen_functional_ops._while(
        input_ + body.captured_inputs,
        CondWrapper,
        _LoopBodyCaptureWrapper(body),
        name=name)
    # Slice off the loop-carried captured inputs.
    ret = ret[:-len(body.captured_inputs)]

  if hostmem:
    # The same host-memory indices are recorded for the op's inputs and
    # outputs.
    while_op = ret[0].op
    for attr_name in ("_input_hostmem", "_output_hostmem"):
      hostmem_attr = attr_value_pb2.AttrValue()
      hostmem_attr.list.i.extend(hostmem)
      while_op._set_attr(attr_name, hostmem_attr)  # pylint: disable=protected-access
  return ret
991
992
993# b/36459430
994#
# Ideally, we would not need to rewrite a For loop into a While loop.
996# However, today, if a While runs on GPU and the condition returns a
997# boolean, the While kernel crashes. Even if we fix the crash, the
998# bool needs to be copied between GPU and CPU. So, a for loop is much
999# preferred when running on GPU.
1000#
# On the other hand, the For op has no direct XLA kernel. So, when we run
# a for loop with XLA, we need to rewrite it using a While op.
1003#
# It should be possible, and probably better, to write an XLA C++ kernel
# implementing the logic in _ForUsingWhile.
def _ForUsingWhile(start,
                   limit,
                   delta,
                   inputs,
                   forbody,
                   name=None,
                   hostmem=None):
  """Helper to implement a For loop using a While.

  The While op carries four int32 bookkeeping values: the iteration counter
  `i`, the trip count `n`, `start`, and `delta`. These are followed by
  `inputs`. Each iteration calls `forbody(start + i * delta, *loop_vars)`,
  mirroring `for index in range(start, limit, delta)`.

  Args:
    start: int32 scalar; the first loop index.
    limit: int32 scalar; the exclusive loop bound.
    delta: int32 scalar; the loop step. May be negative.
    inputs: A list of loop-carried tensors.
    forbody: A `Defun` called as `forbody(index, *loop_vars)` that returns the
      next values of the loop-carried tensors.
    name: A name for the While operation (optional).
    hostmem: A list of indices into `inputs` whose tensors live in host memory
      (optional).

  Returns:
    A list holding the final values of the loop-carried tensors.
  """
  # To support negative delta (e.g., range(100, 0, -3)), we iterate
  # over range(n) and use i * delta + start as the real iteration index
  # (e.g., for i in range(34): index = i * (-3) + 100).
  #
  # Measure the distance to `limit` in the direction of travel. If `limit`
  # lies behind `start` for the sign of `delta` (e.g. range(0, 10, -1)),
  # Python's range is empty, so clamp to zero instead of iterating the wrong
  # way. delta == 0 has sign 0, which yields a distance (and trip count) of 0.
  distance = math_ops.maximum((limit - start) * math_ops.sign(delta), 0)
  # Clamp the divisor to at least 1 so that delta == 0 cannot divide by zero.
  # For any valid non-zero delta this equals abs(delta).
  d = math_ops.maximum(math_ops.abs(delta), 1)
  # n = ceil(distance / d). XLA on TPUs doesn't support integer division, so
  # divide in float32 (float32 holds integers exactly only up to 2**24, so
  # very large ranges may round).
  n = math_ops.cast(
      math_ops.cast(distance + d - 1, dtypes.float32) /
      math_ops.cast(d, dtypes.float32), dtypes.int32)

  # Carried loop variables ("extra_args") are implicitly added to the input list
  # of the WhileBody function. WhileCond does not call forbody, and so does not
  # depend on any of forbody's extra_args. Since WhileCond and WhileBody
  # must have identical inputs, we have to augment the cond signature to take
  # the same types as the carried loop variables.
  # forbody's first (index) argument is replaced by the four int32 bookkeeping
  # values (i, n, start, delta).
  body_sig = [dtypes.int32] * 4 + list(forbody.declared_input_types)[1:]

  cond_name = "%s_Cond" % forbody.name

  @function.Defun(*body_sig, func_name=cond_name)
  def WhileCond(i, n, *args):
    del args
    return i < n

  body_name = "%s_Body" % forbody.name

  @function.Defun(*body_sig, func_name=body_name)
  def WhileBody(i, n, start, delta, *args):
    """A While wrapper for forbody that handles loop-carried captured inputs."""
    for_result = forbody(start + i * delta, *args)
    # Nullary functions return an Operation. Normal functions can't do this
    # because their return values are converted to Tensors.
    if isinstance(for_result, ops.Operation):
      for_result = ()
    # Unary functions return a single Tensor value.
    elif isinstance(for_result, ops.Tensor):
      for_result = (for_result,)
    return (i + 1, n, start, delta) + tuple(for_result)

  # The four bookkeeping values always live in host memory; user-supplied
  # indices are shifted past them.
  if hostmem is not None:
    hostmem = [0, 1, 2, 3] + [4 + idx for idx in hostmem]
  else:
    hostmem = [0, 1, 2, 3]

  results = While(
      input_=[0, n, start, delta] + inputs,
      cond=WhileCond,
      body=WhileBody,
      name=name,
      hostmem=hostmem)
  # Slice off the four bookkeeping values.
  return list(results[4:])
1066
1067
def For(start,
        limit,
        delta,
        inputs,
        body,
        name=None,
        hostmem=None,
        rewrite_with_while=None):
  r"""out = input; for i in range(start, limit, delta) out = body(i, out).

  Args:
    start: A `Tensor` of type `int32`.
    limit: A `Tensor` of type `int32`.
    delta: A `Tensor` of type `int32`.
    inputs: A list of `Tensor` objects. A list of input tensors whose types are
      T.
    body: A function that takes the int32 loop index followed by tensors of
      types T, and returns tensors of types T.
    name: A name for the operation (optional).
    hostmem: A list of integers. If i is in the list, inputs[i] is a host
      memory tensor. In other words, the (i+1)-th argument of the body
      function is expected to be in host memory.
    rewrite_with_while: If True, implement the For using a While op.

  Returns:
    A list of output `Tensor` objects whose types are T (the same as
    `inputs`).
  """
  if rewrite_with_while:
    return _ForUsingWhile(start, limit, delta, inputs, body, name, hostmem)

  if not body.captured_inputs:
    ret = gen_functional_ops._for(start, limit, delta, inputs, body, name=name)
  else:
    # Feed the body's captured inputs in as extra loop-carried values, then
    # drop them from the results.
    ret = gen_functional_ops._for(
        start,
        limit,
        delta,
        inputs + body.captured_inputs,
        _LoopBodyCaptureWrapper(body),
        name=name)
    ret = ret[:-len(body.captured_inputs)]

  if hostmem:
    # The op's inputs are (start, limit, delta, *inputs), so input indices are
    # shifted by 3. Its outputs are only the loop-carried values.
    num_for_params = 3
    for_op = ret[0].op

    input_attr = attr_value_pb2.AttrValue()
    input_attr.list.i.extend([num_for_params + idx for idx in hostmem])
    for_op._set_attr("_input_hostmem", input_attr)  # pylint: disable=protected-access

    output_attr = attr_value_pb2.AttrValue()
    output_attr.list.i.extend(hostmem)
    for_op._set_attr("_output_hostmem", output_attr)  # pylint: disable=protected-access
  return ret
1121
1122
1123# pylint: enable=invalid-name,protected-access
1124
1125
def partitioned_call(args,
                     f,
                     tout=None,
                     executing_eagerly=None,
                     config=None,
                     executor_type=None):
  """Executes a function while respecting device annotations.

  Currently, only those functions that execute within the same address space
  can be executed.

  In eager mode, the generated `stateful_partitioned_call` /
  `partitioned_call` bindings are invoked directly. In graph mode, `f` is
  added to the default graph and a `StatefulPartitionedCall` /
  `PartitionedCall` op is built with `graph.create_op`. The stateful variant
  is chosen whenever `f.stateful_ops` is truthy.

  Args:
    args: The arguments of the function, including captured inputs.
    f: The function to execute; an instance of `_DefinedFunction` or
      `_EagerDefinedFunction`.
    tout: a list containing the output dtypes enums; if `None`, inferred from
      the signature of `f`.
    executing_eagerly: (Optional) A boolean indicating whether the context is
      executing eagerly. If `None`, fetched from the global context.
    config: (Optional) A `tensorflow::ConfigProto` proto, serialized. If `None`,
      `function_utils.get_disabled_rewriter_config()` is used, i.e. all
      optimizations are disabled. The value is passed as the call op's
      `config_proto` attr in both eager and graph mode.
    executor_type: (Optional) A string for the name of the executor to be used
      in the function call. If not set, or set to an empty string, the default
      tensorflow executor will be used.

  Returns:
    The list of `Tensor`s returned by invoking `f(args)`. If the function does
    not return anything, then returns `None` if eager execution is enabled, or
    the `Operation` if not.
  """

  # Default the output dtypes to the declared outputs of `f`.
  if tout is None:
    tout = tuple(x.type for x in f.definition.signature.output_arg)

  if executing_eagerly is None:
    executing_eagerly = context.executing_eagerly()

  # Default to a config with the rewriter disabled.
  if config is None:
    config = function_utils.get_disabled_rewriter_config()

  # An empty executor type selects the default executor.
  if executor_type is None:
    executor_type = ""

  if executing_eagerly:
    # Use the stateful op variant when `f` reports stateful ops.
    if f.stateful_ops:
      outputs = gen_functional_ops.stateful_partitioned_call(
          args=args,
          Tout=tout,
          f=f,
          config_proto=config,
          executor_type=executor_type)
    else:
      outputs = gen_functional_ops.partitioned_call(
          args=args,
          Tout=tout,
          f=f,
          config_proto=config,
          executor_type=executor_type)
    # In eager mode, "no outputs" is reported as None.
    return outputs if outputs else None

  # The generated binding returns an empty list for functions that don't
  # return any Tensors, hence the need to use `create_op` directly.
  args = [ops.convert_to_tensor(x) for x in args]
  # Build the op attrs by hand: Tin from the converted args, Tout from `tout`,
  # and `f` referenced by name.
  tin_attr = attr_value_pb2.AttrValue(
      list=attr_value_pb2.AttrValue.ListValue(
          type=[x.dtype.as_datatype_enum for x in args]))
  tout_attr = attr_value_pb2.AttrValue(
      list=attr_value_pb2.AttrValue.ListValue(type=tout))
  func_attr = attr_value_pb2.AttrValue(
      func=attr_value_pb2.NameAttrList(name=f.name))
  executor_type_attr = attr_value_pb2.AttrValue(
      s=compat.as_bytes(executor_type))

  # When running in graph mode, the graph and function graphs are optimized
  # (i.e. run through grappler) per the session options, so we can disable any
  # eager-specific rewriting.
  config_proto = attr_value_pb2.AttrValue(s=config)

  graph = ops.get_default_graph()
  # Register `f` in the graph; the call op below refers to it only by name
  # (via `func_attr`).
  f.add_to_graph(graph)
  op_name = "StatefulPartitionedCall" if f.stateful_ops else "PartitionedCall"

  # Propagate the attribute indicating the need to compile from function to the
  # call itself.
  xla_compile_attr = "_XlaMustCompile"
  op_attrs = {
      "Tin": tin_attr,
      "Tout": tout_attr,
      "f": func_attr,
      "config_proto": config_proto,
      "executor_type": executor_type_attr,
  }
  if xla_compile_attr in f.definition.attr:
    op_attrs[xla_compile_attr] = f.definition.attr[xla_compile_attr]
  op = graph.create_op(op_name, args, tout, name=op_name, attrs=op_attrs)
  outputs = op.outputs
  # When `f` exposes a `graph` (presumably a FuncGraph; confirm for each
  # function type), annotate the call op for automatic control dependencies.
  # The annotations are the read-only resource inputs and the collective
  # manager ids.
  if hasattr(f, "graph"):
    _set_read_only_resource_inputs_attr(op, f.graph)
    if hasattr(f.graph, "collective_manager_ids_used"):
      ops.set_int_list_attr(op, acd.COLLECTIVE_MANAGER_IDS,
                            f.graph.collective_manager_ids_used)
  # With no outputs, return the Operation itself.
  return outputs if outputs else op
1229
1230
def _set_read_only_resource_inputs_attr(op, func_graph):
  """Records on `op` which of its resource inputs are read-only.

  The attribute is consumed by AutomaticControlDependencies.

  Args:
    op: A PartitionedCall Operation.
    func_graph: The FuncGraph of the function invoked by `op`.
  """
  ops.set_int_list_attr(
      op,
      acd.READ_ONLY_RESOURCE_INPUTS_ATTR,
      acd.get_read_only_resource_input_indices_graph(func_graph),
  )
1243