• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16"""Variables.
17
18See the [Variables](https://www.tensorflow.org/guide/variables) guide.
19"""
20
21from __future__ import absolute_import
22from __future__ import division
23from __future__ import print_function
24
25from tensorflow.python.framework import ops
26from tensorflow.python.framework import tensor_shape
27from tensorflow.python.ops import array_ops
28from tensorflow.python.ops import gen_math_ops
29from tensorflow.python.ops import gen_resource_variable_ops
30from tensorflow.python.ops import gen_state_ops
31# go/tf-wildcard-import
32# pylint: disable=wildcard-import
33from tensorflow.python.ops.gen_state_ops import *
34# pylint: enable=wildcard-import
35from tensorflow.python.util import deprecation
36from tensorflow.python.util.deprecation import deprecated
37from tensorflow.python.util.tf_export import tf_export
38
39
40# pylint: disable=protected-access,g-doc-return-or-yield,g-doc-args
def variable_op(shape, dtype, name="Variable", set_shape=True, container="",
                shared_name=""):
  """Deprecated. Use variable_op_v2 instead.

  Creates a (ref-type) variable op holding a tensor of the given shape/dtype.

  Args:
    shape: The shape of the tensor managed by this variable.
    dtype: The underlying type of the tensor values.
    name: optional name to use for the variable op.
    set_shape: If True, set the static shape of the result to `shape`;
      otherwise the variable is created with an unknown shape.
    container: An optional string. If non-empty, the variable is placed in
      the given container; otherwise a default container is used.
    shared_name: An optional string. If non-empty, names the variable in the
      given bucket; otherwise the node name is used instead.

  Returns:
    A variable tensor.
  """
  if not set_shape:
    shape = tensor_shape.unknown_shape()
  ret = gen_state_ops.variable(shape=shape, dtype=dtype, name=name,
                               container=container, shared_name=shared_name)
  # TODO(mrry): Move this to where it is used, so we can get rid of this op
  #   wrapper?
  if set_shape:
    ret.set_shape(shape)
  return ret
53
54
def variable_op_v2(shape, dtype, name="Variable", container="", shared_name=""):
  """Creates a V2 variable Operation.

  See also variables.Variable.

  Args:
    shape: The shape of the tensor managed by this variable.
    dtype: The underlying type of the tensor values.
    name: optional name to use for the variable op.
    container: An optional string. Defaults to "". If non-empty, this
      variable is placed in the given container; otherwise a default
      container is used.
    shared_name: An optional string. Defaults to "". If non-empty, this
      variable is named in the given bucket with this shared_name;
      otherwise the node name is used instead.

  Returns:
    A variable tensor.
  """
  return gen_state_ops.variable_v2(shape=shape, dtype=dtype, name=name,
                                   container=container,
                                   shared_name=shared_name)
80
81
def init_variable(v, init, name="init"):
  """Initializes variable `v` from `init`.

  This op does the following:
  if `init` is a Tensor (or convertible to one), `v = init`;
  if `init` is callable, `v = init(VariableShape(v), v.dtype)`.

  Args:
    v: Variable to initialize.
    init: Tensor to assign to `v`, or an object convertible to Tensor
      (e.g. an nparray), or an Initializer: a callable that, given the
      shape and base dtype of `v`, returns the tensor `v` should be set to.
    name: Optional name for the op.

  Returns:
    The operation that initializes `v`.
  """
  with ops.name_scope(None, v.op.name + "/", [v, init]):
    with ops.name_scope(name) as scope:
      with ops.colocate_with(v):
        if not callable(init):
          # Plain value: convert and assign directly.
          init_tensor = ops.convert_to_tensor(init, name="init")
          return gen_state_ops.assign(v, init_tensor, name=scope)
        # Initializer callable: needs a fully defined static shape.
        assert v.get_shape().is_fully_defined(), "Variable shape unknown."
        # TODO(mrry): Convert to v.shape when the property and
        # accessor are reconciled (and all initializers support
        # tf.TensorShape objects).
        value = ops.convert_to_tensor(
            init(v.get_shape().as_list(), v.dtype.base_dtype), name="value")
        return gen_state_ops.assign(v, value, name=scope)
