1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
"""Variables.

See the [Variables](https://www.tensorflow.org/guide/variables) guide.
"""
20
21from __future__ import absolute_import
22from __future__ import division
23from __future__ import print_function
24
25from tensorflow.python.framework import ops
26from tensorflow.python.framework import tensor_shape
27from tensorflow.python.ops import array_ops
28from tensorflow.python.ops import gen_math_ops
29from tensorflow.python.ops import gen_resource_variable_ops
30from tensorflow.python.ops import gen_state_ops
31# go/tf-wildcard-import
32# pylint: disable=wildcard-import
33from tensorflow.python.ops.gen_state_ops import *
34# pylint: enable=wildcard-import
35from tensorflow.python.util import deprecation
36from tensorflow.python.util.deprecation import deprecated
37from tensorflow.python.util.tf_export import tf_export
38
39
40# pylint: disable=protected-access,g-doc-return-or-yield,g-doc-args
def variable_op(shape, dtype, name="Variable", set_shape=True, container="",
                shared_name=""):
  """Deprecated. Use variable_op_v2 instead.

  Creates a (ref-type) variable op.

  Args:
    shape: The shape of the tensor managed by this variable.
    dtype: The underlying type of the tensor values.
    name: Optional name to use for the variable op.
    set_shape: If True, statically set the returned tensor's shape to `shape`;
      if False, the variable is created with an unknown shape.
    container: An optional string. Defaults to "". If non-empty, this variable
      is placed in the given container. Otherwise, a default container is used.
    shared_name: An optional string. Defaults to "". If non-empty, this
      variable is named in the given bucket with this shared_name. Otherwise,
      the node name is used instead.

  Returns:
    A variable tensor.
  """
  if not set_shape:
    shape = tensor_shape.unknown_shape()
  ret = gen_state_ops.variable(shape=shape, dtype=dtype, name=name,
                               container=container, shared_name=shared_name)
  # TODO(mrry): Move this to where it is used, so we can get rid of this op
  #   wrapper?
  if set_shape:
    ret.set_shape(shape)
  return ret
53
54
def variable_op_v2(shape, dtype, name="Variable", container="", shared_name=""):
  """Create a variable Operation.

  See also variables.Variable.

  Args:
    shape: The shape of the tensor managed by this variable
    dtype: The underlying type of the tensor values.
    name: optional name to use for the variable op.
    container: An optional string. Defaults to "".
      If non-empty, this variable is placed in the given container.
      Otherwise, a default container is used.
    shared_name: An optional string. Defaults to "".
      If non-empty, this variable is named in the given bucket
      with this shared_name. Otherwise, the node name is used instead.

  Returns:
    A variable tensor.
  """
  # Thin wrapper: every argument is forwarded unchanged to the generated
  # VariableV2 op.
  op_kwargs = dict(
      shape=shape,
      dtype=dtype,
      name=name,
      container=container,
      shared_name=shared_name)
  return gen_state_ops.variable_v2(**op_kwargs)
80
81
def init_variable(v, init, name="init"):
  """Initializes variable with "init".

  This op does the following:
  if init is a Tensor, v = init
  if callable(init): v = init(VariableShape(v), v.dtype)

  Args:
    v: Variable to initialize
    init: Tensor to assign to v,
      Or an object convertible to Tensor e.g. nparray,
      Or an Initializer that generates a tensor given the shape and type of v.
      An "Initializer" is a callable that returns a tensor that "v" should be
      set to. It will be called as init(shape, dtype).
    name: Optional name for the op.

  Returns:
    The operation that initializes v.
  """
  with ops.name_scope(None, v.op.name + "/", [v, init]):
    with ops.name_scope(name) as scope:
      with ops.colocate_with(v):
        # Only the construction of the value differs between the two cases;
        # a single assign op covers both.
        if callable(init):
          assert v.get_shape().is_fully_defined(), "Variable shape unknown."
          # TODO(mrry): Convert to v.shape when the property and
          # accessor are reconciled (and all initializers support
          # tf.TensorShape objects).
          initial_value = init(v.get_shape().as_list(), v.dtype.base_dtype)
          initial_value = ops.convert_to_tensor(initial_value, name="value")
        else:
          initial_value = ops.convert_to_tensor(init, name="init")
        return gen_state_ops.assign(v, initial_value, name=scope)
115
116
def is_variable_initialized(ref, name=None):
  """Checks whether a tensor has been initialized.

  Outputs boolean scalar indicating whether the tensor has been initialized.

  Args:
    ref: A mutable `Tensor`.
      Should be from a `Variable` node. May be uninitialized.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` of type `bool`.
  """
  if not ref.dtype._is_ref_dtype:
    # Resource variables know how to report their own initialization status.
    return ref.is_initialized(name=name)
  return gen_state_ops.is_variable_initialized(ref=ref, name=name)
