# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

"""Functional operations."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.core.framework import attr_value_pb2
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_functional_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops import variable_scope as vs
# pylint: disable=unused-import
from tensorflow.python.ops.gen_functional_ops import remote_call
# pylint: enable=unused-import
from tensorflow.python.ops.gen_functional_ops import symbolic_gradient
from tensorflow.python.util import compat
from tensorflow.python.util import function_utils
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export


# TODO(yuanbyu, mrry): Handle stride to support sliding windows.
@tf_export("foldl")
def foldl(fn, elems, initializer=None, parallel_iterations=10, back_prop=True,
          swap_memory=False, name=None):
  """foldl on the list of tensors unpacked from `elems` on dimension 0.

  This foldl operator repeatedly applies the callable `fn` to a sequence
  of elements from first to last. The elements are made of the tensors
  unpacked from `elems` on dimension 0. The callable `fn` takes two tensors as
  arguments. The first argument is the accumulated value computed from the
  preceding invocation of `fn`. If `initializer` is None, `elems` must contain
  at least one element, and its first element is used as the initializer.

  Suppose that `elems` is unpacked into `values`, a list of tensors. The shape
  of the result tensor is `fn(initializer, values[0]).shape`.

  This method also allows multi-arity `elems` and output of `fn`. If `elems`
  is a (possibly nested) list or tuple of tensors, then each of these tensors
  must have a matching first (unpack) dimension. The signature of `fn` may
  match the structure of `elems`. That is, if `elems` is
  `(t1, [t2, t3, [t4, t5]])`, then an appropriate signature for `fn` is:
  `fn = lambda (t1, [t2, t3, [t4, t5]]):`.

  Args:
    fn: The callable to be performed.
    elems: A tensor or (possibly nested) sequence of tensors, each of which
      will be unpacked along their first dimension. The nested sequence
      of the resulting slices will be the first argument to `fn`.
    initializer: (optional) A tensor or (possibly nested) sequence of tensors,
      as the initial value for the accumulator.
    parallel_iterations: (optional) The number of iterations allowed to run
      in parallel.
    back_prop: (optional) True enables support for back propagation.
    swap_memory: (optional) True enables GPU-CPU memory swapping.
    name: (optional) Name prefix for the returned tensors.

  Returns:
    A tensor or (possibly nested) sequence of tensors, resulting from applying
    `fn` consecutively to the list of tensors unpacked from `elems`, from first
    to last.

  Raises:
    TypeError: if `fn` is not callable.

  Example:
    ```python
    elems = tf.constant([1, 2, 3, 4, 5, 6])
    sum = foldl(lambda a, x: a + x, elems)
    # sum == 21
    ```
  """
  if not callable(fn):
    raise TypeError("fn must be callable.")

  def create_ta(elem):
    return tensor_array_ops.TensorArray(
        dtype=elem.dtype, size=n, dynamic_size=False,
        infer_shape=True).unstack(elem)

  in_graph_mode = not context.executing_eagerly()
  with ops.name_scope(name, "foldl", [elems]):
    # TODO(akshayka): Remove the in_graph_mode check once caching devices are
    # supported in Eager
    if in_graph_mode:
      # Any get_variable calls in fn will cache the first call locally
      # and not issue repeated network I/O requests for each iteration.
      varscope = vs.get_variable_scope()
      varscope_caching_device_was_none = False
      if varscope.caching_device is None:
        # TODO(ebrevdo): Change to using colocate_with here and in other
        # methods.
        varscope.set_caching_device(lambda op: op.device)
        varscope_caching_device_was_none = True

    # Convert elems to tensor array. n may be known statically.
    elems_flat = [
        ops.convert_to_tensor(elem, name="elem") for elem in nest.flatten(elems)
    ]
    n = (tensor_shape.dimension_value(elems_flat[0].shape[0])
         or array_ops.shape(elems_flat[0])[0])

    elems_ta = nest.map_structure(create_ta, elems)

    if initializer is None:
      a = nest.map_structure(lambda elem: elem.read(0), elems_ta)
      i = constant_op.constant(1)
    else:
      a = initializer
      i = constant_op.constant(0)

    def compute(i, a):
      elem_i = nest.map_structure(lambda elem: elem.read(i), elems_ta)
      a = fn(a, elem_i)
      return [i + 1, a]

    _, r_a = control_flow_ops.while_loop(
        lambda i, a: i < n, compute, [i, a],
        parallel_iterations=parallel_iterations,
        back_prop=back_prop,
        swap_memory=swap_memory,
        maximum_iterations=n)

    # TODO(akshayka): Remove the in_graph_mode check once caching devices are
    # supported in Eager
    if in_graph_mode and varscope_caching_device_was_none:
      varscope.set_caching_device(None)

    return r_a


@tf_export("foldr")
def foldr(fn, elems, initializer=None, parallel_iterations=10, back_prop=True,
          swap_memory=False, name=None):
  """foldr on the list of tensors unpacked from `elems` on dimension 0.

  This foldr operator repeatedly applies the callable `fn` to a sequence
  of elements from last to first. The elements are made of the tensors
  unpacked from `elems`. The callable `fn` takes two tensors as arguments.
  The first argument is the accumulated value computed from the preceding
  invocation of `fn`. If `initializer` is None, `elems` must contain at least
  one element, and its last element is used as the initializer.

  Suppose that `elems` is unpacked into `values`, a list of tensors.
  The shape of the result tensor is `fn(initializer, values[0]).shape`.

  This method also allows multi-arity `elems` and output of `fn`. If `elems`
  is a (possibly nested) list or tuple of tensors, then each of these tensors
  must have a matching first (unpack) dimension. The signature of `fn` may
  match the structure of `elems`. That is, if `elems` is
  `(t1, [t2, t3, [t4, t5]])`, then an appropriate signature for `fn` is:
  `fn = lambda (t1, [t2, t3, [t4, t5]]):`.

  Args:
    fn: The callable to be performed.
    elems: A tensor or (possibly nested) sequence of tensors, each of which
      will be unpacked along their first dimension. The nested sequence
      of the resulting slices will be the first argument to `fn`.
    initializer: (optional) A tensor or (possibly nested) sequence of tensors,
      as the initial value for the accumulator.
    parallel_iterations: (optional) The number of iterations allowed to run
      in parallel.
    back_prop: (optional) True enables support for back propagation.
    swap_memory: (optional) True enables GPU-CPU memory swapping.
    name: (optional) Name prefix for the returned tensors.

  Returns:
    A tensor or (possibly nested) sequence of tensors, resulting from applying
    `fn` consecutively to the list of tensors unpacked from `elems`, from last
    to first.

  Raises:
    TypeError: if `fn` is not callable.

  Example:
    ```python
    elems = [1, 2, 3, 4, 5, 6]
    sum = foldr(lambda a, x: a + x, elems)
    # sum == 21
    ```
  """
  if not callable(fn):
    raise TypeError("fn must be callable.")

  def create_ta(elem):
    return tensor_array_ops.TensorArray(
        dtype=elem.dtype, size=n, dynamic_size=False,
        infer_shape=True).unstack(elem)

  in_graph_mode = not context.executing_eagerly()
  with ops.name_scope(name, "foldr", [elems]):
    # TODO(akshayka): Remove the in_graph_mode check once caching devices are
    # supported in Eager
    if in_graph_mode:
      # Any get_variable calls in fn will cache the first call locally and not
      # issue repeated network I/O requests for each iteration.
      varscope = vs.get_variable_scope()
      varscope_caching_device_was_none = False
      if varscope.caching_device is None:
        # TODO(ebrevdo): Change to using colocate_with here and in other
        # methods.
        varscope.set_caching_device(lambda op: op.device)
        varscope_caching_device_was_none = True

    # Convert elems to tensor array. n may be known statically.
    elems_flat = [
        ops.convert_to_tensor(elem, name="elem") for elem in nest.flatten(elems)
    ]
    n = (tensor_shape.dimension_value(elems_flat[0].shape[0])
         or array_ops.shape(elems_flat[0])[0])

    elems_ta = nest.map_structure(create_ta, elems)

    if initializer is None:
      i = n - 1
      a = nest.map_structure(lambda elem: elem.read(i), elems_ta)
    else:
      i = n
      a = initializer

    def compute(i, a):
      i -= 1
      elem = nest.map_structure(lambda elem: elem.read(i), elems_ta)
      a_out = fn(a, elem)
      return [i, a_out]

    _, r_a = control_flow_ops.while_loop(
        lambda i, a: i > 0,
        compute, [i, a],
        parallel_iterations=parallel_iterations,
        back_prop=back_prop,
        swap_memory=swap_memory,
        maximum_iterations=n)

    # TODO(akshayka): Remove the in_graph_mode check once caching devices are
    # supported in Eager
    if in_graph_mode and varscope_caching_device_was_none:
      varscope.set_caching_device(None)

    return r_a


@tf_export("scan")
def scan(fn, elems, initializer=None, parallel_iterations=10, back_prop=True,
         swap_memory=False, infer_shape=True, reverse=False, name=None):
  """scan on the list of tensors unpacked from `elems` on dimension 0.

  The simplest version of `scan` repeatedly applies the callable `fn` to a
  sequence of elements from first to last. The elements are made of the tensors
  unpacked from `elems` on dimension 0. The callable `fn` takes two tensors as
  arguments. The first argument is the accumulated value computed from the
  preceding invocation of `fn`. If `initializer` is None, `elems` must contain
  at least one element, and its first element is used as the initializer.

  Suppose that `elems` is unpacked into `values`, a list of tensors. The shape
  of the result tensor is `[len(values)] + fn(initializer, values[0]).shape`.
  If `reverse=True`, it's `fn(initializer, values[-1]).shape`.

  This method also allows multi-arity `elems` and accumulator. If `elems`
  is a (possibly nested) list or tuple of tensors, then each of these tensors
  must have a matching first (unpack) dimension. The second argument of
  `fn` must match the structure of `elems`.

  If no `initializer` is provided, the output structure and dtypes of `fn`
  are assumed to be the same as its input; and in this case, the first
  argument of `fn` must match the structure of `elems`.

  If an `initializer` is provided, then the output of `fn` must have the same
  structure as `initializer`; and the first argument of `fn` must match
  this structure.

  For example, if `elems` is `(t1, [t2, t3])` and `initializer` is
  `[i1, i2]` then an appropriate signature for `fn` in `python2` is:
  `fn = lambda (acc_p1, acc_p2), (t1, [t2, t3]):` and `fn` must return a list,
  `[acc_n1, acc_n2]`. An alternative correct signature for `fn`, and the
  one that works in `python3`, is:
  `fn = lambda a, t:`, where `a` and `t` correspond to the input tuples.

  Args:
    fn: The callable to be performed. It accepts two arguments. The first
      will have the same structure as `initializer` if one is provided,
      otherwise it will have the same structure as `elems`. The second
      will have the same (possibly nested) structure as `elems`. Its output
      must have the same structure as `initializer` if one is provided,
      otherwise it must have the same structure as `elems`.
    elems: A tensor or (possibly nested) sequence of tensors, each of which
      will be unpacked along their first dimension. The nested sequence
      of the resulting slices will be the first argument to `fn`.
    initializer: (optional) A tensor or (possibly nested) sequence of tensors,
      initial value for the accumulator, and the expected output type of `fn`.
    parallel_iterations: (optional) The number of iterations allowed to run
      in parallel.
    back_prop: (optional) True enables support for back propagation.
    swap_memory: (optional) True enables GPU-CPU memory swapping.
    infer_shape: (optional) False disables tests for consistent output shapes.
    reverse: (optional) True scans the tensor last to first (instead of first
      to last).
    name: (optional) Name prefix for the returned tensors.

  Returns:
    A tensor or (possibly nested) sequence of tensors. Each tensor packs the
    results of applying `fn` to tensors unpacked from `elems` along the first
    dimension, and the previous accumulator value(s), from first to last (or
    last to first, if `reverse=True`).

  Raises:
    TypeError: if `fn` is not callable or the structure of the output of
      `fn` and `initializer` do not match.
    ValueError: if the lengths of the output of `fn` and `initializer`
      do not match.

  Examples:
    ```python
    elems = np.array([1, 2, 3, 4, 5, 6])
    sum = scan(lambda a, x: a + x, elems)
    # sum == [1, 3, 6, 10, 15, 21]
    sum = scan(lambda a, x: a + x, elems, reverse=True)
    # sum == [21, 20, 18, 15, 11, 6]
    ```

    ```python
    elems = np.array([1, 2, 3, 4, 5, 6])
    initializer = np.array(0)
    sum_one = scan(
        lambda a, x: x[0] - x[1] + a, (elems + 1, elems), initializer)
    # sum_one == [1, 2, 3, 4, 5, 6]
    ```

    ```python
    elems = np.array([1, 0, 0, 0, 0, 0])
    initializer = (np.array(0), np.array(1))
    fibonaccis = scan(lambda a, _: (a[1], a[0] + a[1]), elems, initializer)
    # fibonaccis == ([1, 1, 2, 3, 5, 8], [1, 2, 3, 5, 8, 13])
    ```
  """
  if not callable(fn):
    raise TypeError("fn must be callable.")

  input_is_sequence = nest.is_sequence(elems)
  input_flatten = lambda x: nest.flatten(x) if input_is_sequence else [x]
  def input_pack(x):
    return nest.pack_sequence_as(elems, x) if input_is_sequence else x[0]

  if initializer is None:
    output_is_sequence = input_is_sequence
    output_flatten = input_flatten
    output_pack = input_pack
  else:
    output_is_sequence = nest.is_sequence(initializer)
    output_flatten = lambda x: nest.flatten(x) if output_is_sequence else [x]
    def output_pack(x):
      return (nest.pack_sequence_as(initializer, x)
              if output_is_sequence else x[0])

  elems_flat = input_flatten(elems)

  in_graph_mode = not context.executing_eagerly()
  with ops.name_scope(name, "scan", elems_flat):
    # TODO(akshayka): Remove the in_graph_mode check once caching devices are
    # supported in Eager
    if in_graph_mode:
      # Any get_variable calls in fn will cache the first call locally
      # and not issue repeated network I/O requests for each iteration.
      varscope = vs.get_variable_scope()
      varscope_caching_device_was_none = False
      if varscope.caching_device is None:
        # TODO(ebrevdo): Change to using colocate_with here and in other
        # methods.
        varscope.set_caching_device(lambda op: op.device)
        varscope_caching_device_was_none = True

    # Convert elems to tensors.
    elems_flat = [
        ops.convert_to_tensor(elem, name="elem") for elem in elems_flat]

    # n may be known statically.
    n = tensor_shape.dimension_value(elems_flat[0].shape[0])
    if n is None:
      n = array_ops.shape(elems_flat[0])[0]

    # TensorArrays are always flat
    elems_ta = [
        tensor_array_ops.TensorArray(dtype=elem.dtype, size=n,
                                     dynamic_size=False,
                                     element_shape=elem.shape[1:],
                                     infer_shape=True)
        for elem in elems_flat]
    # Unpack elements
    elems_ta = [
        elem_ta.unstack(elem) for elem_ta, elem in zip(elems_ta, elems_flat)]

    if initializer is None:
      a_flat = [elem.read(n - 1 if reverse else 0) for elem in elems_ta]
      i = constant_op.constant(1)
    else:
      initializer_flat = output_flatten(initializer)
      a_flat = [ops.convert_to_tensor(init) for init in initializer_flat]
      i = constant_op.constant(0)

    # Create a tensor array to store the intermediate values.
    accs_ta = [
        tensor_array_ops.TensorArray(
            dtype=init.dtype, size=n,
            element_shape=init.shape if infer_shape else None,
            dynamic_size=False,
            infer_shape=infer_shape)
        for init in a_flat]

    if initializer is None:
      accs_ta = [acc_ta.write(n - 1 if reverse else 0, a)
                 for (acc_ta, a) in zip(accs_ta, a_flat)]

    def compute(i, a_flat, tas):
      """The loop body of scan.

      Args:
        i: the loop counter.
        a_flat: the accumulator value(s), flattened.
        tas: the output accumulator TensorArray(s), flattened.

      Returns:
        [i + 1, a_flat, tas]: the updated counter + new accumulator values +
          updated TensorArrays

      Raises:
        TypeError: if initializer and fn() output structure do not match
        ValueError: if initializer and fn() output lengths do not match
      """
      packed_elems = input_pack([elem_ta.read(i) for elem_ta in elems_ta])
      packed_a = output_pack(a_flat)
      a_out = fn(packed_a, packed_elems)
      nest.assert_same_structure(
          elems if initializer is None else initializer, a_out)
      flat_a_out = output_flatten(a_out)
      tas = [ta.write(i, value) for (ta, value) in zip(tas, flat_a_out)]
      if reverse:
        next_i = i - 1
      else:
        next_i = i + 1
      return (next_i, flat_a_out, tas)

    if reverse:
      initial_i = n - 1 - i
      condition = lambda i, _1, _2: i >= 0
    else:
      initial_i = i
      condition = lambda i, _1, _2: i < n
    _, _, r_a = control_flow_ops.while_loop(
        condition, compute, (initial_i, a_flat, accs_ta),
        parallel_iterations=parallel_iterations,
        back_prop=back_prop, swap_memory=swap_memory,
        maximum_iterations=n)

    results_flat = [r.stack() for r in r_a]

    n_static = tensor_shape.Dimension(tensor_shape.dimension_value(
        elems_flat[0].get_shape().with_rank_at_least(1)[0]))
    for elem in elems_flat[1:]:
      n_static.merge_with(tensor_shape.Dimension(tensor_shape.dimension_value(
          elem.get_shape().with_rank_at_least(1)[0])))
    for r in results_flat:
      r.set_shape(tensor_shape.TensorShape(n_static).concatenate(
          r.get_shape()[1:]))

    # TODO(akshayka): Remove the in_graph_mode check once caching devices are
    # supported in Eager
    if in_graph_mode and varscope_caching_device_was_none:
      varscope.set_caching_device(None)

    return output_pack(results_flat)


# pylint: disable=invalid-name
def If(cond, inputs, then_branch, else_branch, name=None):
  r"""output = Cond(inputs) ? then_branch(inputs) : else_branch(inputs).

  Args:
    cond: A `Tensor`. A scalar. If the scalar is not a boolean, the scalar is
      converted to a boolean according to the following rule: if the
      scalar is a numerical value, non-zero means True and zero means
      False; if the scalar is a string, non-empty means True and empty
      means False.
    inputs: A list of input tensors.
    then_branch: A function that takes 'inputs' and returns a list of tensors
      whose types are the same as what else_branch returns.
    else_branch: A function that takes 'inputs' and returns a list of tensors
      whose types are the same as what then_branch returns.
    name: A name for the operation (optional).

  Returns:
    A list of tensors returned by either then_branch(inputs)
    or else_branch(inputs).
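
  Example (a minimal sketch; assumes the branches are built with
  `@function.Defun` over the same input types; the names below are
  illustrative only):
    ```python
    # Illustrative branch functions; any Defun-defined functions with
    # matching input types would do.
    @function.Defun(dtypes.float32)
    def TwoTimes(x):
      return x * 2.

    @function.Defun(dtypes.float32)
    def ThreeTimes(x):
      return x * 3.

    x = constant_op.constant(1.)
    y = If(math_ops.greater(x, 0), [x], TwoTimes, ThreeTimes)[0]
    # y == 2.0
    ```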
  """
  # pylint: disable=protected-access
  return gen_functional_ops._if(
      cond,
      inputs, [_.type for _ in then_branch.definition.signature.output_arg],
      then_branch,
      else_branch,
      name=name)


def Gradient(inputs, f, name=None):
  r"""Computes the gradient function for function f via backpropagation.

  Args:
    inputs: A list of tensors of size N + M.
    f: The function we want to compute the gradient for.

      The function 'f' must be a numerical function which takes N inputs and
      produces M outputs. Its gradient function 'g' is a function that takes
      N + M inputs and produces N outputs.

      I.e. if we have
         (y1, y2, ..., yM) = f(x1, x2, ..., xN),
      then, g is
         (dL/dx1, dL/dx2, ..., dL/dxN) = g(x1, x2, ..., xN,
                                           dL/dy1, dL/dy2, ..., dL/dyM),

      where L is a scalar-valued function of (x1, x2, ..., xN) (e.g., the
      loss function). dL/dxi is the partial derivative of L with respect
      to xi.

    name: A name for the operation (optional).

  Returns:
    A list of tensors of size N.
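
  Example (a minimal sketch; assumes `f` is defined with `@function.Defun`;
  the second input below plays the role of dL/dy):
    ```python
    # Illustrative function; here N = 1 input and M = 1 output.
    @function.Defun(dtypes.float32)
    def Square(x):
      return x * x

    x = constant_op.constant(3.)
    dy = constant_op.constant(1.)  # dL/dy
    dx = Gradient([x, dy], Square)[0]
    # dx == 2 * x * dy == 6.0
    ```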
  """
  # TODO(zhifengc): Pretty-print the above spec in latex.
  # TODO(zhifengc): Needs some math expert to say the comment above better.
  tlist = [_.type for _ in f.definition.signature.input_arg]
  return symbolic_gradient(input=inputs, Tout=tlist, f=f, name=name)


def _LoopBodyCaptureWrapper(func):
  """Returns a wrapper for `func` that handles loop-carried captured inputs."""

  @function.Defun(
      *func.declared_input_types, func_name="%s_Wrapper" % func.name)
  def Wrapper(*args):
    """A wrapper that handles loop-carried captured inputs."""
    result = func(*args)
    extra_args = tuple(function.get_extra_args())
    # Nullary functions return an Operation. Normal functions can't do this
    # because their return values are converted to Tensors.
    if isinstance(result, ops.Operation):
      return extra_args
    # Unary functions return a single Tensor value.
    elif not isinstance(result, tuple):
      return (result,) + extra_args
    # N-ary functions return a tuple of Tensors.
    else:
      return result + extra_args

  return Wrapper


# pylint: disable=invalid-name,protected-access
def While(input_, cond, body, name=None, hostmem=None):
  r"""output = input; While (Cond(output)) { output = Body(output) }.

  Args:
    input_: A list of `Tensor` objects.
      A list of input tensors whose types are T.
    cond: A function that takes 'input' and returns a tensor. If the tensor
      is a scalar of non-boolean type, the scalar is converted to a boolean
      according to the following rule: if the scalar is a numerical
      value, non-zero means True and zero means False; if the scalar is
      a string, non-empty means True and empty means False. If the
      tensor is not a scalar, non-emptiness means True and emptiness
      means False.
    body: A function that takes a list of tensors and returns another
      list of tensors. Both lists have the same types as specified
      by T.
    name: A name for the operation (optional).
    hostmem: A list of integers. If i is in the list, input[i] is a
      host memory tensor.

  Raises:
    ValueError: if `cond` has implicitly captured inputs or if `cond` and `body`
      have different signatures.

  Returns:
    A list of `Tensor` objects. Has the same type as `input`.
    A list of output tensors whose types are T.
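
  Example (a minimal sketch; assumes `cond` and `body` are defined with
  `@function.Defun` over identical input types):
    ```python
    # Illustrative cond/body pair that sums the integers 1..10.
    @function.Defun(dtypes.int32, dtypes.int32)
    def Cond(n, unused_total):
      return n > 0

    @function.Defun(dtypes.int32, dtypes.int32)
    def Body(n, total):
      return n - 1, total + n

    _, total = While(
        [constant_op.constant(10), constant_op.constant(0)], Cond, Body)
    # total == 55
    ```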
  """
  if cond.captured_inputs:
    raise ValueError("While op 'cond' argument must be a function "
                     "without implicitly captured inputs.")

  if cond.declared_input_types != body.declared_input_types:
    raise ValueError(
        "While op 'cond' and 'body' signatures do not match. %r vs %r" %
        (cond.declared_input_types, body.declared_input_types))

  if body.captured_inputs:
    cond_dtypes = list(
        body.declared_input_types) + [t.dtype for t in body.captured_inputs]

    @function.Defun(*cond_dtypes, func_name="%s_Wrapper" % cond.name)
    def CondWrapper(*args):
      """A wrapper that handles loop-carried captured inputs."""
      return cond(*args[:len(body.declared_input_types)])

    ret = gen_functional_ops._while(
        input_ + body.captured_inputs,
        CondWrapper,
        _LoopBodyCaptureWrapper(body),
        name=name)
    # Slice off the loop-carried captured inputs.
    ret = ret[:-len(body.captured_inputs)]
  else:
    ret = gen_functional_ops._while(input_, cond, body, name=name)
  if hostmem:
    input_attr = attr_value_pb2.AttrValue()
    input_attr.list.i.extend(hostmem)
    ret[0].op._set_attr("_input_hostmem", input_attr)  # pylint: disable=protected-access

    output_attr = attr_value_pb2.AttrValue()
    output_attr.list.i.extend(hostmem)
    ret[0].op._set_attr("_output_hostmem", output_attr)  # pylint: disable=protected-access
  return ret


# b/36459430
#
# Ideally, we would not need to rewrite a For loop into a While loop.
# However, today, if a While runs on GPU and the condition returns a
# boolean, the While kernel crashes. Even if we fix the crash, the
# bool needs to be copied between GPU and CPU. So, a For loop is much
# preferred when running on GPU.
#
# On the other hand, the For op has no direct XLA kernel. So, when we run
# a For loop, we need to rewrite it using a While op.
#
# It should be possible and probably better to write an XLA C++ kernel
# implementing the logic in _ForUsingWhile.
def _ForUsingWhile(start,
                   limit,
                   delta,
                   inputs,
                   forbody,
                   name=None,
                   hostmem=None):
  """Helper to implement a For loop using a While."""
  # To support negative delta (e.g., range(100, 0, -3)), we iterate
  # over the range(n) and use iter * delta + start as the real
  # iteration index. (e.g., for i in range(34): iter = i * (-3) +
  # 100).
  d = math_ops.abs(delta)
  # XLA on TPUs doesn't support integer division
  n = math_ops.cast(
      math_ops.cast((math_ops.abs(limit - start) + d - 1), dtypes.float32) /
      math_ops.cast(d, dtypes.float32), dtypes.int32)

  # Carried loop variables ("extra_args") are implicitly added to the input
  # list of the WhileBody function. WhileCond does not call forbody, and so
  # does not depend on any of forbody's extra_args. Since WhileCond and
  # WhileBody must have identical inputs, we have to augment the cond
  # signature to take the same types as the carried loop variables.
  body_sig = [dtypes.int32] * 4 + list(forbody.declared_input_types)[1:]

  cond_name = "%s_Cond" % forbody.name

  @function.Defun(*body_sig, func_name=cond_name)
  def WhileCond(i, n, *args):
    del args
    return i < n

  body_name = "%s_Body" % forbody.name

  @function.Defun(*body_sig, func_name=body_name)
  def WhileBody(i, n, start, delta, *args):
    """A While wrapper for forbody that handles loop-carried captured inputs."""
    for_result = forbody(start + i * delta, *args)
    # Nullary functions return an Operation. Normal functions can't do this
    # because their return values are converted to Tensors.
    if isinstance(for_result, ops.Operation):
      for_result = ()
    # Unary functions return a single Tensor value.
    elif isinstance(for_result, ops.Tensor):
      for_result = (for_result,)
    return (i + 1, n, start, delta) + tuple(for_result)

  if hostmem is not None:
    hostmem = [0, 1, 2, 3] + [(4 + _) for _ in hostmem]
  else:
    hostmem = [0, 1, 2, 3]

  results = While(
      input_=[0, n, start, delta] + inputs,
      cond=WhileCond,
      body=WhileBody,
      name=name,
      hostmem=hostmem)
  # Slice off the loop-carried captured inputs.
  return list(results[4:len(results)])


def For(start,
        limit,
        delta,
        inputs,
        body,
        name=None,
        hostmem=None,
        rewrite_with_while=None):
  r"""out = input; for i in range(start, limit, delta) out = body(i, out).

  Args:
    start: A `Tensor` of type `int32`.
    limit: A `Tensor` of type `int32`.
    delta: A `Tensor` of type `int32`.
    inputs: A list of `Tensor` objects.
      A list of input tensors whose types are T.
    body: A function that takes a list of tensors and returns another
      list of tensors. Both lists have the same types as (int32, T...).
    name: A name for the operation (optional).
    hostmem: A list of integers. If i is in the list, inputs[i] is a
      host memory tensor. In other words, the (i+1)-th argument of the body
      function expects a host memory tensor.
    rewrite_with_while: If True, use the While op to implement the For.

  Returns:
    A list of `Tensor` objects. Has the same type as `input`.
    A list of output tensors whose types are T.
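
  Example (a minimal sketch; assumes `body` is defined with `@function.Defun`
  and takes the loop index as its first argument):
    ```python
    # Illustrative body that accumulates the loop indices 0..9.
    @function.Defun(dtypes.int32, dtypes.float32)
    def Body(i, total):
      return total + math_ops.cast(i, dtypes.float32)

    total = For(0, 10, 1, [constant_op.constant(0.)], Body)[0]
    # total == 45.0
    ```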
  """
  if rewrite_with_while:
    return _ForUsingWhile(start, limit, delta, inputs, body, name, hostmem)
  if body.captured_inputs:
    ret = gen_functional_ops._for(
        start,
        limit,
        delta,
        inputs + body.captured_inputs,
        _LoopBodyCaptureWrapper(body),
        name=name)
    # Slice off the loop-carried captured inputs.
    ret = ret[:-len(body.captured_inputs)]
  else:
    ret = gen_functional_ops._for(start, limit, delta, inputs, body, name=name)
  if hostmem:
    num_for_params = 3  # start/limit/delta

    input_attr = attr_value_pb2.AttrValue()
    input_attr.list.i.extend([num_for_params + i for i in hostmem])
    ret[0].op._set_attr("_input_hostmem", input_attr)  # pylint: disable=protected-access

    output_attr = attr_value_pb2.AttrValue()
    output_attr.list.i.extend(hostmem)
    ret[0].op._set_attr("_output_hostmem", output_attr)  # pylint: disable=protected-access
  return ret
# pylint: enable=invalid-name,protected-access


def partitioned_call(args, f, tout=None, executing_eagerly=None, config=None,
                     executor_type=None):
  """Executes a function while respecting device annotations.

  Currently, only those functions that execute within the same address space
  can be executed.

  Args:
    args: The arguments of the function, including captured inputs.
    f: The function to execute; an instance of `_DefinedFunction` or
      `_EagerDefinedFunction`.
    tout: A list containing the output dtype enums; if `None`, inferred from
      the signature of `f`.
    executing_eagerly: (Optional) A boolean indicating whether the context is
      executing eagerly. If `None`, fetched from the global context.
    config: (Optional) A `tensorflow::ConfigProto` proto, serialized. If
      `None`, all optimizations are disabled. Currently only handled for eager
      defined functions.
    executor_type: (Optional) A string for the name of the executor to be used
      in the function call. If not set, or set to an empty string, the default
      tensorflow executor will be used.

  Returns:
    The list of `Tensor`s returned by invoking `f(args)`. If the function does
    not return anything, then returns `None` if eager execution is enabled, or
    the `Operation` if not.
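
  Example (a minimal sketch; assumes graph mode and an `f` built with
  `@function.Defun`):
    ```python
    # Illustrative function to be invoked through the PartitionedCall op.
    @function.Defun(dtypes.float32, dtypes.float32)
    def Plus(x, y):
      return x + y

    out = partitioned_call(
        args=[constant_op.constant(1.), constant_op.constant(2.)], f=Plus)
    # out[0] == 3.0
    ```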
  """

  if tout is None:
    tout = tuple(x.type for x in f.definition.signature.output_arg)

  if executing_eagerly is None:
    executing_eagerly = context.executing_eagerly()

  if config is None:
    config = function_utils.get_disabled_rewriter_config()

  if executor_type is None:
    executor_type = ""

  if executing_eagerly or len(tout):
    if f.stateful_ops:
      outputs = gen_functional_ops.stateful_partitioned_call(
          args=args, Tout=tout, f=f, config_proto=config,
          executor_type=executor_type)
    else:
      outputs = gen_functional_ops.partitioned_call(
          args=args, Tout=tout, f=f, config_proto=config,
          executor_type=executor_type)
    return outputs if outputs else None

  # The generated binding returns an empty list for functions that don't
  # return any Tensors, hence the need to use `create_op` directly.
  args = [ops.internal_convert_to_tensor(x) for x in args]
  tin_attr = attr_value_pb2.AttrValue(
      list=attr_value_pb2.AttrValue.ListValue(
          type=[x.dtype.as_datatype_enum for x in args]))
  tout_attr = attr_value_pb2.AttrValue(
      list=attr_value_pb2.AttrValue.ListValue(type=tout))
  func_attr = attr_value_pb2.AttrValue(
      func=attr_value_pb2.NameAttrList(name=f.name))
  executor_type_attr = attr_value_pb2.AttrValue(
      s=compat.as_bytes(executor_type))

  # When running in graph mode, the graph and function graphs are optimized
  # (i.e. run through grappler) per the session options, so we can disable any
  # eager-specific rewriting.
  config_proto = attr_value_pb2.AttrValue(
      s=function_utils.get_disabled_rewriter_config())

  graph = ops.get_default_graph()
  f.add_to_graph(graph)
  op_name = "StatefulPartitionedCall" if f.stateful_ops else "PartitionedCall"
  op = graph.create_op(
      op_name,
      args,
      tout,
      compute_shapes=False,
      name="PartitionedFunctionCall",
      attrs={
          "Tin": tin_attr,
          "Tout": tout_attr,
          "f": func_attr,
          "config_proto": config_proto,
          "executor_type": executor_type_attr,
      })
  outputs = op.outputs
  return outputs if outputs else op