115
116
def is_variable_initialized(ref, name=None):
  """Checks whether a tensor has been initialized.

  Outputs a boolean scalar indicating whether the tensor has been
  initialized.

  Args:
    ref: A mutable `Tensor`. Should be from a `Variable` node. May be
      uninitialized.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` of type `bool`.
  """
  if not ref.dtype._is_ref_dtype:
    # Resource variables provide their own initialization predicate.
    return ref.is_initialized(name=name)
  return gen_state_ops.is_variable_initialized(ref=ref, name=name)
134
135
@tf_export(v1=["assign_sub"])
def assign_sub(ref, value, use_locking=None, name=None):
  """Update 'ref' by subtracting 'value' from it.

  This operation outputs "ref" after the update is done.
  This makes it easier to chain operations that need to use the reset value.

  Args:
    ref: A mutable `Tensor`. Must be one of the following types:
      `float32`, `float64`, `int64`, `int32`, `uint8`, `uint16`, `int16`,
      `int8`, `complex64`, `complex128`, `qint8`, `quint8`, `qint32`, `half`.
      Should be from a `Variable` node.
    value: A `Tensor`. Must have the same type as `ref`.
      The value to be subtracted to the variable.
    use_locking: An optional `bool`. Defaults to `False`.
      If True, the subtraction will be protected by a lock;
      otherwise the behavior is undefined, but may exhibit less contention.
    name: A name for the operation (optional).

  Returns:
    Same as "ref".  Returned as a convenience for operations that want
    to use the new value after the variable has been updated.
  """
  if ref.dtype._is_ref_dtype:
    return gen_state_ops.assign_sub(
        ref, value, use_locking=use_locking, name=name)
  # Forward `name` so the user-supplied op name is not silently dropped on
  # the resource-variable path (consistent with `assign` below).
  return ref.assign_sub(value, name=name)
163
164
@tf_export(v1=["assign_add"])
def assign_add(ref, value, use_locking=None, name=None):
  """Update 'ref' by adding 'value' to it.

  This operation outputs "ref" after the update is done.
  This makes it easier to chain operations that need to use the reset value.

  Args:
    ref: A mutable `Tensor`. Must be one of the following types:
      `float32`, `float64`, `int64`, `int32`, `uint8`, `uint16`, `int16`,
      `int8`, `complex64`, `complex128`, `qint8`, `quint8`, `qint32`, `half`.
      Should be from a `Variable` node.
    value: A `Tensor`. Must have the same type as `ref`.
      The value to be added to the variable.
    use_locking: An optional `bool`. Defaults to `False`.
      If True, the addition will be protected by a lock;
      otherwise the behavior is undefined, but may exhibit less contention.
    name: A name for the operation (optional).

  Returns:
    Same as "ref".  Returned as a convenience for operations that want
    to use the new value after the variable has been updated.
  """
  if ref.dtype._is_ref_dtype:
    return gen_state_ops.assign_add(
        ref, value, use_locking=use_locking, name=name)
  # Forward `name` so the user-supplied op name is not silently dropped on
  # the resource-variable path (consistent with `assign` below).
  return ref.assign_add(value, name=name)
192
193
@tf_export(v1=["assign"])
def assign(ref, value, validate_shape=None, use_locking=None, name=None):
  """Update 'ref' by assigning 'value' to it.

  This operation outputs a Tensor holding the new value of 'ref' once the
  assignment is done, which makes it easy to chain operations that need to
  use the reset value.

  Args:
    ref: A mutable `Tensor`. Should be from a `Variable` node. May be
      uninitialized.
    value: A `Tensor`. Must have the same type as `ref`. The value to be
      assigned to the variable.
    validate_shape: An optional `bool`. Defaults to `True`. If true, the
      operation validates that the shape of 'value' matches the shape of
      the Tensor being assigned to. If false, 'ref' takes on the shape of
      'value'.
    use_locking: An optional `bool`. Defaults to `True`. If True, the
      assignment is protected by a lock; otherwise the behavior is
      undefined, but may exhibit less contention.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` that will hold the new value of 'ref' after the assignment
    has completed.
  """
  if not ref.dtype._is_ref_dtype:
    # Resource variables implement assignment themselves.
    return ref.assign(value, name=name)
  return gen_state_ops.assign(ref, value, validate_shape=validate_shape,
                              use_locking=use_locking, name=name)