134
135
@tf_export(v1=["assign_sub"])
def assign_sub(ref, value, use_locking=None, name=None):
  """Update `ref` by subtracting `value` from it.

  This operation outputs `ref` after the update is done.
  This makes it easier to chain operations that need to use the reset value.
  Unlike `tf.math.subtract`, this op does not broadcast. `ref` and `value`
  must have the same shape.

  Args:
    ref: A mutable `Tensor`. Must be one of the following types: `float32`,
      `float64`, `int64`, `int32`, `uint8`, `uint16`, `int16`, `int8`,
      `complex64`, `complex128`, `qint8`, `quint8`, `qint32`, `half`. Should be
      from a `Variable` node.
    value: A `Tensor`. Must have the same shape and dtype as `ref`. The value to
      be subtracted from the variable.
    use_locking: An optional `bool`. Defaults to `False`. If True, the
      subtraction will be protected by a lock; otherwise the behavior is
      undefined, but may exhibit less contention.
    name: A name for the operation (optional).

  Returns:
    Same as `ref`.  Returned as a convenience for operations that want
    to use the new value after the variable has been updated.

  @compatibility(TF2)
  `tf.compat.v1.assign_sub` is mostly compatible with eager
  execution and `tf.function`.

  To switch to the native TF2 style, one could use method 'assign_sub' of
  `tf.Variable`:

  #### How to Map Arguments

  | TF1 Arg Name          | TF2 Arg Name    | Note                       |
  | :-------------------- | :-------------- | :------------------------- |
  | `ref`                 | `self`          | In `assign_sub()` method   |
  | `value`               | `value`         | In `assign_sub()` method   |
  | `use_locking`         | `use_locking`   | In `assign_sub()` method   |
  | `name`                | `name`          | In `assign_sub()` method   |
  | -                     | `read_value`    | Set to True to replicate   |
  :                       :                 : behavior (True is default) :


  #### Before & After Usage Example

  Before:

  >>> with tf.Graph().as_default():
  ...   with tf.compat.v1.Session() as sess:
  ...     a = tf.compat.v1.Variable(1, dtype=tf.int64)
  ...     sess.run(a.initializer)
  ...     update_op = tf.compat.v1.assign_sub(a, 1)
  ...     res_a = sess.run(update_op)
  ...     res_a
  0

  After:

  >>> b = tf.Variable(1, dtype=tf.int64)
  >>> res_b = b.assign_sub(1)
  >>> res_b.numpy()
  0

  @end_compatibility
  """
  if ref.dtype._is_ref_dtype:
    return gen_state_ops.assign_sub(
        ref, value, use_locking=use_locking, name=name)
  # Resource variable: delegate to the variable's own assign_sub. Forward
  # `name` so the requested op name is honored, consistent with `assign`.
  return ref.assign_sub(value, name=name)
206
207
@tf_export(v1=["assign_add"])
def assign_add(ref, value, use_locking=None, name=None):
  """Update `ref` by adding `value` to it.

  This operation outputs `ref` after the update is done.
  This makes it easier to chain operations that need to use the reset value.
  Unlike `tf.math.add`, this op does not broadcast. `ref` and `value` must have
  the same shape.

  Args:
    ref: A mutable `Tensor`. Must be one of the following types: `float32`,
      `float64`, `int64`, `int32`, `uint8`, `uint16`, `int16`, `int8`,
      `complex64`, `complex128`, `qint8`, `quint8`, `qint32`, `half`. Should be
      from a `Variable` node.
    value: A `Tensor`. Must have the same shape and dtype as `ref`. The value to
      be added to the variable.
    use_locking: An optional `bool`. Defaults to `False`. If True, the addition
      will be protected by a lock; otherwise the behavior is undefined, but may
      exhibit less contention.
    name: A name for the operation (optional).

  Returns:
    Same as `ref`.  Returned as a convenience for operations that want
    to use the new value after the variable has been updated.

  @compatibility(TF2)
  `tf.compat.v1.assign_add` is mostly compatible with eager
  execution and `tf.function`.

  To switch to the native TF2 style, one could use method 'assign_add' of
  `tf.Variable`:

  #### How to Map Arguments

  | TF1 Arg Name          | TF2 Arg Name    | Note                       |
  | :-------------------- | :-------------- | :------------------------- |
  | `ref`                 | `self`          | In `assign_add()` method   |
  | `value`               | `value`         | In `assign_add()` method   |
  | `use_locking`         | `use_locking`   | In `assign_add()` method   |
  | `name`                | `name`          | In `assign_add()` method   |
  | -                     | `read_value`    | Set to True to replicate   |
  :                       :                 : behavior (True is default) :


  #### Before & After Usage Example

  Before:

  >>> with tf.Graph().as_default():
  ...   with tf.compat.v1.Session() as sess:
  ...     a = tf.compat.v1.Variable(0, dtype=tf.int64)
  ...     sess.run(a.initializer)
  ...     update_op = tf.compat.v1.assign_add(a, 1)
  ...     res_a = sess.run(update_op)
  ...     res_a
  1

  After:

  >>> b = tf.Variable(0, dtype=tf.int64)
  >>> res_b = b.assign_add(1)
  >>> res_b.numpy()
  1

  @end_compatibility
  """
  if ref.dtype._is_ref_dtype:
    return gen_state_ops.assign_add(
        ref, value, use_locking=use_locking, name=name)
  # Resource variable: delegate to the variable's own assign_add. Forward
  # `name` so the requested op name is honored, consistent with `assign`.
  return ref.assign_add(value, name=name)
