# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

"""Functional operations."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function


import re

from tensorflow.python.autograph.core import ag_ctx as autograph_ctx
from tensorflow.python.autograph.impl import api as autograph
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import type_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import deprecation
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export


@tf_export(v1=["map_fn"])
@deprecation.deprecated_args(None, "Use fn_output_signature instead", "dtype")
def map_fn(fn,
           elems,
           dtype=None,
           parallel_iterations=None,
           back_prop=True,
           swap_memory=False,
           infer_shape=True,
           name=None,
           fn_output_signature=None):
  """Transforms `elems` by applying `fn` to each element unstacked on axis 0.

  See also `tf.scan`.

  `map_fn` unstacks `elems` on axis 0 to obtain a sequence of elements;
  calls `fn` to transform each element; and then stacks the transformed
  values back together.

  #### Mapping functions with single-Tensor inputs and outputs

  If `elems` is a single tensor and `fn`'s signature is `tf.Tensor->tf.Tensor`,
  then `map_fn(fn, elems)` is equivalent to
  `tf.stack([fn(elem) for elem in tf.unstack(elems)])`.  E.g.:

  >>> tf.map_fn(fn=lambda t: tf.range(t, t + 3), elems=tf.constant([3, 5, 2]))
  <tf.Tensor: shape=(3, 3), dtype=int32, numpy=
    array([[3, 4, 5],
           [5, 6, 7],
           [2, 3, 4]], dtype=int32)>

  `map_fn(fn, elems).shape = [elems.shape[0]] + fn(elems[0]).shape`.

  #### Mapping functions with multi-arity inputs and outputs

  `map_fn` also supports functions with multi-arity inputs and outputs:

  * If `elems` is a tuple (or nested structure) of tensors, then those tensors
    must all have the same outer-dimension size (`num_elems`); and `fn` is
    used to transform each tuple (or structure) of corresponding slices from
    `elems`.  E.g., if `elems` is a tuple `(t1, t2, t3)`, then `fn` is used to
    transform each tuple of slices `(t1[i], t2[i], t3[i])`
    (where `0 <= i < num_elems`).

  * If `fn` returns a tuple (or nested structure) of tensors, then the
    result is formed by stacking corresponding elements from those structures.

  #### Specifying `fn`'s output signature

  If `fn`'s input and output signatures are different, then the output
  signature must be specified using `fn_output_signature`.  (The input and
  output signatures differ if their structures, dtypes, or tensor types do
  not match).  E.g.:

  >>> tf.map_fn(fn=tf.strings.length,  # input & output have different dtypes
  ...           elems=tf.constant(["hello", "moon"]),
  ...           fn_output_signature=tf.int32)
  <tf.Tensor: shape=(2,), dtype=int32, numpy=array([5, 4], dtype=int32)>
  >>> tf.map_fn(fn=tf.strings.join,  # input & output have different structures
  ...           elems=[tf.constant(['The', 'A']), tf.constant(['Dog', 'Cat'])],
  ...           fn_output_signature=tf.string)
  <tf.Tensor: shape=(2,), dtype=string,
   numpy=array([b'TheDog', b'ACat'], dtype=object)>

  `fn_output_signature` can be specified using any of the following:

  * A `tf.DType` or `tf.TensorSpec` (to describe a `tf.Tensor`)
  * A `tf.RaggedTensorSpec` (to describe a `tf.RaggedTensor`)
  * A `tf.SparseTensorSpec` (to describe a `tf.sparse.SparseTensor`)
  * A (possibly nested) tuple, list, or dict containing the above types.

  #### RaggedTensors

  `map_fn` supports `tf.RaggedTensor` inputs and outputs.  In particular:

  * If `elems` is a `RaggedTensor`, then `fn` will be called with each
    row of that ragged tensor.
    * If `elems` has only one ragged dimension, then the values passed to
      `fn` will be `tf.Tensor`s.
    * If `elems` has multiple ragged dimensions, then the values passed to
      `fn` will be `tf.RaggedTensor`s with one fewer ragged dimension.

  * If the result of `map_fn` should be a `RaggedTensor`, then use a
    `tf.RaggedTensorSpec` to specify `fn_output_signature`.
    * If `fn` returns `tf.Tensor`s with varying sizes, then use a
      `tf.RaggedTensorSpec` with `ragged_rank=0` to combine them into a
      single ragged tensor (which will have ragged_rank=1).
    * If `fn` returns `tf.RaggedTensor`s, then use a `tf.RaggedTensorSpec`
      with the same `ragged_rank`.

  >>> # Example: RaggedTensor input
  >>> rt = tf.ragged.constant([[1, 2, 3], [], [4, 5], [6]])
  >>> tf.map_fn(tf.reduce_sum, rt, fn_output_signature=tf.int32)
  <tf.Tensor: shape=(4,), dtype=int32, numpy=array([6, 0, 9, 6], dtype=int32)>

  >>> # Example: RaggedTensor output
  >>> elems = tf.constant([3, 5, 0, 2])
  >>> tf.map_fn(tf.range, elems,
  ...           fn_output_signature=tf.RaggedTensorSpec(shape=[None],
  ...                                                   dtype=tf.int32))
  <tf.RaggedTensor [[0, 1, 2], [0, 1, 2, 3, 4], [], [0, 1]]>

  Note: `map_fn` should only be used if you need to map a function over the
  *rows* of a `RaggedTensor`.  If you wish to map a function over the
  individual values, then you should use:

  * `tf.ragged.map_flat_values(fn, rt)`
    (if fn is expressible as TensorFlow ops)
  * `rt.with_flat_values(map_fn(fn, rt.flat_values))`
    (otherwise)

  E.g.:

  >>> rt = tf.ragged.constant([[1, 2, 3], [], [4, 5], [6]])
  >>> tf.ragged.map_flat_values(lambda x: x + 2, rt)
  <tf.RaggedTensor [[3, 4, 5], [], [6, 7], [8]]>

  #### SparseTensors

  `map_fn` supports `tf.sparse.SparseTensor` inputs and outputs.  In particular:

  * If `elems` is a `SparseTensor`, then `fn` will be called with each row
    of that sparse tensor. In particular, the value passed to `fn` will be a
    `tf.sparse.SparseTensor` with one fewer dimension than `elems`.

  * If the result of `map_fn` should be a `SparseTensor`, then use a
    `tf.SparseTensorSpec` to specify `fn_output_signature`.  The individual
    `SparseTensor`s returned by `fn` will be stacked into a single
    `SparseTensor` with one more dimension.

  >>> # Example: SparseTensor input
  >>> st = tf.sparse.SparseTensor([[0, 0], [2, 0], [2, 1]], [2, 3, 4], [4, 4])
  >>> tf.map_fn(tf.sparse.reduce_sum, st, fn_output_signature=tf.int32)
  <tf.Tensor: shape=(4,), dtype=int32, numpy=array([2, 0, 7, 0], dtype=int32)>

  >>> # Example: SparseTensor output
  >>> tf.sparse.to_dense(
  ...     tf.map_fn(tf.sparse.eye, tf.constant([2, 3]),
  ...               fn_output_signature=tf.SparseTensorSpec(None, tf.float32)))
  <tf.Tensor: shape=(2, 3, 3), dtype=float32, numpy=
    array([[[1., 0., 0.],
            [0., 1., 0.],
            [0., 0., 0.]],
           [[1., 0., 0.],
            [0., 1., 0.],
            [0., 0., 1.]]], dtype=float32)>

  Note: `map_fn` should only be used if you need to map a function over the
  *rows* of a `SparseTensor`.  If you wish to map a function over the nonzero
  values, then you should use:

  * If the function is expressible as TensorFlow ops, use:
    ```python
    tf.sparse.SparseTensor(st.indices, fn(st.values), st.dense_shape)
    ```
  * Otherwise, use:
    ```python
    tf.sparse.SparseTensor(st.indices, tf.map_fn(fn, st.values),
                           st.dense_shape)
    ```

  #### `map_fn` vs. vectorized operations

  `map_fn` will apply the operations used by `fn` to each element of `elems`,
  resulting in `O(elems.shape[0])` total operations.  This is somewhat
  mitigated by the fact that `map_fn` can process elements in parallel.
  However, a transform expressed using `map_fn` is still typically less
  efficient than an equivalent transform expressed using vectorized operations.

  `map_fn` should typically only be used if one of the following is true:

  * It is difficult or expensive to express the desired transform with
    vectorized operations.
  * `fn` creates large intermediate values, so an equivalent vectorized
    transform would take too much memory.
  * Processing elements in parallel is more efficient than an equivalent
    vectorized transform.
  * Efficiency of the transform is not critical, and using `map_fn` is
    more readable.

  E.g., the example given above that maps `fn=lambda t: tf.range(t, t + 3)`
  across `elems` could be rewritten more efficiently using vectorized ops:

  >>> elems = tf.constant([3, 5, 2])
  >>> tf.range(3) + tf.expand_dims(elems, 1)
  <tf.Tensor: shape=(3, 3), dtype=int32, numpy=
    array([[3, 4, 5],
           [5, 6, 7],
           [2, 3, 4]], dtype=int32)>

  In some cases, `tf.vectorized_map` can be used to automatically convert a
  function to a vectorized equivalent.

  #### Eager execution

  When executing eagerly, `map_fn` does not execute in parallel even if
  `parallel_iterations` is set to a value > 1. You can still get the
  performance benefits of running a function in parallel by using the
  `tf.function` decorator:

  >>> fn=lambda t: tf.range(t, t + 3)
  >>> @tf.function
  ... def func(elems):
  ...   return tf.map_fn(fn, elems, parallel_iterations=3)
  >>> func(tf.constant([3, 5, 2]))
  <tf.Tensor: shape=(3, 3), dtype=int32, numpy=
    array([[3, 4, 5],
           [5, 6, 7],
           [2, 3, 4]], dtype=int32)>


  Note: if you use the `tf.function` decorator, any non-TensorFlow Python
  code that you may have written in your function won't get executed. See
  `tf.function` for more details. The recommendation would be to debug without
  `tf.function` but switch to it to get performance benefits of running `map_fn`
  in parallel.

  Args:
    fn: The callable to be performed.  It accepts one argument, which will have
      the same (possibly nested) structure as `elems`.  Its output must have the
      same structure as `fn_output_signature` if one is provided; otherwise it
      must have the same structure as `elems`.
    elems: A tensor or (possibly nested) sequence of tensors, each of which will
      be unstacked along their first dimension.  `fn` will be applied to the
      nested sequence of the resulting slices.  `elems` may include ragged and
      sparse tensors. `elems` must consist of at least one tensor.
    dtype: Deprecated: Equivalent to `fn_output_signature`.
    parallel_iterations: (optional) The number of iterations allowed to run in
      parallel. When graph building, the default value is 10. While executing
      eagerly, the default value is set to 1.
    back_prop: (optional) False disables support for back propagation.
    swap_memory: (optional) True enables GPU-CPU memory swapping.
    infer_shape: (optional) False disables tests for consistent output shapes.
    name: (optional) Name prefix for the returned tensors.
    fn_output_signature: The output signature of `fn`. Must be specified if
      `fn`'s input and output signatures are different (i.e., if their
      structures, dtypes, or tensor types do not match).
      `fn_output_signature` can be specified using any of the following:

      * A `tf.DType` or `tf.TensorSpec` (to describe a `tf.Tensor`)
      * A `tf.RaggedTensorSpec` (to describe a `tf.RaggedTensor`)
      * A `tf.SparseTensorSpec` (to describe a `tf.sparse.SparseTensor`)
      * A (possibly nested) tuple, list, or dict containing the above types.

  Returns:
    A tensor or (possibly nested) sequence of tensors.  Each tensor stacks the
    results of applying `fn` to tensors unstacked from `elems` along the first
    dimension, from first to last.  The result may include ragged and sparse
    tensors.

  Raises:
    TypeError: if `fn` is not callable or the structure of the output of
      `fn` and `fn_output_signature` do not match.
    ValueError: if the lengths of the output of `fn` and `fn_output_signature`
      do not match, or if the `elems` does not contain any tensor.

  Examples:

    >>> elems = np.array([1, 2, 3, 4, 5, 6])
    >>> tf.map_fn(lambda x: x * x, elems)
    <tf.Tensor: shape=(6,), dtype=int64, numpy=array([ 1,  4,  9, 16, 25, 36])>

    >>> elems = (np.array([1, 2, 3]), np.array([-1, 1, -1]))
    >>> tf.map_fn(lambda x: x[0] * x[1], elems, fn_output_signature=tf.int64)
    <tf.Tensor: shape=(3,), dtype=int64, numpy=array([-1,  2, -3])>

    >>> elems = np.array([1, 2, 3])
    >>> tf.map_fn(lambda x: (x, -x), elems,
    ...          fn_output_signature=(tf.int64, tf.int64))
    (<tf.Tensor: shape=(3,), dtype=int64, numpy=array([1, 2, 3])>,
     <tf.Tensor: shape=(3,), dtype=int64, numpy=array([-1, -2, -3])>)
  """
  # This function uses a `while_loop` to call `fn` on each value of the input
  # tensor(s) (unstacked on dimension 0).  The following sequence of variables
  # are used to transform the input tensor(s) (`elems`) into the output
  # tensor(s) (`result`):
  #
  #   - Preparing and unstacking input values for the while_loop:
  #     - elems: The input tensor(s) to map_fn. May include composite tensors.
  #     - elems_flat: Flattened list of tensors from elems (using nest.flatten)
  #                   May include composite tensors.
  #     - elems_batchable: Concatenation of "batchable tensor lists" for each
  #                        tensor in elems_flat.  This "boxes" composite tensors
  #                        into sliceable tf.Tensor objects.  For more info see:
  #                        TensorSpec._to_batched_tensor_list
  #     - elems_batchable_ta: List of TensorArrays used to unstack each Tensor
  #                           in elems_batchable into elems_value_batchable.
  #
  #   - Calling `fn` on each unstacked value in the body of the while_loop:
  #     - elems_value_batchable: Single unstacked value from elems_batchable.
  #     - elems_value_flat: Single unstacked value from elems_flat,
  #                         constructed from elems_value_batchable (using
  #                         TensorSpec._from_tensor_list).
  #     - elems_value: Single unstacked value from elems (the input to fn).
  #     - result_value: Result of calling `fn(elems_value)`.  May contain
  #                     composite tensors.
  #     - result_value_flat: Flattened list of tensors from result_value.
  #                          May contain composite tensors.
  #     - result_value_batchable: Concatenation of batchable tensor lists for
  #                               each tensor in result_value_flat
  #                               (using TensorSpec._to_tensor_list).
  #
  #   - Collecting and stacking output values from the while_loop:
  #     - result_batchable_ta: List of TensorArrays used to stack each tensor
  #                            in result_value_batchable into result_batchable.
  #     - result_batchable: Stacked tensors from result_batchable_ta.
  #     - result_flat: Flat list of tensors for the result, constructed from
  #                    result_batchable (using TensorSpec._from_tensor_list).
  #     - result: Structured result value packed from result_flat
  #               (using nest.pack_sequence_as).

  # `dtype` is the deprecated spelling of `fn_output_signature`.
  if fn_output_signature is None:
    fn_output_signature = dtype

  if not callable(fn):
    raise TypeError("fn must be callable.")

  in_graph_mode = not context.executing_eagerly()
  # Set the default number of parallel_iterations depending on graph/eager mode.
  if in_graph_mode and not parallel_iterations:
    parallel_iterations = 10
  elif not in_graph_mode and not parallel_iterations:
    parallel_iterations = 1
  elif not in_graph_mode and parallel_iterations > 1:
    logging.log_first_n(
        logging.WARN, "Setting parallel_iterations > 1 has no "
        "effect when executing eagerly. Consider calling map_fn"
        " with tf.function to execute fn in "
        "parallel.", 1)
    parallel_iterations = 1

  # Flatten the input tensors, and get the TypeSpec for each one.
  elems_flat = nest.flatten(elems)

  # Check in case this is an empty list
  if not elems_flat:
    raise ValueError(
        "elems must be a Tensor or (possibly nested) sequence of Tensors. "
        "Got {}, which does not contain any Tensors.".format(elems))

  elems_flat_signature = [type_spec.type_spec_from_value(e) for e in elems_flat]
  elems_unflatten = lambda x: nest.pack_sequence_as(elems, x)

  # Flatten fn's output signature.  `result_structure` is the structure that
  # every value returned by `fn` is validated against and that the final result
  # is packed into.  It is chosen with an explicit `is None` test (rather than
  # truthiness) so that validation and packing always agree, even when the
  # caller passes an empty structure such as `()` as the signature.
  if fn_output_signature is None:
    # If fn_output_signature was not specified, then assume that it matches the
    # input signature.
    result_flat_signature = [
        _most_general_compatible_type(s)._unbatch()  # pylint: disable=protected-access
        for s in elems_flat_signature
    ]
    result_structure = elems
    result_unflatten = elems_unflatten
  else:
    result_flat_signature = [
        _dtype_to_spec(d) for d in nest.flatten(fn_output_signature)
    ]
    result_structure = fn_output_signature
    result_unflatten = lambda x: nest.pack_sequence_as(fn_output_signature, x)

  with ops.name_scope(name, "map", elems_flat):
    # TODO(akshayka): Remove the in_graph_mode check once caching devices are
    # supported in Eager
    varscope = None
    varscope_caching_device_was_none = False
    if in_graph_mode:
      # Any get_variable calls in fn will cache the first call locally
      # and not issue repeated network I/O requests for each iteration.
      varscope = vs.get_variable_scope()
      if varscope.caching_device is None:
        # TODO(ebrevdo): Change to using colocate_with here and in other
        # methods.
        varscope.set_caching_device(lambda op: op.device)
        varscope_caching_device_was_none = True

    # Everything between installing the temporary caching device and removing
    # it runs under try/finally: `fn` is user code and the validation below can
    # raise, and the caller's variable scope must not be left modified.
    try:
      elems_flat = [
          ops.convert_to_tensor_or_composite(t, name="elem")
          for t in elems_flat
      ]

      # Check that inputs are not scalars.
      first_elem = elems_flat[0]
      elems_static_shape = first_elem.shape
      if elems_static_shape.ndims is not None and elems_static_shape.ndims < 1:
        if len(elems_flat) == 1:
          raise ValueError(
              "elems must be a 1+ dimensional Tensor, not a scalar")
        else:
          raise ValueError(
              "elements in elems must be 1+ dimensional Tensors, not scalars"
          )

      # Box any composite tensors into tensor lists.
      elems_batchable = _elems_flat_to_batchable(elems_flat)

      # Find the number of iterations, n.  (may be known statically.)
      n_static = tensor_shape.Dimension(
          tensor_shape.dimension_value(
              elems_batchable[0].get_shape().with_rank_at_least(1)[0]))
      for tensor in elems_batchable[1:]:
        n_static.assert_is_compatible_with(
            tensor_shape.Dimension(
                tensor_shape.dimension_value(
                    tensor.get_shape().with_rank_at_least(1)[0])))
      # Falls back to the dynamic outer dimension when the static one is
      # unknown (or zero).
      n = n_static.value or array_ops.shape(elems_batchable[0])[0]

      # Convert elems to tensor array.
      # TODO(edloper): Should we set infer_shape=False for composite tensors?
      elems_batchable_ta = [
          tensor_array_ops.TensorArray(
              dtype=t.dtype, size=n, dynamic_size=False, infer_shape=True)
          for t in elems_batchable
      ]
      # Unpack elements
      elems_batchable_ta = [
          ta.unstack(t) for (ta, t) in zip(elems_batchable_ta, elems_batchable)
      ]

      i = constant_op.constant(0)

      # Prepare result tensor array.
      # TODO(edloper): Should we set infer_shape=False for composite tensors?
      result_batchable_tensor_spec = (
          _result_flat_signature_to_batchable_tensor_spec(
              result_flat_signature))
      result_batchable_ta = []
      for spec in result_batchable_tensor_spec:
        result_batchable_ta.append(
            tensor_array_ops.TensorArray(
                dtype=spec.dtype, size=n, dynamic_size=False,
                infer_shape=infer_shape, element_shape=spec.shape))

      def compute(i, tas):
        """The loop body of map_fn.

        Args:
          i: the loop counter
          tas: the flat TensorArray accumulator list

        Returns:
          (i + 1, tas): the updated counter + updated TensorArrays

        Raises:
          TypeError: if fn_output_signature and result_value structure don't
            match
          ValueError: if fn_output_signature and result_value lengths don't
            match
        """
        elems_value_batchable = [ta.read(i) for ta in elems_batchable_ta]
        elems_value_flat = _elems_value_batchable_to_flat(
            elems_value_batchable, elems_flat_signature)
        elems_value = elems_unflatten(elems_value_flat)
        ag_ctx = autograph_ctx.control_status_ctx()
        autographed_fn = autograph.tf_convert(fn, ag_ctx)
        result_value = autographed_fn(elems_value)
        nest.assert_same_structure(result_structure, result_value)
        result_value_flat = nest.flatten(result_value)
        result_value_batchable = _result_value_flat_to_batchable(
            result_value_flat, result_flat_signature)
        tas = [
            ta.write(i, value)
            for (ta, value) in zip(tas, result_value_batchable)
        ]
        return (i + 1, tas)

      _, r_a = control_flow_ops.while_loop(
          lambda i, _: i < n,
          compute, (i, result_batchable_ta),
          parallel_iterations=parallel_iterations,
          back_prop=back_prop,
          swap_memory=swap_memory,
          maximum_iterations=n)
      result_batchable = [r.stack() for r in r_a]

      # Update each output tensor w/ static shape info about the outer
      # dimension.
      for r in result_batchable:
        r.set_shape(tensor_shape.TensorShape(n_static).concatenate(
            r.get_shape()[1:]))
    finally:
      # TODO(akshayka): Remove the in_graph_mode check once caching devices are
      # supported in Eager
      # `varscope_caching_device_was_none` can only be True in graph mode, so
      # `varscope` is always set when this branch runs.
      if varscope_caching_device_was_none:
        varscope.set_caching_device(None)

    result_flat = _result_batchable_to_flat(result_batchable,
                                            result_flat_signature,
                                            n_static)
    result = result_unflatten(result_flat)
    return result


def _dtype_to_spec(d):
  """Returns `d` as a TypeSpec, wrapping non-TypeSpec values in a TensorSpec.

  Args:
    d: Either a `TypeSpec` (returned unchanged) or a value accepted as the
      dtype argument of `TensorSpec` (e.g. a `tf.DType`).

  Returns:
    `d` itself if it is already a `TypeSpec`; otherwise a `TensorSpec` with
    unknown shape (`None`) and dtype `d`.
  """
  if not isinstance(d, type_spec.TypeSpec):
    d = tensor_spec.TensorSpec(None, d)
  return d


def _most_general_compatible_type(spec):
  """Returns the most general TypeSpec compatible with `spec`.

  For `TensorSpec`, `RaggedTensorSpec` and `SparseTensorSpec` this returns a
  spec of the same kind whose shape is replaced by `None` (unknown); the dtype
  (and, for ragged specs, the ragged rank and row-splits dtype) is preserved.
  Any other spec is returned unchanged.

  Args:
    spec: A `TypeSpec`.

  Returns:
    A `TypeSpec`.
  """
  # TODO(edloper): Consider adding most_general_compatible_type to TypeSpec API
  if isinstance(spec, tensor_spec.TensorSpec):
    return tensor_spec.TensorSpec(None, spec.dtype)
  elif isinstance(spec, ragged_tensor.RaggedTensorSpec):
    # pylint: disable=protected-access
    return ragged_tensor.RaggedTensorSpec(None, spec._dtype, spec._ragged_rank,
                                          spec._row_splits_dtype)
  elif isinstance(spec, sparse_tensor.SparseTensorSpec):
    # pylint: disable=protected-access
    return sparse_tensor.SparseTensorSpec(None, spec.dtype)
  else:
    return spec


def _result_flat_signature_to_batchable_tensor_spec(result_flat_signature):
  """Converts result_flat_signature -> result_batchable_tensor_specs.

  Args:
    result_flat_signature: Flat list of `TypeSpec`s describing `fn`'s outputs.

  Returns:
    The concatenation of the `_flat_tensor_specs` of every spec in
    `result_flat_signature`.

  Raises:
    TypeError: if any spec is not a `BatchableTypeSpec`.
  """
  tensor_specs = []
  for spec in result_flat_signature:
    if not isinstance(spec, type_spec.BatchableTypeSpec):
      raise TypeError("map_fn can not generate %s outputs" % (spec,))
    tensor_specs.extend(spec._flat_tensor_specs)  # pylint: disable=protected-access
  return tensor_specs


def _elems_flat_to_batchable(elems_flat):
  """Converts elems_flat -> elems_batchable.

  Args:
    elems_flat: Flat list of input values (tensors or composite tensors).

  Returns:
    The concatenation of the batched tensor lists produced by each value's
    TypeSpec (via `_to_batched_tensor_list`).

  Raises:
    TypeError: if the TypeSpec of any value is not a `BatchableTypeSpec`.
  """
  elems_batchable = []
  for elems_tensor in elems_flat:
    spec = type_spec.type_spec_from_value(elems_tensor)
    if not isinstance(spec, type_spec.BatchableTypeSpec):
      raise TypeError("map_fn can not consume %s inputs: got %r" %
                      (spec, elems_tensor))
    # pylint: disable=protected-access
    elems_batchable.extend(spec._to_batched_tensor_list(elems_tensor))
  return elems_batchable


def _elems_value_batchable_to_flat(elems_value_batchable, elems_flat_signature):
  """Converts elems_value_batchable -> elems_value_flat.

  Args:
    elems_value_batchable: Flat list of tensors read for a single iteration.
    elems_flat_signature: Flat list of (batched) `TypeSpec`s, one per input
      value; each is unbatched before being used to rebuild its value.

  Returns:
    A list with one reconstructed value per spec in `elems_flat_signature`.
  """
  elems_value_flat = []
  i = 0
  for spec in elems_flat_signature:
    # pylint: disable=protected-access
    spec = spec._unbatch()
    # Each spec consumes as many tensors as it has flat tensor specs.
    tensor_list = elems_value_batchable[i:i + len(spec._flat_tensor_specs)]
    elems_value_flat.append(spec._from_compatible_tensor_list(tensor_list))
    i += len(tensor_list)
  # Internal invariant: every tensor must have been consumed by some spec.
  assert i == len(elems_value_batchable)
  return elems_value_flat


def _result_value_flat_to_batchable(result_value_flat, result_flat_signature):
  """Converts result_value_flat -> result_value_batchable.

  Args:
    result_value_flat: Flat list of values returned by `fn` for one iteration.
    result_flat_signature: Flat list of `TypeSpec`s describing those values.

  Returns:
    A flat list of tensors: values described by a `TensorSpec` are passed
    through unchanged; other values are expanded with their spec's
    `_to_tensor_list`.

  Raises:
    ValueError: if a value described by a non-`TensorSpec` spec is not
      compatible with that spec.
  """
  result_value_batchable = []
  for (r_value, r_spec) in zip(result_value_flat, result_flat_signature):
    if isinstance(r_spec, tensor_spec.TensorSpec):
      result_value_batchable.append(r_value)
    else:
      if not r_spec.is_compatible_with(r_value):
        raise ValueError(
            "Error in map_fn:\n  Expected `fn` to return a:\n    %s\n"
            "  But it returned a:\n    %s\n    (value=%s)\n"
            "  To fix, update the `fn_output_signature` (or `dtype`) "
            "argument to `map_fn`." %
            (r_spec, type_spec.type_spec_from_value(r_value), r_value))
      result_value_batchable.extend(r_spec._to_tensor_list(r_value))  # pylint: disable=protected-access
  return result_value_batchable


def _result_batchable_to_flat(result_batchable, result_flat_signature,
                              batch_size):
  """Converts result_batchable -> result_flat.

  Args:
    result_batchable: Flat list of stacked result tensors.
    result_flat_signature: Flat list of (unbatched) `TypeSpec`s for the
      results; each is batched with `batch_size` before rebuilding its value.
    batch_size: The outer (stacked) dimension, passed to `TypeSpec._batch`.

  Returns:
    A list with one reconstructed value per spec in `result_flat_signature`.
  """
  result_flat = []
  i = 0
  for spec in result_flat_signature:
    # pylint: disable=protected-access
    num_tensors = len(spec._flat_tensor_specs)
    result_flat.append(
        spec._batch(batch_size)._from_compatible_tensor_list(
            result_batchable[i:i + num_tensors]))
    i += num_tensors
  # Internal invariant: every stacked tensor must belong to exactly one spec.
  assert i == len(result_batchable)
  return result_flat


@tf_export("map_fn", v1=[])
@deprecation.deprecated_arg_values(
    None,
    """back_prop=False is deprecated. Consider using tf.stop_gradient instead.