225
226
@tf_export(v1=["count_up_to"])
@deprecated(None, "Prefer Dataset.range instead.")
def count_up_to(ref, limit, name=None):
  r"""Increments 'ref' until it reaches 'limit'.

  Args:
    ref: A Variable. Must be one of the following types: `int32`, `int64`.
      Should be from a scalar `Variable` node.
    limit: An `int`. If incrementing ref would bring it above limit,
      instead generates an 'OutOfRange' error.
    name: A name for the operation (optional).

  Returns:
    A `Tensor`. Has the same type as `ref`. A copy of the input before
    increment. If nothing else modifies the input, the values produced
    will all be distinct.
  """
  if not ref.dtype._is_ref_dtype:
    # Resource variables route through the resource-based kernel.
    return gen_state_ops.resource_count_up_to(ref.handle, limit,
                                              T=ref.dtype, name=name)
  return gen_state_ops.count_up_to(ref, limit=limit, name=name)
249
250
@tf_export(v1=["scatter_update"])
def scatter_update(ref, indices, updates, use_locking=True, name=None):
  # pylint: disable=line-too-long
  r"""Applies sparse updates to a variable reference.

  This operation computes

  ```python
      # Scalar indices
      ref[indices, ...] = updates[...]

      # Vector indices (for each i)
      ref[indices[i], ...] = updates[i, ...]

      # High rank indices (for each i, ..., j)
      ref[indices[i, ..., j], ...] = updates[i, ..., j, ...]
  ```

  The output is `ref` after the update is done, which makes it easy to
  chain operations that need to use the reset value.

  If a value in `ref` would be updated more than once because `indices`
  contains duplicate entries, the order in which those updates happen is
  undefined.

  Requires `updates.shape = indices.shape + ref.shape[1:]`.

  <div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
  <img style="width:100%" src="https://www.tensorflow.org/images/ScatterUpdate.png" alt>
  </div>

  Args:
    ref: A `Variable`.
    indices: A `Tensor`. Must be one of the following types: `int32`,
      `int64`. A tensor of indices into the first dimension of `ref`.
    updates: A `Tensor`. Must have the same type as `ref`. A tensor of
      updated values to store in `ref`.
    use_locking: An optional `bool`. Defaults to `True`. If True, the
      assignment is protected by a lock; otherwise the behavior is
      undefined, but may exhibit less contention.
    name: A name for the operation (optional).

  Returns:
    Same as `ref`. Returned as a convenience for operations that want to
    use the updated values after the update is done.
  """
  if not ref.dtype._is_ref_dtype:
    converted = ops.convert_to_tensor(updates, ref.dtype)
    scatter = gen_resource_variable_ops.resource_scatter_update(
        ref.handle, indices, converted, name=name)
    return ref._lazy_read(scatter)  # pylint: disable=protected-access
  return gen_state_ops.scatter_update(ref, indices, updates,
                                      use_locking=use_locking, name=name)
303
304
@tf_export(v1=["scatter_nd_update"])
def scatter_nd_update(ref, indices, updates, use_locking=True, name=None):
  r"""Applies sparse `updates` to individual values or slices in a Variable.

  `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.

  `indices` must be integer tensor, containing indices into `ref`.
  It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.

  The innermost dimension of `indices` (with length `K`) corresponds to
  indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
  dimension of `ref`.

  `updates` is `Tensor` of rank `Q-1+P-K` with shape:

  ```
  [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
  ```

  For example, say we want to update 4 scattered elements to a rank-1 tensor to
  8 elements. In Python, that update would look like this:

  ```python
      ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
      indices = tf.constant([[4], [3], [1] ,[7]])
      updates = tf.constant([9, 10, 11, 12])
      update = tf.scatter_nd_update(ref, indices, updates)
      with tf.Session() as sess:
        print sess.run(update)
  ```

  The resulting update to ref would look like this:

      [1, 11, 3, 10, 9, 6, 7, 12]

  See `tf.scatter_nd` for more details about how to make updates to
  slices.

  Args:
    ref: A Variable.
    indices: A `Tensor`. Must be one of the following types: `int32`, `int64`.
      A tensor of indices into ref.
    updates: A `Tensor`. Must have the same type as `ref`. A tensor of
      updated values to assign to `ref`.
    use_locking: An optional `bool`. Defaults to `True`. If True, the
      assignment will be protected by a lock; otherwise the behavior is
      undefined, but may exhibit less contention.
    name: A name for the operation (optional).

  Returns:
    The value of the variable after the update.
  """
  if ref.dtype._is_ref_dtype:
    return gen_state_ops.scatter_nd_update(
        ref, indices, updates, use_locking, name)
  return ref._lazy_read(gen_state_ops.resource_scatter_nd_update(  # pylint: disable=protected-access
      ref.handle, indices, ops.convert_to_tensor(updates, ref.dtype),
      name=name))