278
279
@tf_export(v1=["assign"])
def assign(ref, value, validate_shape=None, use_locking=None, name=None):
  """Update `ref` by assigning `value` to it.

  This operation outputs a Tensor that holds the new value of `ref` after
  the value has been assigned. This makes it easier to chain operations that
  need to use the reset value.

  Args:
    ref: A mutable `Tensor`. Should be from a `Variable` node. May be
      uninitialized.
    value: A `Tensor`. Must have the same shape and dtype as `ref`. The value to
      be assigned to the variable.
    validate_shape: An optional `bool`. Defaults to `True`. If true, the
      operation will validate that the shape of 'value' matches the shape of the
      Tensor being assigned to.  If false, 'ref' will take on the shape of
      'value'.
    use_locking: An optional `bool`. Defaults to `True`. If True, the assignment
      will be protected by a lock; otherwise the behavior is undefined, but may
      exhibit less contention.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` that will hold the new value of `ref` after
      the assignment has completed.

  @compatibility(TF2)
  `tf.compat.v1.assign` is mostly compatible with eager
  execution and `tf.function`. However, argument 'validate_shape' will be
  ignored. To avoid shape validation, set 'shape' to tf.TensorShape(None) when
  constructing the variable:

  >>> import tensorflow as tf
  >>> a = tf.Variable([1], shape=tf.TensorShape(None))
  >>> tf.compat.v1.assign(a, [2,3])

  To switch to the native TF2 style, one could use method 'assign' of
  `tf.Variable`:

  #### How to Map Arguments

  | TF1 Arg Name          | TF2 Arg Name    | Note                       |
  | :-------------------- | :-------------- | :------------------------- |
  | `ref`                 | `self`          | In `assign()` method       |
  | `value`               | `value`         | In `assign()` method       |
  | `validate_shape`      | Not supported   | Specify `shape` in the     |
  :                       :                 : constructor to replicate   :
  :                       :                 : behavior                   :
  | `use_locking`         | `use_locking`   | In `assign()` method       |
  | `name`                | `name`          | In `assign()` method       |
  | -                     | `read_value`    | Set to True to replicate   |
  :                       :                 : behavior (True is default) :
  @end_compatibility


  #### Before & After Usage Example

  Before:

  >>> with tf.Graph().as_default():
  ...   with tf.compat.v1.Session() as sess:
  ...     a = tf.compat.v1.Variable(0, dtype=tf.int64)
  ...     sess.run(a.initializer)
  ...     update_op = tf.compat.v1.assign(a, 2)
  ...     res_a = sess.run(update_op)
  ...     res_a
  2

  After:

  >>> b = tf.Variable(0, dtype=tf.int64)
  >>> res_b = b.assign(2)
  >>> res_b.numpy()
  2
  """
  if not ref.dtype._is_ref_dtype:
    # Resource variables implement assignment themselves; `validate_shape`
    # and `use_locking` are not applicable on this path.
    return ref.assign(value, name=name)
  return gen_state_ops.assign(
      ref, value, use_locking=use_locking, name=name,
      validate_shape=validate_shape)
360
361
@tf_export(v1=["count_up_to"])
@deprecated(None, "Prefer Dataset.range instead.")
def count_up_to(ref, limit, name=None):
  r"""Increments 'ref' until it reaches 'limit'.

  Args:
    ref: A Variable. Must be one of the following types: `int32`, `int64`.
      Should be from a scalar `Variable` node.
    limit: An `int`.
      If incrementing ref would bring it above limit, instead generates an
      'OutOfRange' error.
    name: A name for the operation (optional).

  Returns:
    A `Tensor`. Has the same type as `ref`.
    A copy of the input before increment. If nothing else modifies the
    input, the values produced will all be distinct.
  """
  if not ref.dtype._is_ref_dtype:
    # Resource variable: the op operates on the variable's handle.
    return gen_state_ops.resource_count_up_to(
        ref.handle, limit, T=ref.dtype, name=name)
  return gen_state_ops.count_up_to(ref, limit=limit, name=name)
384
385
@tf_export(v1=["scatter_update"])
def scatter_update(ref, indices, updates, use_locking=True, name=None):
  # pylint: disable=line-too-long
  r"""Applies sparse updates to a variable reference.

  This operation computes

  ```python
      # Scalar indices
      ref[indices, ...] = updates[...]

      # Vector indices (for each i)
      ref[indices[i], ...] = updates[i, ...]

      # High rank indices (for each i, ..., j)
      ref[indices[i, ..., j], ...] = updates[i, ..., j, ...]
  ```

  This operation outputs `ref` after the update is done.
  This makes it easier to chain operations that need to use the reset value.

  If values in `ref` is to be updated more than once, because there are
  duplicate entries in `indices`, the order at which the updates happen
  for each value is undefined.

  Requires `updates.shape = indices.shape + ref.shape[1:]`.

  <div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
  <img style="width:100%" src="https://www.tensorflow.org/images/ScatterUpdate.png" alt>
  </div>

  Args:
    ref: A `Variable`.
    indices: A `Tensor`. Must be one of the following types: `int32`, `int64`.
      A tensor of indices into the first dimension of `ref`.
    updates: A `Tensor`. Must have the same type as `ref`.
      A tensor of updated values to store in `ref`.
    use_locking: An optional `bool`. Defaults to `True`.
      If True, the assignment will be protected by a lock;
      otherwise the behavior is undefined, but may exhibit less contention.
    name: A name for the operation (optional).

  Returns:
    Same as `ref`.  Returned as a convenience for operations that want
    to use the updated values after the update is done.
  """
  if not ref.dtype._is_ref_dtype:
    # Resource variable: scatter into the handle, then wrap the op in a lazy
    # read so callers can chain on the updated value.
    new_values = ops.convert_to_tensor(updates, ref.dtype)
    scatter_op = gen_resource_variable_ops.resource_scatter_update(
        ref.handle, indices, new_values, name=name)
    return ref._lazy_read(scatter_op)  # pylint: disable=protected-access
  return gen_state_ops.scatter_update(
      ref, indices, updates, use_locking=use_locking, name=name)
