# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

"""Functional operations."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function


import re

from tensorflow.python.autograph.core import ag_ctx as autograph_ctx
from tensorflow.python.autograph.impl import api as autograph
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import type_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import deprecation
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export


@tf_export(v1=["map_fn"])
@deprecation.deprecated_args(None, "Use fn_output_signature instead", "dtype")
def map_fn(fn,
           elems,
           dtype=None,
           parallel_iterations=None,
           back_prop=True,
           swap_memory=False,
           infer_shape=True,
           name=None,
           fn_output_signature=None):
  """Transforms `elems` by applying `fn` to each element unstacked on axis 0.

  See also `tf.scan`.

  `map_fn` unstacks `elems` on axis 0 to obtain a sequence of elements;
  calls `fn` to transform each element; and then stacks the transformed
  values back together.

  #### Mapping functions with single-Tensor inputs and outputs

  If `elems` is a single tensor and `fn`'s signature is `tf.Tensor->tf.Tensor`,
  then `map_fn(fn, elems)` is equivalent to
  `tf.stack([fn(elem) for elem in tf.unstack(elems)])`.  E.g.:

  >>> tf.map_fn(fn=lambda t: tf.range(t, t + 3), elems=tf.constant([3, 5, 2]))
  <tf.Tensor: shape=(3, 3), dtype=int32, numpy=
    array([[3, 4, 5],
           [5, 6, 7],
           [2, 3, 4]], dtype=int32)>

  `map_fn(fn, elems).shape = [elems.shape[0]] + fn(elems[0]).shape`.

  #### Mapping functions with multi-arity inputs and outputs

  `map_fn` also supports functions with multi-arity inputs and outputs:

  * If `elems` is a tuple (or nested structure) of tensors, then those tensors
    must all have the same outer-dimension size (`num_elems`); and `fn` is
    used to transform each tuple (or structure) of corresponding slices from
    `elems`.  E.g., if `elems` is a tuple `(t1, t2, t3)`, then `fn` is used to
    transform each tuple of slices `(t1[i], t2[i], t3[i])`
    (where `0 <= i < num_elems`).

  * If `fn` returns a tuple (or nested structure) of tensors, then the
    result is formed by stacking corresponding elements from those structures.

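  For example, mapping over a pair of tensors (a small illustrative case; the
  input structure is a tuple but the output is a single tensor, so
  `fn_output_signature` is required, as described in the next section):

  >>> tf.map_fn(fn=lambda x: x[0] * x[1],
  ...           elems=(tf.constant([1, 2, 3]), tf.constant([4, 5, 6])),
  ...           fn_output_signature=tf.int32)
  <tf.Tensor: shape=(3,), dtype=int32, numpy=array([ 4, 10, 18], dtype=int32)>
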
  #### Specifying `fn`'s output signature

  If `fn`'s input and output signatures are different, then the output
  signature must be specified using `fn_output_signature`.  (The input and
  output signatures differ if their structures, dtypes, or tensor types do
  not match.)  E.g.:

  >>> tf.map_fn(fn=tf.strings.length,  # input & output have different dtypes
  ...           elems=tf.constant(["hello", "moon"]),
  ...           fn_output_signature=tf.int32)
  <tf.Tensor: shape=(2,), dtype=int32, numpy=array([5, 4], dtype=int32)>
  >>> tf.map_fn(fn=tf.strings.join,  # input & output have different structures
  ...           elems=[tf.constant(['The', 'A']), tf.constant(['Dog', 'Cat'])],
  ...           fn_output_signature=tf.string)
  <tf.Tensor: shape=(2,), dtype=string,
   numpy=array([b'TheDog', b'ACat'], dtype=object)>

  `fn_output_signature` can be specified using any of the following:

  * A `tf.DType` or `tf.TensorSpec` (to describe a `tf.Tensor`)
  * A `tf.RaggedTensorSpec` (to describe a `tf.RaggedTensor`)
  * A `tf.SparseTensorSpec` (to describe a `tf.sparse.SparseTensor`)
  * A (possibly nested) tuple, list, or dict containing the above types.

  #### RaggedTensors

  `map_fn` supports `tf.RaggedTensor` inputs and outputs.  In particular:

  * If `elems` is a `RaggedTensor`, then `fn` will be called with each
    row of that ragged tensor.
    * If `elems` has only one ragged dimension, then the values passed to
      `fn` will be `tf.Tensor`s.
    * If `elems` has multiple ragged dimensions, then the values passed to
      `fn` will be `tf.RaggedTensor`s with one fewer ragged dimension.

  * If the result of `map_fn` should be a `RaggedTensor`, then use a
    `tf.RaggedTensorSpec` to specify `fn_output_signature`.
    * If `fn` returns `tf.Tensor`s with varying sizes, then use a
      `tf.RaggedTensorSpec` with `ragged_rank=0` to combine them into a
      single ragged tensor (which will have ragged_rank=1).
    * If `fn` returns `tf.RaggedTensor`s, then use a `tf.RaggedTensorSpec`
      with the same `ragged_rank`.

  >>> # Example: RaggedTensor input
  >>> rt = tf.ragged.constant([[1, 2, 3], [], [4, 5], [6]])
  >>> tf.map_fn(tf.reduce_sum, rt, fn_output_signature=tf.int32)
  <tf.Tensor: shape=(4,), dtype=int32, numpy=array([6, 0, 9, 6], dtype=int32)>

  >>> # Example: RaggedTensor output
  >>> elems = tf.constant([3, 5, 0, 2])
  >>> tf.map_fn(tf.range, elems,
  ...           fn_output_signature=tf.RaggedTensorSpec(shape=[None],
  ...                                                   dtype=tf.int32))
  <tf.RaggedTensor [[0, 1, 2], [0, 1, 2, 3, 4], [], [0, 1]]>

  Note: `map_fn` should only be used if you need to map a function over the
  *rows* of a `RaggedTensor`.  If you wish to map a function over the
  individual values, then you should use:

  * `tf.ragged.map_flat_values(fn, rt)`
    (if fn is expressible as TensorFlow ops)
  * `rt.with_flat_values(map_fn(fn, rt.flat_values))`
    (otherwise)

  E.g.:

  >>> rt = tf.ragged.constant([[1, 2, 3], [], [4, 5], [6]])
  >>> tf.ragged.map_flat_values(lambda x: x + 2, rt)
  <tf.RaggedTensor [[3, 4, 5], [], [6, 7], [8]]>

  #### SparseTensors

  `map_fn` supports `tf.sparse.SparseTensor` inputs and outputs.  In particular:

  * If `elems` is a `SparseTensor`, then `fn` will be called with each row
    of that sparse tensor. In particular, the value passed to `fn` will be a
    `tf.sparse.SparseTensor` with one fewer dimension than `elems`.

  * If the result of `map_fn` should be a `SparseTensor`, then use a
    `tf.SparseTensorSpec` to specify `fn_output_signature`.  The individual
    `SparseTensor`s returned by `fn` will be stacked into a single
    `SparseTensor` with one more dimension.

  >>> # Example: SparseTensor input
  >>> st = tf.sparse.SparseTensor([[0, 0], [2, 0], [2, 1]], [2, 3, 4], [4, 4])
  >>> tf.map_fn(tf.sparse.reduce_sum, st, fn_output_signature=tf.int32)
  <tf.Tensor: shape=(4,), dtype=int32, numpy=array([2, 0, 7, 0], dtype=int32)>

  >>> # Example: SparseTensor output
  >>> tf.sparse.to_dense(
  ...     tf.map_fn(tf.sparse.eye, tf.constant([2, 3]),
  ...               fn_output_signature=tf.SparseTensorSpec(None, tf.float32)))
  <tf.Tensor: shape=(2, 3, 3), dtype=float32, numpy=
    array([[[1., 0., 0.],
            [0., 1., 0.],
            [0., 0., 0.]],
           [[1., 0., 0.],
            [0., 1., 0.],
            [0., 0., 1.]]], dtype=float32)>

  Note: `map_fn` should only be used if you need to map a function over the
  *rows* of a `SparseTensor`.  If you wish to map a function over the nonzero
  values, then you should use:

  * If the function is expressible as TensorFlow ops, use:
    ```python
    tf.sparse.SparseTensor(st.indices, fn(st.values), st.dense_shape)
    ```
  * Otherwise, use:
    ```python
    tf.sparse.SparseTensor(st.indices, tf.map_fn(fn, st.values),
                           st.dense_shape)
    ```

  #### `map_fn` vs. vectorized operations

  `map_fn` will apply the operations used by `fn` to each element of `elems`,
  resulting in `O(elems.shape[0])` total operations.  This is somewhat
  mitigated by the fact that `map_fn` can process elements in parallel.
  However, a transform expressed using `map_fn` is still typically less
  efficient than an equivalent transform expressed using vectorized operations.

  `map_fn` should typically only be used if one of the following is true:

  * It is difficult or expensive to express the desired transform with
    vectorized operations.
  * `fn` creates large intermediate values, so an equivalent vectorized
    transform would take too much memory.
  * Processing elements in parallel is more efficient than an equivalent
    vectorized transform.
  * Efficiency of the transform is not critical, and using `map_fn` is
    more readable.

  E.g., the example given above that maps `fn=lambda t: tf.range(t, t + 3)`
  across `elems` could be rewritten more efficiently using vectorized ops:

  >>> elems = tf.constant([3, 5, 2])
  >>> tf.range(3) + tf.expand_dims(elems, 1)
  <tf.Tensor: shape=(3, 3), dtype=int32, numpy=
    array([[3, 4, 5],
           [5, 6, 7],
           [2, 3, 4]], dtype=int32)>

  In some cases, `tf.vectorized_map` can be used to automatically convert a
  function to a vectorized equivalent.
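
  For instance (a small illustrative case), the elementwise square below is
  converted automatically:

  >>> tf.vectorized_map(lambda x: x * x, tf.constant([1, 2, 3]))
  <tf.Tensor: shape=(3,), dtype=int32, numpy=array([1, 4, 9], dtype=int32)>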

  #### Eager execution

  When executing eagerly, `map_fn` does not execute in parallel even if
  `parallel_iterations` is set to a value > 1. You can still get the
  performance benefits of running a function in parallel by using the
  `tf.function` decorator:

  >>> fn=lambda t: tf.range(t, t + 3)
  >>> @tf.function
  ... def func(elems):
  ...   return tf.map_fn(fn, elems, parallel_iterations=3)
  >>> func(tf.constant([3, 5, 2]))
  <tf.Tensor: shape=(3, 3), dtype=int32, numpy=
    array([[3, 4, 5],
           [5, 6, 7],
           [2, 3, 4]], dtype=int32)>


  Note: if you use the `tf.function` decorator, any non-TensorFlow Python
  code that you may have written in your function won't get executed. See
  `tf.function` for more details. The recommendation is to debug without
  `tf.function`, and then switch to it to get the performance benefits of
  running `map_fn` in parallel.

  Args:
    fn: The callable to be performed.  It accepts one argument, which will have
      the same (possibly nested) structure as `elems`.  Its output must have the
      same structure as `fn_output_signature` if one is provided; otherwise it
      must have the same structure as `elems`.
    elems: A tensor or (possibly nested) sequence of tensors, each of which will
      be unstacked along their first dimension.  `fn` will be applied to the
      nested sequence of the resulting slices.  `elems` may include ragged and
      sparse tensors. `elems` must consist of at least one tensor.
    dtype: Deprecated: Equivalent to `fn_output_signature`.
    parallel_iterations: (optional) The number of iterations allowed to run in
      parallel. When graph building, the default value is 10. While executing
      eagerly, the default value is set to 1.
    back_prop: (optional) False disables support for back propagation.
    swap_memory: (optional) True enables GPU-CPU memory swapping.
    infer_shape: (optional) False disables tests for consistent output shapes.
    name: (optional) Name prefix for the returned tensors.
    fn_output_signature: The output signature of `fn`. Must be specified if
      `fn`'s input and output signatures are different (i.e., if their
      structures, dtypes, or tensor types do not match).
      `fn_output_signature` can be specified using any of the following:

      * A `tf.DType` or `tf.TensorSpec` (to describe a `tf.Tensor`)
      * A `tf.RaggedTensorSpec` (to describe a `tf.RaggedTensor`)
      * A `tf.SparseTensorSpec` (to describe a `tf.sparse.SparseTensor`)
      * A (possibly nested) tuple, list, or dict containing the above types.

  Returns:
    A tensor or (possibly nested) sequence of tensors.  Each tensor stacks the
    results of applying `fn` to tensors unstacked from `elems` along the first
    dimension, from first to last.  The result may include ragged and sparse
    tensors.

  Raises:
    TypeError: if `fn` is not callable or the structure of the output of
      `fn` and `fn_output_signature` do not match.
    ValueError: if the lengths of the output of `fn` and `fn_output_signature`
      do not match, or if `elems` does not contain any tensors.

  Examples:

    >>> elems = np.array([1, 2, 3, 4, 5, 6])
    >>> tf.map_fn(lambda x: x * x, elems)
    <tf.Tensor: shape=(6,), dtype=int64, numpy=array([ 1,  4,  9, 16, 25, 36])>

    >>> elems = (np.array([1, 2, 3]), np.array([-1, 1, -1]))
    >>> tf.map_fn(lambda x: x[0] * x[1], elems, fn_output_signature=tf.int64)
    <tf.Tensor: shape=(3,), dtype=int64, numpy=array([-1,  2, -3])>

    >>> elems = np.array([1, 2, 3])
    >>> tf.map_fn(lambda x: (x, -x), elems,
    ...          fn_output_signature=(tf.int64, tf.int64))
    (<tf.Tensor: shape=(3,), dtype=int64, numpy=array([1, 2, 3])>,
     <tf.Tensor: shape=(3,), dtype=int64, numpy=array([-1, -2, -3])>)
  """
  # This function uses a `while_loop` to call `fn` on each value of the input
  # tensor(s) (unstacked on dimension 0).  The following sequence of variables
  # are used to transform the input tensor(s) (`elems`) into the output
  # tensor(s) (`result`):
  #
  #   - Preparing and unstacking input values for the while_loop:
  #     - elems: The input tensor(s) to map_fn. May include composite tensors.
  #     - elems_flat: Flattened list of tensors from elems (using nest.flatten)
  #                   May include composite tensors.
  #     - elems_batchable: Concatenation of "batchable tensor lists" for each
  #                        tensor in elems_flat.  This "boxes" composite tensors
  #                        into sliceable tf.Tensor objects.  For more info see:
  #                        TensorSpec._to_batched_tensor_list
  #     - elems_batchable_ta: List of TensorArrays used to unstack each Tensor
  #                           in elems_batchable into elems_value_batchable.
  #
  #   - Calling `fn` on each unstacked value in the body of the while_loop:
  #     - elems_value_batchable: Single unstacked value from elems_batchable.
  #     - elems_value_flat: Single unstacked value from elems_flat,
  #                         constructed from elems_value_batchable (using
  #                         TensorSpec._from_tensor_list).
  #     - elems_value: Single unstacked value from elems (the input to fn).
  #     - result_value: Result of calling `fn(elems_value)`.  May contain
  #                     composite tensors.
  #     - result_value_flat: Flattened list of tensors from result_value.
  #                          May contain composite tensors.
  #     - result_value_batchable: Concatenation of batchable tensor lists for
  #                               each tensor in result_value_flat
  #                               (using TensorSpec._to_tensor_list).
  #
  #   - Collecting and stacking output values from the while_loop:
  #     - result_batchable_ta: List of TensorArrays used to stack each tensor
  #                            in result_value_batchable into result_batchable.
  #     - result_batchable: Stacked tensors from result_batchable_ta.
  #     - result_flat: Flat list of tensors for the result, constructed from
  #                    result_batchable (using TensorSpec._from_tensor_list).
  #     - result: Structured result value packed from result_flat
  #               (using nest.pack_sequence_as).
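  #
  # As a rough illustration (assuming a plain dense input): if `elems` is a
  # float32 Tensor of shape [4, 5] and `fn` maps a [5] tensor to a [6] tensor,
  # then elems_flat == [elems], elems_batchable == [elems] (dense tensors box
  # to themselves), each elems_value has shape [5], each result_value has
  # shape [6], and the final result has shape [4, 6].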

  if fn_output_signature is None:
    fn_output_signature = dtype

  if not callable(fn):
    raise TypeError("fn must be callable.")

  in_graph_mode = not context.executing_eagerly()
  # Set the default number of parallel_iterations depending on graph/eager mode.
  if in_graph_mode and not parallel_iterations:
    parallel_iterations = 10
  elif not in_graph_mode and not parallel_iterations:
    parallel_iterations = 1
  elif not in_graph_mode and parallel_iterations > 1:
    logging.log_first_n(
        logging.WARN, "Setting parallel_iterations > 1 has no "
        "effect when executing eagerly. Consider calling map_fn"
        " with tf.function to execute fn in "
        "parallel.", 1)
    parallel_iterations = 1

  # Flatten the input tensors, and get the TypeSpec for each one.
  elems_flat = nest.flatten(elems)

  # Check in case this is an empty list
  if len(elems_flat) == 0:
    raise ValueError(
        "elems must be a Tensor or (possibly nested) sequence of Tensors. "
        "Got {}, which does not contain any Tensors.".format(elems))

  elems_flat_signature = [type_spec.type_spec_from_value(e) for e in elems_flat]
  elems_unflatten = lambda x: nest.pack_sequence_as(elems, x)

  # Flatten fn's output signature.
  if fn_output_signature is None:
    # If fn_output_signature was not specified, then assume that it matches the
    # input signature.
    result_flat_signature = [
        _most_general_compatible_type(s)._unbatch()  # pylint: disable=protected-access
        for s in elems_flat_signature
    ]
    result_unflatten = elems_unflatten
  else:
    result_flat_signature = [
        _dtype_to_spec(d) for d in nest.flatten(fn_output_signature)
    ]
    result_unflatten = lambda x: nest.pack_sequence_as(fn_output_signature, x)

  with ops.name_scope(name, "map", elems_flat):
    # TODO(akshayka): Remove the in_graph_mode check once caching devices are
    # supported in Eager
    if in_graph_mode:
      # Any get_variable calls in fn will cache the first call locally
      # and not issue repeated network I/O requests for each iteration.
      varscope = vs.get_variable_scope()
      varscope_caching_device_was_none = False
      if varscope.caching_device is None:
        # TODO(ebrevdo): Change to using colocate_with here and in other
        # methods.
        varscope.set_caching_device(lambda op: op.device)
        varscope_caching_device_was_none = True

    elems_flat = [
        ops.convert_to_tensor_or_composite(t, name="elem") for t in elems_flat
    ]

    # Check that inputs are not scalars.
    first_elem = elems_flat[0]
    elems_static_shape = first_elem.shape
    if elems_static_shape.ndims is not None and elems_static_shape.ndims < 1:
      if len(elems_flat) == 1:
        raise ValueError("elems must be a 1+ dimensional Tensor, not a scalar")
      else:
        raise ValueError(
            "elements in elems must be 1+ dimensional Tensors, not scalars"
        )

    # Box any composite tensors into tensor lists.
    elems_batchable = _elems_flat_to_batchable(elems_flat)

    # Find the number of iterations, n.  (may be known statically.)
    n_static = tensor_shape.Dimension(
        tensor_shape.dimension_value(
            elems_batchable[0].get_shape().with_rank_at_least(1)[0]))
    for tensor in elems_batchable[1:]:
      n_static.assert_is_compatible_with(
          tensor_shape.Dimension(
              tensor_shape.dimension_value(
                  tensor.get_shape().with_rank_at_least(1)[0])))
    n = n_static.value or array_ops.shape(elems_batchable[0])[0]

    # Convert elems to tensor array.
    # TODO(edloper): Should we set infer_shape=False for composite tensors?
    elems_batchable_ta = [
        tensor_array_ops.TensorArray(
            dtype=t.dtype, size=n, dynamic_size=False, infer_shape=True)
        for t in elems_batchable
    ]
    # Unpack elements
    elems_batchable_ta = [
        ta.unstack(t) for (ta, t) in zip(elems_batchable_ta, elems_batchable)
    ]

    i = constant_op.constant(0)

    # Prepare result tensor array.
    # TODO(edloper): Should we set infer_shape=False for composite tensors?
    result_batchable_tensor_spec = (
        _result_flat_signature_to_batchable_tensor_spec(result_flat_signature))
    result_batchable_ta = []
    for spec in result_batchable_tensor_spec:
      result_batchable_ta.append(
          tensor_array_ops.TensorArray(
              dtype=spec.dtype, size=n, dynamic_size=False,
              infer_shape=infer_shape, element_shape=spec.shape))

    def compute(i, tas):
      """The loop body of map_fn.

      Args:
        i: the loop counter
        tas: the flat TensorArray accumulator list

      Returns:
        (i + 1, tas): the updated counter + updated TensorArrays

      Raises:
        TypeError: if fn_output_signature and result_value structure don't match
        ValueError: if fn_output_signature and result_value lengths don't match
      """
      elems_value_batchable = [ta.read(i) for ta in elems_batchable_ta]
      elems_value_flat = _elems_value_batchable_to_flat(elems_value_batchable,
                                                        elems_flat_signature)
      elems_value = elems_unflatten(elems_value_flat)
      ag_ctx = autograph_ctx.control_status_ctx()
      autographed_fn = autograph.tf_convert(fn, ag_ctx)
      result_value = autographed_fn(elems_value)
      nest.assert_same_structure(fn_output_signature or elems, result_value)
      result_value_flat = nest.flatten(result_value)
      result_value_batchable = _result_value_flat_to_batchable(
          result_value_flat, result_flat_signature)
      tas = [
          ta.write(i, value) for (ta, value) in zip(tas, result_value_batchable)
      ]
      return (i + 1, tas)

    _, r_a = control_flow_ops.while_loop(
        lambda i, _: i < n,
        compute, (i, result_batchable_ta),
        parallel_iterations=parallel_iterations,
        back_prop=back_prop,
        swap_memory=swap_memory,
        maximum_iterations=n)
    result_batchable = [r.stack() for r in r_a]

    # Update each output tensor w/ static shape info about the outer dimension.
    for r in result_batchable:
      r.set_shape(tensor_shape.TensorShape(n_static).concatenate(
          r.get_shape()[1:]))

    # TODO(akshayka): Remove the in_graph_mode check once caching devices are
    # supported in Eager
    if in_graph_mode and varscope_caching_device_was_none:
      varscope.set_caching_device(None)

    result_flat = _result_batchable_to_flat(result_batchable,
                                            result_flat_signature,
                                            n_static)
    result = result_unflatten(result_flat)
    return result


def _dtype_to_spec(d):
  if not isinstance(d, type_spec.TypeSpec):
    d = tensor_spec.TensorSpec(None, d)
  return d


def _most_general_compatible_type(spec):
  """Returns the most general TypeSpec compatible with `spec`."""
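  # For example, a TensorSpec([3, 5], tf.int32) is relaxed to
  # TensorSpec(None, tf.int32): the dtype is kept but the shape constraint is
  # dropped.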
  # TODO(edloper): Consider adding most_general_compatible_type to TypeSpec API
  if isinstance(spec, tensor_spec.TensorSpec):
    return tensor_spec.TensorSpec(None, spec.dtype)
  elif isinstance(spec, ragged_tensor.RaggedTensorSpec):
    # pylint: disable=protected-access
    return ragged_tensor.RaggedTensorSpec(None, spec._dtype, spec._ragged_rank,
                                          spec._row_splits_dtype)
  elif isinstance(spec, sparse_tensor.SparseTensorSpec):
    # pylint: disable=protected-access
    return sparse_tensor.SparseTensorSpec(None, spec.dtype)
  else:
    return spec


def _result_flat_signature_to_batchable_tensor_spec(result_flat_signature):
  """Converts result_flat_signature -> result_batchable_tensor_specs."""
  tensor_specs = []
  for spec in result_flat_signature:
    if not isinstance(spec, type_spec.BatchableTypeSpec):
      raise TypeError("map_fn can not generate %s outputs" % (spec,))
    tensor_specs.extend(spec._flat_tensor_specs)  # pylint: disable=protected-access
  return tensor_specs


def _elems_flat_to_batchable(elems_flat):
  """Converts elems_flat -> elems_batchable."""
  elems_batchable = []
  for elems_tensor in elems_flat:
    spec = type_spec.type_spec_from_value(elems_tensor)
    if not isinstance(spec, type_spec.BatchableTypeSpec):
      raise TypeError("map_fn can not consume %s inputs: got %r" %
                      (spec, elems_tensor))
    # pylint: disable=protected-access
    elems_batchable.extend(spec._to_batched_tensor_list(elems_tensor))
  return elems_batchable


def _elems_value_batchable_to_flat(elems_value_batchable, elems_flat_signature):
  """Converts elems_value_batchable -> elems_value_flat."""
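  # Walks the flat list of batchable tensors, regrouping them into one
  # (possibly composite) value per spec in `elems_flat_signature`.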
  elems_value_flat = []
  i = 0
  for spec in elems_flat_signature:
    # pylint: disable=protected-access
    spec = spec._unbatch()
    tensor_list = elems_value_batchable[i:i + len(spec._flat_tensor_specs)]
    elems_value_flat.append(spec._from_compatible_tensor_list(tensor_list))
    i += len(tensor_list)
  assert i == len(elems_value_batchable)
  return elems_value_flat


def _result_value_flat_to_batchable(result_value_flat, result_flat_signature):
  """Converts result_value_flat -> result_value_batchable."""
  result_value_batchable = []
  for (r_value, r_spec) in zip(result_value_flat, result_flat_signature):
    if isinstance(r_spec, tensor_spec.TensorSpec):
      result_value_batchable.append(r_value)
    else:
      if not r_spec.is_compatible_with(r_value):
        raise ValueError(
            "Error in map_fn:\n  Expected `fn` to return a:\n    %s\n"
            "  But it returned a:\n    %s\n    (value=%s)\n"
            "  To fix, update the `fn_output_signature` (or `dtype`) "
            "argument to `map_fn`." %
            (r_spec, type_spec.type_spec_from_value(r_value), r_value))
      result_value_batchable.extend(r_spec._to_tensor_list(r_value))  # pylint: disable=protected-access
  return result_value_batchable


def _result_batchable_to_flat(result_batchable, result_flat_signature,
                              batch_size):
  """Converts result_batchable -> result_flat."""
  result_flat = []
  i = 0
  for spec in result_flat_signature:
    # pylint: disable=protected-access
    num_tensors = len(spec._flat_tensor_specs)
    result_flat.append(
        spec._batch(batch_size)._from_compatible_tensor_list(
            result_batchable[i:i + num_tensors]))
    i += num_tensors
  assert i == len(result_batchable)
  return result_flat


@tf_export("map_fn", v1=[])
@deprecation.deprecated_arg_values(
    None,
    """back_prop=False is deprecated. Consider using tf.stop_gradient instead.
Instead of:
results = tf.map_fn(fn, elems, back_prop=False)
Use:
results = tf.nest.map_structure(tf.stop_gradient, tf.map_fn(fn, elems))""",
    warn_once=True,
    back_prop=False)
@deprecation.deprecated_args(None, "Use fn_output_signature instead", "dtype")
def map_fn_v2(fn,
              elems,
              dtype=None,
              parallel_iterations=None,
              back_prop=True,
              swap_memory=False,
              infer_shape=True,
              name=None,
              fn_output_signature=None):
  """Transform `elems` by applying `fn` to each element unstacked on axis 0."""
  if fn_output_signature is None:
    fn_output_signature = dtype
  return map_fn(
      fn=fn,
      elems=elems,
      fn_output_signature=fn_output_signature,
      parallel_iterations=parallel_iterations,
      back_prop=back_prop,
      swap_memory=swap_memory,
      infer_shape=infer_shape,
      name=name)


# Docstring for v2 is the same as v1, except that back_prop is deprecated.
map_fn_v2.__doc__ = re.sub(
    r"(  back_prop: \(optional\) )(.*)",
    r"\1Deprecated: prefer using `tf.stop_gradient` instead.  \2",
    map_fn.__doc__)
assert "prefer using `tf.stop_gradient` instead" in map_fn_v2.__doc__