365
366
@tf_export(v1=["scatter_add"])
def scatter_add(ref, indices, updates, use_locking=False, name=None):
  # pylint: disable=line-too-long
  r"""Adds sparse updates to the variable referenced by `resource`.

  This operation computes

  ```python
      # Scalar indices
      ref[indices, ...] += updates[...]

      # Vector indices (for each i)
      ref[indices[i], ...] += updates[i, ...]

      # High rank indices (for each i, ..., j)
      ref[indices[i, ..., j], ...] += updates[i, ..., j, ...]
  ```

  The output is `ref` after the update is done, which makes it easy to
  chain operations that need to use the updated value. Duplicate entries
  are handled correctly: if multiple `indices` reference the same
  location, their contributions add.

  Requires `updates.shape = indices.shape + ref.shape[1:]`.

  <div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
  <img style="width:100%" src='https://www.tensorflow.org/images/ScatterAdd.png' alt>
  </div>

  Args:
    ref: A `Variable`.
    indices: A `Tensor`. Must be one of the following types: `int32`,
      `int64`. A tensor of indices into the first dimension of `ref`.
    updates: A `Tensor`. Must have the same type as `ref`. A tensor of
      updated values to store in `ref`.
    use_locking: An optional `bool`. Defaults to `False`. If True, the
      assignment is protected by a lock; otherwise the behavior is
      undefined, but may exhibit less contention.
    name: A name for the operation (optional).

  Returns:
    Same as `ref`. Returned as a convenience for operations that want to
    use the updated values after the update is done.
  """
  if not ref.dtype._is_ref_dtype:
    converted = ops.convert_to_tensor(updates, ref.dtype)
    scatter = gen_resource_variable_ops.resource_scatter_add(
        ref.handle, indices, converted, name=name)
    return ref._lazy_read(scatter)  # pylint: disable=protected-access
  return gen_state_ops.scatter_add(ref, indices, updates,
                                   use_locking=use_locking, name=name)
417
418
@tf_export(v1=["scatter_nd_add"])
def scatter_nd_add(ref, indices, updates, use_locking=False, name=None):
  r"""Applies sparse addition to individual values or slices in a Variable.

  `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.

  `indices` must be an integer tensor, containing indices into `ref`, of
  shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.

  The innermost dimension of `indices` (with length `K`) corresponds to
  indices into elements (if `K = P`) or slices (if `K < P`) along the
  `K`th dimension of `ref`.

  `updates` is a `Tensor` of rank `Q-1+P-K` with shape:

  ```
  [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]]
  ```

  For example, say we want to add 4 scattered elements to a rank-1 tensor to
  8 elements. In Python, that addition would look like this:

  ```python
  ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
  indices = tf.constant([[4], [3], [1], [7]])
  updates = tf.constant([9, 10, 11, 12])
  add = tf.scatter_nd_add(ref, indices, updates)
  with tf.Session() as sess:
    print sess.run(add)
  ```

  The resulting update to ref would look like this:

      [1, 13, 3, 14, 14, 6, 7, 20]

  See `tf.scatter_nd` for more details about how to make updates to
  slices.

  Args:
    ref: A mutable `Tensor`. Must be one of the following types: `float32`,
      `float64`, `int32`, `uint8`, `int16`, `int8`, `complex64`, `int64`,
      `qint8`, `quint8`, `qint32`, `bfloat16`, `uint16`, `complex128`,
      `half`, `uint32`, `uint64`. Should be from a Variable node.
    indices: A `Tensor`. Must be one of the following types: `int32`,
      `int64`. A tensor of indices into ref.
    updates: A `Tensor`. Must have the same type as `ref`. A tensor of
      updated values to add to ref.
    use_locking: An optional `bool`. Defaults to `False`. If True, the
      assignment is protected by a lock; otherwise the behavior is
      undefined, but may exhibit less contention.
    name: A name for the operation (optional).

  Returns:
    A mutable `Tensor`. Has the same type as `ref`.
  """
  if not ref.dtype._is_ref_dtype:
    converted = ops.convert_to_tensor(updates, ref.dtype)
    scatter = gen_state_ops.resource_scatter_nd_add(
        ref.handle, indices, converted, name=name)
    return ref._lazy_read(scatter)  # pylint: disable=protected-access
  return gen_state_ops.scatter_nd_add(ref, indices, updates, use_locking,
                                      name)