438
439
@tf_export(v1=["scatter_nd_update"])
def scatter_nd_update(ref, indices, updates, use_locking=True, name=None):
  r"""Applies sparse `updates` to individual values or slices in a Variable.

  `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.

  `indices` must be integer tensor, containing indices into `ref`.
  It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.

  The innermost dimension of `indices` (with length `K`) corresponds to
  indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
  dimension of `ref`.

  `updates` is `Tensor` of rank `Q-1+P-K` with shape:

  ```
  [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
  ```

  For example, say we want to update 4 scattered elements to a rank-1 tensor to
  8 elements. In Python, that update would look like this:

  ```python
      ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
      indices = tf.constant([[4], [3], [1] ,[7]])
      updates = tf.constant([9, 10, 11, 12])
      update = tf.compat.v1.scatter_nd_update(ref, indices, updates)
      with tf.compat.v1.Session() as sess:
        print(sess.run(update))
  ```

  The resulting update to ref would look like this:

      [1, 11, 3, 10, 9, 6, 7, 12]

  See `tf.scatter_nd` for more details about how to make updates to
  slices.

  Args:
    ref: A Variable.
    indices: A `Tensor`. Must be one of the following types: `int32`, `int64`.
      A tensor of indices into ref.
    updates: A `Tensor`. Must have the same type as `ref`.
      A tensor of updated values to assign to ref.
    use_locking: An optional `bool`. Defaults to `True`.
      If True, the assignment will be protected by a lock;
      otherwise the behavior is undefined, but may exhibit less contention.
    name: A name for the operation (optional).

  Returns:
    The value of the variable after the update.
  """
  if ref.dtype._is_ref_dtype:
    return gen_state_ops.scatter_nd_update(
        ref, indices, updates, use_locking=use_locking, name=name)
  # Resource variable: scatter on the handle, wrapped in a lazy read so the
  # result can be used like the updated variable.
  return ref._lazy_read(gen_state_ops.resource_scatter_nd_update(  # pylint: disable=protected-access
      ref.handle, indices, ops.convert_to_tensor(updates, ref.dtype),
      name=name))
500
501
@tf_export(v1=["scatter_add"])
def scatter_add(ref, indices, updates, use_locking=False, name=None):
  # pylint: disable=line-too-long
  r"""Adds sparse updates to the variable referenced by `resource`.

  This operation computes

  ```python
      # Scalar indices
      ref[indices, ...] += updates[...]

      # Vector indices (for each i)
      ref[indices[i], ...] += updates[i, ...]

      # High rank indices (for each i, ..., j)
      ref[indices[i, ..., j], ...] += updates[i, ..., j, ...]
  ```

  This operation outputs `ref` after the update is done.
  This makes it easier to chain operations that need to use the updated value.
  Duplicate entries are handled correctly: if multiple `indices` reference
  the same location, their contributions add.

  Requires `updates.shape = indices.shape + ref.shape[1:]`.

  <div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
  <img style="width:100%" src='https://www.tensorflow.org/images/ScatterAdd.png' alt>
  </div>

  Args:
    ref: A `Variable`.
    indices: A `Tensor`. Must be one of the following types: `int32`, `int64`.
      A tensor of indices into the first dimension of `ref`.
    updates: A `Tensor`. Must have the same type as `ref`.
      A tensor of updated values to store in `ref`.
    use_locking: An optional `bool`. Defaults to `False`.
      If True, the assignment will be protected by a lock;
      otherwise the behavior is undefined, but may exhibit less contention.
    name: A name for the operation (optional).

  Returns:
    Same as `ref`.  Returned as a convenience for operations that want
    to use the updated values after the update is done.
  """
  if not ref.dtype._is_ref_dtype:
    # Resource variable: scatter-add into the handle, then wrap the op in a
    # lazy read so callers can chain on the updated value.
    addend = ops.convert_to_tensor(updates, ref.dtype)
    add_op = gen_resource_variable_ops.resource_scatter_add(
        ref.handle, indices, addend, name=name)
    return ref._lazy_read(add_op)  # pylint: disable=protected-access
  return gen_state_ops.scatter_add(
      ref, indices, updates, use_locking=use_locking, name=name)
