1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15 16"""Variables. 17 18See the [Variables](https://www.tensorflow.org/guide/variables) guide. 19""" 20 21from __future__ import absolute_import 22from __future__ import division 23from __future__ import print_function 24 25from tensorflow.python.framework import ops 26from tensorflow.python.framework import tensor_shape 27from tensorflow.python.ops import array_ops 28from tensorflow.python.ops import gen_math_ops 29from tensorflow.python.ops import gen_resource_variable_ops 30from tensorflow.python.ops import gen_state_ops 31# go/tf-wildcard-import 32# pylint: disable=wildcard-import 33from tensorflow.python.ops.gen_state_ops import * 34# pylint: enable=wildcard-import 35from tensorflow.python.util import deprecation 36from tensorflow.python.util.deprecation import deprecated 37from tensorflow.python.util.tf_export import tf_export 38 39 40# pylint: disable=protected-access,g-doc-return-or-yield,g-doc-args 41def variable_op(shape, dtype, name="Variable", set_shape=True, container="", 42 shared_name=""): 43 """Deprecated. Used variable_op_v2 instead.""" 44 if not set_shape: 45 shape = tensor_shape.unknown_shape() 46 ret = gen_state_ops.variable(shape=shape, dtype=dtype, name=name, 47 container=container, shared_name=shared_name) 48 # TODO(mrry): Move this to where it is used, so we can get rid of this op 49 # wrapper? 50 if set_shape: 51 ret.set_shape(shape) 52 return ret 53 54 55def variable_op_v2(shape, dtype, name="Variable", container="", shared_name=""): 56 """Create a variable Operation. 57 58 See also variables.Variable. 59 60 Args: 61 shape: The shape of the tensor managed by this variable 62 dtype: The underlying type of the tensor values. 63 name: optional name to use for the variable op. 64 container: An optional string. Defaults to "". 65 If non-empty, this variable is placed in the given container. 66 Otherwise, a default container is used. 67 shared_name: An optional string. Defaults to "". 68 If non-empty, this variable is named in the given bucket 69 with this shared_name. Otherwise, the node name is used instead. 70 71 Returns: 72 A variable tensor. 73 """ 74 return gen_state_ops.variable_v2( 75 shape=shape, 76 dtype=dtype, 77 name=name, 78 container=container, 79 shared_name=shared_name) 80 81 82def init_variable(v, init, name="init"): 83 """Initializes variable with "init". 84 85 This op does the following: 86 if init is a Tensor, v = init 87 if callable(init): v = init(VariableShape(v), v.dtype) 88 89 Args: 90 v: Variable to initialize 91 init: Tensor to assign to v, 92 Or an object convertible to Tensor e.g. nparray, 93 Or an Initializer that generates a tensor given the shape and type of v. 94 An "Initializer" is a callable that returns a tensor that "v" should be 95 set to. It will be called as init(shape, dtype). 96 name: Optional name for the op. 97 98 Returns: 99 The operation that initializes v. 100 """ 101 with ops.name_scope(None, v.op.name + "/", [v, init]): 102 with ops.name_scope(name) as scope: 103 with ops.colocate_with(v): 104 if callable(init): 105 assert v.get_shape().is_fully_defined(), "Variable shape unknown." 106 # TODO(mrry): Convert to v.shape when the property and 107 # accessor are reconciled (and all initializers support 108 # tf.TensorShape objects). 109 value = init(v.get_shape().as_list(), v.dtype.base_dtype) 110 value = ops.convert_to_tensor(value, name="value") 111 return gen_state_ops.assign(v, value, name=scope) 112 else: 113 init = ops.convert_to_tensor(init, name="init") 114 return gen_state_ops.assign(v, init, name=scope) 115 116 117def is_variable_initialized(ref, name=None): 118 """Checks whether a tensor has been initialized. 119 120 Outputs boolean scalar indicating whether the tensor has been initialized. 121 122 Args: 123 ref: A mutable `Tensor`. 124 Should be from a `Variable` node. May be uninitialized. 125 name: A name for the operation (optional). 126 127 Returns: 128 A `Tensor` of type `bool`. 129 """ 130 if ref.dtype._is_ref_dtype: 131 return gen_state_ops.is_variable_initialized(ref=ref, name=name) 132 # Handle resource variables. 133 return ref.is_initialized(name=name) 134 135 136@tf_export(v1=["assign_sub"]) 137def assign_sub(ref, value, use_locking=None, name=None): 138 """Update `ref` by subtracting `value` from it. 139 140 This operation outputs `ref` after the update is done. 141 This makes it easier to chain operations that need to use the reset value. 142 Unlike `tf.math.subtract`, this op does not broadcast. `ref` and `value` 143 must have the same shape. 144 145 Args: 146 ref: A mutable `Tensor`. Must be one of the following types: `float32`, 147 `float64`, `int64`, `int32`, `uint8`, `uint16`, `int16`, `int8`, 148 `complex64`, `complex128`, `qint8`, `quint8`, `qint32`, `half`. Should be 149 from a `Variable` node. 150 value: A `Tensor`. Must have the same shape and dtype as `ref`. The value to 151 be subtracted to the variable. 152 use_locking: An optional `bool`. Defaults to `False`. If True, the 153 subtraction will be protected by a lock; otherwise the behavior is 154 undefined, but may exhibit less contention. 155 name: A name for the operation (optional). 156 157 Returns: 158 Same as `ref`. Returned as a convenience for operations that want 159 to use the new value after the variable has been updated. 160 161 @compatibility(TF2) 162 `tf.compat.v1.assign_sub` is mostly compatible with eager 163 execution and `tf.function`. 164 165 To switch to the native TF2 style, one could use method 'assign_sub' of 166 `tf.Variable`: 167 168 #### How to Map Arguments 169 170 | TF1 Arg Name | TF2 Arg Name | Note | 171 | :-------------------- | :-------------- | :------------------------- | 172 | `ref` | `self` | In `assign_sub()` method | 173 | `value` | `value` | In `assign_sub()` method | 174 | `use_locking` | `use_locking` | In `assign_sub()` method | 175 | `name` | `name` | In `assign_sub()` method | 176 | - | `read_value` | Set to True to replicate | 177 : : : behavior (True is default) : 178 179 180 #### Before & After Usage Example 181 182 Before: 183 184 >>> with tf.Graph().as_default(): 185 ... with tf.compat.v1.Session() as sess: 186 ... a = tf.compat.v1.Variable(1, dtype=tf.int64) 187 ... sess.run(a.initializer) 188 ... update_op = tf.compat.v1.assign_sub(a, 1) 189 ... res_a = sess.run(update_op) 190 ... res_a 191 0 192 193 After: 194 195 >>> b = tf.Variable(1, dtype=tf.int64) 196 >>> res_b = b.assign_sub(1) 197 >>> res_b.numpy() 198 0 199 200 @end_compatibility 201 """ 202 if ref.dtype._is_ref_dtype: 203 return gen_state_ops.assign_sub( 204 ref, value, use_locking=use_locking, name=name) 205 return ref.assign_sub(value) 206 207 208@tf_export(v1=["assign_add"]) 209def assign_add(ref, value, use_locking=None, name=None): 210 """Update `ref` by adding `value` to it. 211 212 This operation outputs `ref` after the update is done. 213 This makes it easier to chain operations that need to use the reset value. 214 Unlike `tf.math.add`, this op does not broadcast. `ref` and `value` must have 215 the same shape. 216 217 Args: 218 ref: A mutable `Tensor`. Must be one of the following types: `float32`, 219 `float64`, `int64`, `int32`, `uint8`, `uint16`, `int16`, `int8`, 220 `complex64`, `complex128`, `qint8`, `quint8`, `qint32`, `half`. Should be 221 from a `Variable` node. 222 value: A `Tensor`. Must have the same shape and dtype as `ref`. The value to 223 be added to the variable. 224 use_locking: An optional `bool`. Defaults to `False`. If True, the addition 225 will be protected by a lock; otherwise the behavior is undefined, but may 226 exhibit less contention. 227 name: A name for the operation (optional). 228 229 Returns: 230 Same as `ref`. Returned as a convenience for operations that want 231 to use the new value after the variable has been updated. 232 233 @compatibility(TF2) 234 `tf.compat.v1.assign_add` is mostly compatible with eager 235 execution and `tf.function`. 236 237 To switch to the native TF2 style, one could use method 'assign_add' of 238 `tf.Variable`: 239 240 #### How to Map Arguments 241 242 | TF1 Arg Name | TF2 Arg Name | Note | 243 | :-------------------- | :-------------- | :------------------------- | 244 | `ref` | `self` | In `assign_add()` method | 245 | `value` | `value` | In `assign_add()` method | 246 | `use_locking` | `use_locking` | In `assign_add()` method | 247 | `name` | `name` | In `assign_add()` method | 248 | - | `read_value` | Set to True to replicate | 249 : : : behavior (True is default) : 250 251 252 #### Before & After Usage Example 253 254 Before: 255 256 >>> with tf.Graph().as_default(): 257 ... with tf.compat.v1.Session() as sess: 258 ... a = tf.compat.v1.Variable(0, dtype=tf.int64) 259 ... sess.run(a.initializer) 260 ... update_op = tf.compat.v1.assign_add(a, 1) 261 ... res_a = sess.run(update_op) 262 ... res_a 263 1 264 265 After: 266 267 >>> b = tf.Variable(0, dtype=tf.int64) 268 >>> res_b = b.assign_add(1) 269 >>> res_b.numpy() 270 1 271 272 @end_compatibility 273 """ 274 if ref.dtype._is_ref_dtype: 275 return gen_state_ops.assign_add( 276 ref, value, use_locking=use_locking, name=name) 277 return ref.assign_add(value) 278 279 280@tf_export(v1=["assign"]) 281def assign(ref, value, validate_shape=None, use_locking=None, name=None): 282 """Update `ref` by assigning `value` to it. 283 284 This operation outputs a Tensor that holds the new value of `ref` after 285 the value has been assigned. This makes it easier to chain operations that 286 need to use the reset value. 287 288 Args: 289 ref: A mutable `Tensor`. Should be from a `Variable` node. May be 290 uninitialized. 291 value: A `Tensor`. Must have the same shape and dtype as `ref`. The value to 292 be assigned to the variable. 293 validate_shape: An optional `bool`. Defaults to `True`. If true, the 294 operation will validate that the shape of 'value' matches the shape of the 295 Tensor being assigned to. If false, 'ref' will take on the shape of 296 'value'. 297 use_locking: An optional `bool`. Defaults to `True`. If True, the assignment 298 will be protected by a lock; otherwise the behavior is undefined, but may 299 exhibit less contention. 300 name: A name for the operation (optional). 301 302 Returns: 303 A `Tensor` that will hold the new value of `ref` after 304 the assignment has completed. 305 306 @compatibility(TF2) 307 `tf.compat.v1.assign` is mostly compatible with eager 308 execution and `tf.function`. However, argument 'validate_shape' will be 309 ignored. To avoid shape validation, set 'shape' to tf.TensorShape(None) when 310 constructing the variable: 311 312 >>> import tensorflow as tf 313 >>> a = tf.Variable([1], shape=tf.TensorShape(None)) 314 >>> tf.compat.v1.assign(a, [2,3]) 315 316 To switch to the native TF2 style, one could use method 'assign' of 317 `tf.Variable`: 318 319 #### How to Map Arguments 320 321 | TF1 Arg Name | TF2 Arg Name | Note | 322 | :-------------------- | :-------------- | :------------------------- | 323 | `ref` | `self` | In `assign()` method | 324 | `value` | `value` | In `assign()` method | 325 | `validate_shape` | Not supported | Specify `shape` in the | 326 : : : constructor to replicate : 327 : : : behavior : 328 | `use_locking` | `use_locking` | In `assign()` method | 329 | `name` | `name` | In `assign()` method | 330 | - | `read_value` | Set to True to replicate | 331 : : : behavior (True is default) : 332 @end_compatibility 333 334 335 #### Before & After Usage Example 336 337 Before: 338 339 >>> with tf.Graph().as_default(): 340 ... with tf.compat.v1.Session() as sess: 341 ... a = tf.compat.v1.Variable(0, dtype=tf.int64) 342 ... sess.run(a.initializer) 343 ... update_op = tf.compat.v1.assign(a, 2) 344 ... res_a = sess.run(update_op) 345 ... res_a 346 2 347 348 After: 349 350 >>> b = tf.Variable(0, dtype=tf.int64) 351 >>> res_b = b.assign(2) 352 >>> res_b.numpy() 353 2 354 """ 355 if ref.dtype._is_ref_dtype: 356 return gen_state_ops.assign( 357 ref, value, use_locking=use_locking, name=name, 358 validate_shape=validate_shape) 359 return ref.assign(value, name=name) 360 361 362@tf_export(v1=["count_up_to"]) 363@deprecated(None, "Prefer Dataset.range instead.") 364def count_up_to(ref, limit, name=None): 365 r"""Increments 'ref' until it reaches 'limit'. 366 367 Args: 368 ref: A Variable. Must be one of the following types: `int32`, `int64`. 369 Should be from a scalar `Variable` node. 370 limit: An `int`. 371 If incrementing ref would bring it above limit, instead generates an 372 'OutOfRange' error. 373 name: A name for the operation (optional). 374 375 Returns: 376 A `Tensor`. Has the same type as `ref`. 377 A copy of the input before increment. If nothing else modifies the 378 input, the values produced will all be distinct. 379 """ 380 if ref.dtype._is_ref_dtype: 381 return gen_state_ops.count_up_to(ref, limit=limit, name=name) 382 return gen_state_ops.resource_count_up_to( 383 ref.handle, limit, T=ref.dtype, name=name) 384 385 386@tf_export(v1=["scatter_update"]) 387def scatter_update(ref, indices, updates, use_locking=True, name=None): 388 # pylint: disable=line-too-long 389 r"""Applies sparse updates to a variable reference. 390 391 This operation computes 392 393 ```python 394 # Scalar indices 395 ref[indices, ...] = updates[...] 396 397 # Vector indices (for each i) 398 ref[indices[i], ...] = updates[i, ...] 399 400 # High rank indices (for each i, ..., j) 401 ref[indices[i, ..., j], ...] = updates[i, ..., j, ...] 402 ``` 403 404 This operation outputs `ref` after the update is done. 405 This makes it easier to chain operations that need to use the reset value. 406 407 If values in `ref` is to be updated more than once, because there are 408 duplicate entries in `indices`, the order at which the updates happen 409 for each value is undefined. 410 411 Requires `updates.shape = indices.shape + ref.shape[1:]`. 412 413 <div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;"> 414 <img style="width:100%" src="https://www.tensorflow.org/images/ScatterUpdate.png" alt> 415 </div> 416 417 Args: 418 ref: A `Variable`. 419 indices: A `Tensor`. Must be one of the following types: `int32`, `int64`. 420 A tensor of indices into the first dimension of `ref`. 421 updates: A `Tensor`. Must have the same type as `ref`. 422 A tensor of updated values to store in `ref`. 423 use_locking: An optional `bool`. Defaults to `True`. 424 If True, the assignment will be protected by a lock; 425 otherwise the behavior is undefined, but may exhibit less contention. 426 name: A name for the operation (optional). 427 428 Returns: 429 Same as `ref`. Returned as a convenience for operations that want 430 to use the updated values after the update is done. 431 """ 432 if ref.dtype._is_ref_dtype: 433 return gen_state_ops.scatter_update(ref, indices, updates, 434 use_locking=use_locking, name=name) 435 return ref._lazy_read(gen_resource_variable_ops.resource_scatter_update( # pylint: disable=protected-access 436 ref.handle, indices, ops.convert_to_tensor(updates, ref.dtype), 437 name=name)) 438 439 440@tf_export(v1=["scatter_nd_update"]) 441def scatter_nd_update(ref, indices, updates, use_locking=True, name=None): 442 r"""Applies sparse `updates` to individual values or slices in a Variable. 443 444 `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`. 445 446 `indices` must be integer tensor, containing indices into `ref`. 447 It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`. 448 449 The innermost dimension of `indices` (with length `K`) corresponds to 450 indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th 451 dimension of `ref`. 452 453 `updates` is `Tensor` of rank `Q-1+P-K` with shape: 454 455 ``` 456 [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]]. 457 ``` 458 459 For example, say we want to update 4 scattered elements to a rank-1 tensor to 460 8 elements. In Python, that update would look like this: 461 462 ```python 463 ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8]) 464 indices = tf.constant([[4], [3], [1] ,[7]]) 465 updates = tf.constant([9, 10, 11, 12]) 466 update = tf.compat.v1.scatter_nd_update(ref, indices, updates) 467 with tf.compat.v1.Session() as sess: 468 print sess.run(update) 469 ``` 470 471 The resulting update to ref would look like this: 472 473 [1, 11, 3, 10, 9, 6, 7, 12] 474 475 See `tf.scatter_nd` for more details about how to make updates to 476 slices. 477 478 Args: 479 ref: A Variable. 480 indices: A `Tensor`. Must be one of the following types: `int32`, `int64`. 481 A tensor of indices into ref. 482 updates: A `Tensor`. Must have the same type as `ref`. 483 A Tensor. Must have the same type as ref. A tensor of updated 484 values to add to ref. 485 use_locking: An optional `bool`. Defaults to `True`. 486 An optional bool. Defaults to True. If True, the assignment will 487 be protected by a lock; otherwise the behavior is undefined, 488 but may exhibit less contention. 489 name: A name for the operation (optional). 490 491 Returns: 492 The value of the variable after the update. 493 """ 494 if ref.dtype._is_ref_dtype: 495 return gen_state_ops.scatter_nd_update( 496 ref, indices, updates, use_locking, name) 497 return ref._lazy_read(gen_state_ops.resource_scatter_nd_update( # pylint: disable=protected-access 498 ref.handle, indices, ops.convert_to_tensor(updates, ref.dtype), 499 name=name)) 500 501 502@tf_export(v1=["scatter_add"]) 503def scatter_add(ref, indices, updates, use_locking=False, name=None): 504 # pylint: disable=line-too-long 505 r"""Adds sparse updates to the variable referenced by `resource`. 506 507 This operation computes 508 509 ```python 510 # Scalar indices 511 ref[indices, ...] += updates[...] 512 513 # Vector indices (for each i) 514 ref[indices[i], ...] += updates[i, ...] 515 516 # High rank indices (for each i, ..., j) 517 ref[indices[i, ..., j], ...] += updates[i, ..., j, ...] 518 ``` 519 520 This operation outputs `ref` after the update is done. 521 This makes it easier to chain operations that need to use the updated value. 522 Duplicate entries are handled correctly: if multiple `indices` reference 523 the same location, their contributions add. 524 525 Requires `updates.shape = indices.shape + ref.shape[1:]`. 526 527 <div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;"> 528 <img style="width:100%" src='https://www.tensorflow.org/images/ScatterAdd.png' alt> 529 </div> 530 531 Args: 532 ref: A `Variable`. 533 indices: A `Tensor`. Must be one of the following types: `int32`, `int64`. 534 A tensor of indices into the first dimension of `ref`. 535 updates: A `Tensor`. Must have the same type as `ref`. 536 A tensor of updated values to store in `ref`. 537 use_locking: An optional `bool`. Defaults to `False`. 538 If True, the assignment will be protected by a lock; 539 otherwise the behavior is undefined, but may exhibit less contention. 540 name: A name for the operation (optional). 541 542 Returns: 543 Same as `ref`. Returned as a convenience for operations that want 544 to use the updated values after the update is done. 545 """ 546 if ref.dtype._is_ref_dtype: 547 return gen_state_ops.scatter_add(ref, indices, updates, 548 use_locking=use_locking, name=name) 549 return ref._lazy_read(gen_resource_variable_ops.resource_scatter_add( # pylint: disable=protected-access 550 ref.handle, indices, ops.convert_to_tensor(updates, ref.dtype), 551 name=name)) 552 553 554@tf_export(v1=["scatter_nd_add"]) 555def scatter_nd_add(ref, indices, updates, use_locking=False, name=None): 556 r"""Applies sparse addition to individual values or slices in a Variable. 557 558 `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`. 559 560 `indices` must be integer tensor, containing indices into `ref`. 561 It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`. 562 563 The innermost dimension of `indices` (with length `K`) corresponds to 564 indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th 565 dimension of `ref`. 566 567 `updates` is `Tensor` of rank `Q-1+P-K` with shape: 568 569 ``` 570 [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]] 571 ``` 572 573 For example, say we want to add 4 scattered elements to a rank-1 tensor to 574 8 elements. In Python, that addition would look like this: 575 576 ```python 577 ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8]) 578 indices = tf.constant([[4], [3], [1], [7]]) 579 updates = tf.constant([9, 10, 11, 12]) 580 add = tf.compat.v1.scatter_nd_add(ref, indices, updates) 581 with tf.compat.v1.Session() as sess: 582 print sess.run(add) 583 ``` 584 585 The resulting update to ref would look like this: 586 587 [1, 13, 3, 14, 14, 6, 7, 20] 588 589 See `tf.scatter_nd` for more details about how to make updates to 590 slices. 591 592 Args: 593 ref: A mutable `Tensor`. Must be one of the following types: `float32`, 594 `float64`, `int32`, `uint8`, `int16`, `int8`, `complex64`, `int64`, 595 `qint8`, `quint8`, `qint32`, `bfloat16`, `uint16`, `complex128`, `half`, 596 `uint32`, `uint64`. A mutable Tensor. Should be from a Variable node. 597 indices: A `Tensor`. Must be one of the following types: `int32`, `int64`. 598 A tensor of indices into ref. 599 updates: A `Tensor`. Must have the same type as `ref`. 600 A tensor of updated values to add to ref. 601 use_locking: An optional `bool`. Defaults to `False`. 602 If True, the assignment will be protected by a lock; 603 otherwise the behavior is undefined, but may exhibit less contention. 604 name: A name for the operation (optional). 605 606 Returns: 607 A mutable `Tensor`. Has the same type as `ref`. 608 """ 609 if ref.dtype._is_ref_dtype: 610 return gen_state_ops.scatter_nd_add( 611 ref, indices, updates, use_locking, name) 612 return ref._lazy_read(gen_state_ops.resource_scatter_nd_add( # pylint: disable=protected-access 613 ref.handle, indices, ops.convert_to_tensor(updates, ref.dtype), 614 name=name)) 615 616 617@tf_export(v1=["scatter_sub"]) 618def scatter_sub(ref, indices, updates, use_locking=False, name=None): 619 r"""Subtracts sparse updates to a variable reference. 620 621 ```python 622 # Scalar indices 623 ref[indices, ...] -= updates[...] 624 625 # Vector indices (for each i) 626 ref[indices[i], ...] -= updates[i, ...] 627 628 # High rank indices (for each i, ..., j) 629 ref[indices[i, ..., j], ...] -= updates[i, ..., j, ...] 630 ``` 631 632 This operation outputs `ref` after the update is done. 633 This makes it easier to chain operations that need to use the reset value. 634 635 Duplicate entries are handled correctly: if multiple `indices` reference 636 the same location, their (negated) contributions add. 637 638 Requires `updates.shape = indices.shape + ref.shape[1:]` or 639 `updates.shape = []`. 640 641 <div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;"> 642 <img style="width:100%" 643 src="https://www.tensorflow.org/images/ScatterSub.png" alt> 644 </div> 645 646 Args: 647 ref: A mutable `Tensor`. Must be one of the following types: `float32`, 648 `float64`, `int32`, `uint8`, `int16`, `int8`, `complex64`, `int64`, 649 `qint8`, `quint8`, `qint32`, `bfloat16`, `uint16`, `complex128`, `half`, 650 `uint32`, `uint64`. Should be from a `Variable` node. 651 indices: A `Tensor`. Must be one of the following types: `int32`, `int64`. 652 A tensor of indices into the first dimension of `ref`. 653 updates: A `Tensor`. Must have the same type as `ref`. 654 A tensor of updated values to subtract from `ref`. 655 use_locking: An optional `bool`. Defaults to `False`. 656 If True, the subtraction will be protected by a lock; 657 otherwise the behavior is undefined, but may exhibit less contention. 658 name: A name for the operation (optional). 659 660 Returns: 661 A mutable `Tensor`. Has the same type as `ref`. 662 """ 663 if ref.dtype._is_ref_dtype: 664 return gen_state_ops.scatter_sub(ref, indices, updates, 665 use_locking=use_locking, name=name) 666 return ref._lazy_read(gen_resource_variable_ops.resource_scatter_sub( # pylint: disable=protected-access 667 ref.handle, indices, ops.convert_to_tensor(updates, ref.dtype), 668 name=name)) 669 670 671@tf_export(v1=["scatter_nd_sub"]) 672def scatter_nd_sub(ref, indices, updates, use_locking=False, name=None): 673 r"""Applies sparse subtraction to individual values or slices in a Variable. 674 675 `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`. 676 677 `indices` must be integer tensor, containing indices into `ref`. 678 It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`. 679 680 The innermost dimension of `indices` (with length `K`) corresponds to 681 indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th 682 dimension of `ref`. 683 684 `updates` is `Tensor` of rank `Q-1+P-K` with shape: 685 686 ``` 687 [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]] 688 ``` 689 690 For example, say we want to subtract 4 scattered elements from a rank-1 tensor 691 with 8 elements. In Python, that update would look like this: 692 693 ```python 694 ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8]) 695 indices = tf.constant([[4], [3], [1] ,[7]]) 696 updates = tf.constant([9, 10, 11, 12]) 697 op = tf.compat.v1.scatter_nd_sub(ref, indices, updates) 698 with tf.compat.v1.Session() as sess: 699 print sess.run(op) 700 ``` 701 702 The resulting update to ref would look like this: 703 704 [1, -9, 3, -6, -6, 6, 7, -4] 705 706 See `tf.scatter_nd` for more details about how to make updates to 707 slices. 708 709 Args: 710 ref: A mutable `Tensor`. Must be one of the following types: `float32`, 711 `float64`, `int32`, `uint8`, `int16`, `int8`, `complex64`, `int64`, 712 `qint8`, `quint8`, `qint32`, `bfloat16`, `uint16`, `complex128`, `half`, 713 `uint32`, `uint64`. A mutable Tensor. Should be from a Variable node. 714 indices: A `Tensor`. Must be one of the following types: `int32`, `int64`. 715 A tensor of indices into ref. 716 updates: A `Tensor`. Must have the same type as `ref`. 717 A tensor of updated values to add to ref. 718 use_locking: An optional `bool`. Defaults to `False`. 719 An optional bool. Defaults to True. If True, the assignment will 720 be protected by a lock; otherwise the behavior is undefined, 721 but may exhibit less contention. 722 name: A name for the operation (optional). 723 724 Returns: 725 A mutable `Tensor`. Has the same type as `ref`. 726 """ 727 if ref.dtype._is_ref_dtype: 728 return gen_state_ops.scatter_nd_sub( 729 ref, indices, updates, use_locking, name) 730 return ref._lazy_read(gen_state_ops.resource_scatter_nd_sub( # pylint: disable=protected-access 731 ref.handle, indices, ops.convert_to_tensor(updates, ref.dtype), 732 name=name)) 733 734 735@tf_export(v1=["scatter_mul"]) 736def scatter_mul(ref, indices, updates, use_locking=False, name=None): 737 # pylint: disable=line-too-long 738 r"""Multiplies sparse updates into a variable reference. 739 740 This operation computes 741 742 ```python 743 # Scalar indices 744 ref[indices, ...] *= updates[...] 745 746 # Vector indices (for each i) 747 ref[indices[i], ...] *= updates[i, ...] 748 749 # High rank indices (for each i, ..., j) 750 ref[indices[i, ..., j], ...] *= updates[i, ..., j, ...] 751 ``` 752 753 This operation outputs `ref` after the update is done. 754 This makes it easier to chain operations that need to use the reset value. 755 756 Duplicate entries are handled correctly: if multiple `indices` reference 757 the same location, their contributions multiply. 758 759 Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape = 760 []`. 761 762 Args: 763 ref: A mutable `Tensor`. Must be one of the following types: `float32`, 764 `float64`, `int32`, `uint8`, `int16`, `int8`, `complex64`, `int64`, 765 `qint8`, `quint8`, `qint32`, `bfloat16`, `uint16`, `complex128`, `half`, 766 `uint32`, `uint64`. Should be from a `Variable` node. 767 indices: A `Tensor`. Must be one of the following types: `int32`, `int64`. A 768 tensor of indices into the first dimension of `ref`. 769 updates: A `Tensor`. Must have the same type as `ref`. A tensor of updated 770 values to multiply to `ref`. 771 use_locking: An optional `bool`. Defaults to `False`. If True, the operation 772 will be protected by a lock; otherwise the behavior is undefined, but may 773 exhibit less contention. 774 name: A name for the operation (optional). 775 776 Returns: 777 A mutable `Tensor`. Has the same type as `ref`. 778 """ 779 if ref.dtype._is_ref_dtype: 780 return gen_state_ops.scatter_mul(ref, indices, updates, 781 use_locking=use_locking, name=name) 782 return ref._lazy_read(gen_resource_variable_ops.resource_scatter_mul( # pylint: disable=protected-access 783 ref.handle, indices, ops.convert_to_tensor(updates, ref.dtype), 784 name=name)) 785 786 787@tf_export(v1=["scatter_div"]) 788def scatter_div(ref, indices, updates, use_locking=False, name=None): 789 # pylint: disable=line-too-long 790 r"""Divides a variable reference by sparse updates. 791 792 This operation computes 793 794 ```python 795 # Scalar indices 796 ref[indices, ...] /= updates[...] 797 798 # Vector indices (for each i) 799 ref[indices[i], ...] /= updates[i, ...] 800 801 # High rank indices (for each i, ..., j) 802 ref[indices[i, ..., j], ...] /= updates[i, ..., j, ...] 803 ``` 804 805 This operation outputs `ref` after the update is done. 806 This makes it easier to chain operations that need to use the reset value. 807 808 Duplicate entries are handled correctly: if multiple `indices` reference 809 the same location, their contributions divide. 810 811 Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape = 812 []`. 813 814 Args: 815 ref: A mutable `Tensor`. Must be one of the following types: `float32`, 816 `float64`, `int32`, `uint8`, `int16`, `int8`, `complex64`, `int64`, 817 `qint8`, `quint8`, `qint32`, `bfloat16`, `uint16`, `complex128`, `half`, 818 `uint32`, `uint64`. Should be from a `Variable` node. 819 indices: A `Tensor`. Must be one of the following types: `int32`, `int64`. A 820 tensor of indices into the first dimension of `ref`. 821 updates: A `Tensor`. Must have the same type as `ref`. A tensor of values 822 that `ref` is divided by. 823 use_locking: An optional `bool`. Defaults to `False`. If True, the operation 824 will be protected by a lock; otherwise the behavior is undefined, but may 825 exhibit less contention. 826 name: A name for the operation (optional). 827 828 Returns: 829 A mutable `Tensor`. Has the same type as `ref`. 830 """ 831 if ref.dtype._is_ref_dtype: 832 return gen_state_ops.scatter_div(ref, indices, updates, 833 use_locking=use_locking, name=name) 834 return ref._lazy_read(gen_resource_variable_ops.resource_scatter_div( # pylint: disable=protected-access 835 ref.handle, indices, ops.convert_to_tensor(updates, ref.dtype), 836 name=name)) 837 838 839@tf_export(v1=["scatter_max"]) 840def scatter_max(ref, indices, updates, use_locking=False, name=None): 841 # pylint: disable=line-too-long 842 r"""Reduces sparse updates into a variable reference using the `max` operation. 843 844 This operation computes 845 846 # Scalar indices 847 ref[indices, ...] = max(ref[indices, ...], updates[...]) 848 849 # Vector indices (for each i) 850 ref[indices[i], ...] = max(ref[indices[i], ...], updates[i, ...]) 851 852 # High rank indices (for each i, ..., j) 853 ref[indices[i, ..., j], ...] = max(ref[indices[i, ..., j], ...], 854 updates[i, ..., j, ...]) 855 856 This operation outputs `ref` after the update is done. 857 This makes it easier to chain operations that need to use the reset value. 858 859 Duplicate entries are handled correctly: if multiple `indices` reference 860 the same location, their contributions combine. 861 862 Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape = 863 []`. 864 865 <div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;"> 866 <img style="width:100%" src="https://www.tensorflow.org/images/ScatterAdd.png" 867 alt> 868 </div> 869 870 Args: 871 ref: A mutable `Tensor`. Must be one of the following types: `half`, 872 `bfloat16`, `float32`, `float64`, `int32`, `int64`. Should be from a 873 `Variable` node. 874 indices: A `Tensor`. Must be one of the following types: `int32`, `int64`. A 875 tensor of indices into the first dimension of `ref`. 876 updates: A `Tensor`. Must have the same type as `ref`. A tensor of updated 877 values to reduce into `ref`. 878 use_locking: An optional `bool`. Defaults to `False`. If True, the update 879 will be protected by a lock; otherwise the behavior is undefined, but may 880 exhibit less contention. 881 name: A name for the operation (optional). 882 883 Returns: 884 A mutable `Tensor`. Has the same type as `ref`. 885 """ 886 if ref.dtype._is_ref_dtype: 887 return gen_state_ops.scatter_max(ref, indices, updates, 888 use_locking=use_locking, name=name) 889 return ref._lazy_read(gen_resource_variable_ops.resource_scatter_max( # pylint: disable=protected-access 890 ref.handle, indices, ops.convert_to_tensor(updates, ref.dtype), 891 name=name)) 892 893 894@tf_export(v1=["scatter_min"]) 895def scatter_min(ref, indices, updates, use_locking=False, name=None): 896 # pylint: disable=line-too-long 897 r"""Reduces sparse updates into a variable reference using the `min` operation. 898 899 This operation computes 900 901 # Scalar indices 902 ref[indices, ...] = min(ref[indices, ...], updates[...]) 903 904 # Vector indices (for each i) 905 ref[indices[i], ...] = min(ref[indices[i], ...], updates[i, ...]) 906 907 # High rank indices (for each i, ..., j) 908 ref[indices[i, ..., j], ...] = min(ref[indices[i, ..., j], ...], 909 updates[i, ..., j, ...]) 910 911 This operation outputs `ref` after the update is done. 912 This makes it easier to chain operations that need to use the reset value. 913 914 Duplicate entries are handled correctly: if multiple `indices` reference 915 the same location, their contributions combine. 916 917 Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape = 918 []`. 919 920 <div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;"> 921 <img style="width:100%" src="https://www.tensorflow.org/images/ScatterAdd.png" 922 alt> 923 </div> 924 925 Args: 926 ref: A mutable `Tensor`. Must be one of the following types: `half`, 927 `bfloat16`, `float32`, `float64`, `int32`, `int64`. Should be from a 928 `Variable` node. 929 indices: A `Tensor`. Must be one of the following types: `int32`, `int64`. A 930 tensor of indices into the first dimension of `ref`. 931 updates: A `Tensor`. Must have the same type as `ref`. A tensor of updated 932 values to reduce into `ref`. 933 use_locking: An optional `bool`. Defaults to `False`. If True, the update 934 will be protected by a lock; otherwise the behavior is undefined, but may 935 exhibit less contention. 936 name: A name for the operation (optional). 937 938 Returns: 939 A mutable `Tensor`. Has the same type as `ref`. 940 """ 941 if ref.dtype._is_ref_dtype: 942 return gen_state_ops.scatter_min(ref, indices, updates, 943 use_locking=use_locking, name=name) 944 return ref._lazy_read(gen_resource_variable_ops.resource_scatter_min( # pylint: disable=protected-access 945 ref.handle, indices, ops.convert_to_tensor(updates, ref.dtype), 946 name=name)) 947 948 949@tf_export(v1=["batch_scatter_update"]) 950@deprecation.deprecated( 951 "2018-11-29", "Use the batch_scatter_update method of Variable instead.") 952def batch_scatter_update(ref, indices, updates, use_locking=True, name=None): 953 """Generalization of `tf.compat.v1.scatter_update` to axis different than 0. 954 955 Analogous to `batch_gather`. This assumes that `ref`, `indices` and `updates` 956 have a series of leading dimensions that are the same for all of them, and the 957 updates are performed on the last dimension of indices. In other words, the 958 dimensions should be the following: 959 960 `num_prefix_dims = indices.ndims - 1` 961 `batch_dim = num_prefix_dims + 1` 962 `updates.shape = indices.shape + var.shape[batch_dim:]` 963 964 where 965 966 `updates.shape[:num_prefix_dims]` 967 `== indices.shape[:num_prefix_dims]` 968 `== var.shape[:num_prefix_dims]` 969 970 And the operation performed can be expressed as: 971 972 `var[i_1, ..., i_n, indices[i_1, ..., i_n, j]] = updates[i_1, ..., i_n, j]` 973 974 When indices is a 1D tensor, this operation is equivalent to 975 `tf.compat.v1.scatter_update`. 976 977 To avoid this operation there would be 2 alternatives: 978 1) Reshaping the variable by merging the first `ndims` dimensions. However, 979 this is not possible because `tf.reshape` returns a Tensor, which we 980 cannot use `tf.compat.v1.scatter_update` on. 981 2) Looping over the first `ndims` of the variable and using 982 `tf.compat.v1.scatter_update` on the subtensors that result of slicing the 983 first 984 dimension. This is a valid option for `ndims = 1`, but less efficient than 985 this implementation. 986 987 See also `tf.compat.v1.scatter_update` and `tf.compat.v1.scatter_nd_update`. 988 989 Args: 990 ref: `Variable` to scatter onto. 991 indices: Tensor containing indices as described above. 992 updates: Tensor of updates to apply to `ref`. 993 use_locking: Boolean indicating whether to lock the writing operation. 994 name: Optional scope name string. 995 996 Returns: 997 Ref to `variable` after it has been modified. 998 999 Raises: 1000 ValueError: If the initial `ndims` of `ref`, `indices`, and `updates` are 1001 not the same. 1002 """ 1003 with ops.name_scope(name): 1004 indices = ops.convert_to_tensor(indices, name="indices") 1005 indices_shape = array_ops.shape(indices) 1006 indices_dimensions = indices.get_shape().ndims 1007 1008 if indices_dimensions is None: 1009 raise ValueError("batch_gather does not allow indices with unknown " 1010 "shape.") 1011 1012 nd_indices = array_ops.expand_dims(indices, axis=-1) 1013 nd_indices_list = [] 1014 1015 # Scatter ND requires indices to have an additional dimension, in which the 1016 # coordinates of the updated things are specified. For this to be adapted to 1017 # the scatter_update with several leading dimensions, we simply make use of 1018 # a tf.range for all the leading dimensions followed by concat of all the 1019 # coordinates we created with the original indices. 1020 1021 # For example if indices.shape = [2, 3, 4], we should generate the following 1022 # indices for tf.compat.v1.scatter_nd_update: 1023 # nd_indices[:, :, 0] = [[0, 0, 0], [1, 1, 1]] 1024 # nd_indices[:, :, 1] = [[0, 1, 2], [0, 1, 2]] 1025 # nd_indices[:, :, 2] = indices 1026 for dimension in range(indices_dimensions - 1): 1027 # In this loop we generate the following for the example (one for each 1028 # iteration). 1029 # nd_indices[:, :, 0] = [[0, 0, 0], [1, 1, 1]] 1030 # nd_indices[:, :, 1] = [[0, 1, 2], [0, 1, 2]] 1031 # This is done at every iteration with a tf.range over the size of the 1032 # i-th dimension and using broadcasting over the desired shape. 1033 dimension_size = indices_shape[dimension] 1034 shape_to_broadcast = [1] * (indices_dimensions + 1) 1035 shape_to_broadcast[dimension] = dimension_size 1036 dimension_range = array_ops.reshape( 1037 gen_math_ops._range(0, dimension_size, 1), shape_to_broadcast) 1038 if dimension_range.dtype.base_dtype != nd_indices.dtype: 1039 dimension_range = gen_math_ops.cast(dimension_range, nd_indices.dtype) 1040 nd_indices_list.append( 1041 dimension_range * array_ops.ones_like(nd_indices)) 1042 # Add the original indices at the end, as described above, and concat. 1043 nd_indices_list.append(nd_indices) 1044 final_indices = array_ops.concat(nd_indices_list, axis=-1) 1045 return scatter_nd_update( 1046 ref, final_indices, updates, use_locking=use_locking) 1047