480
481
@tf_export(v1=["scatter_sub"])
def scatter_sub(ref, indices, updates, use_locking=False, name=None):
  r"""Subtracts sparse updates to a variable reference.

  ```python
      # Scalar indices
      ref[indices, ...] -= updates[...]

      # Vector indices (for each i)
      ref[indices[i], ...] -= updates[i, ...]

      # High rank indices (for each i, ..., j)
      ref[indices[i, ..., j], ...] -= updates[i, ..., j, ...]
  ```

  The output is `ref` after the update is done, which makes it easy to
  chain operations that need to use the reset value.

  Duplicate entries are handled correctly: if multiple `indices` reference
  the same location, their (negated) contributions add.

  Requires `updates.shape = indices.shape + ref.shape[1:]` or
  `updates.shape = []`.

  <div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
  <img style="width:100%"
       src="https://www.tensorflow.org/images/ScatterSub.png" alt>
  </div>

  Args:
    ref: A mutable `Tensor`. Must be one of the following types: `float32`,
      `float64`, `int32`, `uint8`, `int16`, `int8`, `complex64`, `int64`,
      `qint8`, `quint8`, `qint32`, `bfloat16`, `uint16`, `complex128`,
      `half`, `uint32`, `uint64`. Should be from a `Variable` node.
    indices: A `Tensor`. Must be one of the following types: `int32`,
      `int64`. A tensor of indices into the first dimension of `ref`.
    updates: A `Tensor`. Must have the same type as `ref`. A tensor of
      updated values to subtract from `ref`.
    use_locking: An optional `bool`. Defaults to `False`. If True, the
      subtraction is protected by a lock; otherwise the behavior is
      undefined, but may exhibit less contention.
    name: A name for the operation (optional).

  Returns:
    A mutable `Tensor`. Has the same type as `ref`.
  """
  if not ref.dtype._is_ref_dtype:
    converted = ops.convert_to_tensor(updates, ref.dtype)
    scatter = gen_resource_variable_ops.resource_scatter_sub(
        ref.handle, indices, converted, name=name)
    return ref._lazy_read(scatter)  # pylint: disable=protected-access
  return gen_state_ops.scatter_sub(ref, indices, updates,
                                   use_locking=use_locking, name=name)
534
535
@tf_export(v1=["scatter_nd_sub"])
def scatter_nd_sub(ref, indices, updates, use_locking=False, name=None):
  r"""Applies sparse subtraction to individual values or slices in a Variable.

  `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.

  `indices` must be integer tensor, containing indices into `ref`.
  It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.

  The innermost dimension of `indices` (with length `K`) corresponds to
  indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
  dimension of `ref`.

  `updates` is `Tensor` of rank `Q-1+P-K` with shape:

  ```
  [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]]
  ```

  For example, say we want to subtract 4 scattered elements from a rank-1 tensor
  with 8 elements. In Python, that update would look like this:

  ```python
  ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
  indices = tf.constant([[4], [3], [1] ,[7]])
  updates = tf.constant([9, 10, 11, 12])
  op = tf.scatter_nd_sub(ref, indices, updates)
  with tf.Session() as sess:
    print sess.run(op)
  ```

  The resulting update to ref would look like this:

      [1, -9, 3, -6, -4, 6, 7, -4]

  See `tf.scatter_nd` for more details about how to make updates to
  slices.

  Args:
    ref: A mutable `Tensor`. Must be one of the following types: `float32`,
      `float64`, `int32`, `uint8`, `int16`, `int8`, `complex64`, `int64`,
      `qint8`, `quint8`, `qint32`, `bfloat16`, `uint16`, `complex128`, `half`,
      `uint32`, `uint64`. A mutable Tensor. Should be from a Variable node.
    indices: A `Tensor`. Must be one of the following types: `int32`, `int64`.
      A tensor of indices into ref.
    updates: A `Tensor`. Must have the same type as `ref`.
      A tensor of updated values to subtract from ref.
    use_locking: An optional `bool`. Defaults to `False`. If True, the
      assignment will be protected by a lock; otherwise the behavior is
      undefined, but may exhibit less contention.
    name: A name for the operation (optional).

  Returns:
    A mutable `Tensor`. Has the same type as `ref`.
  """
  if ref.dtype._is_ref_dtype:
    return gen_state_ops.scatter_nd_sub(
        ref, indices, updates, use_locking, name)
  return ref._lazy_read(gen_state_ops.resource_scatter_nd_sub(  # pylint: disable=protected-access
      ref.handle, indices, ops.convert_to_tensor(updates, ref.dtype),
      name=name))