552
553
@tf_export(v1=["scatter_nd_add"])
def scatter_nd_add(ref, indices, updates, use_locking=False, name=None):
  r"""Applies sparse addition to individual values or slices in a Variable.

  `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.

  `indices` must be integer tensor, containing indices into `ref`.
  It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.

  The innermost dimension of `indices` (with length `K`) corresponds to
  indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
  dimension of `ref`.

  `updates` is `Tensor` of rank `Q-1+P-K` with shape:

  ```
  [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]]
  ```

  For example, say we want to add 4 scattered elements to a rank-1 tensor to
  8 elements. In Python, that addition would look like this:

  ```python
  ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
  indices = tf.constant([[4], [3], [1], [7]])
  updates = tf.constant([9, 10, 11, 12])
  add = tf.compat.v1.scatter_nd_add(ref, indices, updates)
  with tf.compat.v1.Session() as sess:
    print(sess.run(add))
  ```

  The resulting update to ref would look like this:

      [1, 13, 3, 14, 14, 6, 7, 20]

  See `tf.scatter_nd` for more details about how to make updates to
  slices.

  Args:
    ref: A mutable `Tensor`. Must be one of the following types: `float32`,
      `float64`, `int32`, `uint8`, `int16`, `int8`, `complex64`, `int64`,
      `qint8`, `quint8`, `qint32`, `bfloat16`, `uint16`, `complex128`, `half`,
      `uint32`, `uint64`. A mutable Tensor. Should be from a Variable node.
    indices: A `Tensor`. Must be one of the following types: `int32`, `int64`.
      A tensor of indices into ref.
    updates: A `Tensor`. Must have the same type as `ref`.
      A tensor of updated values to add to ref.
    use_locking: An optional `bool`. Defaults to `False`.
      If True, the assignment will be protected by a lock;
      otherwise the behavior is undefined, but may exhibit less contention.
    name: A name for the operation (optional).

  Returns:
    A mutable `Tensor`. Has the same type as `ref`.
  """
  if ref.dtype._is_ref_dtype:
    return gen_state_ops.scatter_nd_add(
        ref, indices, updates, use_locking=use_locking, name=name)
  # Resource variable: scatter-add on the handle, wrapped in a lazy read so
  # the result can be used like the updated variable.
  return ref._lazy_read(gen_state_ops.resource_scatter_nd_add(  # pylint: disable=protected-access
      ref.handle, indices, ops.convert_to_tensor(updates, ref.dtype),
      name=name))
615
616
@tf_export(v1=["scatter_sub"])
def scatter_sub(ref, indices, updates, use_locking=False, name=None):
  r"""Subtracts sparse updates to a variable reference.

  ```python
      # Scalar indices
      ref[indices, ...] -= updates[...]

      # Vector indices (for each i)
      ref[indices[i], ...] -= updates[i, ...]

      # High rank indices (for each i, ..., j)
      ref[indices[i, ..., j], ...] -= updates[i, ..., j, ...]
  ```

  This operation outputs `ref` after the update is done.
  This makes it easier to chain operations that need to use the reset value.

  Duplicate entries are handled correctly: if multiple `indices` reference
  the same location, their (negated) contributions add.

  Requires `updates.shape = indices.shape + ref.shape[1:]` or
  `updates.shape = []`.

  <div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
  <img style="width:100%"
       src="https://www.tensorflow.org/images/ScatterSub.png" alt>
  </div>

  Args:
    ref: A mutable `Tensor`. Must be one of the following types: `float32`,
      `float64`, `int32`, `uint8`, `int16`, `int8`, `complex64`, `int64`,
      `qint8`, `quint8`, `qint32`, `bfloat16`, `uint16`, `complex128`, `half`,
      `uint32`, `uint64`. Should be from a `Variable` node.
    indices: A `Tensor`. Must be one of the following types: `int32`, `int64`.
      A tensor of indices into the first dimension of `ref`.
    updates: A `Tensor`. Must have the same type as `ref`.
      A tensor of updated values to subtract from `ref`.
    use_locking: An optional `bool`. Defaults to `False`.
      If True, the subtraction will be protected by a lock;
      otherwise the behavior is undefined, but may exhibit less contention.
    name: A name for the operation (optional).

  Returns:
    A mutable `Tensor`. Has the same type as `ref`.
  """
  if not ref.dtype._is_ref_dtype:
    # Resource variable: scatter-subtract on the handle, then wrap the op in
    # a lazy read so callers can chain on the updated value.
    subtrahend = ops.convert_to_tensor(updates, ref.dtype)
    sub_op = gen_resource_variable_ops.resource_scatter_sub(
        ref.handle, indices, subtrahend, name=name)
    return ref._lazy_read(sub_op)  # pylint: disable=protected-access
  return gen_state_ops.scatter_sub(
      ref, indices, updates, use_locking=use_locking, name=name)