Instead of:
results = tf.map_fn(fn, elems, back_prop=False)
Use:
results = tf.nest.map_structure(tf.stop_gradient, tf.map_fn(fn, elems))""",
    warn_once=True,
    back_prop=False)
@deprecation.deprecated_args(None, "Use fn_output_signature instead", "dtype")
def map_fn_v2(fn,
              elems,
              dtype=None,
              parallel_iterations=None,
              back_prop=True,
              swap_memory=False,
              infer_shape=True,
              name=None,
              fn_output_signature=None):
  """Transform `elems` by applying `fn` to each element unstacked on axis 0."""
  # Resolve the deprecated `dtype` alias here so that `map_fn` is only ever
  # called with `fn_output_signature`.
  if fn_output_signature is None:
    fn_output_signature = dtype
  return map_fn(
      fn=fn,
      elems=elems,
      fn_output_signature=fn_output_signature,
      parallel_iterations=parallel_iterations,
      back_prop=back_prop,
      swap_memory=swap_memory,
      infer_shape=infer_shape,
      name=name)


# Docstring for v2 is the same as v1, except that back_prop is deprecated.
map_fn_v2.__doc__ = re.sub(
    r"(  back_prop: \(optional\) )(.*)",
    r"\1Deprecated: prefer using `tf.stop_gradient` instead.  \2",
    map_fn.__doc__)
assert "prefer using `tf.stop_gradient` instead" in map_fn_v2.__doc__