# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""for_loop and pfor ops."""
# pylint: disable=g-direct-tensorflow-import

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools

from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import indexed_slices
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.framework import type_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops.parallel_for.pfor import PFor
from tensorflow.python.ops.parallel_for.pfor import PForConfig
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_inspect
from tensorflow.python.util.tf_export import tf_export


def for_loop(loop_fn, loop_fn_dtypes, iters, parallel_iterations=None):
  """Runs `loop_fn` `iters` times and stacks the outputs.

  Runs `loop_fn` `iters` times, with input values from 0 to `iters - 1`, and
  stacks corresponding outputs of the different runs.
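
  For example, a minimal usage sketch (the values here are illustrative only):

  ```python
  # Each iteration returns a scalar; `for_loop` stacks them into shape [8].
  squares = for_loop(
      lambda i: tf.square(tf.cast(i, tf.float32)), tf.float32, 8)
  ```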

  Args:
    loop_fn: A function that takes an int32 scalar tf.Tensor object
      representing the iteration number, and returns a possibly nested
      structure of tensor objects. The shape of these outputs should not
      depend on the input.
    loop_fn_dtypes: dtypes for the outputs of `loop_fn`.
    iters: Number of iterations for which to run `loop_fn`.
    parallel_iterations: The number of iterations that can be dispatched in
      parallel. This knob can be used to control the total memory usage.

  Returns:
    A nested structure of stacked output tensor objects with the same nested
    structure as the output of `loop_fn`.
  """
  flat_loop_fn_dtypes = nest.flatten(loop_fn_dtypes)
  is_none_list = []

  def while_body(i, *ta_list):
    """Body of while loop."""
    fn_output = nest.flatten(loop_fn(i))
    if len(fn_output) != len(flat_loop_fn_dtypes):
      raise ValueError(
          "Number of expected outputs, %d, does not match the number of "
          "actual outputs, %d, from loop_fn" % (len(flat_loop_fn_dtypes),
                                                len(fn_output)))
    outputs = []
    del is_none_list[:]
    is_none_list.extend(x is None for x in fn_output)
    for out, ta in zip(fn_output, ta_list):
      # TODO(agarwal): support returning Operation objects from loop_fn.
      if out is not None:
        # out may be a ref tensor, wrap it in identity to get a non-ref tensor.
        ta = ta.write(i, array_ops.expand_dims(out, 0))
      outputs.append(ta)
    return tuple([i + 1] + outputs)

  if parallel_iterations is not None:
    extra_args = {"parallel_iterations": parallel_iterations}
  else:
    extra_args = {}
  ta_list = control_flow_ops.while_loop(
      lambda i, *ta: i < iters,
      while_body,
      [0] + [tensor_array_ops.TensorArray(dtype.base_dtype, iters)
             for dtype in flat_loop_fn_dtypes],
      **extra_args)[1:]

  # TODO(rachelim): enable this for sparse tensors

  output = [None if is_none else ta.concat()
            for ta, is_none in zip(ta_list, is_none_list)]
  assert len(output) in (0, len(flat_loop_fn_dtypes))
  if not output:
    # This may happen for the case where iters == 0.
    return None
  else:
    return nest.pack_sequence_as(loop_fn_dtypes, output)


def _flatten_first_two_dims(x):
  """Flattens the first two dimensions of x into a single dimension."""
  old_shape = array_ops.shape(x)
  new_shape = array_ops.concat([[old_shape[0] * old_shape[1]], old_shape[2:]],
                               axis=0)
  return array_ops.reshape(x, new_shape)


PFOR_CONFIG_ARG = "pfor_config"


def _is_under_xla_context():
  """Check if we are currently inside an XLA compile context."""
  g = ops.get_default_graph()
  while g is not None:
    control_flow_context = g._get_control_flow_context()  # pylint: disable=protected-access
    while control_flow_context is not None:
      if control_flow_context.IsXLAContext():
        return True
      else:
        control_flow_context = control_flow_context.outer_context
    # If g is a FuncGraph, get its outer_graph.
    g = getattr(g, "outer_graph", None)
  return False


def pfor(loop_fn, iters, fallback_to_while_loop=True, parallel_iterations=None):
  """Equivalent to running `loop_fn` `iters` times and stacking the outputs.

  `pfor` has functionality similar to `for_loop`, i.e. running `loop_fn`
  `iters` times, with input from 0 to `iters - 1`, and stacking the
  corresponding output of each iteration. However, the implementation does not
  use a `tf.while_loop`. Instead it adds new operations to the graph that
  collectively compute the same value as what running `loop_fn` in a loop
  would compute.

  This is an experimental feature and currently has a lot of limitations:
    - There should be no data dependency between the different iterations. For
      example, a future iteration should not depend on a value or side-effect
      of a previous iteration.
    - Stateful kernels are mostly unsupported since they often imply a data
      dependency or ordering of the iterations. We do support a limited set of
      such stateful kernels though (like RandomFoo, Variable operations like
      reads, etc).
    - Conversion works only on a limited set of kernels for which a converter
      has been registered.
    - `loop_fn` has limited support for control flow operations. `tf.cond` in
      particular is not supported.
    - `loop_fn` should return a nested structure of Tensors or Operations.
      However, if an Operation is returned, it should have zero outputs.
    - The shape and dtype of `loop_fn` outputs should not depend on the input
      to `loop_fn`.
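
  For example, a minimal sketch of typical usage (the tensor shapes and values
  here are illustrative only):

  ```python
  x = tf.random.uniform([8, 3])

  def loop_fn(i):
    # Each iteration reads one row of `x`; iterations are independent.
    return tf.reduce_sum(tf.gather(x, i))

  # Stacks the 8 per-row sums into a Tensor of shape [8].
  row_sums = pfor(loop_fn, 8)
  ```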

  Args:
    loop_fn: A function that takes an int32 scalar tf.Tensor object
      representing the iteration number, and optionally a keyword argument
      `pfor_config` set to a PForConfig object. It returns a possibly nested
      structure of Tensor or Operation objects. Note that if setting
      `parallel_iterations` argument to something other than None, `loop_fn`
      may be called more than once during graph construction. So it may need
      to avoid mutating global state.
    iters: Number of iterations for which to run `loop_fn`.
    fallback_to_while_loop: If true, on failing to vectorize an operation,
      pfor falls back to using a `tf.while_loop` to dispatch the iterations.
    parallel_iterations: A knob to control how many iterations are vectorized
      and dispatched in parallel. The default value of None corresponds to
      vectorizing all the iterations. If `parallel_iterations` is smaller than
      `iters`, then chunks of at most that many iterations are dispatched in
      sequence. This knob can be used to control the total memory usage.

  Returns:
    A nested structure of stacked tensor objects with the same nested
    structure as the output of `loop_fn`.

  Raises:
    ValueError: If parallel_iterations is not None and not an integer > 1.
  """
  def f():
    return _pfor_impl(loop_fn,
                      iters,
                      fallback_to_while_loop=fallback_to_while_loop,
                      parallel_iterations=parallel_iterations)
  # Note that we wrap into a tf.function if in eager execution mode or under
  # XLA compilation. The latter is so that we don't compile operations like
  # tf.placeholder that are created by the loop body.
  functions_run_eagerly = None
  if context.executing_eagerly() or _is_under_xla_context():
    functions_run_eagerly = def_function.functions_run_eagerly()
    if functions_run_eagerly:
      logging.warning(
          "It looks like tf.function behavior was disabled, perhaps using "
          "tf.config.run_functions_eagerly. Vectorization "
          "primitives (e.g. tf.vectorized_map) require tf.function to work. "
          "These primitives will override the disable.")
      def_function.run_functions_eagerly(False)
    f = def_function.function(f)
  outputs = f()
  if functions_run_eagerly is not None:
    def_function.run_functions_eagerly(functions_run_eagerly)
  return outputs


def _should_expand_composite(value):
  return (isinstance(value, composite_tensor.CompositeTensor)
          # Leave sparse tensors to be converted by `PFor._convert_sparse`.
          and not isinstance(value, sparse_tensor.SparseTensor)
          and not isinstance(value, indexed_slices.IndexedSlices))


# pylint: disable=protected-access
def _composite_to_tensors(value, is_batched=False):
  """Converts a CompositeTensor into a list of stackable tensors."""
  if _should_expand_composite(value):
    spec = value._type_spec
    if not isinstance(spec, type_spec.BatchableTypeSpec):
      raise ValueError("CompositeTensor instance {} returned from "
                       "parallel_for or vectorized_map loop body must provide "
                       "a `BatchableTypeSpec` (saw: {}).".format(
                           value, spec))
    if is_batched:
      return spec._to_batched_tensor_list(value)
    return spec._to_tensor_list(value)
  return value
# pylint: enable=protected-access


# pylint: disable=protected-access
def _composite_from_tensors(stacked_tensors,
                            preconverted_value,
                            batch_size):
  """Converts a list of stacked tensors to a batch CompositeTensor."""
  if _should_expand_composite(preconverted_value):
    batch_type_spec = preconverted_value._type_spec._batch(batch_size)
    return batch_type_spec._from_compatible_tensor_list(stacked_tensors)
  return stacked_tensors
# pylint: enable=protected-access


def _loop_fn_has_config(loop_fn):
  """Test if `loop_fn` has a `pfor_config` argument."""
  if tf_inspect.isfunction(loop_fn):
    argspec = tf_inspect.getargspec(loop_fn)
    return PFOR_CONFIG_ARG in argspec.args
  elif isinstance(loop_fn, functools.partial):
    fn = loop_fn.func
    argspec = tf_inspect.getargspec(fn)
    return (PFOR_CONFIG_ARG in argspec.args and
            PFOR_CONFIG_ARG not in loop_fn.keywords)
  else:
    loop_class = tf_decorator.unwrap(loop_fn)[1]
    if not hasattr(loop_class, "__call__"):
      raise ValueError("loop_fn object did not have a __call__ method")
    argspec = tf_inspect.getargspec(loop_class.__call__)
    return PFOR_CONFIG_ARG in argspec.args


def _pfor_impl(loop_fn,
               iters,
               fallback_to_while_loop,
               parallel_iterations=None,
               pfor_config=None):
  """Implementation of pfor."""
  assert not context.executing_eagerly()
  loop_fn_has_config = _loop_fn_has_config(loop_fn)
  existing_ops = set(ops.get_default_graph().get_operations())
  iters_value = tensor_util.constant_value(iters)
  # Run the loop body
  with ops.name_scope("loop_body"):
    loop_var = array_ops.placeholder_with_default(0, shape=[])
    if loop_fn_has_config:
      if pfor_config is None:
        pfor_config = PForConfig()
        pfor_config._set_iters(iters)  # pylint: disable=protected-access
      loop_fn_outputs = loop_fn(loop_var, **{PFOR_CONFIG_ARG: pfor_config})
    else:
      assert pfor_config is None
      loop_fn_outputs = loop_fn(loop_var)
    loop_fn_output_tensors = nest.map_structure(_composite_to_tensors,
                                                loop_fn_outputs)

  # Convert outputs to Tensor if needed.
  tmp_loop_fn_outputs = []
  for loop_fn_output in nest.flatten(loop_fn_output_tensors):
    if (loop_fn_output is not None and not isinstance(
        loop_fn_output,
        (ops.Operation, ops.Tensor, sparse_tensor.SparseTensor))):
      if isinstance(loop_fn_output, indexed_slices.IndexedSlices):
        logging.warn("Converting %s to a dense representation may make it slow."
                     " Alternatively, output the indices and values of the"
                     " IndexedSlices separately, and handle the vectorized"
                     " outputs directly." % loop_fn_output)
        loop_fn_output = ops.convert_to_tensor(loop_fn_output)
      else:
        loop_fn_output = ops.convert_to_tensor(loop_fn_output)
    tmp_loop_fn_outputs.append(loop_fn_output)
  loop_fn_output_tensors = nest.pack_sequence_as(loop_fn_output_tensors,
                                                 tmp_loop_fn_outputs)

  new_ops = set(ops.get_default_graph().get_operations()) - existing_ops
  iters = ops.convert_to_tensor(iters)
  if parallel_iterations is not None:
    if parallel_iterations < 1:
      raise ValueError("parallel_iterations must be None or a positive integer")
    if parallel_iterations == 1:
      raise ValueError("Found parallel_iterations == 1. Use for_loop instead.")
    if iters_value is not None and iters_value < parallel_iterations:
      parallel_iterations = None
  if parallel_iterations is None:
    with ops.name_scope("pfor"):
      converter = PFor(loop_var, iters, new_ops,
                       fallback_to_while_loop=fallback_to_while_loop,
                       pfor_config=pfor_config)
      flattened_output_tensors = []
      for loop_fn_output in nest.flatten(loop_fn_output_tensors):
        output = converter.convert(loop_fn_output)
        flattened_output_tensors.append(output)
  else:
    if pfor_config is not None and pfor_config._has_reductions():  # pylint: disable=protected-access
      raise ValueError("Setting parallel_iterations currently unsupported if"
                       " reductions across iterations are performed.")
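    # The first `iters % parallel_iterations` iterations are handled by a
    # single vectorized call ("pfor_untiled"). The remaining iterations are
    # split into `iters // parallel_iterations` chunks of size
    # `parallel_iterations`, each vectorized and dispatched sequentially via
    # `for_loop` ("pfor_tiled"). The two sets of outputs are concatenated at
    # the end.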
    num_tiled_iterations = iters // parallel_iterations
    num_remaining_iterations = iters % parallel_iterations
    # TODO(agarwal): Avoid calling loop_fn twice. Generate the loop body inside
    # a tf.function and extract the graph from there to vectorize it.
    with ops.name_scope("pfor_untiled"):
      converter = PFor(loop_var, num_remaining_iterations, new_ops,
                       fallback_to_while_loop=fallback_to_while_loop,
                       pfor_config=pfor_config)
      remaining_output_tensors = []
      flattened_output_tensors = nest.flatten(loop_fn_output_tensors)
      for loop_fn_output in flattened_output_tensors:
        output = converter.convert(loop_fn_output)
        remaining_output_tensors.append(output)

    with ops.name_scope("pfor_tiled"):
      loop_fn_dtypes = [ops.convert_to_tensor(x).dtype
                        for x in flattened_output_tensors]

      def tiled_loop_body(j):
        offset = j * parallel_iterations + num_remaining_iterations

        def tiled_loop_fn(i, pfor_config=None):
          if loop_fn_has_config:
            loop_fn_outputs = loop_fn(i + offset, pfor_config=pfor_config)
          else:
            loop_fn_outputs = loop_fn(i + offset)
          return nest.flatten(
              # Stacking across iterations requires explicit Tensors.
              nest.map_structure(_composite_to_tensors, loop_fn_outputs))

        return _pfor_impl(
            tiled_loop_fn,
            parallel_iterations,
            fallback_to_while_loop=fallback_to_while_loop,
            pfor_config=pfor_config)

      tiled_output_tensors = for_loop(
          tiled_loop_body, loop_fn_dtypes,
          num_tiled_iterations, parallel_iterations=1)
      tiled_output_tensors = [
          _flatten_first_two_dims(y) for y in tiled_output_tensors]

    with ops.name_scope("pfor"):
      if iters_value is None or iters_value % parallel_iterations:
        output_tensors = control_flow_ops.cond(
            math_ops.equal(num_remaining_iterations, 0),
            lambda: tiled_output_tensors,
            lambda: [array_ops.concat([x, y], axis=0)  # pylint: disable=g-long-lambda
                     for x, y in zip(remaining_output_tensors,
                                     tiled_output_tensors)])
      else:
        output_tensors = tiled_output_tensors
      flattened_output_tensors = nest.flatten(output_tensors)

      for output, original_output in zip(flattened_output_tensors,
                                         nest.flatten(loop_fn_output_tensors)):
        # Restore any shape information lost from tiling.
        # TODO(b/174254748): this may not be correct for stacked `variant`s.
        output.set_shape(
            tensor_shape.TensorShape([iters_value]).concatenate(
                original_output.shape))

  return nest.map_structure_up_to(
      loop_fn_outputs,
      functools.partial(_composite_from_tensors, batch_size=iters_value),
      nest.pack_sequence_as(loop_fn_output_tensors,
                            flattened_output_tensors),
      loop_fn_outputs)


def _broadcasting_gather(x, i):
  """Wrapper for gather that implicitly broadcasts unit dimensions."""
  static_first_dim = tensor_shape.dimension_value(x.shape[0])
  if static_first_dim == 1:
    i = 0
  elif static_first_dim is None:
    i = array_ops.where_v2(array_ops.shape(x)[0] > 1, i, 0)
  result = array_ops.gather(x, i)
  return result


# pylint: disable=protected-access
def _gather_from_tensor_or_composite(x, i):
  """Wrapper for gather that handles CompositeTensors."""
  if _should_expand_composite(x):
    spec = x._type_spec
    gathered_tensors = [_broadcasting_gather(t, i)
                        for t in spec._to_batched_tensor_list(x)]
    return spec._unbatch()._from_compatible_tensor_list(gathered_tensors)
  return _broadcasting_gather(x, i)
# pylint: enable=protected-access


@tf_export("vectorized_map")
def vectorized_map(fn, elems, fallback_to_while_loop=True):
  """Parallel map on the list of tensors unpacked from `elems` on dimension 0.

  This method works similarly to `tf.map_fn` but is optimized to run much
  faster, possibly with a much larger memory footprint. The speedups are
  obtained by vectorization (see [Auto-Vectorizing TensorFlow Graphs:
  Jacobians, Auto-Batching and Beyond](https://arxiv.org/pdf/1903.04243.pdf)).
  The idea behind vectorization is to semantically launch all the invocations
  of `fn` in parallel and fuse corresponding operations across all these
  invocations. This fusion is done statically at graph generation time and the
  generated code is often similar in performance to a manually fused version.

  Because `tf.vectorized_map` fully parallelizes the batch, this method will
  generally be significantly faster than using `tf.map_fn`, especially in
  eager mode. However, this is an experimental feature and currently has a lot
  of limitations:
    - There should be no data dependency between the different semantic
      invocations of `fn`, i.e. it should be safe to map the elements of the
      inputs in any order.
    - Stateful kernels are mostly unsupported since they often imply a data
      dependency. We do support a limited set of such stateful kernels though
      (like RandomFoo, Variable operations like reads, etc).
    - `fn` has limited support for control flow operations.
    - `fn` should return a nested structure of Tensors or Operations. However,
      if an Operation is returned, it should have zero outputs.
    - The shape and dtype of any intermediate or output tensors in the
      computation of `fn` should not depend on the input to `fn`.

  Examples:
  ```python
  def outer_product(a):
    return tf.tensordot(a, a, 0)

  batch_size = 100
  a = tf.ones((batch_size, 32, 32))
  c = tf.vectorized_map(outer_product, a)
  assert c.shape == (batch_size, 32, 32, 32, 32)
  ```

  ```python
  # Computing per-example gradients

  batch_size = 10
  num_features = 32
  layer = tf.keras.layers.Dense(1)

  def model_fn(arg):
    with tf.GradientTape() as g:
      inp, label = arg
      inp = tf.expand_dims(inp, 0)
      label = tf.expand_dims(label, 0)
      prediction = layer(inp)
      loss = tf.nn.l2_loss(label - prediction)
    return g.gradient(loss, (layer.kernel, layer.bias))

  inputs = tf.random.uniform([batch_size, num_features])
  labels = tf.random.uniform([batch_size, 1])
  per_example_gradients = tf.vectorized_map(model_fn, (inputs, labels))
  assert per_example_gradients[0].shape == (batch_size, num_features, 1)
  assert per_example_gradients[1].shape == (batch_size, 1)
  ```
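
  As a further illustration (a sketch; the shapes here are arbitrary), inputs
  whose first dimension is 1 are broadcast against the batch dimension of the
  other inputs:

  ```python
  a = tf.random.uniform([8, 3])
  b = tf.random.uniform([1, 3])
  # `b` is reused for every invocation of `fn` since its first dimension is 1.
  c = tf.vectorized_map(lambda x: x[0] + x[1], (a, b))
  assert c.shape == (8, 3)
  ```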

  Args:
    fn: The callable to be performed. It accepts one argument, which will have
      the same (possibly nested) structure as `elems`, and returns a possibly
      nested structure of Tensors and Operations, which may be different from
      the structure of `elems`.
    elems: A tensor or (possibly nested) sequence of tensors, each of which
      will be unpacked along its first dimension. The nested sequence of the
      resulting slices will be mapped over by `fn`. The first dimensions of
      all elements must broadcast to a consistent value; equivalently, each
      element tensor must have first dimension of either `B` or `1`, for some
      common batch size `B >= 1`.
    fallback_to_while_loop: If true, on failing to vectorize an operation,
      the unsupported op is wrapped in a tf.while_loop to execute the map
      iterations. Note that this fallback only happens for unsupported ops and
      other parts of `fn` are still vectorized. If false, on encountering an
      unsupported op, a ValueError is raised. Note that the fallbacks can
      result in slowdowns since vectorization often yields speedups of one to
      two orders of magnitude.

  Returns:
    A tensor or (possibly nested) sequence of tensors. Each tensor packs the
    results of applying fn to tensors unpacked from elems along the first
    dimension, from first to last.

    Although they are less common as user-visible inputs and outputs, note
    that tensors of type `tf.variant` which represent tensor lists (for
    example from `tf.raw_ops.TensorListFromTensor`) are vectorized by stacking
    the list contents rather than the variant itself, and so the container
    tensor will have a scalar shape when returned rather than the usual
    stacked shape. This improves the performance of control flow gradient
    vectorization.

  Raises:
    ValueError: If vectorization fails and fallback_to_while_loop is False.
519 """ 520 elems = nest.map_structure(ops.convert_to_tensor, 521 elems, 522 expand_composites=True) 523 524 def loop_fn(i): 525 gathered_elems = nest.map_structure( 526 lambda x: _gather_from_tensor_or_composite(x, i), elems) 527 return fn(gathered_elems) 528 529 # Extract batch size from the maximum first dimension of any element. 530 flat_elems = nest.flatten( 531 nest.map_structure( 532 functools.partial(_composite_to_tensors, 533 is_batched=True), 534 elems)) 535 def _get_shape(x): 536 if x.shape.rank is None: 537 return None 538 return x.shape.as_list()[0] 539 static_first_dims = [_get_shape(elem) for elem in flat_elems] 540 if any([s is None for s in static_first_dims]): 541 batch_size = math_ops.reduce_max( 542 [array_ops.shape(elem)[0] for elem in flat_elems]) 543 else: 544 batch_size = max(static_first_dims) 545 546 return pfor(loop_fn, batch_size, 547 fallback_to_while_loop=fallback_to_while_loop) 548