669
670
@tf_export(v1=["scatter_nd_sub"])
def scatter_nd_sub(ref, indices, updates, use_locking=False, name=None):
  r"""Applies sparse subtraction to individual values or slices in a Variable.

  `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.

  `indices` must be integer tensor, containing indices into `ref`.
  It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.

  The innermost dimension of `indices` (with length `K`) corresponds to
  indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
  dimension of `ref`.

  `updates` is `Tensor` of rank `Q-1+P-K` with shape:

  ```
  [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]]
  ```

  For example, say we want to subtract 4 scattered elements from a rank-1 tensor
  with 8 elements. In Python, that update would look like this:

  ```python
  ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
  indices = tf.constant([[4], [3], [1] ,[7]])
  updates = tf.constant([9, 10, 11, 12])
  op = tf.compat.v1.scatter_nd_sub(ref, indices, updates)
  with tf.compat.v1.Session() as sess:
    print(sess.run(op))
  ```

  The resulting update to ref would look like this:

      [1, -9, 3, -6, -6, 6, 7, -4]

  See `tf.scatter_nd` for more details about how to make updates to
  slices.

  Args:
    ref: A mutable `Tensor`. Must be one of the following types: `float32`,
      `float64`, `int32`, `uint8`, `int16`, `int8`, `complex64`, `int64`,
      `qint8`, `quint8`, `qint32`, `bfloat16`, `uint16`, `complex128`, `half`,
      `uint32`, `uint64`. A mutable Tensor. Should be from a Variable node.
    indices: A `Tensor`. Must be one of the following types: `int32`, `int64`.
      A tensor of indices into ref.
    updates: A `Tensor`. Must have the same type as `ref`.
      A tensor of updated values to subtract from ref.
    use_locking: An optional `bool`. Defaults to `False`.
      If True, the subtraction will be protected by a lock;
      otherwise the behavior is undefined, but may exhibit less contention.
    name: A name for the operation (optional).

  Returns:
    A mutable `Tensor`. Has the same type as `ref`.
  """
  if ref.dtype._is_ref_dtype:
    return gen_state_ops.scatter_nd_sub(
        ref, indices, updates, use_locking=use_locking, name=name)
  # Resource variable: scatter-subtract on the handle, wrapped in a lazy read
  # so the result can be used like the updated variable.
  return ref._lazy_read(gen_state_ops.resource_scatter_nd_sub(  # pylint: disable=protected-access
      ref.handle, indices, ops.convert_to_tensor(updates, ref.dtype),
      name=name))
733
734
@tf_export(v1=["scatter_mul"])
def scatter_mul(ref, indices, updates, use_locking=False, name=None):
  # pylint: disable=line-too-long
  r"""Multiplies sparse updates into a variable reference.

  This operation computes

  ```python
      # Scalar indices
      ref[indices, ...] *= updates[...]

      # Vector indices (for each i)
      ref[indices[i], ...] *= updates[i, ...]

      # High rank indices (for each i, ..., j)
      ref[indices[i, ..., j], ...] *= updates[i, ..., j, ...]
  ```

  This operation outputs `ref` after the update is done.
  This makes it easier to chain operations that need to use the reset value.

  Duplicate entries are handled correctly: if multiple `indices` reference
  the same location, their contributions multiply.

  Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape =
  []`.

  Args:
    ref: A mutable `Tensor`. Must be one of the following types: `float32`,
      `float64`, `int32`, `uint8`, `int16`, `int8`, `complex64`, `int64`,
      `qint8`, `quint8`, `qint32`, `bfloat16`, `uint16`, `complex128`, `half`,
      `uint32`, `uint64`. Should be from a `Variable` node.
    indices: A `Tensor`. Must be one of the following types: `int32`, `int64`. A
      tensor of indices into the first dimension of `ref`.
    updates: A `Tensor`. Must have the same type as `ref`. A tensor of updated
      values to multiply to `ref`.
    use_locking: An optional `bool`. Defaults to `False`. If True, the operation
      will be protected by a lock; otherwise the behavior is undefined, but may
      exhibit less contention.
    name: A name for the operation (optional).

  Returns:
    A mutable `Tensor`. Has the same type as `ref`.
  """
  if not ref.dtype._is_ref_dtype:
    # Resource variable: scatter-multiply on the handle, then wrap the op in
    # a lazy read so callers can chain on the updated value.
    factor = ops.convert_to_tensor(updates, ref.dtype)
    mul_op = gen_resource_variable_ops.resource_scatter_mul(
        ref.handle, indices, factor, name=name)
    return ref._lazy_read(mul_op)  # pylint: disable=protected-access
  return gen_state_ops.scatter_mul(
      ref, indices, updates, use_locking=use_locking, name=name)
785
786
@tf_export(v1=["scatter_div"])
def scatter_div(ref, indices, updates, use_locking=False, name=None):
  # pylint: disable=line-too-long
  r"""Divides a variable reference by sparse updates.

  This operation computes

  ```python
      # Scalar indices
      ref[indices, ...] /= updates[...]

      # Vector indices (for each i)
      ref[indices[i], ...] /= updates[i, ...]

      # High rank indices (for each i, ..., j)
      ref[indices[i, ..., j], ...] /= updates[i, ..., j, ...]
  ```

  This operation outputs `ref` after the update is done, which makes it
  easier to chain operations that need to use the reset value.

  Duplicate entries are handled correctly: if multiple `indices` reference
  the same location, their contributions divide.

  Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape =
  []`.

  Args:
    ref: A mutable `Tensor`. Must be one of the following types: `float32`,
      `float64`, `int32`, `uint8`, `int16`, `int8`, `complex64`, `int64`,
      `qint8`, `quint8`, `qint32`, `bfloat16`, `uint16`, `complex128`, `half`,
      `uint32`, `uint64`. Should be from a `Variable` node.
    indices: A `Tensor`. Must be one of the following types: `int32`, `int64`. A
      tensor of indices into the first dimension of `ref`.
    updates: A `Tensor`. Must have the same type as `ref`. A tensor of values
      that `ref` is divided by.
    use_locking: An optional `bool`. Defaults to `False`. If True, the operation
      will be protected by a lock; otherwise the behavior is undefined, but may
      exhibit less contention.
    name: A name for the operation (optional).

  Returns:
    A mutable `Tensor`. Has the same type as `ref`.
  """
  # Legacy reference-dtype variables go through the ref-based kernel, which
  # honors `use_locking`.
  if ref.dtype._is_ref_dtype:  # pylint: disable=protected-access
    return gen_state_ops.scatter_div(
        ref, indices, updates, use_locking=use_locking, name=name)
  # Resource variables: issue the resource op on the handle and wrap it in a
  # lazy read so the result can be chained like a ref-variable output.
  converted_updates = ops.convert_to_tensor(updates, ref.dtype)
  scatter = gen_resource_variable_ops.resource_scatter_div(
      ref.handle, indices, converted_updates, name=name)
  return ref._lazy_read(scatter)  # pylint: disable=protected-access