598
599
@tf_export(v1=["scatter_mul"])
def scatter_mul(ref, indices, updates, use_locking=False, name=None):
  # pylint: disable=line-too-long
  r"""Multiplies sparse updates into a variable reference.

  This operation computes

  ```python
      # Scalar indices
      ref[indices, ...] *= updates[...]

      # Vector indices (for each i)
      ref[indices[i], ...] *= updates[i, ...]

      # High rank indices (for each i, ..., j)
      ref[indices[i, ..., j], ...] *= updates[i, ..., j, ...]
  ```

  The output is `ref` after the update is done, which makes it easy to
  chain operations that need to use the reset value.

  Duplicate entries are handled correctly: if multiple `indices` reference
  the same location, their contributions multiply.

  Requires `updates.shape = indices.shape + ref.shape[1:]` or
  `updates.shape = []`.

  Args:
    ref: A mutable `Tensor`. Must be one of the following types: `float32`,
      `float64`, `int32`, `uint8`, `int16`, `int8`, `complex64`, `int64`,
      `qint8`, `quint8`, `qint32`, `bfloat16`, `uint16`, `complex128`,
      `half`, `uint32`, `uint64`. Should be from a `Variable` node.
    indices: A `Tensor`. Must be one of the following types: `int32`,
      `int64`. A tensor of indices into the first dimension of `ref`.
    updates: A `Tensor`. Must have the same type as `ref`. A tensor of
      updated values to multiply to `ref`.
    use_locking: An optional `bool`. Defaults to `False`. If True, the
      operation is protected by a lock; otherwise the behavior is
      undefined, but may exhibit less contention.
    name: A name for the operation (optional).

  Returns:
    A mutable `Tensor`. Has the same type as `ref`.
  """
  return gen_state_ops.scatter_mul(ref=ref, indices=indices,
                                   updates=updates,
                                   use_locking=use_locking, name=name)
650
651
@tf_export(v1=["scatter_div"])
def scatter_div(ref, indices, updates, use_locking=False, name=None):
  # pylint: disable=line-too-long
  r"""Divides a variable reference by sparse updates.

  This operation computes

  ```python
      # Scalar indices
      ref[indices, ...] /= updates[...]

      # Vector indices (for each i)
      ref[indices[i], ...] /= updates[i, ...]

      # High rank indices (for each i, ..., j)
      ref[indices[i, ..., j], ...] /= updates[i, ..., j, ...]
  ```

  The output is `ref` after the update is done, which makes it easy to
  chain operations that need to use the reset value.

  Duplicate entries are handled correctly: if multiple `indices` reference
  the same location, their contributions divide.

  Requires `updates.shape = indices.shape + ref.shape[1:]` or
  `updates.shape = []`.

  Args:
    ref: A mutable `Tensor`. Must be one of the following types: `float32`,
      `float64`, `int32`, `uint8`, `int16`, `int8`, `complex64`, `int64`,
      `qint8`, `quint8`, `qint32`, `bfloat16`, `uint16`, `complex128`,
      `half`, `uint32`, `uint64`. Should be from a `Variable` node.
    indices: A `Tensor`. Must be one of the following types: `int32`,
      `int64`. A tensor of indices into the first dimension of `ref`.
    updates: A `Tensor`. Must have the same type as `ref`. A tensor of
      values that `ref` is divided by.
    use_locking: An optional `bool`. Defaults to `False`. If True, the
      operation is protected by a lock; otherwise the behavior is
      undefined, but may exhibit less contention.
    name: A name for the operation (optional).

  Returns:
    A mutable `Tensor`. Has the same type as `ref`.
  """
  return gen_state_ops.scatter_div(ref=ref, indices=indices,
                                   updates=updates,
                                   use_locking=use_locking, name=name)
