1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================= 15"""Functional operations.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21from tensorflow.core.framework import attr_value_pb2 22from tensorflow.python.eager import context 23from tensorflow.python.framework import auto_control_deps_utils as acd 24from tensorflow.python.framework import constant_op 25from tensorflow.python.framework import dtypes 26from tensorflow.python.framework import function 27from tensorflow.python.framework import ops 28from tensorflow.python.framework import tensor_shape 29from tensorflow.python.ops import array_ops 30from tensorflow.python.ops import control_flow_ops 31from tensorflow.python.ops import gen_functional_ops 32from tensorflow.python.ops import math_ops 33from tensorflow.python.ops import tensor_array_ops 34from tensorflow.python.ops import variable_scope as vs 35# pylint: disable=unused-import 36from tensorflow.python.ops.gen_functional_ops import remote_call 37# pylint: enable=unused-import 38from tensorflow.python.ops.gen_functional_ops import symbolic_gradient 39from tensorflow.python.util import compat 40from tensorflow.python.util import deprecation 41from tensorflow.python.util import dispatch 42from tensorflow.python.util import function_utils 43from 
tensorflow.python.util import nest 44from tensorflow.python.util.tf_export import tf_export 45 46 47# TODO(yuanbyu, mrry): Handle stride to support sliding windows. 48@tf_export(v1=["foldl"]) 49@dispatch.add_dispatch_support 50def foldl(fn, 51 elems, 52 initializer=None, 53 parallel_iterations=10, 54 back_prop=True, 55 swap_memory=False, 56 name=None): 57 """foldl on the list of tensors unpacked from `elems` on dimension 0. 58 59 This foldl operator repeatedly applies the callable `fn` to a sequence 60 of elements from first to last. The elements are made of the tensors 61 unpacked from `elems` on dimension 0. The callable fn takes two tensors as 62 arguments. The first argument is the accumulated value computed from the 63 preceding invocation of fn, and the second is the value at the current 64 position of `elems`. If `initializer` is None, `elems` must contain at least 65 one element, and its first element is used as the initializer. 66 67 Suppose that `elems` is unpacked into `values`, a list of tensors. The shape 68 of the result tensor is fn(initializer, values[0]).shape`. 69 70 This method also allows multi-arity `elems` and output of `fn`. If `elems` 71 is a (possibly nested) list or tuple of tensors, then each of these tensors 72 must have a matching first (unpack) dimension. The signature of `fn` may 73 match the structure of `elems`. That is, if `elems` is 74 `(t1, [t2, t3, [t4, t5]])`, then an appropriate signature for `fn` is: 75 `fn = lambda (t1, [t2, t3, [t4, t5]]):`. 76 77 Args: 78 fn: The callable to be performed. 79 elems: A tensor or (possibly nested) sequence of tensors, each of which will 80 be unpacked along their first dimension. The nested sequence of the 81 resulting slices will be the first argument to `fn`. 82 initializer: (optional) A tensor or (possibly nested) sequence of tensors, 83 as the initial value for the accumulator. 84 parallel_iterations: (optional) The number of iterations allowed to run in 85 parallel. 
86 back_prop: (optional) True enables support for back propagation. 87 swap_memory: (optional) True enables GPU-CPU memory swapping. 88 name: (optional) Name prefix for the returned tensors. 89 90 Returns: 91 A tensor or (possibly nested) sequence of tensors, resulting from applying 92 `fn` consecutively to the list of tensors unpacked from `elems`, from first 93 to last. 94 95 Raises: 96 TypeError: if `fn` is not callable. 97 98 Example: 99 ```python 100 elems = tf.constant([1, 2, 3, 4, 5, 6]) 101 sum = foldl(lambda a, x: a + x, elems) 102 # sum == 21 103 ``` 104 """ 105 if not callable(fn): 106 raise TypeError( 107 f"{fn.__name__} is not callable. Please provide a callable function.") 108 109 def create_ta(elem): 110 return tensor_array_ops.TensorArray( 111 dtype=elem.dtype, size=n, dynamic_size=False, 112 infer_shape=True).unstack(elem) 113 114 in_graph_mode = not context.executing_eagerly() 115 with ops.name_scope(name, "foldl", [elems]): 116 # TODO(akshayka): Remove the in_graph_mode check once caching devices are 117 # supported in Eager 118 if in_graph_mode: 119 # Any get_variable calls in fn will cache the first call locally 120 # and not issue repeated network I/O requests for each iteration. 121 varscope = vs.get_variable_scope() 122 varscope_caching_device_was_none = False 123 if varscope.caching_device is None: 124 # TODO(ebrevdo): Change to using colocate_with here and in other 125 # methods. 126 varscope.set_caching_device(lambda op: op.device) 127 varscope_caching_device_was_none = True 128 129 # Convert elems to tensor array. n may be known statically. 
130 elems_flat = [ 131 ops.convert_to_tensor(elem, name="elem") for elem in nest.flatten(elems) 132 ] 133 n = ( 134 tensor_shape.dimension_value(elems_flat[0].shape[0]) or 135 array_ops.shape(elems_flat[0])[0]) 136 137 elems_ta = nest.map_structure(create_ta, elems) 138 139 if initializer is None: 140 a = nest.map_structure(lambda elem: elem.read(0), elems_ta) 141 i = constant_op.constant(1) 142 else: 143 a = initializer 144 i = constant_op.constant(0) 145 146 def compute(i, a): 147 elem_i = nest.map_structure(lambda elem: elem.read(i), elems_ta) 148 a = fn(a, elem_i) 149 return [i + 1, a] 150 151 _, r_a = control_flow_ops.while_loop( 152 lambda i, a: i < n, 153 compute, [i, a], 154 parallel_iterations=parallel_iterations, 155 back_prop=back_prop, 156 swap_memory=swap_memory, 157 maximum_iterations=n) 158 159 # TODO(akshayka): Remove the in_graph_mode check once caching devices are 160 # supported in Eager 161 if in_graph_mode and varscope_caching_device_was_none: 162 varscope.set_caching_device(None) 163 164 return r_a 165 166 167@tf_export("foldl", v1=[]) 168@dispatch.add_dispatch_support 169@deprecation.deprecated_arg_values( 170 None, 171 """back_prop=False is deprecated. Consider using tf.stop_gradient instead. 172Instead of: 173results = tf.foldl(fn, elems, back_prop=False) 174Use: 175results = tf.nest.map_structure(tf.stop_gradient, tf.foldl(fn, elems))""", 176 warn_once=True, 177 back_prop=False) 178def foldl_v2(fn, 179 elems, 180 initializer=None, 181 parallel_iterations=10, 182 back_prop=True, 183 swap_memory=False, 184 name=None): 185 """foldl on the list of tensors unpacked from `elems` on dimension 0. 186 187 This foldl operator repeatedly applies the callable `fn` to a sequence 188 of elements from first to last. The elements are made of the tensors 189 unpacked from `elems` on dimension 0. The callable fn takes two tensors as 190 arguments. 
The first argument is the accumulated value computed from the 191 preceding invocation of fn, and the second is the value at the current 192 position of `elems`. If `initializer` is None, `elems` must contain at least 193 one element, and its first element is used as the initializer. 194 195 Suppose that `elems` is unpacked into `values`, a list of tensors. The shape 196 of the result tensor is fn(initializer, values[0]).shape`. 197 198 This method also allows multi-arity `elems` and output of `fn`. If `elems` 199 is a (possibly nested) list or tuple of tensors, then each of these tensors 200 must have a matching first (unpack) dimension. The signature of `fn` may 201 match the structure of `elems`. That is, if `elems` is 202 `(t1, [t2, t3, [t4, t5]])`, then an appropriate signature for `fn` is: 203 `fn = lambda (t1, [t2, t3, [t4, t5]]):`. 204 205 Args: 206 fn: The callable to be performed. 207 elems: A tensor or (possibly nested) sequence of tensors, each of which will 208 be unpacked along their first dimension. The nested sequence of the 209 resulting slices will be the first argument to `fn`. 210 initializer: (optional) A tensor or (possibly nested) sequence of tensors, 211 as the initial value for the accumulator. 212 parallel_iterations: (optional) The number of iterations allowed to run in 213 parallel. 214 back_prop: (optional) Deprecated. False disables support for back 215 propagation. Prefer using `tf.stop_gradient` instead. 216 swap_memory: (optional) True enables GPU-CPU memory swapping. 217 name: (optional) Name prefix for the returned tensors. 218 219 Returns: 220 A tensor or (possibly nested) sequence of tensors, resulting from applying 221 `fn` consecutively to the list of tensors unpacked from `elems`, from first 222 to last. 223 224 Raises: 225 TypeError: if `fn` is not callable. 
226 227 Example: 228 ```python 229 elems = tf.constant([1, 2, 3, 4, 5, 6]) 230 sum = foldl(lambda a, x: a + x, elems) 231 # sum == 21 232 ``` 233 """ 234 return foldl( 235 fn=fn, 236 elems=elems, 237 initializer=initializer, 238 parallel_iterations=parallel_iterations, 239 back_prop=back_prop, 240 swap_memory=swap_memory, 241 name=name) 242 243 244@tf_export(v1=["foldr"]) 245@dispatch.add_dispatch_support 246def foldr(fn, 247 elems, 248 initializer=None, 249 parallel_iterations=10, 250 back_prop=True, 251 swap_memory=False, 252 name=None): 253 """foldr on the list of tensors unpacked from `elems` on dimension 0. 254 255 This foldr operator repeatedly applies the callable `fn` to a sequence 256 of elements from last to first. The elements are made of the tensors 257 unpacked from `elems`. The callable fn takes two tensors as arguments. 258 The first argument is the accumulated value computed from the preceding 259 invocation of fn, and the second is the value at the current position of 260 `elems`. If `initializer` is None, `elems` must contain at least one element, 261 and its first element is used as the initializer. 262 263 Suppose that `elems` is unpacked into `values`, a list of tensors. The shape 264 of the result tensor is `fn(initializer, values[0]).shape`. 265 266 This method also allows multi-arity `elems` and output of `fn`. If `elems` 267 is a (possibly nested) list or tuple of tensors, then each of these tensors 268 must have a matching first (unpack) dimension. The signature of `fn` may 269 match the structure of `elems`. That is, if `elems` is 270 `(t1, [t2, t3, [t4, t5]])`, then an appropriate signature for `fn` is: 271 `fn = lambda (t1, [t2, t3, [t4, t5]]):`. 272 273 Args: 274 fn: The callable to be performed. 275 elems: A tensor or (possibly nested) sequence of tensors, each of which will 276 be unpacked along their first dimension. The nested sequence of the 277 resulting slices will be the first argument to `fn`. 
278 initializer: (optional) A tensor or (possibly nested) sequence of tensors, 279 as the initial value for the accumulator. 280 parallel_iterations: (optional) The number of iterations allowed to run in 281 parallel. 282 back_prop: (optional) True enables support for back propagation. 283 swap_memory: (optional) True enables GPU-CPU memory swapping. 284 name: (optional) Name prefix for the returned tensors. 285 286 Returns: 287 A tensor or (possibly nested) sequence of tensors, resulting from applying 288 `fn` consecutively to the list of tensors unpacked from `elems`, from last 289 to first. 290 291 Raises: 292 TypeError: if `fn` is not callable. 293 294 Example: 295 ```python 296 elems = [1, 2, 3, 4, 5, 6] 297 sum = foldr(lambda a, x: a + x, elems) 298 # sum == 21 299 ``` 300 """ 301 if not callable(fn): 302 raise TypeError( 303 f"{fn.__name__} is not callable. Please provide a callable function.") 304 305 def create_ta(elem): 306 return tensor_array_ops.TensorArray( 307 dtype=elem.dtype, size=n, dynamic_size=False, 308 infer_shape=True).unstack(elem) 309 310 in_graph_mode = not context.executing_eagerly() 311 with ops.name_scope(name, "foldr", [elems]): 312 # TODO(akshayka): Remove the in_graph_mode check once caching devices are 313 # supported in Eager 314 if in_graph_mode: 315 # Any get_variable calls in fn will cache the first call locally and not 316 # issue repeated network I/O requests for each iteration. 317 varscope = vs.get_variable_scope() 318 varscope_caching_device_was_none = False 319 if varscope.caching_device is None: 320 # TODO(ebrevdo): Change to using colocate_with here and in other 321 # methods. 322 varscope.set_caching_device(lambda op: op.device) 323 varscope_caching_device_was_none = True 324 325 # Convert elems to tensor array. n may be known statically. 
326 elems_flat = [ 327 ops.convert_to_tensor(elem, name="elem") for elem in nest.flatten(elems) 328 ] 329 n = ( 330 tensor_shape.dimension_value(elems_flat[0].shape[0]) or 331 array_ops.shape(elems_flat[0])[0]) 332 333 elems_ta = nest.map_structure(create_ta, elems) 334 335 if initializer is None: 336 i = n - 1 337 a = nest.map_structure(lambda elem: elem.read(i), elems_ta) 338 else: 339 i = n 340 a = initializer 341 342 def compute(i, a): 343 i -= 1 344 elem = nest.map_structure(lambda elem: elem.read(i), elems_ta) 345 a_out = fn(a, elem) 346 return [i, a_out] 347 348 _, r_a = control_flow_ops.while_loop( 349 lambda i, a: i > 0, 350 compute, [i, a], 351 parallel_iterations=parallel_iterations, 352 back_prop=back_prop, 353 swap_memory=swap_memory, 354 maximum_iterations=n) 355 356 # TODO(akshayka): Remove the in_graph_mode check once caching devices are 357 # supported in Eager 358 if in_graph_mode and varscope_caching_device_was_none: 359 varscope.set_caching_device(None) 360 361 return r_a 362 363 364@tf_export("foldr", v1=[]) 365@dispatch.add_dispatch_support 366@deprecation.deprecated_arg_values( 367 None, 368 """back_prop=False is deprecated. Consider using tf.stop_gradient instead. 369Instead of: 370results = tf.foldr(fn, elems, back_prop=False) 371Use: 372results = tf.nest.map_structure(tf.stop_gradient, tf.foldr(fn, elems))""", 373 warn_once=True, 374 back_prop=False) 375def foldr_v2(fn, 376 elems, 377 initializer=None, 378 parallel_iterations=10, 379 back_prop=True, 380 swap_memory=False, 381 name=None): 382 """foldr on the list of tensors unpacked from `elems` on dimension 0. 383 384 This foldr operator repeatedly applies the callable `fn` to a sequence 385 of elements from last to first. The elements are made of the tensors 386 unpacked from `elems`. The callable fn takes two tensors as arguments. 
387 The first argument is the accumulated value computed from the preceding 388 invocation of fn, and the second is the value at the current position of 389 `elems`. If `initializer` is None, `elems` must contain at least one element, 390 and its first element is used as the initializer. 391 392 Suppose that `elems` is unpacked into `values`, a list of tensors. The shape 393 of the result tensor is `fn(initializer, values[0]).shape`. 394 395 This method also allows multi-arity `elems` and output of `fn`. If `elems` 396 is a (possibly nested) list or tuple of tensors, then each of these tensors 397 must have a matching first (unpack) dimension. The signature of `fn` may 398 match the structure of `elems`. That is, if `elems` is 399 `(t1, [t2, t3, [t4, t5]])`, then an appropriate signature for `fn` is: 400 `fn = lambda (t1, [t2, t3, [t4, t5]]):`. 401 402 Args: 403 fn: The callable to be performed. 404 elems: A tensor or (possibly nested) sequence of tensors, each of which will 405 be unpacked along their first dimension. The nested sequence of the 406 resulting slices will be the first argument to `fn`. 407 initializer: (optional) A tensor or (possibly nested) sequence of tensors, 408 as the initial value for the accumulator. 409 parallel_iterations: (optional) The number of iterations allowed to run in 410 parallel. 411 back_prop: (optional) Deprecated. False disables support for back 412 propagation. Prefer using `tf.stop_gradient` instead. 413 swap_memory: (optional) True enables GPU-CPU memory swapping. 414 name: (optional) Name prefix for the returned tensors. 415 416 Returns: 417 A tensor or (possibly nested) sequence of tensors, resulting from applying 418 `fn` consecutively to the list of tensors unpacked from `elems`, from last 419 to first. 420 421 Raises: 422 TypeError: if `fn` is not callable. 
423 424 Example: 425 ```python 426 elems = [1, 2, 3, 4, 5, 6] 427 sum = foldr(lambda a, x: a + x, elems) 428 # sum == 21 429 ``` 430 """ 431 return foldr( 432 fn=fn, 433 elems=elems, 434 initializer=initializer, 435 parallel_iterations=parallel_iterations, 436 back_prop=back_prop, 437 swap_memory=swap_memory, 438 name=name) 439 440 441@tf_export(v1=["scan"]) 442@dispatch.add_dispatch_support 443def scan(fn, 444 elems, 445 initializer=None, 446 parallel_iterations=10, 447 back_prop=True, 448 swap_memory=False, 449 infer_shape=True, 450 reverse=False, 451 name=None): 452 """scan on the list of tensors unpacked from `elems` on dimension 0. 453 454 See also `tf.map_fn`. 455 456 The simplest version of `scan` repeatedly applies the callable `fn` to a 457 sequence of elements from first to last. The elements are made of the tensors 458 unpacked from `elems` on dimension 0. The callable fn takes two tensors as 459 arguments. The first argument is the accumulated value computed from the 460 preceding invocation of fn, and the second is the value at the current 461 position of `elems`. If `initializer` is None, `elems` must contain at least 462 one element, and its first element is used as the initializer. 463 464 Suppose that `elems` is unpacked into `values`, a list of tensors. The shape 465 of the result tensor is `[len(values)] + fn(initializer, values[0]).shape`. 466 If reverse=True, it's fn(initializer, values[-1]).shape. 467 468 This method also allows multi-arity `elems` and accumulator. If `elems` 469 is a (possibly nested) list or tuple of tensors, then each of these tensors 470 must have a matching first (unpack) dimension. The second argument of 471 `fn` must match the structure of `elems`. 472 473 If no `initializer` is provided, the output structure and dtypes of `fn` 474 are assumed to be the same as its input; and in this case, the first 475 argument of `fn` must match the structure of `elems`. 
476 477 If an `initializer` is provided, then the output of `fn` must have the same 478 structure as `initializer`; and the first argument of `fn` must match 479 this structure. 480 481 For example, if `elems` is `(t1, [t2, t3])` and `initializer` is 482 `[i1, i2]` then an appropriate signature for `fn` in `python2` is: 483 `fn = lambda (acc_p1, acc_p2), (t1, [t2, t3]):` and `fn` must return a list, 484 `[acc_n1, acc_n2]`. An alternative correct signature for `fn`, and the 485 one that works in `python3`, is: 486 `fn = lambda a, t:`, where `a` and `t` correspond to the input tuples. 487 488 Args: 489 fn: The callable to be performed. It accepts two arguments. The first will 490 have the same structure as `initializer` if one is provided, otherwise it 491 will have the same structure as `elems`. The second will have the same 492 (possibly nested) structure as `elems`. Its output must have the same 493 structure as `initializer` if one is provided, otherwise it must have the 494 same structure as `elems`. 495 elems: A tensor or (possibly nested) sequence of tensors, each of which will 496 be unpacked along their first dimension. The nested sequence of the 497 resulting slices will be the first argument to `fn`. 498 initializer: (optional) A tensor or (possibly nested) sequence of tensors, 499 initial value for the accumulator, and the expected output type of `fn`. 500 parallel_iterations: (optional) The number of iterations allowed to run in 501 parallel. 502 back_prop: (optional) True enables support for back propagation. 503 swap_memory: (optional) True enables GPU-CPU memory swapping. 504 infer_shape: (optional) False disables tests for consistent output shapes. 505 reverse: (optional) True scans the tensor last to first (instead of first to 506 last). 507 name: (optional) Name prefix for the returned tensors. 508 509 Returns: 510 A tensor or (possibly nested) sequence of tensors. 
Each tensor packs the 511 results of applying `fn` to tensors unpacked from `elems` along the first 512 dimension, and the previous accumulator value(s), from first to last (or 513 last to first, if `reverse=True`). 514 515 Raises: 516 TypeError: if `fn` is not callable or the structure of the output of 517 `fn` and `initializer` do not match. 518 ValueError: if the lengths of the output of `fn` and `initializer` 519 do not match. 520 521 Examples: 522 ```python 523 elems = np.array([1, 2, 3, 4, 5, 6]) 524 sum = scan(lambda a, x: a + x, elems) 525 # sum == [1, 3, 6, 10, 15, 21] 526 sum = scan(lambda a, x: a + x, elems, reverse=True) 527 # sum == [21, 20, 18, 15, 11, 6] 528 ``` 529 530 ```python 531 elems = np.array([1, 2, 3, 4, 5, 6]) 532 initializer = np.array(0) 533 sum_one = scan( 534 lambda a, x: x[0] - x[1] + a, (elems + 1, elems), initializer) 535 # sum_one == [1, 2, 3, 4, 5, 6] 536 ``` 537 538 ```python 539 elems = np.array([1, 0, 0, 0, 0, 0]) 540 initializer = (np.array(0), np.array(1)) 541 fibonaccis = scan(lambda a, _: (a[1], a[0] + a[1]), elems, initializer) 542 # fibonaccis == ([1, 1, 2, 3, 5, 8], [1, 2, 3, 5, 8, 13]) 543 ``` 544 """ 545 if not callable(fn): 546 raise TypeError( 547 f"{fn.__name__} is not callable. 
Please provide a callable function.") 548 549 input_is_sequence = nest.is_sequence(elems) 550 input_flatten = lambda x: nest.flatten(x) if input_is_sequence else [x] 551 552 def input_pack(x): 553 return nest.pack_sequence_as(elems, x) if input_is_sequence else x[0] 554 555 if initializer is None: 556 output_is_sequence = input_is_sequence 557 output_flatten = input_flatten 558 output_pack = input_pack 559 else: 560 output_is_sequence = nest.is_sequence(initializer) 561 output_flatten = lambda x: nest.flatten(x) if output_is_sequence else [x] 562 563 def output_pack(x): 564 return (nest.pack_sequence_as(initializer, x) 565 if output_is_sequence else x[0]) 566 567 elems_flat = input_flatten(elems) 568 569 in_graph_mode = not context.executing_eagerly() 570 with ops.name_scope(name, "scan", elems_flat): 571 # TODO(akshayka): Remove the in_graph_mode check once caching devices are 572 # supported in Eager 573 if in_graph_mode: 574 # Any get_variable calls in fn will cache the first call locally 575 # and not issue repeated network I/O requests for each iteration. 576 varscope = vs.get_variable_scope() 577 varscope_caching_device_was_none = False 578 if varscope.caching_device is None: 579 # TODO(ebrevdo): Change to using colocate_with here and in other 580 # methods. 581 varscope.set_caching_device(lambda op: op.device) 582 varscope_caching_device_was_none = True 583 584 # Convert elems to tensor array. 585 elems_flat = [ 586 ops.convert_to_tensor(elem, name="elem") for elem in elems_flat 587 ] 588 589 # Convert elems to tensor array. n may be known statically. 
590 n = tensor_shape.dimension_value(elems_flat[0].shape[0]) 591 if n is None: 592 n = array_ops.shape(elems_flat[0])[0] 593 594 # TensorArrays are always flat 595 elems_ta = [ 596 tensor_array_ops.TensorArray( 597 dtype=elem.dtype, 598 size=n, 599 dynamic_size=False, 600 element_shape=elem.shape[1:], 601 infer_shape=True) for elem in elems_flat 602 ] 603 # Unpack elements 604 elems_ta = [ 605 elem_ta.unstack(elem) for elem_ta, elem in zip(elems_ta, elems_flat) 606 ] 607 608 if initializer is None: 609 a_flat = [elem.read(n - 1 if reverse else 0) for elem in elems_ta] 610 i = 1 611 else: 612 initializer_flat = output_flatten(initializer) 613 a_flat = [ops.convert_to_tensor(init) for init in initializer_flat] 614 i = 0 615 616 # Create a tensor array to store the intermediate values. 617 accs_ta = [ 618 tensor_array_ops.TensorArray( 619 dtype=init.dtype, 620 size=n, 621 element_shape=init.shape if infer_shape else None, 622 dynamic_size=False, 623 infer_shape=infer_shape) for init in a_flat 624 ] 625 626 if initializer is None: 627 accs_ta = [ 628 acc_ta.write(n - 1 if reverse else 0, a) 629 for (acc_ta, a) in zip(accs_ta, a_flat) 630 ] 631 632 def compute(i, a_flat, tas): 633 """The loop body of scan. 634 635 Args: 636 i: the loop counter. 637 a_flat: the accumulator value(s), flattened. 638 tas: the output accumulator TensorArray(s), flattened. 
639 640 Returns: 641 [i + 1, a_flat, tas]: the updated counter + new accumulator values + 642 updated TensorArrays 643 644 Raises: 645 TypeError: if initializer and fn() output structure do not match 646 ValueType: if initializer and fn() output lengths do not match 647 """ 648 packed_elems = input_pack([elem_ta.read(i) for elem_ta in elems_ta]) 649 packed_a = output_pack(a_flat) 650 a_out = fn(packed_a, packed_elems) 651 nest.assert_same_structure(elems if initializer is None else initializer, 652 a_out) 653 flat_a_out = output_flatten(a_out) 654 tas = [ta.write(i, value) for (ta, value) in zip(tas, flat_a_out)] 655 if reverse: 656 next_i = i - 1 657 else: 658 next_i = i + 1 659 return (next_i, flat_a_out, tas) 660 661 if reverse: 662 initial_i = n - 1 - i 663 condition = lambda i, _1, _2: i >= 0 664 else: 665 initial_i = i 666 condition = lambda i, _1, _2: i < n 667 _, _, r_a = control_flow_ops.while_loop( 668 condition, 669 compute, (initial_i, a_flat, accs_ta), 670 parallel_iterations=parallel_iterations, 671 back_prop=back_prop, 672 swap_memory=swap_memory, 673 maximum_iterations=n) 674 675 results_flat = [r.stack() for r in r_a] 676 677 n_static = tensor_shape.Dimension( 678 tensor_shape.dimension_value( 679 elems_flat[0].get_shape().with_rank_at_least(1)[0])) 680 for elem in elems_flat[1:]: 681 n_static.assert_is_compatible_with( 682 tensor_shape.Dimension( 683 tensor_shape.dimension_value( 684 elem.get_shape().with_rank_at_least(1)[0]))) 685 for r in results_flat: 686 r.set_shape( 687 tensor_shape.TensorShape(n_static).concatenate(r.get_shape()[1:])) 688 689 # TODO(akshayka): Remove the in_graph_mode check once caching devices are 690 # supported in Eager 691 if in_graph_mode and varscope_caching_device_was_none: 692 varscope.set_caching_device(None) 693 694 return output_pack(results_flat) 695 696 697@tf_export("scan", v1=[]) 698@dispatch.add_dispatch_support 699@deprecation.deprecated_arg_values( 700 None, 701 """back_prop=False is deprecated. 
Consider using tf.stop_gradient instead. 702Instead of: 703results = tf.scan(fn, elems, back_prop=False) 704Use: 705results = tf.nest.map_structure(tf.stop_gradient, tf.scan(fn, elems))""", 706 warn_once=True, 707 back_prop=False) 708def scan_v2(fn, 709 elems, 710 initializer=None, 711 parallel_iterations=10, 712 back_prop=True, 713 swap_memory=False, 714 infer_shape=True, 715 reverse=False, 716 name=None): 717 """scan on the list of tensors unpacked from `elems` on dimension 0. 718 719 The simplest version of `scan` repeatedly applies the callable `fn` to a 720 sequence of elements from first to last. The elements are made of the tensors 721 unpacked from `elems` on dimension 0. The callable fn takes two tensors as 722 arguments. The first argument is the accumulated value computed from the 723 preceding invocation of fn, and the second is the value at the current 724 position of `elems`. If `initializer` is None, `elems` must contain at least 725 one element, and its first element is used as the initializer. 726 727 Suppose that `elems` is unpacked into `values`, a list of tensors. The shape 728 of the result tensor is `[len(values)] + fn(initializer, values[0]).shape`. 729 If reverse=True, it's fn(initializer, values[-1]).shape. 730 731 This method also allows multi-arity `elems` and accumulator. If `elems` 732 is a (possibly nested) list or tuple of tensors, then each of these tensors 733 must have a matching first (unpack) dimension. The second argument of 734 `fn` must match the structure of `elems`. 735 736 If no `initializer` is provided, the output structure and dtypes of `fn` 737 are assumed to be the same as its input; and in this case, the first 738 argument of `fn` must match the structure of `elems`. 739 740 If an `initializer` is provided, then the output of `fn` must have the same 741 structure as `initializer`; and the first argument of `fn` must match 742 this structure. 

  For example, if `elems` is `(t1, [t2, t3])` and `initializer` is
  `[i1, i2]` then an appropriate signature for `fn` in `python2` is:
  `fn = lambda (acc_p1, acc_p2), (t1, [t2, t3]):` and `fn` must return a list,
  `[acc_n1, acc_n2]`. An alternative correct signature for `fn`, and the
  one that works in `python3`, is:
  `fn = lambda a, t:`, where `a` and `t` correspond to the input tuples.

  Args:
    fn: The callable to be performed. It accepts two arguments. The first will
      have the same structure as `initializer` if one is provided, otherwise it
      will have the same structure as `elems`. The second will have the same
      (possibly nested) structure as `elems`. Its output must have the same
      structure as `initializer` if one is provided, otherwise it must have the
      same structure as `elems`.
    elems: A tensor or (possibly nested) sequence of tensors, each of which will
      be unpacked along their first dimension. The nested sequence of the
      resulting slices will be the second argument to `fn`.
    initializer: (optional) A tensor or (possibly nested) sequence of tensors,
      initial value for the accumulator, and the expected output type of `fn`.
    parallel_iterations: (optional) The number of iterations allowed to run in
      parallel.
    back_prop: (optional) Deprecated. False disables support for back
      propagation. Prefer using `tf.stop_gradient` instead.
    swap_memory: (optional) True enables GPU-CPU memory swapping.
    infer_shape: (optional) False disables tests for consistent output shapes.
    reverse: (optional) True scans the tensor last to first (instead of first to
      last).
    name: (optional) Name prefix for the returned tensors.

  Returns:
    A tensor or (possibly nested) sequence of tensors. Each tensor packs the
    results of applying `fn` to tensors unpacked from `elems` along the first
    dimension, and the previous accumulator value(s), from first to last (or
    last to first, if `reverse=True`).

  Raises:
    TypeError: if `fn` is not callable or the structure of the output of
      `fn` and `initializer` do not match.
    ValueError: if the lengths of the output of `fn` and `initializer`
      do not match.

  Examples:
    ```python
    elems = np.array([1, 2, 3, 4, 5, 6])
    sum = scan(lambda a, x: a + x, elems)
    # sum == [1, 3, 6, 10, 15, 21]
    sum = scan(lambda a, x: a + x, elems, reverse=True)
    # sum == [21, 20, 18, 15, 11, 6]
    ```

    ```python
    elems = np.array([1, 2, 3, 4, 5, 6])
    initializer = np.array(0)
    sum_one = scan(
        lambda a, x: x[0] - x[1] + a, (elems + 1, elems), initializer)
    # sum_one == [1, 2, 3, 4, 5, 6]
    ```

    ```python
    elems = np.array([1, 0, 0, 0, 0, 0])
    initializer = (np.array(0), np.array(1))
    fibonaccis = scan(lambda a, _: (a[1], a[0] + a[1]), elems, initializer)
    # fibonaccis == ([1, 1, 2, 3, 5, 8], [1, 2, 3, 5, 8, 13])
    ```
  """
  # Thin alias: every argument is forwarded unchanged to `scan`.
  return scan(
      fn=fn,
      elems=elems,
      initializer=initializer,
      parallel_iterations=parallel_iterations,
      back_prop=back_prop,
      swap_memory=swap_memory,
      infer_shape=infer_shape,
      reverse=reverse,
      name=name)


# pylint: disable=invalid-name
def If(cond, inputs, then_branch, else_branch, name=None):
  r"""output = cond ? then_branch(inputs) : else_branch(inputs).

  Args:
    cond: A `Tensor`. A scalar. If the scalar is not a boolean, the scalar is
      converted to a boolean according to the following rule: if the scalar is a
      numerical value, non-zero means True and zero means False; if the scalar
      is a string, non-empty means True and empty means False.
    inputs: A list of input tensors.
    then_branch: A function that takes 'inputs' and returns a list of tensors,
      whose types are the same as what else_branch returns.
    else_branch: A function that takes 'inputs' and returns a list of tensors,
      whose types are the same as what then_branch returns.
    name: A name for the operation (optional).

  Returns:
    A list of tensors returned by either then_branch(inputs)
    or else_branch(inputs). If `then_branch` is a `function._DefinedFunction`
    (Defun), this is the flat op output list. Otherwise the outputs are
    re-packed into the structure of `then_branch.structured_outputs`, which
    restores any composite tensors.
  """
  # pylint: disable=protected-access
  # Handle the Defun case until users have transitioned to tf.function. Note
  # that composites may need to be re-packed by the caller.
  if isinstance(then_branch, function._DefinedFunction):
    # Output dtypes come straight from the Defun's declared output signature.
    tlist = [_.type for _ in then_branch.definition.signature.output_arg]
    return gen_functional_ops._if(
        cond, inputs, tlist, then_branch, else_branch, name=name)

  # We assume that `then_branch` is a ConcreteFunction here.
  then_out = then_branch.structured_outputs
  else_out = else_branch.structured_outputs

  # Ensure then/else are the same type of composites to avoid an invalid call
  # to pack_sequence_as later on. This raises if the two structures differ.
  nest.assert_same_structure(then_out, else_out, expand_composites=True)

  tlist = nest.flatten(then_branch.output_dtypes)
  ret = gen_functional_ops._if(
      cond, inputs, tlist, then_branch, else_branch, name=name)

  # Re-pack the outputs to restore any CompositeTensors
  return nest.pack_sequence_as(then_out, ret, expand_composites=True)


def Gradient(inputs, f, name=None):
  r"""Computes the gradient function for function f via backpropagation.

  Args:
    inputs: A list of tensors of size N + M.
    f: The function we want to compute the gradient for. The function 'f' must
      be a numerical function which takes N inputs and produces M outputs. Its
      gradient function 'g' is a function that takes N + M inputs and
      produces N outputs. I.e. if we have (y1, y2, ..., yM) = f(x1, x2, ...,
      xN), then g is (dL/dx1, dL/dx2, ..., dL/dxN) = g(x1, x2, ..., xN, dL/dy1,
      dL/dy2, ..., dL/dyM), where L is a scalar-valued function of (x1, x2, ...,
      xN) (e.g., the loss function). dL/dxi is the partial derivative of L with
      respect to xi.
    name: A name for the operation (optional).

  Returns:
    A list of tensors of size N.
  """
  # TODO(zhifengc): Pretty-print the above spec in latex.
  # TODO(zhfiengc): Needs some math expert to say the comment above better.
  # The gradient has one output per *input* of f, so the output dtypes are
  # taken from f's input signature.
  tlist = [_.type for _ in f.definition.signature.input_arg]
  return symbolic_gradient(input=inputs, Tout=tlist, f=f, name=name)


def _GetInputDtypes(func):
  """Returns the input dtypes of func, excluding dtypes for captured inputs.

  Args:
    func: A `function._DefinedFunction` (Defun), or an object assumed to be a
      ConcreteFunction exposing `inputs` and `captured_inputs`.

  Returns:
    For a Defun, its `declared_input_types`. Otherwise, a list with the dtype
    of each entry of `func.inputs`, excluding the last
    `len(func.captured_inputs)` entries.
  """
  if isinstance(func, function._DefinedFunction):  # pylint: disable=protected-access
    return func.declared_input_types

  # We assume that `func` is a ConcreteFunction here, but we are not able to
  # verify since importing eager function library will cause cyclic dependence.
  #
  # ConcreteFunction.inputs includes captured inputs. This slice assumes they
  # are the trailing entries of `func.inputs`.
  num_non_captured_inputs = len(func.inputs) - len(func.captured_inputs)
  inputs_without_captured = func.inputs[:num_non_captured_inputs]
  return [t.dtype for t in inputs_without_captured]


def _LoopBodyCaptureWrapper(func):
  """Returns a wrapper for `func` that handles loop-carried captured inputs.

  The returned Defun is declared with `func`'s non-captured input dtypes.
  Tensors it implicitly captures (reported by `function.get_extra_args()`) are
  appended to its outputs, so they are passed back into the next iteration
  as loop variables.

  Args:
    func: The loop body; a Defun or ConcreteFunction (see `_GetInputDtypes`).

  Returns:
    A `function.Defun`-decorated function named "<func.name>_Wrapper".
  """

  @function.Defun(*_GetInputDtypes(func), func_name="%s_Wrapper" % func.name)
  def Wrapper(*args):
    """A wrapper that handles loop-carried captured inputs."""
    result = func(*args)
    extra_args = tuple(function.get_extra_args())
    # Nullary functions return an Operation. Normal functions can't do this
    # because their return values are converted to Tensors.
    if isinstance(result, ops.Operation):
      return extra_args
    # Unary functions return a single Tensor value.
    elif not isinstance(result, (list, tuple)):
      return (result,) + extra_args
    # N-ary functions return a tuple of Tensors.
    else:
      # Convert extra_args to the result's own container type so list/tuple
      # concatenation is well-typed.
      return result + type(result)(extra_args)

  return Wrapper


# pylint: disable=invalid-name,protected-access
def While(input_, cond, body, name=None, hostmem=None):
  r"""output = input; While (Cond(output)) { output = Body(output) }.

  Args:
    input_: A list of `Tensor` objects. A list of input tensors whose types are
      T.
    cond: A function that takes 'input' and returns a tensor. If the tensor is
      a scalar of non-boolean type, the scalar is converted to a boolean
      according to the following rule: if the scalar is a numerical value,
      non-zero means True and zero means False; if the scalar is a string,
      non-empty means True and empty means False. If the tensor is not a
      scalar, non-emptiness means True and False otherwise.
    body: A function that takes a list of tensors and returns another list of
      tensors. Both lists have the same types as specified by T.
    name: A name for the operation (optional).
    hostmem: A list of integers. If i is in the list, input[i] is a host memory
      tensor.

  Raises:
    ValueError: if `cond` has implicitly captured inputs or if `cond` and `body`
      have different signatures.

  Returns:
    A list of `Tensor` objects. Has the same type as `input`.
    A list of output tensors whose types are T.
  """
  if cond.captured_inputs:
    raise ValueError(
        "The 'cond' argument can not have implicitly captured inputs. Received "
        f"captured_inputs: {cond.captured_inputs}")

  cond_input_types = _GetInputDtypes(cond)
  body_input_types = _GetInputDtypes(body)

  if cond_input_types != body_input_types:
    raise ValueError(
        "The 'cond' and 'body' signatures do not match. Received: "
        f"cond_input_types={cond_input_types}, body_input_types="
        f"{body_input_types}")

  if body.captured_inputs:
    # The body's captured tensors become extra loop variables (see
    # _LoopBodyCaptureWrapper). cond must therefore accept them too, so its
    # wrapper takes the same dtypes and ignores the trailing captured ones.
    cond_dtypes = list(body_input_types) + [
        t.dtype for t in body.captured_inputs
    ]

    @function.Defun(*cond_dtypes, func_name="%s_Wrapper" % cond.name)
    def CondWrapper(*args):
      """Calls `cond` on the loop variables, ignoring trailing captured inputs."""
      return cond(*args[:len(body_input_types)])

    ret = gen_functional_ops._while(
        input_ + body.captured_inputs,
        CondWrapper,
        _LoopBodyCaptureWrapper(body),
        name=name)
    # Slice off the loop-carried captured inputs.
    ret = ret[:-len(body.captured_inputs)]
  else:
    ret = gen_functional_ops._while(input_, cond, body, name=name)
  if hostmem:
    # All outputs come from the single While op, so ret[0].op is that op.
    # NOTE(review): this raises IndexError if the loop has no outputs.
    input_attr = attr_value_pb2.AttrValue()
    input_attr.list.i.extend(hostmem)
    ret[0].op._set_attr("_input_hostmem", input_attr)  # pylint: disable=protected-access

    output_attr = attr_value_pb2.AttrValue()
    output_attr.list.i.extend(hostmem)
    ret[0].op._set_attr("_output_hostmem", output_attr)  # pylint: disable=protected-access
  return ret


# b/36459430
#
# Ideally, we would not need to rewrite a For loop into a While loop.
# However, today, if a While runs on GPU and the condition returns a
# boolean, the While kernel crashes. Even if we fix the crash, the
# bool needs to be copied between GPU and CPU. So, a For loop is much
# preferred when running on GPU.
#
# On the other hand, the For op has no direct XLA kernel. So, when we run
# a For loop under XLA, we need to rewrite it using a While op.
#
# It should be possible, and probably better, to write an XLA C++ kernel
# implementing the logic in _ForUsingWhile.
def _ForUsingWhile(start,
                   limit,
                   delta,
                   inputs,
                   forbody,
                   name=None,
                   hostmem=None):
  """Helper to implement a For loop using a While.

  The loop body runs once per element of `range(start, limit, delta)`. An
  empty range runs zero iterations. That includes a range whose direction
  disagrees with the sign of `delta`, e.g. `range(0, 10, -1)`.

  Args:
    start: A scalar `int32` `Tensor`; the first loop index.
    limit: A scalar `int32` `Tensor`; the exclusive bound on the loop index.
    delta: A scalar `int32` `Tensor`; the step between loop indices. It may be
      negative.
    inputs: A list of loop-carried input tensors.
    forbody: A Defun called as `forbody(index, *loop_vars)`. Its `name` is used
      to name the generated cond/body functions, and its
      `declared_input_types[1:]` give the loop-variable dtypes.
    name: A name for the operation (optional).
    hostmem: (Optional) A list of indices into `inputs` whose tensors are host
      memory tensors.

  Returns:
    A list with the final values of the loop-carried tensors.
  """
  # To support negative delta (e.g., range(100, 0, -3)), we iterate
  # over range(n) and use i * delta + start as the real iteration index
  # (e.g., for i in range(34): index = i * (-3) + 100).
  span = limit - start
  d = math_ops.abs(delta)
  # XLA on TPUs doesn't support integer division, so the ceiling division
  # ceil(|limit - start| / |delta|) is computed in float32.
  n = math_ops.cast(
      math_ops.cast((math_ops.abs(span) + d - 1), dtypes.float32) /
      math_ops.cast(d, dtypes.float32), dtypes.int32)
  # The count above ignores direction. Without this step, range(0, 10, -1)
  # would run 10 iterations (indices 0, -1, ..., -9) instead of none. As with
  # Python's range, the loop only runs when limit - start has the same sign as
  # delta. This also forces a zero count when delta == 0 and start != limit;
  # there the float division above yields a non-finite value.
  same_direction = math_ops.equal(math_ops.sign(span), math_ops.sign(delta))
  n = n * math_ops.cast(same_direction, dtypes.int32)

  # Carried loop variables ("extra_args") are implicitly added to the input list
  # of the WhileBody function. WhileCond does not call forbody, and so does not
  # depend on any of forbody's extra_args. Since WhileCond and WhileBody
  # must have identical inputs, we have to augment the cond signature to take
  # the same types as the carried loop variables.
  body_sig = [dtypes.int32] * 4 + list(forbody.declared_input_types)[1:]

  cond_name = "%s_Cond" % forbody.name

  @function.Defun(*body_sig, func_name=cond_name)
  def WhileCond(i, n, *args):
    del args
    return i < n

  body_name = "%s_Body" % forbody.name

  @function.Defun(*body_sig, func_name=body_name)
  def WhileBody(i, n, start, delta, *args):
    """A While wrapper for forbody that handles loop-carried captured inputs."""
    for_result = forbody(start + i * delta, *args)
    # Nullary functions return an Operation. Normal functions can't do this
    # because their return values are converted to Tensors.
    if isinstance(for_result, ops.Operation):
      for_result = ()
    # Unary functions return a single Tensor value.
    elif isinstance(for_result, ops.Tensor):
      for_result = (for_result,)
    return (i + 1, n, start, delta) + tuple(for_result)

  # The While loop carries (i, n, start, delta) ahead of `inputs`. The four
  # int32 counters are always marked as host memory, and user-supplied
  # indices shift by 4.
  if hostmem is not None:
    hostmem = [0, 1, 2, 3] + [4 + idx for idx in hostmem]
  else:
    hostmem = [0, 1, 2, 3]

  results = While(
      input_=[0, n, start, delta] + inputs,
      cond=WhileCond,
      body=WhileBody,
      name=name,
      hostmem=hostmem)
  # Drop the (i, n, start, delta) counters; the rest are the loop variables.
  return list(results[4:])


def For(start,
        limit,
        delta,
        inputs,
        body,
        name=None,
        hostmem=None,
        rewrite_with_while=None):
  r"""out = input; for i in range(start, limit, delta) out = body(i, out).

  Args:
    start: A `Tensor` of type `int32`.
    limit: A `Tensor` of type `int32`.
    delta: A `Tensor` of type `int32`.
    inputs: A list of `Tensor` objects. A list of input tensors whose types are
      T.
    body: A function that takes the loop index followed by the loop-carried
      tensors, i.e. types (int32, T...), and returns a list of tensors of
      types T.
    name: A name for the operation (optional).
    hostmem: A list of integers. If i is in the list, inputs[i] is a host
      memory tensor. In other words, the (i+1)-th argument of the body
      function is expected to be in host memory.
    rewrite_with_while: If True, implement the For loop with a While op (see
      `_ForUsingWhile`).

  Returns:
    A list of `Tensor` objects. Has the same type as `input`.
    A list of output tensors whose types are T.
  """
  if rewrite_with_while:
    return _ForUsingWhile(start, limit, delta, inputs, body, name, hostmem)
  if body.captured_inputs:
    # Captured tensors are threaded through the loop as extra loop variables.
    ret = gen_functional_ops._for(
        start,
        limit,
        delta,
        inputs + body.captured_inputs,
        _LoopBodyCaptureWrapper(body),
        name=name)
    # Slice off the loop-carried captured inputs.
    ret = ret[:-len(body.captured_inputs)]
  else:
    ret = gen_functional_ops._for(start, limit, delta, inputs, body, name=name)
  if hostmem:
    # The For op's own inputs are (start, limit, delta, *inputs), so input
    # indices shift by 3; its outputs are just the loop variables.
    num_for_params = 3  # start/limit/delta

    input_attr = attr_value_pb2.AttrValue()
    input_attr.list.i.extend([num_for_params + i for i in hostmem])
    ret[0].op._set_attr("_input_hostmem", input_attr)  # pylint: disable=protected-access

    output_attr = attr_value_pb2.AttrValue()
    output_attr.list.i.extend(hostmem)
    ret[0].op._set_attr("_output_hostmem", output_attr)  # pylint: disable=protected-access
  return ret


# pylint: enable=invalid-name,protected-access


def partitioned_call(args,
                     f,
                     tout=None,
                     executing_eagerly=None,
                     config=None,
                     executor_type=None):
  """Executes a function while respecting device annotations.

  Currently, only those functions that execute within the same address space
  can be executed.

  Args:
    args: The arguments of the function, including captured inputs.
    f: The function to execute; an instance of `_DefinedFunction` or
      `_EagerDefinedFunction`.
    tout: a list containing the output dtypes enums; if `None`, inferred from
      the signature of `f`.
    executing_eagerly: (Optional) A boolean indicating whether the context is
      executing eagerly. If `None`, fetched from the global context.
    config: (Optional) A `tensorflow::ConfigProto` proto, serialized. If `None`,
      `function_utils.get_disabled_rewriter_config()` is used, which disables
      all optimizations. It is passed as the call op's `config_proto` attr in
      both eager and graph mode.
    executor_type: (Optional) A string for the name of the executor to be used
      in the function call. If not set, or set to an empty string, the default
      tensorflow executor will be used.

  Returns:
    The list of `Tensor`s returned by invoking `f(args)`. If the function does
    not return anything, then returns `None` if eager execution is enabled, or
    the `Operation` if not.
  """

  if tout is None:
    tout = tuple(x.type for x in f.definition.signature.output_arg)

  if executing_eagerly is None:
    executing_eagerly = context.executing_eagerly()

  if config is None:
    config = function_utils.get_disabled_rewriter_config()

  if executor_type is None:
    executor_type = ""

  if executing_eagerly:
    # Functions containing stateful ops must use the stateful variant.
    if f.stateful_ops:
      outputs = gen_functional_ops.stateful_partitioned_call(
          args=args,
          Tout=tout,
          f=f,
          config_proto=config,
          executor_type=executor_type)
    else:
      outputs = gen_functional_ops.partitioned_call(
          args=args,
          Tout=tout,
          f=f,
          config_proto=config,
          executor_type=executor_type)
    return outputs if outputs else None

  # The generated binding returns an empty list for functions that don't
  # return any Tensors, hence the need to use `create_op` directly.
  args = [ops.convert_to_tensor(x) for x in args]
  tin_attr = attr_value_pb2.AttrValue(
      list=attr_value_pb2.AttrValue.ListValue(
          type=[x.dtype.as_datatype_enum for x in args]))
  tout_attr = attr_value_pb2.AttrValue(
      list=attr_value_pb2.AttrValue.ListValue(type=tout))
  func_attr = attr_value_pb2.AttrValue(
      func=attr_value_pb2.NameAttrList(name=f.name))
  executor_type_attr = attr_value_pb2.AttrValue(
      s=compat.as_bytes(executor_type))

  # When running in graph mode, the graph and function graphs are optimized
  # (i.e. run through grappler) per the session options, so we can disable any
  # eager-specific rewriting.
  config_proto = attr_value_pb2.AttrValue(s=config)

  graph = ops.get_default_graph()
  f.add_to_graph(graph)
  op_name = "StatefulPartitionedCall" if f.stateful_ops else "PartitionedCall"

  # Propagate the attribute indicating the need to compile from function to the
  # call itself.
  xla_compile_attr = "_XlaMustCompile"
  op_attrs = {
      "Tin": tin_attr,
      "Tout": tout_attr,
      "f": func_attr,
      "config_proto": config_proto,
      "executor_type": executor_type_attr,
  }
  if xla_compile_attr in f.definition.attr:
    op_attrs[xla_compile_attr] = f.definition.attr[xla_compile_attr]
  op = graph.create_op(op_name, args, tout, name=op_name, attrs=op_attrs)
  outputs = op.outputs
  # Functions backed by a FuncGraph also record which resource inputs are only
  # read, and which collective managers they use, for automatic control deps.
  if hasattr(f, "graph"):
    _set_read_only_resource_inputs_attr(op, f.graph)
    if hasattr(f.graph, "collective_manager_ids_used"):
      ops.set_int_list_attr(op, acd.COLLECTIVE_MANAGER_IDS,
                            f.graph.collective_manager_ids_used)
  return outputs if outputs else op


def _set_read_only_resource_inputs_attr(op, func_graph):
  """Sets the list of resource inputs which are read-only.

  This is used by AutomaticControlDependencies.

  Args:
    op: PartitionedCall Operation.
    func_graph: FuncGraph.
  """
  read_only_indices = acd.get_read_only_resource_input_indices_graph(func_graph)
  ops.set_int_list_attr(op, acd.READ_ONLY_RESOURCE_INPUTS_ATTR,
                        read_only_indices)