837
838
@tf_export(v1=["scatter_max"])
def scatter_max(ref, indices, updates, use_locking=False, name=None):
  # pylint: disable=line-too-long
  r"""Reduces sparse updates into a variable reference using the `max` operation.

  This operation computes

      # Scalar indices
      ref[indices, ...] = max(ref[indices, ...], updates[...])

      # Vector indices (for each i)
      ref[indices[i], ...] = max(ref[indices[i], ...], updates[i, ...])

      # High rank indices (for each i, ..., j)
      ref[indices[i, ..., j], ...] = max(ref[indices[i, ..., j], ...],
      updates[i, ..., j, ...])

  This operation outputs `ref` after the update is done, which makes it
  easier to chain operations that need to use the reset value.

  Duplicate entries are handled correctly: if multiple `indices` reference
  the same location, their contributions combine.

  Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape =
  []`.

  <div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
  <img style="width:100%" src="https://www.tensorflow.org/images/ScatterAdd.png"
  alt>
  </div>

  Args:
    ref: A mutable `Tensor`. Must be one of the following types: `half`,
      `bfloat16`, `float32`, `float64`, `int32`, `int64`. Should be from a
      `Variable` node.
    indices: A `Tensor`. Must be one of the following types: `int32`, `int64`. A
      tensor of indices into the first dimension of `ref`.
    updates: A `Tensor`. Must have the same type as `ref`. A tensor of updated
      values to reduce into `ref`.
    use_locking: An optional `bool`. Defaults to `False`. If True, the update
      will be protected by a lock; otherwise the behavior is undefined, but may
      exhibit less contention.
    name: A name for the operation (optional).

  Returns:
    A mutable `Tensor`. Has the same type as `ref`.
  """
  # Resource variables take the handle-based path and defer the read; legacy
  # ref-dtype variables use the ref kernel, which honors `use_locking`.
  if not ref.dtype._is_ref_dtype:  # pylint: disable=protected-access
    typed_updates = ops.convert_to_tensor(updates, ref.dtype)
    resource_op = gen_resource_variable_ops.resource_scatter_max(
        ref.handle, indices, typed_updates, name=name)
    return ref._lazy_read(resource_op)  # pylint: disable=protected-access
  return gen_state_ops.scatter_max(
      ref, indices, updates, use_locking=use_locking, name=name)
892
893
@tf_export(v1=["scatter_min"])
def scatter_min(ref, indices, updates, use_locking=False, name=None):
  # pylint: disable=line-too-long
  r"""Reduces sparse updates into a variable reference using the `min` operation.

  This operation computes

      # Scalar indices
      ref[indices, ...] = min(ref[indices, ...], updates[...])

      # Vector indices (for each i)
      ref[indices[i], ...] = min(ref[indices[i], ...], updates[i, ...])

      # High rank indices (for each i, ..., j)
      ref[indices[i, ..., j], ...] = min(ref[indices[i, ..., j], ...],
      updates[i, ..., j, ...])

  This operation outputs `ref` after the update is done, which makes it
  easier to chain operations that need to use the reset value.

  Duplicate entries are handled correctly: if multiple `indices` reference
  the same location, their contributions combine.

  Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape =
  []`.

  <div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
  <img style="width:100%" src="https://www.tensorflow.org/images/ScatterAdd.png"
  alt>
  </div>

  Args:
    ref: A mutable `Tensor`. Must be one of the following types: `half`,
      `bfloat16`, `float32`, `float64`, `int32`, `int64`. Should be from a
      `Variable` node.
    indices: A `Tensor`. Must be one of the following types: `int32`, `int64`. A
      tensor of indices into the first dimension of `ref`.
    updates: A `Tensor`. Must have the same type as `ref`. A tensor of updated
      values to reduce into `ref`.
    use_locking: An optional `bool`. Defaults to `False`. If True, the update
      will be protected by a lock; otherwise the behavior is undefined, but may
      exhibit less contention.
    name: A name for the operation (optional).

  Returns:
    A mutable `Tensor`. Has the same type as `ref`.
  """
  is_ref_variable = ref.dtype._is_ref_dtype  # pylint: disable=protected-access
  if is_ref_variable:
    # Legacy ref-dtype path: the ref kernel honors `use_locking`.
    return gen_state_ops.scatter_min(
        ref, indices, updates, use_locking=use_locking, name=name)
  # Resource-variable path: scatter through the handle, then lazily read.
  cast_updates = ops.convert_to_tensor(updates, ref.dtype)
  update_op = gen_resource_variable_ops.resource_scatter_min(
      ref.handle, indices, cast_updates, name=name)
  return ref._lazy_read(update_op)  # pylint: disable=protected-access
