# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Variables.

See the [Variables](https://www.tensorflow.org/guide/variables) guide.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import gen_resource_variable_ops
from tensorflow.python.ops import gen_state_ops
# go/tf-wildcard-import
# pylint: disable=wildcard-import
from tensorflow.python.ops.gen_state_ops import *
# pylint: enable=wildcard-import
from tensorflow.python.util import deprecation
from tensorflow.python.util.deprecation import deprecated
from tensorflow.python.util.tf_export import tf_export


# pylint: disable=protected-access,g-doc-return-or-yield,g-doc-args
def variable_op(shape, dtype, name="Variable", set_shape=True, container="",
                shared_name=""):
  """Deprecated. Use variable_op_v2 instead."""
  if not set_shape:
    shape = tensor_shape.unknown_shape()
  ret = gen_state_ops.variable(shape=shape, dtype=dtype, name=name,
                               container=container, shared_name=shared_name)
  # TODO(mrry): Move this to where it is used, so we can get rid of this op
  # wrapper?
  if set_shape:
    ret.set_shape(shape)
  return ret


def variable_op_v2(shape, dtype, name="Variable", container="", shared_name=""):
  """Create a variable Operation.

  See also variables.Variable.

  Args:
    shape: The shape of the tensor managed by this variable
    dtype: The underlying type of the tensor values.
    name: optional name to use for the variable op.
    container: An optional string. Defaults to "".
      If non-empty, this variable is placed in the given container.
      Otherwise, a default container is used.
    shared_name: An optional string. Defaults to "".
      If non-empty, this variable is named in the given bucket
      with this shared_name. Otherwise, the node name is used instead.

  Returns:
    A variable tensor.
  """
  return gen_state_ops.variable_v2(
      shape=shape,
      dtype=dtype,
      name=name,
      container=container,
      shared_name=shared_name)


def init_variable(v, init, name="init"):
  """Initializes variable with "init".

  This op does the following:
  if init is a Tensor, v = init
  if callable(init): v = init(VariableShape(v), v.dtype)

  Args:
    v: Variable to initialize
    init: Tensor to assign to v,
      Or an object convertible to Tensor e.g. nparray,
      Or an Initializer that generates a tensor given the shape and type of v.
      An "Initializer" is a callable that returns a tensor that "v" should be
      set to. It will be called as init(shape, dtype).
    name: Optional name for the op.

  Returns:
    The operation that initializes v.
  """
  with ops.name_scope(None, v.op.name + "/", [v, init]):
    with ops.name_scope(name) as scope:
      with ops.colocate_with(v):
        if callable(init):
          # An initializer needs a concrete shape to produce a value.
          assert v.get_shape().is_fully_defined(), "Variable shape unknown."
          # TODO(mrry): Convert to v.shape when the property and
          # accessor are reconciled (and all initializers support
          # tf.TensorShape objects).
          value = init(v.get_shape().as_list(), v.dtype.base_dtype)
          value = ops.convert_to_tensor(value, name="value")
          return gen_state_ops.assign(v, value, name=scope)
        else:
          init = ops.convert_to_tensor(init, name="init")
          return gen_state_ops.assign(v, init, name=scope)


def is_variable_initialized(ref, name=None):
  """Checks whether a tensor has been initialized.

  Outputs boolean scalar indicating whether the tensor has been initialized.

  Args:
    ref: A mutable `Tensor`.
      Should be from a `Variable` node. May be uninitialized.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` of type `bool`.
  """
  if ref.dtype._is_ref_dtype:
    return gen_state_ops.is_variable_initialized(ref=ref, name=name)
  # Handle resource variables.
  return ref.is_initialized(name=name)


@tf_export(v1=["assign_sub"])
def assign_sub(ref, value, use_locking=None, name=None):
  """Update 'ref' by subtracting 'value' from it.

  This operation outputs "ref" after the update is done.
  This makes it easier to chain operations that need to use the reset value.

  Args:
    ref: A mutable `Tensor`. Must be one of the following types:
      `float32`, `float64`, `int64`, `int32`, `uint8`, `uint16`, `int16`,
      `int8`, `complex64`, `complex128`, `qint8`, `quint8`, `qint32`, `half`.
      Should be from a `Variable` node.
    value: A `Tensor`. Must have the same type as `ref`.
      The value to be subtracted from the variable.
    use_locking: An optional `bool`. Defaults to `False`.
      If True, the subtraction will be protected by a lock;
      otherwise the behavior is undefined, but may exhibit less contention.
    name: A name for the operation (optional).

  Returns:
    Same as "ref". Returned as a convenience for operations that want
    to use the new value after the variable has been updated.
  """
  if ref.dtype._is_ref_dtype:
    return gen_state_ops.assign_sub(
        ref, value, use_locking=use_locking, name=name)
  # Non-ref variables delegate to the variable object's own method.
  return ref.assign_sub(value)


@tf_export(v1=["assign_add"])
def assign_add(ref, value, use_locking=None, name=None):
  """Update 'ref' by adding 'value' to it.

  This operation outputs "ref" after the update is done.
  This makes it easier to chain operations that need to use the reset value.

  Args:
    ref: A mutable `Tensor`. Must be one of the following types:
      `float32`, `float64`, `int64`, `int32`, `uint8`, `uint16`, `int16`,
      `int8`, `complex64`, `complex128`, `qint8`, `quint8`, `qint32`, `half`.
      Should be from a `Variable` node.
    value: A `Tensor`. Must have the same type as `ref`.
      The value to be added to the variable.
    use_locking: An optional `bool`. Defaults to `False`.
      If True, the addition will be protected by a lock;
      otherwise the behavior is undefined, but may exhibit less contention.
    name: A name for the operation (optional).

  Returns:
    Same as "ref". Returned as a convenience for operations that want
    to use the new value after the variable has been updated.
  """
  if ref.dtype._is_ref_dtype:
    return gen_state_ops.assign_add(
        ref, value, use_locking=use_locking, name=name)
  # Non-ref variables delegate to the variable object's own method.
  return ref.assign_add(value)


@tf_export(v1=["assign"])
def assign(ref, value, validate_shape=None, use_locking=None, name=None):
  """Update 'ref' by assigning 'value' to it.

  This operation outputs a Tensor that holds the new value of 'ref' after
  the value has been assigned. This makes it easier to chain operations
  that need to use the reset value.

  Args:
    ref: A mutable `Tensor`.
      Should be from a `Variable` node. May be uninitialized.
    value: A `Tensor`. Must have the same type as `ref`.
      The value to be assigned to the variable.
    validate_shape: An optional `bool`. Defaults to `True`.
      If true, the operation will validate that the shape
      of 'value' matches the shape of the Tensor being assigned to.  If false,
      'ref' will take on the shape of 'value'.
    use_locking: An optional `bool`. Defaults to `True`.
      If True, the assignment will be protected by a lock;
      otherwise the behavior is undefined, but may exhibit less contention.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` that will hold the new value of 'ref' after
      the assignment has completed.
  """
  if ref.dtype._is_ref_dtype:
    return gen_state_ops.assign(
        ref, value, use_locking=use_locking, name=name,
        validate_shape=validate_shape)
  # Non-ref variables delegate to the variable object's own method.
  return ref.assign(value, name=name)


@tf_export(v1=["count_up_to"])
@deprecated(None, "Prefer Dataset.range instead.")
def count_up_to(ref, limit, name=None):
  r"""Increments 'ref' until it reaches 'limit'.

  Args:
    ref: A Variable. Must be one of the following types: `int32`, `int64`.
      Should be from a scalar `Variable` node.
    limit: An `int`.
      If incrementing ref would bring it above limit, instead generates an
      'OutOfRange' error.
    name: A name for the operation (optional).

  Returns:
    A `Tensor`. Has the same type as `ref`.
    A copy of the input before increment. If nothing else modifies the
    input, the values produced will all be distinct.
  """
  if ref.dtype._is_ref_dtype:
    return gen_state_ops.count_up_to(ref, limit=limit, name=name)
  # Resource path: operate on the variable's handle.
  return gen_state_ops.resource_count_up_to(
      ref.handle, limit, T=ref.dtype, name=name)


@tf_export(v1=["scatter_update"])
def scatter_update(ref, indices, updates, use_locking=True, name=None):
  # pylint: disable=line-too-long
  r"""Applies sparse updates to a variable reference.

  This operation computes

  ```python
      # Scalar indices
      ref[indices, ...] = updates[...]

      # Vector indices (for each i)
      ref[indices[i], ...] = updates[i, ...]

      # High rank indices (for each i, ..., j)
      ref[indices[i, ..., j], ...] = updates[i, ..., j, ...]
  ```

  This operation outputs `ref` after the update is done.
  This makes it easier to chain operations that need to use the reset value.

  If values in `ref` is to be updated more than once, because there are
  duplicate entries in `indices`, the order at which the updates happen
  for each value is undefined.

  Requires `updates.shape = indices.shape + ref.shape[1:]`.

  <div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
  <img style="width:100%" src="https://www.tensorflow.org/images/ScatterUpdate.png" alt>
  </div>

  Args:
    ref: A `Variable`.
    indices: A `Tensor`. Must be one of the following types: `int32`, `int64`.
      A tensor of indices into the first dimension of `ref`.
    updates: A `Tensor`. Must have the same type as `ref`.
      A tensor of updated values to store in `ref`.
    use_locking: An optional `bool`. Defaults to `True`.
      If True, the assignment will be protected by a lock;
      otherwise the behavior is undefined, but may exhibit less contention.
    name: A name for the operation (optional).

  Returns:
    Same as `ref`.  Returned as a convenience for operations that want
    to use the updated values after the update is done.
  """
  if ref.dtype._is_ref_dtype:
    return gen_state_ops.scatter_update(ref, indices, updates,
                                        use_locking=use_locking, name=name)
  # Resource path: the resource op takes no `use_locking` argument.
  return ref._lazy_read(gen_resource_variable_ops.resource_scatter_update(  # pylint: disable=protected-access
      ref.handle, indices, ops.convert_to_tensor(updates, ref.dtype),
      name=name))


@tf_export(v1=["scatter_nd_update"])
def scatter_nd_update(ref, indices, updates, use_locking=True, name=None):
  r"""Applies sparse `updates` to individual values or slices in a Variable.

  `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.

  `indices` must be integer tensor, containing indices into `ref`.
  It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.

  The innermost dimension of `indices` (with length `K`) corresponds to
  indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
  dimension of `ref`.

  `updates` is `Tensor` of rank `Q-1+P-K` with shape:

  ```
  [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
  ```

  For example, say we want to update 4 scattered elements to a rank-1 tensor to
  8 elements. In Python, that update would look like this:

  ```python
      ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
      indices = tf.constant([[4], [3], [1] ,[7]])
      updates = tf.constant([9, 10, 11, 12])
      update = tf.scatter_nd_update(ref, indices, updates)
      with tf.Session() as sess:
        print(sess.run(update))
  ```

  The resulting update to ref would look like this:

      [1, 11, 3, 10, 9, 6, 7, 12]

  See `tf.scatter_nd` for more details about how to make updates to
  slices.

  Args:
    ref: A Variable.
    indices: A `Tensor`. Must be one of the following types: `int32`, `int64`.
      A tensor of indices into ref.
    updates: A `Tensor`. Must have the same type as `ref`.
      A tensor of updated values to store in ref.
    use_locking: An optional `bool`. Defaults to `True`.
      If True, the assignment will be protected by a lock;
      otherwise the behavior is undefined, but may exhibit less contention.
    name: A name for the operation (optional).

  Returns:
    The value of the variable after the update.
  """
  if ref.dtype._is_ref_dtype:
    return gen_state_ops.scatter_nd_update(
        ref, indices, updates, use_locking, name)
  # Resource path: the resource op takes no `use_locking` argument.
  return ref._lazy_read(gen_state_ops.resource_scatter_nd_update(  # pylint: disable=protected-access
      ref.handle, indices, ops.convert_to_tensor(updates, ref.dtype),
      name=name))


@tf_export(v1=["scatter_add"])
def scatter_add(ref, indices, updates, use_locking=False, name=None):
  # pylint: disable=line-too-long
  r"""Adds sparse updates to the variable referenced by `ref`.

  This operation computes

  ```python
      # Scalar indices
      ref[indices, ...] += updates[...]

      # Vector indices (for each i)
      ref[indices[i], ...] += updates[i, ...]

      # High rank indices (for each i, ..., j)
      ref[indices[i, ..., j], ...] += updates[i, ..., j, ...]
  ```

  This operation outputs `ref` after the update is done.
  This makes it easier to chain operations that need to use the updated value.
  Duplicate entries are handled correctly: if multiple `indices` reference
  the same location, their contributions add.

  Requires `updates.shape = indices.shape + ref.shape[1:]`.

  <div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
  <img style="width:100%" src='https://www.tensorflow.org/images/ScatterAdd.png' alt>
  </div>

  Args:
    ref: A `Variable`.
    indices: A `Tensor`. Must be one of the following types: `int32`, `int64`.
      A tensor of indices into the first dimension of `ref`.
    updates: A `Tensor`. Must have the same type as `ref`.
      A tensor of updated values to add to `ref`.
    use_locking: An optional `bool`. Defaults to `False`.
      If True, the assignment will be protected by a lock;
      otherwise the behavior is undefined, but may exhibit less contention.
    name: A name for the operation (optional).

  Returns:
    Same as `ref`.  Returned as a convenience for operations that want
    to use the updated values after the update is done.
  """
  if ref.dtype._is_ref_dtype:
    return gen_state_ops.scatter_add(ref, indices, updates,
                                     use_locking=use_locking, name=name)
  # Resource path: the resource op takes no `use_locking` argument.
  return ref._lazy_read(gen_resource_variable_ops.resource_scatter_add(  # pylint: disable=protected-access
      ref.handle, indices, ops.convert_to_tensor(updates, ref.dtype),
      name=name))


@tf_export(v1=["scatter_nd_add"])
def scatter_nd_add(ref, indices, updates, use_locking=False, name=None):
  r"""Applies sparse addition to individual values or slices in a Variable.

  `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.

  `indices` must be integer tensor, containing indices into `ref`.
  It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.

  The innermost dimension of `indices` (with length `K`) corresponds to
  indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
  dimension of `ref`.

  `updates` is `Tensor` of rank `Q-1+P-K` with shape:

  ```
  [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]]
  ```

  For example, say we want to add 4 scattered elements to a rank-1 tensor to
  8 elements. In Python, that addition would look like this:

  ```python
  ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
  indices = tf.constant([[4], [3], [1], [7]])
  updates = tf.constant([9, 10, 11, 12])
  add = tf.scatter_nd_add(ref, indices, updates)
  with tf.Session() as sess:
    print(sess.run(add))
  ```

  The resulting update to ref would look like this:

      [1, 13, 3, 14, 14, 6, 7, 20]

  See `tf.scatter_nd` for more details about how to make updates to
  slices.

  Args:
    ref: A mutable `Tensor`. Must be one of the following types: `float32`,
      `float64`, `int32`, `uint8`, `int16`, `int8`, `complex64`, `int64`,
      `qint8`, `quint8`, `qint32`, `bfloat16`, `uint16`, `complex128`, `half`,
      `uint32`, `uint64`. A mutable Tensor. Should be from a Variable node.
    indices: A `Tensor`. Must be one of the following types: `int32`, `int64`.
      A tensor of indices into ref.
    updates: A `Tensor`. Must have the same type as `ref`.
      A tensor of updated values to add to ref.
    use_locking: An optional `bool`. Defaults to `False`.
      If True, the assignment will be protected by a lock;
      otherwise the behavior is undefined, but may exhibit less contention.
    name: A name for the operation (optional).

  Returns:
    A mutable `Tensor`. Has the same type as `ref`.
  """
  if ref.dtype._is_ref_dtype:
    return gen_state_ops.scatter_nd_add(
        ref, indices, updates, use_locking, name)
  # Resource path: the resource op takes no `use_locking` argument.
  return ref._lazy_read(gen_state_ops.resource_scatter_nd_add(  # pylint: disable=protected-access
      ref.handle, indices, ops.convert_to_tensor(updates, ref.dtype),
      name=name))


@tf_export(v1=["scatter_sub"])
def scatter_sub(ref, indices, updates, use_locking=False, name=None):
  r"""Subtracts sparse updates from a variable reference.

  ```python
      # Scalar indices
      ref[indices, ...] -= updates[...]

      # Vector indices (for each i)
      ref[indices[i], ...] -= updates[i, ...]

      # High rank indices (for each i, ..., j)
      ref[indices[i, ..., j], ...] -= updates[i, ..., j, ...]
  ```

  This operation outputs `ref` after the update is done.
  This makes it easier to chain operations that need to use the reset value.

  Duplicate entries are handled correctly: if multiple `indices` reference
  the same location, their (negated) contributions add.

  Requires `updates.shape = indices.shape + ref.shape[1:]` or
  `updates.shape = []`.

  <div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
  <img style="width:100%"
       src="https://www.tensorflow.org/images/ScatterSub.png" alt>
  </div>

  Args:
    ref: A mutable `Tensor`. Must be one of the following types: `float32`,
      `float64`, `int32`, `uint8`, `int16`, `int8`, `complex64`, `int64`,
      `qint8`, `quint8`, `qint32`, `bfloat16`, `uint16`, `complex128`, `half`,
      `uint32`, `uint64`. Should be from a `Variable` node.
    indices: A `Tensor`. Must be one of the following types: `int32`, `int64`.
      A tensor of indices into the first dimension of `ref`.
    updates: A `Tensor`. Must have the same type as `ref`.
      A tensor of updated values to subtract from `ref`.
    use_locking: An optional `bool`. Defaults to `False`.
      If True, the subtraction will be protected by a lock;
      otherwise the behavior is undefined, but may exhibit less contention.
    name: A name for the operation (optional).

  Returns:
    A mutable `Tensor`. Has the same type as `ref`.
  """
  if ref.dtype._is_ref_dtype:
    return gen_state_ops.scatter_sub(ref, indices, updates,
                                     use_locking=use_locking, name=name)
  # Resource path: the resource op takes no `use_locking` argument.
  return ref._lazy_read(gen_resource_variable_ops.resource_scatter_sub(  # pylint: disable=protected-access
      ref.handle, indices, ops.convert_to_tensor(updates, ref.dtype),
      name=name))


@tf_export(v1=["scatter_nd_sub"])
def scatter_nd_sub(ref, indices, updates, use_locking=False, name=None):
  r"""Applies sparse subtraction to individual values or slices in a Variable.

  `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.

  `indices` must be integer tensor, containing indices into `ref`.
  It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.

  The innermost dimension of `indices` (with length `K`) corresponds to
  indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
  dimension of `ref`.

  `updates` is `Tensor` of rank `Q-1+P-K` with shape:

  ```
  [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]]
  ```

  For example, say we want to subtract 4 scattered elements from a rank-1 tensor
  with 8 elements. In Python, that update would look like this:

  ```python
  ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
  indices = tf.constant([[4], [3], [1] ,[7]])
  updates = tf.constant([9, 10, 11, 12])
  op = tf.scatter_nd_sub(ref, indices, updates)
  with tf.Session() as sess:
    print(sess.run(op))
  ```

  The resulting update to ref would look like this:

      [1, -9, 3, -6, -4, 6, 7, -4]

  See `tf.scatter_nd` for more details about how to make updates to
  slices.

  Args:
    ref: A mutable `Tensor`. Must be one of the following types: `float32`,
      `float64`, `int32`, `uint8`, `int16`, `int8`, `complex64`, `int64`,
      `qint8`, `quint8`, `qint32`, `bfloat16`, `uint16`, `complex128`, `half`,
      `uint32`, `uint64`. A mutable Tensor. Should be from a Variable node.
    indices: A `Tensor`. Must be one of the following types: `int32`, `int64`.
      A tensor of indices into ref.
    updates: A `Tensor`. Must have the same type as `ref`.
      A tensor of updated values to subtract from ref.
    use_locking: An optional `bool`. Defaults to `False`.
      If True, the subtraction will be protected by a lock;
      otherwise the behavior is undefined, but may exhibit less contention.
    name: A name for the operation (optional).

  Returns:
    A mutable `Tensor`. Has the same type as `ref`.
  """
  if ref.dtype._is_ref_dtype:
    return gen_state_ops.scatter_nd_sub(
        ref, indices, updates, use_locking, name)
  # Resource path: the resource op takes no `use_locking` argument.
  return ref._lazy_read(gen_state_ops.resource_scatter_nd_sub(  # pylint: disable=protected-access
      ref.handle, indices, ops.convert_to_tensor(updates, ref.dtype),
      name=name))


@tf_export(v1=["scatter_mul"])
def scatter_mul(ref, indices, updates, use_locking=False, name=None):
  # pylint: disable=line-too-long
  r"""Multiplies sparse updates into a variable reference.

  This operation computes

  ```python
      # Scalar indices
      ref[indices, ...] *= updates[...]

      # Vector indices (for each i)
      ref[indices[i], ...] *= updates[i, ...]

      # High rank indices (for each i, ..., j)
      ref[indices[i, ..., j], ...] *= updates[i, ..., j, ...]
  ```

  This operation outputs `ref` after the update is done.
  This makes it easier to chain operations that need to use the reset value.

  Duplicate entries are handled correctly: if multiple `indices` reference
  the same location, their contributions multiply.

  Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape =
  []`.

  Args:
    ref: A mutable `Tensor`. Must be one of the following types: `float32`,
      `float64`, `int32`, `uint8`, `int16`, `int8`, `complex64`, `int64`,
      `qint8`, `quint8`, `qint32`, `bfloat16`, `uint16`, `complex128`, `half`,
      `uint32`, `uint64`. Should be from a `Variable` node.
    indices: A `Tensor`. Must be one of the following types: `int32`, `int64`. A
      tensor of indices into the first dimension of `ref`.
    updates: A `Tensor`. Must have the same type as `ref`. A tensor of updated
      values to multiply to `ref`.
    use_locking: An optional `bool`. Defaults to `False`. If True, the operation
      will be protected by a lock; otherwise the behavior is undefined, but may
      exhibit less contention.
    name: A name for the operation (optional).

  Returns:
    A mutable `Tensor`. Has the same type as `ref`.
  """
  return gen_state_ops.scatter_mul(
      ref=ref,
      indices=indices,
      updates=updates,
      use_locking=use_locking,
      name=name)


@tf_export(v1=["scatter_div"])
def scatter_div(ref, indices, updates, use_locking=False, name=None):
  # pylint: disable=line-too-long
  r"""Divides a variable reference by sparse updates.

  This operation computes

  ```python
      # Scalar indices
      ref[indices, ...] /= updates[...]

      # Vector indices (for each i)
      ref[indices[i], ...] /= updates[i, ...]

      # High rank indices (for each i, ..., j)
      ref[indices[i, ..., j], ...] /= updates[i, ..., j, ...]
  ```

  This operation outputs `ref` after the update is done.
  This makes it easier to chain operations that need to use the reset value.

  Duplicate entries are handled correctly: if multiple `indices` reference
  the same location, their contributions divide.

  Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape =
  []`.

  Args:
    ref: A mutable `Tensor`. Must be one of the following types: `float32`,
      `float64`, `int32`, `uint8`, `int16`, `int8`, `complex64`, `int64`,
      `qint8`, `quint8`, `qint32`, `bfloat16`, `uint16`, `complex128`, `half`,
      `uint32`, `uint64`. Should be from a `Variable` node.
    indices: A `Tensor`. Must be one of the following types: `int32`, `int64`. A
      tensor of indices into the first dimension of `ref`.
    updates: A `Tensor`. Must have the same type as `ref`. A tensor of values
      that `ref` is divided by.
    use_locking: An optional `bool`. Defaults to `False`. If True, the operation
      will be protected by a lock; otherwise the behavior is undefined, but may
      exhibit less contention.
    name: A name for the operation (optional).

  Returns:
    A mutable `Tensor`. Has the same type as `ref`.
  """
  return gen_state_ops.scatter_div(
      ref=ref,
      indices=indices,
      updates=updates,
      use_locking=use_locking,
      name=name)


@tf_export(v1=["scatter_max"])
def scatter_max(ref, indices, updates, use_locking=False, name=None):
  # pylint: disable=line-too-long
  r"""Reduces sparse updates into a variable reference using the `max` operation.

  This operation computes

      # Scalar indices
      ref[indices, ...] = max(ref[indices, ...], updates[...])

      # Vector indices (for each i)
      ref[indices[i], ...] = max(ref[indices[i], ...], updates[i, ...])

      # High rank indices (for each i, ..., j)
      ref[indices[i, ..., j], ...] = max(ref[indices[i, ..., j], ...],
      updates[i, ..., j, ...])

  This operation outputs `ref` after the update is done.
  This makes it easier to chain operations that need to use the reset value.

  Duplicate entries are handled correctly: if multiple `indices` reference
  the same location, their contributions combine.

  Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape =
  []`.

  <div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
  <img style="width:100%" src="https://www.tensorflow.org/images/ScatterAdd.png"
  alt>
  </div>

  Args:
    ref: A mutable `Tensor`. Must be one of the following types: `half`,
      `bfloat16`, `float32`, `float64`, `int32`, `int64`. Should be from a
      `Variable` node.
    indices: A `Tensor`. Must be one of the following types: `int32`, `int64`. A
      tensor of indices into the first dimension of `ref`.
    updates: A `Tensor`. Must have the same type as `ref`. A tensor of updated
      values to reduce into `ref`.
    use_locking: An optional `bool`. Defaults to `False`. If True, the update
      will be protected by a lock; otherwise the behavior is undefined, but may
      exhibit less contention.
    name: A name for the operation (optional).

  Returns:
    A mutable `Tensor`. Has the same type as `ref`.
  """
  return gen_state_ops.scatter_max(
      ref=ref,
      indices=indices,
      updates=updates,
      use_locking=use_locking,
      name=name)


@tf_export(v1=["scatter_min"])
def scatter_min(ref, indices, updates, use_locking=False, name=None):
  # pylint: disable=line-too-long
  r"""Reduces sparse updates into a variable reference using the `min` operation.

  This operation computes

      # Scalar indices
      ref[indices, ...] = min(ref[indices, ...], updates[...])

      # Vector indices (for each i)
      ref[indices[i], ...] = min(ref[indices[i], ...], updates[i, ...])

      # High rank indices (for each i, ..., j)
      ref[indices[i, ..., j], ...] = min(ref[indices[i, ..., j], ...],
      updates[i, ..., j, ...])

  This operation outputs `ref` after the update is done.
  This makes it easier to chain operations that need to use the reset value.

  Duplicate entries are handled correctly: if multiple `indices` reference
  the same location, their contributions combine.

  Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape =
  []`.

  <div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
  <img style="width:100%" src="https://www.tensorflow.org/images/ScatterAdd.png"
  alt>
  </div>

  Args:
    ref: A mutable `Tensor`. Must be one of the following types: `half`,
      `bfloat16`, `float32`, `float64`, `int32`, `int64`. Should be from a
      `Variable` node.
    indices: A `Tensor`. Must be one of the following types: `int32`, `int64`. A
      tensor of indices into the first dimension of `ref`.
    updates: A `Tensor`. Must have the same type as `ref`. A tensor of updated
      values to reduce into `ref`.
    use_locking: An optional `bool`. Defaults to `False`. If True, the update
      will be protected by a lock; otherwise the behavior is undefined, but may
      exhibit less contention.
    name: A name for the operation (optional).

  Returns:
    A mutable `Tensor`. Has the same type as `ref`.
  """
  return gen_state_ops.scatter_min(
      ref=ref,
      indices=indices,
      updates=updates,
      use_locking=use_locking,
      name=name)


@tf_export(v1=["batch_scatter_update"])
@deprecation.deprecated(
    "2018-11-29", "Use the batch_scatter_update method of Variable instead.")
def batch_scatter_update(ref, indices, updates, use_locking=True, name=None):
  """Generalization of `tf.scatter_update` to axis different than 0.

  Analogous to `batch_gather`. This assumes that `ref`, `indices` and `updates`
  have a series of leading dimensions that are the same for all of them, and the
  updates are performed on the last dimension of indices. In other words, the
  dimensions should be the following:

  `num_prefix_dims = indices.ndims - 1`
  `batch_dim = num_prefix_dims + 1`
  `updates.shape = indices.shape + var.shape[batch_dim:]`

  where

  `updates.shape[:num_prefix_dims]`
  `== indices.shape[:num_prefix_dims]`
  `== var.shape[:num_prefix_dims]`

  And the operation performed can be expressed as:

  `var[i_1, ..., i_n, indices[i_1, ..., i_n, j]] = updates[i_1, ..., i_n, j]`

  When indices is a 1D tensor, this operation is equivalent to
  `tf.scatter_update`.

  To avoid this operation there would be 2 alternatives:
  1) Reshaping the variable by merging the first `ndims` dimensions. However,
     this is not possible because `tf.reshape` returns a Tensor, which we
     cannot use `tf.scatter_update` on.
  2) Looping over the first `ndims` of the variable and using
     `tf.scatter_update` on the subtensors that result of slicing the first
     dimension. This is a valid option for `ndims = 1`, but less efficient than
     this implementation.

  See also `tf.scatter_update` and `tf.scatter_nd_update`.

  Args:
    ref: `Variable` to scatter onto.
    indices: Tensor containing indices as described above.
    updates: Tensor of updates to apply to `ref`.
    use_locking: Boolean indicating whether to lock the writing operation.
    name: Optional scope name string.

  Returns:
    Ref to `variable` after it has been modified.

  Raises:
    ValueError: If the rank of `indices` is not statically known, or if the
      initial `ndims` of `ref`, `indices`, and `updates` are not the same.
  """
  with ops.name_scope(name):
    indices = ops.convert_to_tensor(indices, name="indices")
    indices_shape = array_ops.shape(indices)
    indices_dimensions = indices.get_shape().ndims

    # The loop below needs a Python-int rank to build its broadcast shapes.
    if indices_dimensions is None:
      raise ValueError("batch_scatter_update does not allow indices with "
                       "unknown shape.")

    nd_indices = array_ops.expand_dims(indices, axis=-1)
    nd_indices_list = []

    # Scatter ND requires indices to have an additional dimension, in which the
    # coordinates of the updated things are specified. For this to be adapted to
    # the scatter_update with several leading dimensions, we simply make use of
    # a tf.range for all the leading dimensions followed by concat of all the
    # coordinates we created with the original indices.

    # For example if indices.shape = [2, 3, 4], we should generate the following
    # indices for tf.scatter_nd_update:
    # nd_indices[:, :, 0] = [[0, 0, 0], [1, 1, 1]]
    # nd_indices[:, :, 1] = [[0, 1, 2], [0, 1, 2]]
    # nd_indices[:, :, 2] = indices
    for dimension in range(indices_dimensions - 1):
      # In this loop we generate the following for the example (one for each
      # iteration).
      # nd_indices[:, :, 0] = [[0, 0, 0], [1, 1, 1]]
      # nd_indices[:, :, 1] = [[0, 1, 2], [0, 1, 2]]
      # This is done at every iteration with a tf.range over the size of the
      # i-th dimension and using broadcasting over the desired shape.
      dimension_size = indices_shape[dimension]
      shape_to_broadcast = [1] * (indices_dimensions + 1)
      shape_to_broadcast[dimension] = dimension_size
      dimension_range = array_ops.reshape(
          gen_math_ops._range(0, dimension_size, 1), shape_to_broadcast)
      # Concat requires every coordinate tensor to share the indices' dtype.
      if dimension_range.dtype.base_dtype != nd_indices.dtype:
        dimension_range = gen_math_ops.cast(dimension_range, nd_indices.dtype)
      nd_indices_list.append(
          dimension_range * array_ops.ones_like(nd_indices))
    # Add the original indices at the end, as described above, and concat.
    nd_indices_list.append(nd_indices)
    final_indices = array_ops.concat(nd_indices_list, axis=-1)
    return scatter_nd_update(
        ref, final_indices, updates, use_locking=use_locking)