# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""for_loop and pfor ops."""
# pylint: disable=g-direct-tensorflow-import

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools

from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import indexed_slices
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.framework import type_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops.parallel_for.pfor import PFor
from tensorflow.python.ops.parallel_for.pfor import PForConfig
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_inspect
from tensorflow.python.util.tf_export import tf_export


def for_loop(loop_fn, loop_fn_dtypes, iters, parallel_iterations=None):
  """Runs `loop_fn` `iters` times and stacks the outputs.

  Runs `loop_fn` `iters` times, with input values from 0 to `iters - 1`, and
  stacks corresponding outputs of the different runs.

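  Example (a minimal sketch; assumes each iteration returns a single scalar
  int32 value):

  ```python
  squares = for_loop(lambda i: tf.math.square(i), tf.int32, iters=4)
  # squares has shape [4] and holds [0, 1, 4, 9].
  ```
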
  Args:
    loop_fn: A function that takes an int32 scalar tf.Tensor object representing
      the iteration number, and returns a possibly nested structure of tensor
      objects. The shape of these outputs should not depend on the input.
    loop_fn_dtypes: dtypes for the outputs of `loop_fn`.
    iters: Number of iterations for which to run `loop_fn`.
    parallel_iterations: The number of iterations that can be dispatched in
      parallel. This knob can be used to control the total memory usage.

  Returns:
    Returns a nested structure of stacked output tensor objects with the same
    nested structure as the output of `loop_fn`.
  """

  flat_loop_fn_dtypes = nest.flatten(loop_fn_dtypes)
  is_none_list = []

  def while_body(i, *ta_list):
    """Body of while loop."""
    fn_output = nest.flatten(loop_fn(i))
    if len(fn_output) != len(flat_loop_fn_dtypes):
      raise ValueError(
          "Number of expected outputs, %d, does not match the number of "
          "actual outputs, %d, from loop_fn" % (len(flat_loop_fn_dtypes),
                                                len(fn_output)))
    outputs = []
    del is_none_list[:]
    is_none_list.extend(x is None for x in fn_output)
    for out, ta in zip(fn_output, ta_list):
      # TODO(agarwal): support returning Operation objects from loop_fn.
      if out is not None:
        # out may be a ref tensor, wrap it in identity to get a non-ref tensor.
        ta = ta.write(i, array_ops.expand_dims(out, 0))
      outputs.append(ta)
    return tuple([i + 1] + outputs)

  if parallel_iterations is not None:
    extra_args = {"parallel_iterations": parallel_iterations}
  else:
    extra_args = {}
  ta_list = control_flow_ops.while_loop(
      lambda i, *ta: i < iters,
      while_body,
      [0] + [tensor_array_ops.TensorArray(dtype.base_dtype, iters)
             for dtype in flat_loop_fn_dtypes],
      **extra_args)[1:]

  # TODO(rachelim): enable this for sparse tensors

  output = [None if is_none else ta.concat()
            for ta, is_none in zip(ta_list, is_none_list)]
  assert len(output) in (0, len(flat_loop_fn_dtypes))
  if not output:
    # This may happen for the case where iters == 0.
    return None
  else:
    return nest.pack_sequence_as(loop_fn_dtypes, output)


def _flatten_first_two_dims(x):
  """Flattens the first two dimensions of x into a single dimension."""
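  # For example, a tensor of shape [2, 3, 4] is reshaped to shape [6, 4].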
  old_shape = array_ops.shape(x)
  new_shape = array_ops.concat([[old_shape[0] * old_shape[1]], old_shape[2:]],
                               axis=0)
  return array_ops.reshape(x, new_shape)


PFOR_CONFIG_ARG = "pfor_config"


def _is_under_xla_context():
  """Check if we are currently inside an XLA compile context."""
  g = ops.get_default_graph()
  while g is not None:
    control_flow_context = g._get_control_flow_context()  # pylint: disable=protected-access
    while control_flow_context is not None:
      if control_flow_context.IsXLAContext():
        return True
      else:
        control_flow_context = control_flow_context.outer_context
    # If g is a FuncGraph, get its outer_graph.
    g = getattr(g, "outer_graph", None)
  return False


def pfor(loop_fn, iters, fallback_to_while_loop=True, parallel_iterations=None):
  """Equivalent to running `loop_fn` `iters` times and stacking the outputs.

  `pfor` has functionality similar to `for_loop`, i.e. running `loop_fn` `iters`
  times, with input from 0 to `iters - 1`, and stacking corresponding output of
  each iteration. However, the implementation does not use a `tf.while_loop`.
  Instead, it adds new operations to the graph that collectively compute the
  same value as what running `loop_fn` in a loop would compute.

  This is an experimental feature and currently has a lot of limitations:
    - There should be no data dependency between the different iterations. For
      example, a future iteration should not depend on a value or side-effect of
      a previous iteration.
    - Stateful kernels are largely not supported since these often imply a data
      dependency or ordering of the iterations. We do support a limited set of
      such stateful kernels though (like RandomFoo, Variable operations like
      reads, etc).
    - Conversion works only on a limited set of kernels for which a converter
      has been registered.
    - `loop_fn` has limited support for control flow operations. `tf.cond` in
      particular is not supported.
    - `loop_fn` should return a nested structure of Tensors or Operations.
      However, if an Operation is returned, it should have zero outputs.
    - The shape and dtype of `loop_fn` outputs should not depend on the input
      to loop_fn.

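  Example (a minimal sketch; each iteration gathers an independent slice of the
  inputs, so there is no cross-iteration dependency):

  ```python
  x = tf.random.uniform([8, 3, 3])
  y = tf.random.uniform([8, 3])

  def loop_fn(i):
    return tf.linalg.matvec(tf.gather(x, i), tf.gather(y, i))

  # Stacks the 8 independent matrix-vector products; `out` has shape [8, 3].
  out = pfor(loop_fn, 8)
  ```
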
  Args:
    loop_fn: A function that takes an int32 scalar tf.Tensor object representing
      the iteration number, and optionally a keyword argument `pfor_config` set
      to a PForConfig object. It returns a possibly nested structure of Tensor
      or Operation objects. Note that if the `parallel_iterations` argument is
      set to something other than None, `loop_fn` may be called more than once
      during graph construction, so it may need to avoid mutating global state.
    iters: Number of iterations for which to run `loop_fn`.
    fallback_to_while_loop: If true, on failing to vectorize an operation, pfor
      falls back to using a `tf.while_loop` to dispatch the iterations.
    parallel_iterations: A knob to control how many iterations are vectorized
      and dispatched in parallel. The default value of None corresponds to
      vectorizing all the iterations.  If `parallel_iterations` is smaller than
      `iters`, then chunks of at most that many iterations are dispatched in
      sequence. This knob can be used to control the total memory usage.

  Returns:
    Returns a nested structure of stacked tensor objects with the same nested
    structure as the output of `loop_fn`.
  Raises:
    ValueError: If parallel_iterations is not None and not an integer > 1.
  """
  def f():
    return _pfor_impl(loop_fn,
                      iters,
                      fallback_to_while_loop=fallback_to_while_loop,
                      parallel_iterations=parallel_iterations)
  # Note that we wrap into a tf.function if in eager execution mode or under
  # XLA compilation. The latter is so that we don't compile operations like
  # tf.placeholder that are created by the loop body.
  functions_run_eagerly = None
  if context.executing_eagerly() or _is_under_xla_context():
    functions_run_eagerly = def_function.functions_run_eagerly()
    if functions_run_eagerly:
      logging.warning(
          "It looks like tf.function behavior was disabled, perhaps using "
          "tf.config.run_functions_eagerly. Vectorization "
          "primitives (e.g. tf.vectorized_map) require tf.function to work. "
          "These primitives will override the disable.")
      def_function.run_functions_eagerly(False)
    f = def_function.function(f)
  outputs = f()
  if functions_run_eagerly is not None:
    def_function.run_functions_eagerly(functions_run_eagerly)
  return outputs


def _should_expand_composite(value):
  return (isinstance(value, composite_tensor.CompositeTensor)
          # Leave sparse tensors to be converted by `PFor._convert_sparse`.
          and not isinstance(value, sparse_tensor.SparseTensor)
          and not isinstance(value, indexed_slices.IndexedSlices))


# pylint: disable=protected-access
def _composite_to_tensors(value, is_batched=False):
  """Converts a CompositeTensor into a list of stackable tensors."""
  if _should_expand_composite(value):
    spec = value._type_spec
    if not isinstance(spec, type_spec.BatchableTypeSpec):
      raise ValueError("CompositeTensor instance {} returned from "
                       "parallel_for or vectorized_map loop body must provide "
                       "a `BatchableTypeSpec` (saw: {}).".format(
                           value, spec))
    if is_batched:
      return spec._to_batched_tensor_list(value)
    return spec._to_tensor_list(value)
  return value
# pylint: enable=protected-access


# pylint: disable=protected-access
def _composite_from_tensors(stacked_tensors,
                            preconverted_value,
                            batch_size):
  """Converts a list of stacked tensors to a batch CompositeTensor."""
  if _should_expand_composite(preconverted_value):
    batch_type_spec = preconverted_value._type_spec._batch(batch_size)
    return batch_type_spec._from_compatible_tensor_list(stacked_tensors)
  return stacked_tensors
# pylint: enable=protected-access


def _loop_fn_has_config(loop_fn):
  """Test if `loop_fn` has a `pfor_config` argument."""
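  # For example, `def loop_fn(i, pfor_config): ...` is detected as having the
  # config argument, while `def loop_fn(i): ...` is not.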
  if tf_inspect.isfunction(loop_fn):
    argspec = tf_inspect.getargspec(loop_fn)
    return PFOR_CONFIG_ARG in argspec.args
  elif isinstance(loop_fn, functools.partial):
    fn = loop_fn.func
    argspec = tf_inspect.getargspec(fn)
    return (PFOR_CONFIG_ARG in argspec.args and
            PFOR_CONFIG_ARG not in loop_fn.keywords)
  else:
    loop_class = tf_decorator.unwrap(loop_fn)[1]
    if not hasattr(loop_class, "__call__"):
      raise ValueError("loop_fn object did not have a __call__ method")
    argspec = tf_inspect.getargspec(loop_class.__call__)
    return PFOR_CONFIG_ARG in argspec.args


def _pfor_impl(loop_fn,
               iters,
               fallback_to_while_loop,
               parallel_iterations=None,
               pfor_config=None):
  """Implementation of pfor."""
  assert not context.executing_eagerly()
  loop_fn_has_config = _loop_fn_has_config(loop_fn)
  existing_ops = set(ops.get_default_graph().get_operations())
  iters_value = tensor_util.constant_value(iters)
  # Run the loop body
  with ops.name_scope("loop_body"):
    loop_var = array_ops.placeholder_with_default(0, shape=[])
    if loop_fn_has_config:
      if pfor_config is None:
        pfor_config = PForConfig()
        pfor_config._set_iters(iters)  # pylint: disable=protected-access
      loop_fn_outputs = loop_fn(loop_var, **{PFOR_CONFIG_ARG: pfor_config})
    else:
      assert pfor_config is None
      loop_fn_outputs = loop_fn(loop_var)
    loop_fn_output_tensors = nest.map_structure(_composite_to_tensors,
                                                loop_fn_outputs)

  # Convert outputs to Tensor if needed.
  tmp_loop_fn_outputs = []
  for loop_fn_output in nest.flatten(loop_fn_output_tensors):
    if (loop_fn_output is not None and not isinstance(
        loop_fn_output,
        (ops.Operation, ops.Tensor, sparse_tensor.SparseTensor))):
      if isinstance(loop_fn_output, indexed_slices.IndexedSlices):
        logging.warn("Converting %s to a dense representation may make it slow."
                     " Alternatively, output the indices and values of the"
                     " IndexedSlices separately, and handle the vectorized"
                     " outputs directly." % loop_fn_output)
        loop_fn_output = ops.convert_to_tensor(loop_fn_output)
      else:
        loop_fn_output = ops.convert_to_tensor(loop_fn_output)
    tmp_loop_fn_outputs.append(loop_fn_output)
  loop_fn_output_tensors = nest.pack_sequence_as(loop_fn_output_tensors,
                                                 tmp_loop_fn_outputs)

  new_ops = set(ops.get_default_graph().get_operations()) - existing_ops
  iters = ops.convert_to_tensor(iters)
  if parallel_iterations is not None:
    if parallel_iterations < 1:
      raise ValueError("parallel_iterations must be None or a positive integer")
    if parallel_iterations == 1:
      raise ValueError("Found parallel_iterations == 1. Use for_loop instead.")
    if iters_value is not None and iters_value < parallel_iterations:
      parallel_iterations = None
  if parallel_iterations is None:
    with ops.name_scope("pfor"):
      converter = PFor(loop_var, iters, new_ops,
                       fallback_to_while_loop=fallback_to_while_loop,
                       pfor_config=pfor_config)
      flattened_output_tensors = []
      for loop_fn_output in nest.flatten(loop_fn_output_tensors):
        output = converter.convert(loop_fn_output)
        flattened_output_tensors.append(output)
  else:
    if pfor_config is not None and pfor_config._has_reductions():  # pylint: disable=protected-access
      raise ValueError("Setting parallel_iterations is currently unsupported "
                       "if reductions across iterations are performed.")
    num_tiled_iterations = iters // parallel_iterations
    num_remaining_iterations = iters % parallel_iterations
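    # For example, with iters=10 and parallel_iterations=4, the 2 remaining
    # iterations are vectorized as one untiled chunk below, followed by 2 tiles
    # of 4 iterations each, dispatched in sequence via for_loop.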
    # TODO(agarwal): Avoid calling loop_fn twice. Generate the loop body inside
    # a tf.function and extract the graph from there to vectorize it.
    with ops.name_scope("pfor_untiled"):
      converter = PFor(loop_var, num_remaining_iterations, new_ops,
                       fallback_to_while_loop=fallback_to_while_loop,
                       pfor_config=pfor_config)
      remaining_output_tensors = []
      flattened_output_tensors = nest.flatten(loop_fn_output_tensors)
      for loop_fn_output in flattened_output_tensors:
        output = converter.convert(loop_fn_output)
        remaining_output_tensors.append(output)

    with ops.name_scope("pfor_tiled"):
      loop_fn_dtypes = [ops.convert_to_tensor(x).dtype
                        for x in flattened_output_tensors]

      def tiled_loop_body(j):
        offset = j * parallel_iterations + num_remaining_iterations

        def tiled_loop_fn(i, pfor_config=None):
          if loop_fn_has_config:
            loop_fn_outputs = loop_fn(i + offset, pfor_config=pfor_config)
          else:
            loop_fn_outputs = loop_fn(i + offset)
          return nest.flatten(
              # Stacking across iterations requires explicit Tensors.
              nest.map_structure(_composite_to_tensors, loop_fn_outputs))

        return _pfor_impl(
            tiled_loop_fn,
            parallel_iterations,
            fallback_to_while_loop=fallback_to_while_loop,
            pfor_config=pfor_config)

      tiled_output_tensors = for_loop(
          tiled_loop_body, loop_fn_dtypes,
          num_tiled_iterations, parallel_iterations=1)
      tiled_output_tensors = [
          _flatten_first_two_dims(y) for y in tiled_output_tensors]

    with ops.name_scope("pfor"):
      if iters_value is None or iters_value % parallel_iterations:
        output_tensors = control_flow_ops.cond(
            math_ops.equal(num_remaining_iterations, 0),
            lambda: tiled_output_tensors,
            lambda: [array_ops.concat([x, y], axis=0)  # pylint: disable=g-long-lambda
                     for x, y in zip(remaining_output_tensors,
                                     tiled_output_tensors)])
      else:
        output_tensors = tiled_output_tensors
      flattened_output_tensors = nest.flatten(output_tensors)

      for output, original_output in zip(flattened_output_tensors,
                                         nest.flatten(loop_fn_output_tensors)):
        # Restore any shape information lost from tiling.
        # TODO(b/174254748): this may not be correct for stacked `variant`s.
        output.set_shape(
            tensor_shape.TensorShape([iters_value]).concatenate(
                original_output.shape))

  return nest.map_structure_up_to(
      loop_fn_outputs,
      functools.partial(_composite_from_tensors, batch_size=iters_value),
      nest.pack_sequence_as(loop_fn_output_tensors,
                            flattened_output_tensors),
      loop_fn_outputs)


def _broadcasting_gather(x, i):
  """Wrapper for gather that implicitly broadcasts unit dimensions."""
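  # If the first dimension of x is statically 1, index 0 is gathered regardless
  # of i, so a single value is broadcast across all loop iterations.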
  static_first_dim = tensor_shape.dimension_value(x.shape[0])
  if static_first_dim == 1:
    i = 0
  elif static_first_dim is None:
    i = array_ops.where_v2(array_ops.shape(x)[0] > 1, i, 0)
  result = array_ops.gather(x, i)
  return result


# pylint: disable=protected-access
def _gather_from_tensor_or_composite(x, i):
  """Wrapper for gather that handles CompositeTensors."""
  if _should_expand_composite(x):
    spec = x._type_spec
    gathered_tensors = [_broadcasting_gather(t, i)
                        for t in spec._to_batched_tensor_list(x)]
    return spec._unbatch()._from_compatible_tensor_list(gathered_tensors)
  return _broadcasting_gather(x, i)
# pylint: enable=protected-access


@tf_export("vectorized_map")
def vectorized_map(fn, elems, fallback_to_while_loop=True):
  """Parallel map on the list of tensors unpacked from `elems` on dimension 0.

  This method works similarly to `tf.map_fn` but is optimized to run much
  faster, possibly with a much larger memory footprint. The speedups are
  obtained by vectorization (see [Auto-Vectorizing TensorFlow Graphs:
  Jacobians, Auto-Batching and Beyond](https://arxiv.org/pdf/1903.04243.pdf)).
  The idea behind vectorization is to semantically launch all the invocations
  of `fn` in parallel and fuse corresponding operations across all these
  invocations. This fusion is done statically at graph generation time and the
  generated code is often similar in performance to a manually fused version.

  Because `tf.vectorized_map` fully parallelizes the batch, this method will
  generally be significantly faster than using `tf.map_fn`, especially in eager
  mode. However, this is an experimental feature and currently has a lot of
  limitations:
    - There should be no data dependency between the different semantic
      invocations of `fn`, i.e. it should be safe to map the elements of the
      inputs in any order.
    - Stateful kernels are largely not supported since these often imply a data
      dependency. We do support a limited set of such stateful kernels though
      (like RandomFoo, Variable operations like reads, etc).
    - `fn` has limited support for control flow operations.
    - `fn` should return a nested structure of Tensors or Operations. However,
      if an Operation is returned, it should have zero outputs.
    - The shape and dtype of any intermediate or output tensors in the
      computation of `fn` should not depend on the input to `fn`.

  Examples:
  ```python
  def outer_product(a):
    return tf.tensordot(a, a, 0)

  batch_size = 100
  a = tf.ones((batch_size, 32, 32))
  c = tf.vectorized_map(outer_product, a)
  assert c.shape == (batch_size, 32, 32, 32, 32)
  ```

  ```python
  # Computing per-example gradients

  batch_size = 10
  num_features = 32
  layer = tf.keras.layers.Dense(1)

  def model_fn(arg):
    with tf.GradientTape() as g:
      inp, label = arg
      inp = tf.expand_dims(inp, 0)
      label = tf.expand_dims(label, 0)
      prediction = layer(inp)
      loss = tf.nn.l2_loss(label - prediction)
    return g.gradient(loss, (layer.kernel, layer.bias))

  inputs = tf.random.uniform([batch_size, num_features])
  labels = tf.random.uniform([batch_size, 1])
  per_example_gradients = tf.vectorized_map(model_fn, (inputs, labels))
  assert per_example_gradients[0].shape == (batch_size, num_features, 1)
  assert per_example_gradients[1].shape == (batch_size, 1)
  ```

  Args:
    fn: The callable to be performed. It accepts one argument, which will have
      the same (possibly nested) structure as `elems`, and returns a possibly
      nested structure of Tensors and Operations, which may be different than
      the structure of `elems`.
    elems: A tensor or (possibly nested) sequence of tensors, each of which will
      be unpacked along their first dimension. The nested sequence of the
      resulting slices will be mapped over by `fn`. The first dimensions of all
      elements must broadcast to a consistent value; equivalently, each
      element tensor must have first dimension of either `B` or `1`, for some
      common batch size `B >= 1`.
    fallback_to_while_loop: If true, on failing to vectorize an operation,
      the unsupported op is wrapped in a tf.while_loop to execute the map
      iterations. Note that this fallback only happens for unsupported ops and
      other parts of `fn` are still vectorized. If false, on encountering an
      unsupported op, a ValueError is thrown. Note that the fallbacks can result
      in slowdowns since vectorization often yields speedup of one to two orders
      of magnitude.

  Returns:
    A tensor or (possibly nested) sequence of tensors. Each tensor packs the
    results of applying fn to tensors unpacked from elems along the first
    dimension, from first to last.

    Although they are less common as user-visible inputs and outputs, note that
    tensors of type `tf.variant` which represent tensor lists (for example from
    `tf.raw_ops.TensorListFromTensor`) are vectorized by stacking the list
    contents rather than the variant itself, and so the container tensor will
    have a scalar shape when returned rather than the usual stacked shape. This
    improves the performance of control flow gradient vectorization.

  Raises:
    ValueError: If vectorization fails and fallback_to_while_loop is False.
  """
  elems = nest.map_structure(ops.convert_to_tensor,
                             elems,
                             expand_composites=True)

  def loop_fn(i):
    gathered_elems = nest.map_structure(
        lambda x: _gather_from_tensor_or_composite(x, i), elems)
    return fn(gathered_elems)

  # Extract batch size from the maximum first dimension of any element.
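  # For example, elems whose flattened components have first dimensions (8, 1)
  # yield a batch size of 8; the size-1 components are broadcast across
  # iterations by _broadcasting_gather.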
  flat_elems = nest.flatten(
      nest.map_structure(
          functools.partial(_composite_to_tensors,
                            is_batched=True),
          elems))
  def _get_shape(x):
    if x.shape.rank is None:
      return None
    return x.shape.as_list()[0]
  static_first_dims = [_get_shape(elem) for elem in flat_elems]
  if any([s is None for s in static_first_dims]):
    batch_size = math_ops.reduce_max(
        [array_ops.shape(elem)[0] for elem in flat_elems])
  else:
    batch_size = max(static_first_dims)

  return pfor(loop_fn, batch_size,
              fallback_to_while_loop=fallback_to_while_loop)