947
948
@tf_export(v1=["batch_scatter_update"])
@deprecation.deprecated(
    "2018-11-29", "Use the batch_scatter_update method of Variable instead.")
def batch_scatter_update(ref, indices, updates, use_locking=True, name=None):
  """Generalization of `tf.compat.v1.scatter_update` to axis different than 0.

  Analogous to `batch_gather`. This assumes that `ref`, `indices` and `updates`
  have a series of leading dimensions that are the same for all of them, and the
  updates are performed on the last dimension of indices. In other words, the
  dimensions should be the following:

  `num_prefix_dims = indices.ndims - 1`
  `batch_dim = num_prefix_dims + 1`
  `updates.shape = indices.shape + var.shape[batch_dim:]`

  where

  `updates.shape[:num_prefix_dims]`
  `== indices.shape[:num_prefix_dims]`
  `== var.shape[:num_prefix_dims]`

  And the operation performed can be expressed as:

  `var[i_1, ..., i_n, indices[i_1, ..., i_n, j]] = updates[i_1, ..., i_n, j]`

  When indices is a 1D tensor, this operation is equivalent to
  `tf.compat.v1.scatter_update`.

  To avoid this operation there would be 2 alternatives:
  1) Reshaping the variable by merging the first `ndims` dimensions. However,
     this is not possible because `tf.reshape` returns a Tensor, which we
     cannot use `tf.compat.v1.scatter_update` on.
  2) Looping over the first `ndims` of the variable and using
     `tf.compat.v1.scatter_update` on the subtensors that result of slicing the
     first
     dimension. This is a valid option for `ndims = 1`, but less efficient than
     this implementation.

  See also `tf.compat.v1.scatter_update` and `tf.compat.v1.scatter_nd_update`.

  Args:
    ref: `Variable` to scatter onto.
    indices: Tensor containing indices as described above.
    updates: Tensor of updates to apply to `ref`.
    use_locking: Boolean indicating whether to lock the writing operation.
    name: Optional scope name string.

  Returns:
    Ref to `variable` after it has been modified.

  Raises:
    ValueError: If the static rank of `indices` is unknown, since the prefix
        coordinates cannot be constructed without it.
  """
  with ops.name_scope(name):
    indices = ops.convert_to_tensor(indices, name="indices")
    indices_shape = array_ops.shape(indices)
    indices_dimensions = indices.get_shape().ndims

    if indices_dimensions is None:
      # Fixed copy-paste error: the message previously blamed `batch_gather`.
      raise ValueError("batch_scatter_update does not allow indices with "
                       "unknown shape.")

    nd_indices = array_ops.expand_dims(indices, axis=-1)
    nd_indices_list = []

    # Scatter ND requires indices to have an additional dimension, in which the
    # coordinates of the updated things are specified. For this to be adapted to
    # the scatter_update with several leading dimensions, we simply make use of
    # a tf.range for all the leading dimensions followed by concat of all the
    # coordinates we created with the original indices.

    # For example if indices.shape = [2, 3, 4], we should generate the following
    # indices for tf.compat.v1.scatter_nd_update:
    # nd_indices[:, :, 0] = [[0, 0, 0], [1, 1, 1]]
    # nd_indices[:, :, 1] = [[0, 1, 2], [0, 1, 2]]
    # nd_indices[:, :, 2] = indices
    for dimension in range(indices_dimensions - 1):
      # In this loop we generate the following for the example (one for each
      # iteration).
      # nd_indices[:, :, 0] = [[0, 0, 0], [1, 1, 1]]
      # nd_indices[:, :, 1] = [[0, 1, 2], [0, 1, 2]]
      # This is done at every iteration with a tf.range over the size of the
      # i-th dimension and using broadcasting over the desired shape.
      dimension_size = indices_shape[dimension]
      shape_to_broadcast = [1] * (indices_dimensions + 1)
      shape_to_broadcast[dimension] = dimension_size
      dimension_range = array_ops.reshape(
          gen_math_ops._range(0, dimension_size, 1), shape_to_broadcast)
      # `range` produces int32 by default; match the indices dtype so the
      # concat below is well-typed when indices are int64.
      if dimension_range.dtype.base_dtype != nd_indices.dtype:
        dimension_range = gen_math_ops.cast(dimension_range, nd_indices.dtype)
      # Broadcast the per-dimension coordinate to the full indices shape.
      nd_indices_list.append(
          dimension_range * array_ops.ones_like(nd_indices))
    # Add the original indices at the end, as described above, and concat.
    nd_indices_list.append(nd_indices)
    final_indices = array_ops.concat(nd_indices_list, axis=-1)
    return scatter_nd_update(
        ref, final_indices, updates, use_locking=use_locking)
1047