702
703
@tf_export(v1=["scatter_max"])
def scatter_max(ref, indices, updates, use_locking=False, name=None):
  # pylint: disable=line-too-long
  r"""Reduces sparse updates into a variable reference using the `max` operation.

  This operation computes

      # Scalar indices
      ref[indices, ...] = max(ref[indices, ...], updates[...])

      # Vector indices (for each i)
      ref[indices[i], ...] = max(ref[indices[i], ...], updates[i, ...])

      # High rank indices (for each i, ..., j)
      ref[indices[i, ..., j], ...] = max(ref[indices[i, ..., j], ...],
      updates[i, ..., j, ...])

  The output is `ref` after the update is done, which makes it easy to
  chain operations that need to use the reset value.

  Duplicate entries are handled correctly: if multiple `indices` reference
  the same location, their contributions combine.

  Requires `updates.shape = indices.shape + ref.shape[1:]` or
  `updates.shape = []`.

  <div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
  <img style="width:100%" src="https://www.tensorflow.org/images/ScatterAdd.png"
  alt>
  </div>

  Args:
    ref: A mutable `Tensor`. Must be one of the following types: `half`,
      `bfloat16`, `float32`, `float64`, `int32`, `int64`. Should be from a
      `Variable` node.
    indices: A `Tensor`. Must be one of the following types: `int32`,
      `int64`. A tensor of indices into the first dimension of `ref`.
    updates: A `Tensor`. Must have the same type as `ref`. A tensor of
      updated values to reduce into `ref`.
    use_locking: An optional `bool`. Defaults to `False`. If True, the
      update is protected by a lock; otherwise the behavior is undefined,
      but may exhibit less contention.
    name: A name for the operation (optional).

  Returns:
    A mutable `Tensor`. Has the same type as `ref`.
  """
  return gen_state_ops.scatter_max(ref=ref, indices=indices,
                                   updates=updates,
                                   use_locking=use_locking, name=name)
757
758
@tf_export(v1=["scatter_min"])
def scatter_min(ref, indices, updates, use_locking=False, name=None):
  # pylint: disable=line-too-long
  r"""Reduces sparse updates into a variable reference using the `min` operation.

  This operation computes

      # Scalar indices
      ref[indices, ...] = min(ref[indices, ...], updates[...])

      # Vector indices (for each i)
      ref[indices[i], ...] = min(ref[indices[i], ...], updates[i, ...])

      # High rank indices (for each i, ..., j)
      ref[indices[i, ..., j], ...] = min(ref[indices[i, ..., j], ...],
      updates[i, ..., j, ...])

  The output is `ref` after the update is done, which makes it easy to
  chain operations that need to use the reset value.

  Duplicate entries are handled correctly: if multiple `indices` reference
  the same location, their contributions combine.

  Requires `updates.shape = indices.shape + ref.shape[1:]` or
  `updates.shape = []`.

  <div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
  <img style="width:100%" src="https://www.tensorflow.org/images/ScatterAdd.png"
  alt>
  </div>

  Args:
    ref: A mutable `Tensor`. Must be one of the following types: `half`,
      `bfloat16`, `float32`, `float64`, `int32`, `int64`. Should be from a
      `Variable` node.
    indices: A `Tensor`. Must be one of the following types: `int32`,
      `int64`. A tensor of indices into the first dimension of `ref`.
    updates: A `Tensor`. Must have the same type as `ref`. A tensor of
      updated values to reduce into `ref`.
    use_locking: An optional `bool`. Defaults to `False`. If True, the
      update is protected by a lock; otherwise the behavior is undefined,
      but may exhibit less contention.
    name: A name for the operation (optional).

  Returns:
    A mutable `Tensor`. Has the same type as `ref`.
  """
  return gen_state_ops.scatter_min(ref=ref, indices=indices,
                                   updates=updates,
                                   use_locking=use_locking, name=name)
812
813
@tf_export(v1=["batch_scatter_update"])
@deprecation.deprecated(
    "2018-11-29", "Use the batch_scatter_update method of Variable instead.")
