# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Functional operations."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.core.framework import attr_value_pb2
from tensorflow.python.eager import context
from tensorflow.python.framework import auto_control_deps_utils as acd
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_functional_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops import variable_scope as vs
# pylint: disable=unused-import
from tensorflow.python.ops.gen_functional_ops import remote_call
# pylint: enable=unused-import
from tensorflow.python.ops.gen_functional_ops import symbolic_gradient
from tensorflow.python.util import compat
from tensorflow.python.util import deprecation
from tensorflow.python.util import dispatch
from tensorflow.python.util import function_utils
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export


# TODO(yuanbyu, mrry): Handle stride to support sliding windows.
@tf_export(v1=["foldl"])
@dispatch.add_dispatch_support
def foldl(fn,
          elems,
          initializer=None,
          parallel_iterations=10,
          back_prop=True,
          swap_memory=False,
          name=None):
  """foldl on the list of tensors unpacked from `elems` on dimension 0.

  This foldl operator repeatedly applies the callable `fn` to a sequence
  of elements from first to last. The elements are made of the tensors
  unpacked from `elems` on dimension 0. The callable fn takes two tensors as
  arguments. The first argument is the accumulated value computed from the
  preceding invocation of fn, and the second is the value at the current
  position of `elems`. If `initializer` is None, `elems` must contain at least
  one element, and its first element is used as the initializer.

  Suppose that `elems` is unpacked into `values`, a list of tensors. The shape
  of the result tensor is `fn(initializer, values[0]).shape`.

  This method also allows multi-arity `elems` and output of `fn`. If `elems`
  is a (possibly nested) list or tuple of tensors, then each of these tensors
  must have a matching first (unpack) dimension. The signature of `fn` may
  match the structure of `elems`. That is, if `elems` is
  `(t1, [t2, t3, [t4, t5]])`, then an appropriate signature for `fn` is:
  `fn = lambda (t1, [t2, t3, [t4, t5]]):`.

  Args:
    fn: The callable to be performed.
    elems: A tensor or (possibly nested) sequence of tensors, each of which will
      be unpacked along their first dimension. The nested sequence of the
      resulting slices will be the first argument to `fn`.
    initializer: (optional) A tensor or (possibly nested) sequence of tensors,
      as the initial value for the accumulator.
    parallel_iterations: (optional) The number of iterations allowed to run in
      parallel.
    back_prop: (optional) True enables support for back propagation.
    swap_memory: (optional) True enables GPU-CPU memory swapping.
    name: (optional) Name prefix for the returned tensors.

  Returns:
    A tensor or (possibly nested) sequence of tensors, resulting from applying
    `fn` consecutively to the list of tensors unpacked from `elems`, from first
    to last.

  Raises:
    TypeError: if `fn` is not callable.

  Example:
    ```python
    elems = tf.constant([1, 2, 3, 4, 5, 6])
    sum = foldl(lambda a, x: a + x, elems)
    # sum == 21
    ```
  """
  if not callable(fn):
    raise TypeError("fn must be callable.")

  def create_ta(elem):
    return tensor_array_ops.TensorArray(
        dtype=elem.dtype, size=n, dynamic_size=False,
        infer_shape=True).unstack(elem)

  in_graph_mode = not context.executing_eagerly()
  with ops.name_scope(name, "foldl", [elems]):
    # TODO(akshayka): Remove the in_graph_mode check once caching devices are
    # supported in Eager
    if in_graph_mode:
      # Any get_variable calls in fn will cache the first call locally
      # and not issue repeated network I/O requests for each iteration.
      varscope = vs.get_variable_scope()
      varscope_caching_device_was_none = False
      if varscope.caching_device is None:
        # TODO(ebrevdo): Change to using colocate_with here and in other
        # methods.
        varscope.set_caching_device(lambda op: op.device)
        varscope_caching_device_was_none = True

    # Convert elems to tensor array. n may be known statically.
    elems_flat = [
        ops.convert_to_tensor(elem, name="elem") for elem in nest.flatten(elems)
    ]
    n = (
        tensor_shape.dimension_value(elems_flat[0].shape[0]) or
        array_ops.shape(elems_flat[0])[0])

    elems_ta = nest.map_structure(create_ta, elems)

    if initializer is None:
      a = nest.map_structure(lambda elem: elem.read(0), elems_ta)
      i = constant_op.constant(1)
    else:
      a = initializer
      i = constant_op.constant(0)

    def compute(i, a):
      elem_i = nest.map_structure(lambda elem: elem.read(i), elems_ta)
      a = fn(a, elem_i)
      return [i + 1, a]

    _, r_a = control_flow_ops.while_loop(
        lambda i, a: i < n,
        compute, [i, a],
        parallel_iterations=parallel_iterations,
        back_prop=back_prop,
        swap_memory=swap_memory,
        maximum_iterations=n)

    # TODO(akshayka): Remove the in_graph_mode check once caching devices are
    # supported in Eager
    if in_graph_mode and varscope_caching_device_was_none:
      varscope.set_caching_device(None)

    return r_a
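

# A commented-out usage sketch (not part of the examples above; the names
# `elems`, `acc` and `pair` are illustrative only): `foldl` with a multi-arity
# `elems` structure and an explicit `initializer`.
#
#   elems = (tf.constant([1, 2, 3]), tf.constant([10, 20, 30]))
#   total = foldl(lambda acc, pair: acc + pair[0] * pair[1], elems,
#                 initializer=tf.constant(0))
#   # total == 1 * 10 + 2 * 20 + 3 * 30 == 140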


@tf_export("foldl", v1=[])
@dispatch.add_dispatch_support
@deprecation.deprecated_arg_values(
    None,
    """back_prop=False is deprecated. Consider using tf.stop_gradient instead.
Instead of:
results = tf.foldl(fn, elems, back_prop=False)
Use:
results = tf.nest.map_structure(tf.stop_gradient, tf.foldl(fn, elems))""",
    warn_once=True,
    back_prop=False)
def foldl_v2(fn,
             elems,
             initializer=None,
             parallel_iterations=10,
             back_prop=True,
             swap_memory=False,
             name=None):
  """foldl on the list of tensors unpacked from `elems` on dimension 0.

  This foldl operator repeatedly applies the callable `fn` to a sequence
  of elements from first to last. The elements are made of the tensors
  unpacked from `elems` on dimension 0. The callable fn takes two tensors as
  arguments. The first argument is the accumulated value computed from the
  preceding invocation of fn, and the second is the value at the current
  position of `elems`. If `initializer` is None, `elems` must contain at least
  one element, and its first element is used as the initializer.

  Suppose that `elems` is unpacked into `values`, a list of tensors. The shape
  of the result tensor is `fn(initializer, values[0]).shape`.

  This method also allows multi-arity `elems` and output of `fn`. If `elems`
  is a (possibly nested) list or tuple of tensors, then each of these tensors
  must have a matching first (unpack) dimension. The signature of `fn` may
  match the structure of `elems`. That is, if `elems` is
  `(t1, [t2, t3, [t4, t5]])`, then an appropriate signature for `fn` is:
  `fn = lambda (t1, [t2, t3, [t4, t5]]):`.

  Args:
    fn: The callable to be performed.
    elems: A tensor or (possibly nested) sequence of tensors, each of which will
      be unpacked along their first dimension. The nested sequence of the
      resulting slices will be the first argument to `fn`.
    initializer: (optional) A tensor or (possibly nested) sequence of tensors,
      as the initial value for the accumulator.
    parallel_iterations: (optional) The number of iterations allowed to run in
      parallel.
    back_prop: (optional) Deprecated. False disables support for back
      propagation. Prefer using `tf.stop_gradient` instead.
    swap_memory: (optional) True enables GPU-CPU memory swapping.
    name: (optional) Name prefix for the returned tensors.

  Returns:
    A tensor or (possibly nested) sequence of tensors, resulting from applying
    `fn` consecutively to the list of tensors unpacked from `elems`, from first
    to last.

  Raises:
    TypeError: if `fn` is not callable.

  Example:
    ```python
    elems = tf.constant([1, 2, 3, 4, 5, 6])
    sum = foldl(lambda a, x: a + x, elems)
    # sum == 21
    ```
  """
  return foldl(
      fn=fn,
      elems=elems,
      initializer=initializer,
      parallel_iterations=parallel_iterations,
      back_prop=back_prop,
      swap_memory=swap_memory,
      name=name)


@tf_export(v1=["foldr"])
@dispatch.add_dispatch_support
def foldr(fn,
          elems,
          initializer=None,
          parallel_iterations=10,
          back_prop=True,
          swap_memory=False,
          name=None):
  """foldr on the list of tensors unpacked from `elems` on dimension 0.

  This foldr operator repeatedly applies the callable `fn` to a sequence
  of elements from last to first. The elements are made of the tensors
  unpacked from `elems`. The callable fn takes two tensors as arguments.
  The first argument is the accumulated value computed from the preceding
  invocation of fn, and the second is the value at the current position of
  `elems`.
  If `initializer` is None, `elems` must contain at least one element,
  and its last element is used as the initializer.

  Suppose that `elems` is unpacked into `values`, a list of tensors. The shape
  of the result tensor is `fn(initializer, values[0]).shape`.

  This method also allows multi-arity `elems` and output of `fn`. If `elems`
  is a (possibly nested) list or tuple of tensors, then each of these tensors
  must have a matching first (unpack) dimension. The signature of `fn` may
  match the structure of `elems`. That is, if `elems` is
  `(t1, [t2, t3, [t4, t5]])`, then an appropriate signature for `fn` is:
  `fn = lambda (t1, [t2, t3, [t4, t5]]):`.

  Args:
    fn: The callable to be performed.
    elems: A tensor or (possibly nested) sequence of tensors, each of which will
      be unpacked along their first dimension. The nested sequence of the
      resulting slices will be the first argument to `fn`.
    initializer: (optional) A tensor or (possibly nested) sequence of tensors,
      as the initial value for the accumulator.
    parallel_iterations: (optional) The number of iterations allowed to run in
      parallel.
    back_prop: (optional) True enables support for back propagation.
    swap_memory: (optional) True enables GPU-CPU memory swapping.
    name: (optional) Name prefix for the returned tensors.

  Returns:
    A tensor or (possibly nested) sequence of tensors, resulting from applying
    `fn` consecutively to the list of tensors unpacked from `elems`, from last
    to first.

  Raises:
    TypeError: if `fn` is not callable.

  Example:
    ```python
    elems = [1, 2, 3, 4, 5, 6]
    sum = foldr(lambda a, x: a + x, elems)
    # sum == 21
    ```
  """
  if not callable(fn):
    raise TypeError("fn must be callable.")

  def create_ta(elem):
    return tensor_array_ops.TensorArray(
        dtype=elem.dtype, size=n, dynamic_size=False,
        infer_shape=True).unstack(elem)

  in_graph_mode = not context.executing_eagerly()
  with ops.name_scope(name, "foldr", [elems]):
    # TODO(akshayka): Remove the in_graph_mode check once caching devices are
    # supported in Eager
    if in_graph_mode:
      # Any get_variable calls in fn will cache the first call locally and not
      # issue repeated network I/O requests for each iteration.
      varscope = vs.get_variable_scope()
      varscope_caching_device_was_none = False
      if varscope.caching_device is None:
        # TODO(ebrevdo): Change to using colocate_with here and in other
        # methods.
        varscope.set_caching_device(lambda op: op.device)
        varscope_caching_device_was_none = True

    # Convert elems to tensor array. n may be known statically.
    elems_flat = [
        ops.convert_to_tensor(elem, name="elem") for elem in nest.flatten(elems)
    ]
    n = (
        tensor_shape.dimension_value(elems_flat[0].shape[0]) or
        array_ops.shape(elems_flat[0])[0])

    elems_ta = nest.map_structure(create_ta, elems)

    if initializer is None:
      i = n - 1
      a = nest.map_structure(lambda elem: elem.read(i), elems_ta)
    else:
      i = n
      a = initializer

    def compute(i, a):
      i -= 1
      elem = nest.map_structure(lambda elem: elem.read(i), elems_ta)
      a_out = fn(a, elem)
      return [i, a_out]

    _, r_a = control_flow_ops.while_loop(
        lambda i, a: i > 0,
        compute, [i, a],
        parallel_iterations=parallel_iterations,
        back_prop=back_prop,
        swap_memory=swap_memory,
        maximum_iterations=n)

    # TODO(akshayka): Remove the in_graph_mode check once caching devices are
    # supported in Eager
    if in_graph_mode and varscope_caching_device_was_none:
      varscope.set_caching_device(None)

    return r_a
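

# A commented-out usage sketch (illustrative names only): unlike `foldl`,
# `foldr` consumes `elems` from last to first, which matters whenever `fn` is
# not commutative.
#
#   elems = tf.constant([1, 2, 3, 4])
#   left = foldl(lambda a, x: a - x, elems)   # ((1 - 2) - 3) - 4 == -8
#   right = foldr(lambda a, x: a - x, elems)  # ((4 - 3) - 2) - 1 == -2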


@tf_export("foldr", v1=[])
@dispatch.add_dispatch_support
@deprecation.deprecated_arg_values(
    None,
    """back_prop=False is deprecated. Consider using tf.stop_gradient instead.
Instead of:
results = tf.foldr(fn, elems, back_prop=False)
Use:
results = tf.nest.map_structure(tf.stop_gradient, tf.foldr(fn, elems))""",
    warn_once=True,
    back_prop=False)
def foldr_v2(fn,
             elems,
             initializer=None,
             parallel_iterations=10,
             back_prop=True,
             swap_memory=False,
             name=None):
  """foldr on the list of tensors unpacked from `elems` on dimension 0.

  This foldr operator repeatedly applies the callable `fn` to a sequence
  of elements from last to first. The elements are made of the tensors
  unpacked from `elems`. The callable fn takes two tensors as arguments.
  The first argument is the accumulated value computed from the preceding
  invocation of fn, and the second is the value at the current position of
  `elems`. If `initializer` is None, `elems` must contain at least one element,
  and its last element is used as the initializer.

  Suppose that `elems` is unpacked into `values`, a list of tensors. The shape
  of the result tensor is `fn(initializer, values[0]).shape`.

  This method also allows multi-arity `elems` and output of `fn`. If `elems`
  is a (possibly nested) list or tuple of tensors, then each of these tensors
  must have a matching first (unpack) dimension. The signature of `fn` may
  match the structure of `elems`. That is, if `elems` is
  `(t1, [t2, t3, [t4, t5]])`, then an appropriate signature for `fn` is:
  `fn = lambda (t1, [t2, t3, [t4, t5]]):`.

  Args:
    fn: The callable to be performed.
    elems: A tensor or (possibly nested) sequence of tensors, each of which will
      be unpacked along their first dimension. The nested sequence of the
      resulting slices will be the first argument to `fn`.
    initializer: (optional) A tensor or (possibly nested) sequence of tensors,
      as the initial value for the accumulator.
    parallel_iterations: (optional) The number of iterations allowed to run in
      parallel.
    back_prop: (optional) Deprecated. False disables support for back
      propagation. Prefer using `tf.stop_gradient` instead.
    swap_memory: (optional) True enables GPU-CPU memory swapping.
    name: (optional) Name prefix for the returned tensors.

  Returns:
    A tensor or (possibly nested) sequence of tensors, resulting from applying
    `fn` consecutively to the list of tensors unpacked from `elems`, from last
    to first.

  Raises:
    TypeError: if `fn` is not callable.

  Example:
    ```python
    elems = [1, 2, 3, 4, 5, 6]
    sum = foldr(lambda a, x: a + x, elems)
    # sum == 21
    ```
  """
  return foldr(
      fn=fn,
      elems=elems,
      initializer=initializer,
      parallel_iterations=parallel_iterations,
      back_prop=back_prop,
      swap_memory=swap_memory,
      name=name)


@tf_export(v1=["scan"])
@dispatch.add_dispatch_support
def scan(fn,
         elems,
         initializer=None,
         parallel_iterations=10,
         back_prop=True,
         swap_memory=False,
         infer_shape=True,
         reverse=False,
         name=None):
  """scan on the list of tensors unpacked from `elems` on dimension 0.

  See also `tf.map_fn`.

  The simplest version of `scan` repeatedly applies the callable `fn` to a
  sequence of elements from first to last. The elements are made of the tensors
  unpacked from `elems` on dimension 0. The callable fn takes two tensors as
  arguments. The first argument is the accumulated value computed from the
  preceding invocation of fn, and the second is the value at the current
  position of `elems`. If `initializer` is None, `elems` must contain at least
  one element, and its first element is used as the initializer.

  Suppose that `elems` is unpacked into `values`, a list of tensors. The shape
  of the result tensor is `[len(values)] + fn(initializer, values[0]).shape`.
  If reverse=True, it's `fn(initializer, values[-1]).shape`.

  This method also allows multi-arity `elems` and accumulator. If `elems`
  is a (possibly nested) list or tuple of tensors, then each of these tensors
  must have a matching first (unpack) dimension. The second argument of
  `fn` must match the structure of `elems`.

  If no `initializer` is provided, the output structure and dtypes of `fn`
  are assumed to be the same as its input; and in this case, the first
  argument of `fn` must match the structure of `elems`.

  If an `initializer` is provided, then the output of `fn` must have the same
  structure as `initializer`; and the first argument of `fn` must match
  this structure.

  For example, if `elems` is `(t1, [t2, t3])` and `initializer` is
  `[i1, i2]` then an appropriate signature for `fn` in `python2` is:
  `fn = lambda (acc_p1, acc_p2), (t1, [t2, t3]):` and `fn` must return a list,
  `[acc_n1, acc_n2]`. An alternative correct signature for `fn`, and the
  one that works in `python3`, is:
  `fn = lambda a, t:`, where `a` and `t` correspond to the input tuples.

  Args:
    fn: The callable to be performed. It accepts two arguments. The first will
      have the same structure as `initializer` if one is provided, otherwise it
      will have the same structure as `elems`. The second will have the same
      (possibly nested) structure as `elems`. Its output must have the same
      structure as `initializer` if one is provided, otherwise it must have the
      same structure as `elems`.
    elems: A tensor or (possibly nested) sequence of tensors, each of which will
      be unpacked along their first dimension. The nested sequence of the
      resulting slices will be the first argument to `fn`.
    initializer: (optional) A tensor or (possibly nested) sequence of tensors,
      initial value for the accumulator, and the expected output type of `fn`.
    parallel_iterations: (optional) The number of iterations allowed to run in
      parallel.
    back_prop: (optional) True enables support for back propagation.
    swap_memory: (optional) True enables GPU-CPU memory swapping.
    infer_shape: (optional) False disables tests for consistent output shapes.
    reverse: (optional) True scans the tensor last to first (instead of first to
      last).
    name: (optional) Name prefix for the returned tensors.

  Returns:
    A tensor or (possibly nested) sequence of tensors. Each tensor packs the
    results of applying `fn` to tensors unpacked from `elems` along the first
    dimension, and the previous accumulator value(s), from first to last (or
    last to first, if `reverse=True`).

  Raises:
    TypeError: if `fn` is not callable or the structure of the output of
      `fn` and `initializer` do not match.
    ValueError: if the lengths of the output of `fn` and `initializer`
      do not match.

  Examples:
    ```python
    elems = np.array([1, 2, 3, 4, 5, 6])
    sum = scan(lambda a, x: a + x, elems)
    # sum == [1, 3, 6, 10, 15, 21]
    sum = scan(lambda a, x: a + x, elems, reverse=True)
    # sum == [21, 20, 18, 15, 11, 6]
    ```

    ```python
    elems = np.array([1, 2, 3, 4, 5, 6])
    initializer = np.array(0)
    sum_one = scan(
        lambda a, x: x[0] - x[1] + a, (elems + 1, elems), initializer)
    # sum_one == [1, 2, 3, 4, 5, 6]
    ```

    ```python
    elems = np.array([1, 0, 0, 0, 0, 0])
    initializer = (np.array(0), np.array(1))
    fibonaccis = scan(lambda a, _: (a[1], a[0] + a[1]), elems, initializer)
    # fibonaccis == ([1, 1, 2, 3, 5, 8], [1, 2, 3, 5, 8, 13])
    ```
  """
  if not callable(fn):
    raise TypeError("fn must be callable.")

  input_is_sequence = nest.is_sequence(elems)
  input_flatten = lambda x: nest.flatten(x) if input_is_sequence else [x]

  def input_pack(x):
    return nest.pack_sequence_as(elems, x) if input_is_sequence else x[0]

  if initializer is None:
    output_is_sequence = input_is_sequence
    output_flatten = input_flatten
    output_pack = input_pack
  else:
    output_is_sequence = nest.is_sequence(initializer)
    output_flatten = lambda x: nest.flatten(x) if output_is_sequence else [x]

    def output_pack(x):
      return (nest.pack_sequence_as(initializer, x)
              if output_is_sequence else x[0])

  elems_flat = input_flatten(elems)

  in_graph_mode = not context.executing_eagerly()
  with ops.name_scope(name, "scan", elems_flat):
    # TODO(akshayka): Remove the in_graph_mode check once caching devices are
    # supported in Eager
    if in_graph_mode:
      # Any get_variable calls in fn will cache the first call locally
      # and not issue repeated network I/O requests for each iteration.
      varscope = vs.get_variable_scope()
      varscope_caching_device_was_none = False
      if varscope.caching_device is None:
        # TODO(ebrevdo): Change to using colocate_with here and in other
        # methods.
        varscope.set_caching_device(lambda op: op.device)
        varscope_caching_device_was_none = True

    # Convert elems to tensor array.
    elems_flat = [
        ops.convert_to_tensor(elem, name="elem") for elem in elems_flat
    ]

    # Convert elems to tensor array. n may be known statically.
    n = tensor_shape.dimension_value(elems_flat[0].shape[0])
    if n is None:
      n = array_ops.shape(elems_flat[0])[0]

    # TensorArrays are always flat
    elems_ta = [
        tensor_array_ops.TensorArray(
            dtype=elem.dtype,
            size=n,
            dynamic_size=False,
            element_shape=elem.shape[1:],
            infer_shape=True) for elem in elems_flat
    ]
    # Unpack elements
    elems_ta = [
        elem_ta.unstack(elem) for elem_ta, elem in zip(elems_ta, elems_flat)
    ]

    if initializer is None:
      a_flat = [elem.read(n - 1 if reverse else 0) for elem in elems_ta]
      i = 1
    else:
      initializer_flat = output_flatten(initializer)
      a_flat = [ops.convert_to_tensor(init) for init in initializer_flat]
      i = 0

    # Create a tensor array to store the intermediate values.
    accs_ta = [
        tensor_array_ops.TensorArray(
            dtype=init.dtype,
            size=n,
            element_shape=init.shape if infer_shape else None,
            dynamic_size=False,
            infer_shape=infer_shape) for init in a_flat
    ]

    if initializer is None:
      accs_ta = [
          acc_ta.write(n - 1 if reverse else 0, a)
          for (acc_ta, a) in zip(accs_ta, a_flat)
      ]

    def compute(i, a_flat, tas):
      """The loop body of scan.

      Args:
        i: the loop counter.
        a_flat: the accumulator value(s), flattened.
        tas: the output accumulator TensorArray(s), flattened.

      Returns:
        [i + 1, a_flat, tas]: the updated counter + new accumulator values +
          updated TensorArrays

      Raises:
        TypeError: if initializer and fn() output structure do not match
        ValueError: if initializer and fn() output lengths do not match
      """
      packed_elems = input_pack([elem_ta.read(i) for elem_ta in elems_ta])
      packed_a = output_pack(a_flat)
      a_out = fn(packed_a, packed_elems)
      nest.assert_same_structure(elems if initializer is None else initializer,
                                 a_out)
      flat_a_out = output_flatten(a_out)
      tas = [ta.write(i, value) for (ta, value) in zip(tas, flat_a_out)]
      if reverse:
        next_i = i - 1
      else:
        next_i = i + 1
      return (next_i, flat_a_out, tas)

    if reverse:
      initial_i = n - 1 - i
      condition = lambda i, _1, _2: i >= 0
    else:
      initial_i = i
      condition = lambda i, _1, _2: i < n
    _, _, r_a = control_flow_ops.while_loop(
        condition,
        compute, (initial_i, a_flat, accs_ta),
        parallel_iterations=parallel_iterations,
        back_prop=back_prop,
        swap_memory=swap_memory,
        maximum_iterations=n)

    results_flat = [r.stack() for r in r_a]

    n_static = tensor_shape.Dimension(
        tensor_shape.dimension_value(
            elems_flat[0].get_shape().with_rank_at_least(1)[0]))
    for elem in elems_flat[1:]:
      n_static.assert_is_compatible_with(
          tensor_shape.Dimension(
              tensor_shape.dimension_value(
                  elem.get_shape().with_rank_at_least(1)[0])))
    for r in results_flat:
      r.set_shape(
          tensor_shape.TensorShape(n_static).concatenate(r.get_shape()[1:]))

    # TODO(akshayka): Remove the in_graph_mode check once caching devices are
    # supported in Eager
    if in_graph_mode and varscope_caching_device_was_none:
      varscope.set_caching_device(None)

    return output_pack(results_flat)


@tf_export("scan", v1=[])
@dispatch.add_dispatch_support
@deprecation.deprecated_arg_values(
    None,
    """back_prop=False is deprecated. Consider using tf.stop_gradient instead.
Instead of:
results = tf.scan(fn, elems, back_prop=False)
Use:
results = tf.nest.map_structure(tf.stop_gradient, tf.scan(fn, elems))""",
    warn_once=True,
    back_prop=False)
def scan_v2(fn,
            elems,
            initializer=None,
            parallel_iterations=10,
            back_prop=True,
            swap_memory=False,
            infer_shape=True,
            reverse=False,
            name=None):
  """scan on the list of tensors unpacked from `elems` on dimension 0.

  The simplest version of `scan` repeatedly applies the callable `fn` to a
  sequence of elements from first to last. The elements are made of the tensors
  unpacked from `elems` on dimension 0. The callable fn takes two tensors as
  arguments. The first argument is the accumulated value computed from the
  preceding invocation of fn, and the second is the value at the current
  position of `elems`. If `initializer` is None, `elems` must contain at least
  one element, and its first element is used as the initializer.

  Suppose that `elems` is unpacked into `values`, a list of tensors. The shape
  of the result tensor is `[len(values)] + fn(initializer, values[0]).shape`.
  If reverse=True, it's `fn(initializer, values[-1]).shape`.

  This method also allows multi-arity `elems` and accumulator. If `elems`
  is a (possibly nested) list or tuple of tensors, then each of these tensors
  must have a matching first (unpack) dimension. The second argument of
  `fn` must match the structure of `elems`.

  If no `initializer` is provided, the output structure and dtypes of `fn`
  are assumed to be the same as its input; and in this case, the first
  argument of `fn` must match the structure of `elems`.

  If an `initializer` is provided, then the output of `fn` must have the same
  structure as `initializer`; and the first argument of `fn` must match
  this structure.

  For example, if `elems` is `(t1, [t2, t3])` and `initializer` is
  `[i1, i2]` then an appropriate signature for `fn` in `python2` is:
  `fn = lambda (acc_p1, acc_p2), (t1, [t2, t3]):` and `fn` must return a list,
  `[acc_n1, acc_n2]`. An alternative correct signature for `fn`, and the
  one that works in `python3`, is:
  `fn = lambda a, t:`, where `a` and `t` correspond to the input tuples.

  Args:
    fn: The callable to be performed. It accepts two arguments. The first will
      have the same structure as `initializer` if one is provided, otherwise it
      will have the same structure as `elems`. The second will have the same
      (possibly nested) structure as `elems`. Its output must have the same
      structure as `initializer` if one is provided, otherwise it must have the
      same structure as `elems`.
    elems: A tensor or (possibly nested) sequence of tensors, each of which will
      be unpacked along their first dimension. The nested sequence of the
      resulting slices will be the first argument to `fn`.
    initializer: (optional) A tensor or (possibly nested) sequence of tensors,
      initial value for the accumulator, and the expected output type of `fn`.
    parallel_iterations: (optional) The number of iterations allowed to run in
      parallel.
    back_prop: (optional) Deprecated. False disables support for back
      propagation. Prefer using `tf.stop_gradient` instead.
    swap_memory: (optional) True enables GPU-CPU memory swapping.
    infer_shape: (optional) False disables tests for consistent output shapes.
    reverse: (optional) True scans the tensor last to first (instead of first to
      last).
    name: (optional) Name prefix for the returned tensors.

  Returns:
    A tensor or (possibly nested) sequence of tensors. Each tensor packs the
    results of applying `fn` to tensors unpacked from `elems` along the first
    dimension, and the previous accumulator value(s), from first to last (or
    last to first, if `reverse=True`).

  Raises:
    TypeError: if `fn` is not callable or the structure of the output of
      `fn` and `initializer` do not match.
    ValueError: if the lengths of the output of `fn` and `initializer`
      do not match.

  Examples:
    ```python
    elems = np.array([1, 2, 3, 4, 5, 6])
    sum = scan(lambda a, x: a + x, elems)
    # sum == [1, 3, 6, 10, 15, 21]
    sum = scan(lambda a, x: a + x, elems, reverse=True)
    # sum == [21, 20, 18, 15, 11, 6]
    ```

    ```python
    elems = np.array([1, 2, 3, 4, 5, 6])
    initializer = np.array(0)
    sum_one = scan(
        lambda a, x: x[0] - x[1] + a, (elems + 1, elems), initializer)
    # sum_one == [1, 2, 3, 4, 5, 6]
    ```

    ```python
    elems = np.array([1, 0, 0, 0, 0, 0])
    initializer = (np.array(0), np.array(1))
    fibonaccis = scan(lambda a, _: (a[1], a[0] + a[1]), elems, initializer)
    # fibonaccis == ([1, 1, 2, 3, 5, 8], [1, 2, 3, 5, 8, 13])
    ```
  """
  return scan(
      fn=fn,
      elems=elems,
      initializer=initializer,
      parallel_iterations=parallel_iterations,
      back_prop=back_prop,
      swap_memory=swap_memory,
      infer_shape=infer_shape,
      reverse=reverse,
      name=name)


# pylint: disable=invalid-name
def If(cond, inputs, then_branch, else_branch, name=None):
  r"""output = Cond(inputs) ? then_branch(inputs) : else_branch(inputs).

  Args:
    cond: A `Tensor`. A scalar. If the scalar is not a boolean, the scalar is
      converted to a boolean according to the following rule: if the scalar is
      a numerical value, non-zero means True and zero means False; if the
      scalar is a string, non-empty means True and empty means False.
    inputs: A list of input tensors.
    then_branch: A function that takes 'inputs' and returns a list of tensors,
      whose types are the same as what else_branch returns.
    else_branch: A function that takes 'inputs' and returns a list of tensors,
      whose types are the same as what then_branch returns.
    name: A name for the operation (optional).

  Returns:
    A list of tensors returned by either then_branch(inputs)
    or else_branch(inputs).
  """
  # pylint: disable=protected-access
  # Handle the Defun case until users have transitioned to tf.function. Note
  # that composites may need to be re-packed by the caller.
  if isinstance(then_branch, function._DefinedFunction):
    tlist = [_.type for _ in then_branch.definition.signature.output_arg]
    return gen_functional_ops._if(
        cond, inputs, tlist, then_branch, else_branch, name=name)

  # We assume that `then_branch` is a ConcreteFunction here.
  then_out = then_branch.structured_outputs
  else_out = else_branch.structured_outputs

  # Ensure then/else are the same type of composites to avoid an invalid call
  # to pack_sequence_as later on.
  nest.assert_same_structure(then_out, else_out, expand_composites=True)

  tlist = nest.flatten(then_branch.output_dtypes)
  ret = gen_functional_ops._if(
      cond, inputs, tlist, then_branch, else_branch, name=name)

  # Re-pack the outputs to restore any CompositeTensors
  return nest.pack_sequence_as(then_out, ret, expand_composites=True)
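

# A commented-out usage sketch (the `double`/`halve` functions are made up):
# calling `If` with concrete functions obtained from `tf.function`. Both
# branches must return outputs with matching structures and dtypes.
#
#   @def_function.function
#   def double(x):
#     return [x * 2.0]
#
#   @def_function.function
#   def halve(x):
#     return [x / 2.0]
#
#   x = constant_op.constant(8.0)
#   out = If(constant_op.constant(True), [x],
#            double.get_concrete_function(x), halve.get_concrete_function(x))
#   # out == [16.0]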


def Gradient(inputs, f, name=None):
  r"""Computes the gradient function for function f via backpropagation.

  Args:
    inputs: A list of tensors of size N + M.
    f: The function we want to compute the gradient for. The function 'f' must
      be a numerical function which takes N inputs and produces M outputs. Its
      gradient function 'g' is a function that takes N + M inputs and produces
      N outputs. I.e. if we have (y1, y2, ..., yM) = f(x1, x2, ..., xN), then
      g is (dL/dx1, dL/dx2, ..., dL/dxN) = g(x1, x2, ..., xN, dL/dy1, dL/dy2,
      ..., dL/dyM), where L is a scalar-valued function of (x1, x2, ..., xN)
      (e.g., the loss function). dL/dxi is the partial derivative of L with
      respect to xi.
    name: A name for the operation (optional).

  Returns:
    A list of tensors of size N.
  """
  # TODO(zhifengc): Pretty-print the above spec in latex.
  # TODO(zhifengc): Needs some math expert to say the comment above better.
  tlist = [_.type for _ in f.definition.signature.input_arg]
  return symbolic_gradient(input=inputs, Tout=tlist, f=f, name=name)
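

# A commented-out usage sketch (the `Poly` function is made up): `Gradient`
# over a Defun with N=1 input and M=1 output, so the gradient function takes
# N + M = 2 inputs (x, dL/dy) and returns N = 1 output (dL/dx).
#
#   @function.Defun(dtypes.float32)
#   def Poly(x):
#     return x * x + 1.0
#
#   x = constant_op.constant(3.0)
#   dy = constant_op.constant(1.0)   # upstream gradient dL/dy
#   (dx,) = Gradient([x, dy], Poly)  # dx == 2 * x == 6.0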


def _GetInputDtypes(func):
  """Returns the input dtypes of func, excluding dtypes for captured inputs."""
  if isinstance(func, function._DefinedFunction):  # pylint: disable=protected-access
    return func.declared_input_types

  # We assume that `func` is a ConcreteFunction here, but we are not able to
  # verify since importing eager function library will cause cyclic dependence.
  #
  # ConcreteFunction.inputs includes captured inputs.
  num_non_captured_inputs = len(func.inputs) - len(func.captured_inputs)
  inputs_without_captured = func.inputs[:num_non_captured_inputs]
  return [t.dtype for t in inputs_without_captured]


def _LoopBodyCaptureWrapper(func):
  """Returns a wrapper for `func` that handles loop-carried captured inputs."""

  @function.Defun(*_GetInputDtypes(func), func_name="%s_Wrapper" % func.name)
  def Wrapper(*args):
    """A wrapper that handles loop-carried captured inputs."""
    result = func(*args)
    extra_args = tuple(function.get_extra_args())
    # Nullary functions return an Operation. Normal functions can't do this
    # because their return values are converted to Tensors.
    if isinstance(result, ops.Operation):
      return extra_args
    # Unary functions return a single Tensor value.
    elif not isinstance(result, (list, tuple)):
      return (result,) + extra_args
    # N-ary functions return a tuple of Tensors.
    else:
      return result + type(result)(extra_args)

  return Wrapper


# pylint: disable=invalid-name,protected-access
def While(input_, cond, body, name=None, hostmem=None):
  r"""output = input; While (Cond(output)) { output = Body(output) }.

  Args:
    input_: A list of `Tensor` objects. A list of input tensors whose types
      are T.
    cond: A function that takes 'input' and returns a tensor. If the tensor is
      a scalar of non-boolean type, the scalar is converted to a boolean
      according to the following rule: if the scalar is a numerical value,
      non-zero means True and zero means False; if the scalar is a string,
      non-empty means True and empty means False. If the tensor is not a
      scalar, non-emptiness means True and False otherwise.
    body: A function that takes a list of tensors and returns another list of
      tensors. Both lists have the same types as specified by T.
    name: A name for the operation (optional).
    hostmem: A list of integers. If i is in the list, input[i] is a host memory
      tensor.

  Raises:
    ValueError: if `cond` has implicitly captured inputs or if `cond` and `body`
      have different signatures.

  Returns:
    A list of `Tensor` objects. Has the same type as `input`.
    A list of output tensors whose types are T.
  """
  if cond.captured_inputs:
    raise ValueError("While op 'cond' argument must be a function "
                     "without implicitly captured inputs.")

  cond_input_types = _GetInputDtypes(cond)
  body_input_types = _GetInputDtypes(body)

  if cond_input_types != body_input_types:
    raise ValueError(
        "While op 'cond' and 'body' signatures do not match. %r vs %r" %
        (cond_input_types, body_input_types))

  if body.captured_inputs:
    cond_dtypes = list(body_input_types) + [
        t.dtype for t in body.captured_inputs
    ]

    @function.Defun(*cond_dtypes, func_name="%s_Wrapper" % cond.name)
    def CondWrapper(*args):
      """A wrapper that handles loop-carried captured inputs."""
      return cond(*args[:len(body_input_types)])

    ret = gen_functional_ops._while(
        input_ + body.captured_inputs,
        CondWrapper,
        _LoopBodyCaptureWrapper(body),
        name=name)
    # Slice off the loop-carried captured inputs.
    ret = ret[:-len(body.captured_inputs)]
  else:
    ret = gen_functional_ops._while(input_, cond, body, name=name)
  if hostmem:
    input_attr = attr_value_pb2.AttrValue()
    input_attr.list.i.extend(hostmem)
    ret[0].op._set_attr("_input_hostmem", input_attr)  # pylint: disable=protected-access

    output_attr = attr_value_pb2.AttrValue()
    output_attr.list.i.extend(hostmem)
    ret[0].op._set_attr("_output_hostmem", output_attr)  # pylint: disable=protected-access
  return ret
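

# A commented-out usage sketch (`Cond` and `Body` are made-up names): `While`
# with Defun-defined condition and body over a single int32 loop variable.
#
#   @function.Defun(dtypes.int32)
#   def Cond(x):
#     return x < 10
#
#   @function.Defun(dtypes.int32)
#   def Body(x):
#     return x + 1
#
#   result = While([constant_op.constant(0)], Cond, Body)
#   # result == [10]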


# b/36459430
#
# Ideally, we should not need to rewrite a For loop into a While loop.
# However, today, if a While runs on GPU and the condition returns a
# boolean, the While kernel crashes. Even if we fix the crash, the
# bool needs to be copied between GPU and CPU. So, a For loop is much
# preferred when running on GPU.
#
# On the other hand, the For op has no direct XLA kernel. So, when we run
# a For loop, we need to rewrite it using a While op.
#
# It should be possible and probably better to write a XLA C++ kernel
# implementing the logic in _ForUsingWhile.
def _ForUsingWhile(start,
                   limit,
                   delta,
                   inputs,
                   forbody,
                   name=None,
                   hostmem=None):
  """Helper to implement a For loop using a While."""
  # To support negative delta (e.g., range(100, 0, -3)), we iterate
  # over the range(n) and use iter * delta + start as the real
  # iteration index. (e.g., for i in range(34): iter = i * (-3) +
  # 100).
  d = math_ops.abs(delta)
  # XLA on TPUs doesn't support integer division
  n = math_ops.cast(
      math_ops.cast((math_ops.abs(limit - start) + d - 1), dtypes.float32) /
      math_ops.cast(d, dtypes.float32), dtypes.int32)

  # Carried loop variables ("extra_args") are implicitly added to the input list
  # of the WhileBody function. WhileCond does not call forbody, and so does not
  # depend on any of forbody's extra_args. Since WhileCond and WhileBody
  # must have identical inputs, we have to augment the cond signature to take
  # the same types as the carried loop variables.
  body_sig = [dtypes.int32] * 4 + list(forbody.declared_input_types)[1:]

  cond_name = "%s_Cond" % forbody.name

  @function.Defun(*body_sig, func_name=cond_name)
  def WhileCond(i, n, *args):
    del args
    return i < n

  body_name = "%s_Body" % forbody.name

  @function.Defun(*body_sig, func_name=body_name)
  def WhileBody(i, n, start, delta, *args):
    """A While wrapper for forbody that handles loop-carried captured inputs."""
    for_result = forbody(start + i * delta, *args)
    # Nullary functions return an Operation. Normal functions can't do this
    # because their return values are converted to Tensors.
    if isinstance(for_result, ops.Operation):
      for_result = ()
    # Unary functions return a single Tensor value.
    elif isinstance(for_result, ops.Tensor):
      for_result = (for_result,)
    return (i + 1, n, start, delta) + tuple(for_result)

  if hostmem is not None:
    hostmem = [0, 1, 2, 3] + [(4 + _) for _ in hostmem]
  else:
    hostmem = [0, 1, 2, 3]

  results = While(
      input_=[0, n, start, delta] + inputs,
      cond=WhileCond,
      body=WhileBody,
      name=name,
      hostmem=hostmem)
  # Slice off the loop-carried captured inputs.
  return list(results[4:len(results)])


def For(start,
        limit,
        delta,
        inputs,
        body,
        name=None,
        hostmem=None,
        rewrite_with_while=None):
  r"""out = input; for i in range(start, limit, delta) out = body(i, out).

  Args:
    start: A `Tensor` of type `int32`.
    limit: A `Tensor` of type `int32`.
    delta: A `Tensor` of type `int32`.
    inputs: A list of `Tensor` objects. A list of input tensors whose types
      are T.
    body: A function that takes a list of tensors and returns another list of
      tensors. Both lists have the same types as (int32, T...).
    name: A name for the operation (optional).
    hostmem: A list of integers. If i is in the list, inputs[i] is a host
      memory tensor. In other words, the (i+1)-th argument of the body function
      expects a host memory tensor.
    rewrite_with_while: If True, use a While op to implement the For loop.

  Returns:
    A list of `Tensor` objects. Has the same type as `input`.
    A list of output tensors whose types are T.
  """
  if rewrite_with_while:
    return _ForUsingWhile(start, limit, delta, inputs, body, name, hostmem)
  if body.captured_inputs:
    ret = gen_functional_ops._for(
        start,
        limit,
        delta,
        inputs + body.captured_inputs,
        _LoopBodyCaptureWrapper(body),
        name=name)
    # Slice off the loop-carried captured inputs.
    ret = ret[:-len(body.captured_inputs)]
  else:
    ret = gen_functional_ops._for(start, limit, delta, inputs, body, name=name)
  if hostmem:
    num_for_params = 3  # start/limit/delta

    input_attr = attr_value_pb2.AttrValue()
    input_attr.list.i.extend([num_for_params + i for i in hostmem])
    ret[0].op._set_attr("_input_hostmem", input_attr)  # pylint: disable=protected-access

    output_attr = attr_value_pb2.AttrValue()
    output_attr.list.i.extend(hostmem)
    ret[0].op._set_attr("_output_hostmem", output_attr)  # pylint: disable=protected-access
  return ret
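

# A commented-out usage sketch (`Accumulate` is a made-up name): `For` runs a
# Defun-defined body once per loop index; the body receives the index followed
# by the loop-carried tensors.
#
#   @function.Defun(dtypes.int32, dtypes.int32)
#   def Accumulate(i, total):
#     return total + i
#
#   result = For(start=0, limit=10, delta=1,
#                inputs=[constant_op.constant(0)], body=Accumulate)
#   # result == [0 + 1 + ... + 9] == [45]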


# pylint: enable=invalid-name,protected-access


def partitioned_call(args,
                     f,
                     tout=None,
                     executing_eagerly=None,
                     config=None,
                     executor_type=None):
  """Executes a function while respecting device annotations.

  Currently, only those functions that execute within the same address space
  can be executed.

  Args:
    args: The arguments of the function, including captured inputs.
    f: The function to execute; an instance of `_DefinedFunction` or
      `_EagerDefinedFunction`.
    tout: a list containing the output dtypes enums; if `None`, inferred from
      the signature of `f`.
    executing_eagerly: (Optional) A boolean indicating whether the context is
      executing eagerly. If `None`, fetched from the global context.
    config: (Optional) A `tensorflow::ConfigProto` proto, serialized. If `None`,
      all optimizations are disabled. Currently only handled for eager defined
      functions.
    executor_type: (Optional) A string for the name of the executor to be used
      in the function call. If not set, or set to an empty string, the default
      tensorflow executor will be used.

  Returns:
    The list of `Tensor`s returned by invoking `f(args)`. If the function does
    not return anything, then returns `None` if eager execution is enabled, or
    the `Operation` if not.
  """

  if tout is None:
    tout = tuple(x.type for x in f.definition.signature.output_arg)

  if executing_eagerly is None:
    executing_eagerly = context.executing_eagerly()

  if config is None:
    config = function_utils.get_disabled_rewriter_config()

  if executor_type is None:
    executor_type = ""

  if executing_eagerly:
    if f.stateful_ops:
      outputs = gen_functional_ops.stateful_partitioned_call(
          args=args,
          Tout=tout,
          f=f,
          config_proto=config,
          executor_type=executor_type)
    else:
      outputs = gen_functional_ops.partitioned_call(
          args=args,
          Tout=tout,
          f=f,
          config_proto=config,
          executor_type=executor_type)
    return outputs if outputs else None

  # The generated binding returns an empty list for functions that don't
  # return any Tensors, hence the need to use `create_op` directly.
  args = [ops.convert_to_tensor(x) for x in args]
  tin_attr = attr_value_pb2.AttrValue(
      list=attr_value_pb2.AttrValue.ListValue(
          type=[x.dtype.as_datatype_enum for x in args]))
  tout_attr = attr_value_pb2.AttrValue(
      list=attr_value_pb2.AttrValue.ListValue(type=tout))
  func_attr = attr_value_pb2.AttrValue(
      func=attr_value_pb2.NameAttrList(name=f.name))
  executor_type_attr = attr_value_pb2.AttrValue(
      s=compat.as_bytes(executor_type))

  # When running in graph mode, the graph and function graphs are optimized
  # (i.e. run through grappler) per the session options, so we can disable any
  # eager-specific rewriting.
  config_proto = attr_value_pb2.AttrValue(s=config)

  graph = ops.get_default_graph()
  f.add_to_graph(graph)
  op_name = "StatefulPartitionedCall" if f.stateful_ops else "PartitionedCall"

  # Propagate the attribute indicating the need to compile from function to the
  # call itself.
  xla_compile_attr = "_XlaMustCompile"
  op_attrs = {
      "Tin": tin_attr,
      "Tout": tout_attr,
      "f": func_attr,
      "config_proto": config_proto,
      "executor_type": executor_type_attr,
  }
  if xla_compile_attr in f.definition.attr:
    op_attrs[xla_compile_attr] = f.definition.attr[xla_compile_attr]
  op = graph.create_op(op_name, args, tout, name=op_name, attrs=op_attrs)
  outputs = op.outputs
  if hasattr(f, "graph"):
    _set_read_only_resource_inputs_attr(op, f.graph)
    if hasattr(f.graph, "collective_manager_ids_used"):
      ops.set_int_list_attr(op, acd.COLLECTIVE_MANAGER_IDS,
                            f.graph.collective_manager_ids_used)
  return outputs if outputs else op
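

# A commented-out usage sketch (`Square` is a made-up name): invoking a Defun
# through `partitioned_call`; `tout` is inferred from the function signature
# when not supplied.
#
#   @function.Defun(dtypes.float32)
#   def Square(x):
#     return x * x
#
#   (y,) = partitioned_call(args=[constant_op.constant(4.0)], f=Square)
#   # y == 16.0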


def _set_read_only_resource_inputs_attr(op, func_graph):
  """Sets the list of resource inputs which are read-only.

  This is used by AutomaticControlDependencies.

  Args:
    op: PartitionedCall Operation.
    func_graph: FuncGraph.
  """
  read_only_indices = acd.get_read_only_resource_input_indices_graph(func_graph)
  ops.set_int_list_attr(op, acd.READ_ONLY_RESOURCE_INPUTS_ATTR,
                        read_only_indices)