def batch_scatter_update(ref, indices, updates, use_locking=True, name=None):
  """Generalization of `tf.scatter_update` to axis different than 0.

  Analogous to `batch_gather`. This assumes that `ref`, `indices` and `updates`
  have a series of leading dimensions that are the same for all of them, and the
  updates are performed on the last dimension of indices. In other words, the
  dimensions should be the following:

  `num_prefix_dims = indices.ndims - 1`
  `batch_dim = num_prefix_dims + 1`
  `updates.shape = indices.shape + var.shape[batch_dim:]`

  where

  `updates.shape[:num_prefix_dims]`
  `== indices.shape[:num_prefix_dims]`
  `== var.shape[:num_prefix_dims]`

  And the operation performed can be expressed as:

  `var[i_1, ..., i_n, indices[i_1, ..., i_n, j]] = updates[i_1, ..., i_n, j]`

  When indices is a 1D tensor, this operation is equivalent to
  `tf.scatter_update`.

  To avoid this operation there would be 2 alternatives:
  1) Reshaping the variable by merging the first `ndims` dimensions. However,
     this is not possible because `tf.reshape` returns a Tensor, which we
     cannot use `tf.scatter_update` on.
  2) Looping over the first `ndims` of the variable and using
     `tf.scatter_update` on the subtensors that result of slicing the first
     dimension. This is a valid option for `ndims = 1`, but less efficient than
     this implementation.

  See also `tf.scatter_update` and `tf.scatter_nd_update`.

  Args:
    ref: `Variable` to scatter onto.
    indices: Tensor containing indices as described above.
    updates: Tensor of updates to apply to `ref`.
    use_locking: Boolean indicating whether to lock the writing operation.
    name: Optional scope name string.

  Returns:
    Ref to `variable` after it has been modified.

  Raises:
    ValueError: If the initial `ndims` of `ref`, `indices`, and `updates` are
        not the same.
  """
  with ops.name_scope(name):
    indices = ops.convert_to_tensor(indices, name="indices")
    # Dynamic shape is needed for the tf.range broadcasting below; the static
    # rank is needed up front to decide how many prefix dimensions exist.
    indices_shape = array_ops.shape(indices)
    indices_dimensions = indices.get_shape().ndims

    if indices_dimensions is None:
      # Without a static rank we cannot build the nd-index prefix below.
      raise ValueError("batch_scatter_update does not allow indices with "
                       "unknown rank.")

    nd_indices = array_ops.expand_dims(indices, axis=-1)
    nd_indices_list = []

    # Scatter ND requires indices to have an additional dimension, in which the
    # coordinates of the updated things are specified. For this to be adapted to
    # the scatter_update with several leading dimensions, we simply make use of
    # a tf.range for all the leading dimensions followed by concat of all the
    # coordinates we created with the original indices.

    # For example if indices.shape = [2, 3, 4], we should generate the following
    # indices for tf.scatter_nd_update:
    # nd_indices[:, :, 0] = [[0, 0, 0], [1, 1, 1]]
    # nd_indices[:, :, 1] = [[0, 1, 2], [0, 1, 2]]
    # nd_indices[:, :, 2] = indices
    for dimension in range(indices_dimensions - 1):
      # In this loop we generate the following for the example (one for each
      # iteration).
      # nd_indices[:, :, 0] = [[0, 0, 0], [1, 1, 1]]
      # nd_indices[:, :, 1] = [[0, 1, 2], [0, 1, 2]]
      # This is done at every iteration with a tf.range over the size of the
      # i-th dimension and using broadcasting over the desired shape.
      dimension_size = indices_shape[dimension]
      shape_to_broadcast = [1] * (indices_dimensions + 1)
      shape_to_broadcast[dimension] = dimension_size
      dimension_range = array_ops.reshape(
          gen_math_ops._range(0, dimension_size, 1), shape_to_broadcast)
      if dimension_range.dtype.base_dtype != nd_indices.dtype:
        # tf.range yields int32 by default; match the caller's index dtype so
        # the concat below does not fail on mixed dtypes.
        dimension_range = gen_math_ops.cast(dimension_range, nd_indices.dtype)
      # Multiplying by ones_like broadcasts the per-dimension range to the full
      # indices shape (plus the trailing coordinate dimension).
      nd_indices_list.append(
          dimension_range * array_ops.ones_like(nd_indices))
    # Add the original indices at the end, as described above, and concat.
    nd_indices_list.append(nd_indices)
    final_indices = array_ops.concat(nd_indices_list, axis=-1)
    return scatter_nd_update(
        ref, final_indices, updates, use_locking=use_locking)
911