1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Variable class."""
16
17import enum
18import functools
19import itertools
20import os
21
22from tensorflow.core.framework import attr_value_pb2
23from tensorflow.core.framework import variable_pb2
24from tensorflow.python import pywrap_tensorflow  # pylint: disable=unused-import
25from tensorflow.python.eager import context
26from tensorflow.python.framework import dtypes
27from tensorflow.python.framework import indexed_slices
28from tensorflow.python.framework import ops
29from tensorflow.python.framework import tensor_shape
30from tensorflow.python.ops import array_ops
31from tensorflow.python.ops import control_flow_ops
32from tensorflow.python.ops import gen_array_ops
33from tensorflow.python.ops import gen_math_ops
34from tensorflow.python.ops import gen_state_ops
35from tensorflow.python.ops import math_ops
36from tensorflow.python.ops import state_ops
37from tensorflow.python.platform import tf_logging as logging
38from tensorflow.python.trackable import base as trackable
39from tensorflow.python.types import core
40from tensorflow.python.util import _pywrap_utils
41from tensorflow.python.util import compat
42from tensorflow.python.util import object_identity
43from tensorflow.python.util import tf_should_use
44from tensorflow.python.util import traceback_utils
45from tensorflow.python.util.deprecation import deprecated
46from tensorflow.python.util.deprecation import deprecated_args
48from tensorflow.python.util.tf_export import tf_export
49
50
51def default_variable_creator(_, **kwds):
52  del kwds
53  raise NotImplementedError("variable_scope needs to be imported")
54
55
56def default_variable_creator_v2(_, **kwds):
57  del kwds
58  raise NotImplementedError("variable_scope needs to be imported")
59
60
61def _make_getter(captured_getter, captured_previous):
62  """To avoid capturing loop variables."""
63
64  def getter(**kwargs):
65    return captured_getter(captured_previous, **kwargs)
66
67  return getter
68
69
70@tf_export("VariableSynchronization")
71class VariableSynchronization(enum.Enum):
72  """Indicates when a distributed variable will be synced.
73
  * `AUTO`: Indicates that the synchronization will be determined by the current
    `DistributionStrategy` (e.g. with `MirroredStrategy` this would be
    `ON_WRITE`).
  * `NONE`: Indicates that there will only be one copy of the variable, so
    there is no need to sync.
  * `ON_WRITE`: Indicates that the variable will be updated across devices
    every time it is written.
  * `ON_READ`: Indicates that the variable will be aggregated across devices
    when it is read (e.g. when checkpointing or when evaluating an op that uses
    the variable).

  Example:

  >>> temp_grad = [tf.Variable([0.], trainable=False,
  ...                          synchronization=tf.VariableSynchronization.ON_READ,
  ...                          aggregation=tf.VariableAggregation.MEAN)]
  """
91  AUTO = 0
92  NONE = 1
93  ON_WRITE = 2
94  ON_READ = 3
95
96
97# LINT.IfChange
98@tf_export("VariableAggregation", v1=[])
99class VariableAggregationV2(enum.Enum):
100  """Indicates how a distributed variable will be aggregated.
101
102  `tf.distribute.Strategy` distributes a model by making multiple copies
103  (called "replicas") acting data-parallel on different elements of the input
104  batch. When performing some variable-update operation, say
105  `var.assign_add(x)`, in a model, we need to resolve how to combine the
106  different values for `x` computed in the different replicas.
107
108  * `NONE`: This is the default, giving an error if you use a
109    variable-update operation with multiple replicas.
110  * `SUM`: Add the updates across replicas.
111  * `MEAN`: Take the arithmetic mean ("average") of the updates across replicas.
112  * `ONLY_FIRST_REPLICA`: This is for when every replica is performing the same
113    update, but we only want to perform the update once. Used, e.g., for the
114    global step counter.
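
  For example, a minimal sketch (assuming a `tf.distribute.MirroredStrategy`
  can be created on the current machine; `num_examples` is a hypothetical
  counter):

  ```python
  strategy = tf.distribute.MirroredStrategy()
  with strategy.scope():
    # Each replica keeps a local count; reads return the sum across replicas.
    num_examples = tf.Variable(
        0., trainable=False,
        synchronization=tf.VariableSynchronization.ON_READ,
        aggregation=tf.VariableAggregation.SUM)
  ```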
115  """
116  NONE = 0
117  SUM = 1
118  MEAN = 2
119  ONLY_FIRST_REPLICA = 3
120
121  def __hash__(self):
122    return hash(self.value)
123
124  def __eq__(self, other):
125    if self is other:
126      return True
127    elif isinstance(other, VariableAggregation):
128      return int(self.value) == int(other.value)
129    else:
130      return False
131
132
133@tf_export(v1=["VariableAggregation"])
134class VariableAggregation(enum.Enum):
135  NONE = 0
136  SUM = 1
137  MEAN = 2
138  ONLY_FIRST_REPLICA = 3
139  ONLY_FIRST_TOWER = 3  # DEPRECATED
140
141  def __hash__(self):
142    return hash(self.value)
143
144
145# LINT.ThenChange(//tensorflow/core/framework/variable.proto)
146#
147# Note that we are currently relying on the integer values of the Python enums
148# matching the integer values of the proto enums.
149
150VariableAggregation.__doc__ = (
151    VariableAggregationV2.__doc__ +
152    "* `ONLY_FIRST_TOWER`: Deprecated alias for `ONLY_FIRST_REPLICA`.\n  ")
153
154
155def validate_synchronization_aggregation_trainable(synchronization, aggregation,
156                                                   trainable, name):
157  """Given user-provided variable properties, sets defaults and validates."""
158  if aggregation is None:
159    aggregation = VariableAggregation.NONE
160  else:
161    if not isinstance(aggregation,
162                      (VariableAggregation, VariableAggregationV2)):
163      try:
164        aggregation = VariableAggregationV2(aggregation)
165      except ValueError:
166        raise ValueError(
167            "Invalid variable aggregation mode: {} for variable: {}".format(
168                aggregation, name))
169  if synchronization is None:
170    synchronization = VariableSynchronization.AUTO
171  else:
172    try:
173      synchronization = VariableSynchronization(synchronization)
174    except ValueError:
175      raise ValueError(
176          "Invalid variable synchronization mode: {} for variable: {}".format(
177              synchronization, name))
178  if trainable is None:
179    trainable = synchronization != VariableSynchronization.ON_READ
180  return synchronization, aggregation, trainable
181
182
183class VariableMetaclass(type):
184  """Metaclass to allow construction of tf.Variable to be overridden."""
185
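  # A minimal sketch of the creator-stack dispatch implemented below (assuming
  # the public `tf.variable_creator_scope` API; `creator` is a hypothetical
  # callback):
  #
  #   def creator(next_creator, **kwargs):
  #     kwargs["trainable"] = False   # tweak arguments, then delegate
  #     return next_creator(**kwargs)
  #
  #   with tf.variable_creator_scope(creator):
  #     v = tf.Variable(1.)  # routed through `creator` by `__call__` below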
186  def _variable_v1_call(cls,
187                        initial_value=None,
188                        trainable=None,
189                        collections=None,
190                        validate_shape=True,
191                        caching_device=None,
192                        name=None,
193                        variable_def=None,
194                        dtype=None,
195                        expected_shape=None,
196                        import_scope=None,
197                        constraint=None,
198                        use_resource=None,
199                        synchronization=VariableSynchronization.AUTO,
200                        aggregation=VariableAggregation.NONE,
201                        shape=None):
202    """Call on Variable class. Useful to force the signature."""
203    previous_getter = lambda **kwargs: default_variable_creator(None, **kwargs)
204    for _, getter in ops.get_default_graph()._variable_creator_stack:  # pylint: disable=protected-access
205      previous_getter = _make_getter(getter, previous_getter)
206
207    # Reset `aggregation` that is explicitly set as `None` to the enum NONE.
208    if aggregation is None:
209      aggregation = VariableAggregation.NONE
210    return previous_getter(
211        initial_value=initial_value,
212        trainable=trainable,
213        collections=collections,
214        validate_shape=validate_shape,
215        caching_device=caching_device,
216        name=name,
217        variable_def=variable_def,
218        dtype=dtype,
219        expected_shape=expected_shape,
220        import_scope=import_scope,
221        constraint=constraint,
222        use_resource=use_resource,
223        synchronization=synchronization,
224        aggregation=aggregation,
225        shape=shape)
226
227  def _variable_v2_call(cls,
228                        initial_value=None,
229                        trainable=None,
230                        validate_shape=True,
231                        caching_device=None,
232                        name=None,
233                        variable_def=None,
234                        dtype=None,
235                        import_scope=None,
236                        constraint=None,
237                        synchronization=VariableSynchronization.AUTO,
238                        aggregation=VariableAggregation.NONE,
239                        shape=None):
240    """Call on Variable class. Useful to force the signature."""
241    previous_getter = lambda **kws: default_variable_creator_v2(None, **kws)
242    for _, getter in ops.get_default_graph()._variable_creator_stack:  # pylint: disable=protected-access
243      previous_getter = _make_getter(getter, previous_getter)
244
245    # Reset `aggregation` that is explicitly set as `None` to the enum NONE.
246    if aggregation is None:
247      aggregation = VariableAggregation.NONE
248    return previous_getter(
249        initial_value=initial_value,
250        trainable=trainable,
251        validate_shape=validate_shape,
252        caching_device=caching_device,
253        name=name,
254        variable_def=variable_def,
255        dtype=dtype,
256        import_scope=import_scope,
257        constraint=constraint,
258        synchronization=synchronization,
259        aggregation=aggregation,
260        shape=shape)
261
262  @traceback_utils.filter_traceback
263  def __call__(cls, *args, **kwargs):
264    if cls is VariableV1:
265      return cls._variable_v1_call(*args, **kwargs)
266    elif cls is Variable:
267      return cls._variable_v2_call(*args, **kwargs)
268    else:
269      return super(VariableMetaclass, cls).__call__(*args, **kwargs)
270
271
272@tf_export("Variable", v1=[])
273# TODO(mdan): This should subclass core.Tensor, and not all its subclasses?
274class Variable(trackable.Trackable, metaclass=VariableMetaclass):
275  """See the [variable guide](https://tensorflow.org/guide/variable).
276
277  A variable maintains shared, persistent state manipulated by a program.
278
279  The `Variable()` constructor requires an initial value for the variable, which
280  can be a `Tensor` of any type and shape. This initial value defines the type
281  and shape of the variable. After construction, the type and shape of the
282  variable are fixed. The value can be changed using one of the assign methods.
283
284  >>> v = tf.Variable(1.)
285  >>> v.assign(2.)
286  <tf.Variable ... shape=() dtype=float32, numpy=2.0>
287  >>> v.assign_add(0.5)
288  <tf.Variable ... shape=() dtype=float32, numpy=2.5>
289
290  The `shape` argument to `Variable`'s constructor allows you to construct a
291  variable with a less defined shape than its `initial_value`:
292
293  >>> v = tf.Variable(1., shape=tf.TensorShape(None))
294  >>> v.assign([[1.]])
295  <tf.Variable ... shape=<unknown> dtype=float32, numpy=array([[1.]], ...)>
296
297  Just like any `Tensor`, variables created with `Variable()` can be used as
298  inputs to operations. Additionally, all the operators overloaded for the
299  `Tensor` class are carried over to variables.
300
301  >>> w = tf.Variable([[1.], [2.]])
302  >>> x = tf.constant([[3., 4.]])
303  >>> tf.matmul(w, x)
304  <tf.Tensor:... shape=(2, 2), ... numpy=
305    array([[3., 4.],
306           [6., 8.]], dtype=float32)>
307  >>> tf.sigmoid(w + x)
308  <tf.Tensor:... shape=(2, 2), ...>
309
310  When building a machine learning model it is often convenient to distinguish
311  between variables holding trainable model parameters and other variables such
312  as a `step` variable used to count training steps. To make this easier, the
313  variable constructor supports a `trainable=<bool>`
314  parameter. `tf.GradientTape` watches trainable variables by default:
315
316  >>> with tf.GradientTape(persistent=True) as tape:
317  ...   trainable = tf.Variable(1.)
318  ...   non_trainable = tf.Variable(2., trainable=False)
319  ...   x1 = trainable * 2.
320  ...   x2 = non_trainable * 3.
321  >>> tape.gradient(x1, trainable)
322  <tf.Tensor:... shape=(), dtype=float32, numpy=2.0>
323  >>> assert tape.gradient(x2, non_trainable) is None  # Unwatched
324
325  Variables are automatically tracked when assigned to attributes of types
326  inheriting from `tf.Module`.
327
328  >>> m = tf.Module()
329  >>> m.v = tf.Variable([1.])
330  >>> m.trainable_variables
331  (<tf.Variable ... shape=(1,) ... numpy=array([1.], dtype=float32)>,)
332
333  This tracking then allows saving variable values to
334  [training checkpoints](https://www.tensorflow.org/guide/checkpoint), or to
335  [SavedModels](https://www.tensorflow.org/guide/saved_model) which include
336  serialized TensorFlow graphs.
337
  Variables are often captured and manipulated by `tf.function`s. This works the
  same way as it would in an un-decorated function:
340
341  >>> v = tf.Variable(0.)
342  >>> read_and_decrement = tf.function(lambda: v.assign_sub(0.1))
343  >>> read_and_decrement()
344  <tf.Tensor: shape=(), dtype=float32, numpy=-0.1>
345  >>> read_and_decrement()
346  <tf.Tensor: shape=(), dtype=float32, numpy=-0.2>
347
348  Variables created inside a `tf.function` must be owned outside the function
349  and be created only once:
350
351  >>> class M(tf.Module):
352  ...   @tf.function
353  ...   def __call__(self, x):
354  ...     if not hasattr(self, "v"):  # Or set self.v to None in __init__
355  ...       self.v = tf.Variable(x)
356  ...     return self.v * x
357  >>> m = M()
358  >>> m(2.)
359  <tf.Tensor: shape=(), dtype=float32, numpy=4.0>
360  >>> m(3.)
361  <tf.Tensor: shape=(), dtype=float32, numpy=6.0>
362  >>> m.v
363  <tf.Variable ... shape=() dtype=float32, numpy=2.0>
364
365  See the `tf.function` documentation for details.
366  """
367
368  @deprecated_args(
369      None, "A variable's value can be manually cached by calling "
370      "tf.Variable.read_value() under a tf.device scope. The caching_device "
371      "argument does not work properly.", "caching_device")
372  def __init__(self,
373               initial_value=None,
374               trainable=None,
375               validate_shape=True,
376               caching_device=None,
377               name=None,
378               variable_def=None,
379               dtype=None,
380               import_scope=None,
381               constraint=None,
382               synchronization=VariableSynchronization.AUTO,
383               aggregation=VariableAggregation.NONE,
384               shape=None):
385    """Creates a new variable with value `initial_value`.
386
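    A small illustrative sketch, following the argument descriptions below
    (see `initial_value`, `dtype` and `shape`):

    ```python
    # When `initial_value` is a callable, `dtype` should be specified.
    v = tf.Variable(lambda: tf.zeros([3, 2]), dtype=tf.float32,
                    shape=tf.TensorShape(None))
    v.assign(tf.ones([5, 2]))  # Allowed: the variable's shape is unspecified.
    ```
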
387    Args:
388      initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
389        which is the initial value for the Variable. The initial value must have
390        a shape specified unless `validate_shape` is set to False. Can also be a
391        callable with no argument that returns the initial value when called. In
392        that case, `dtype` must be specified. (Note that initializer functions
393        from init_ops.py must first be bound to a shape before being used here.)
394      trainable: If `True`, GradientTapes automatically watch uses of this
395        variable. Defaults to `True`, unless `synchronization` is set to
396        `ON_READ`, in which case it defaults to `False`.
397      validate_shape: If `False`, allows the variable to be initialized with a
398        value of unknown shape. If `True`, the default, the shape of
399        `initial_value` must be known.
400      caching_device: Note: This argument is only valid when using a v1-style
401        `Session`. Optional device string describing where the Variable should
402        be cached for reading. Defaults to the Variable's device. If not `None`,
403        caches on another device. Typical use is to cache on the device where
404        the Ops using the Variable reside, to deduplicate copying through
405        `Switch` and other conditional statements.
406      name: Optional name for the variable. Defaults to `'Variable'` and gets
407        uniquified automatically.
408      variable_def: `VariableDef` protocol buffer. If not `None`, recreates the
409        Variable object with its contents, referencing the variable's nodes in
410        the graph, which must already exist. The graph is not changed.
411        `variable_def` and the other arguments are mutually exclusive.
412      dtype: If set, initial_value will be converted to the given type. If
413        `None`, either the datatype will be kept (if `initial_value` is a
414        Tensor), or `convert_to_tensor` will decide.
      import_scope: Optional `string`. Name scope to add to the `Variable`. Only
        used when initializing from protocol buffer.
417      constraint: An optional projection function to be applied to the variable
418        after being updated by an `Optimizer` (e.g. used to implement norm
419        constraints or value constraints for layer weights). The function must
420        take as input the unprojected Tensor representing the value of the
421        variable and return the Tensor for the projected value (which must have
422        the same shape). Constraints are not safe to use when doing asynchronous
423        distributed training.
      synchronization: Indicates when a distributed variable will be
        synchronized. Accepted values are constants defined in the class
        `tf.VariableSynchronization`. By default the synchronization is set to
        `AUTO` and the current `DistributionStrategy` chooses when to
        synchronize.
429      aggregation: Indicates how a distributed variable will be aggregated.
430        Accepted values are constants defined in the class
431        `tf.VariableAggregation`.
432      shape: (optional) The shape of this variable. If None, the shape of
433        `initial_value` will be used. When setting this argument to
434        `tf.TensorShape(None)` (representing an unspecified shape), the variable
435        can be assigned with values of different shapes.
436
437    Raises:
438      ValueError: If both `variable_def` and initial_value are specified.
439      ValueError: If the initial value is not specified, or does not have a
440        shape and `validate_shape` is `True`.
441    """
442    raise NotImplementedError
443
444  def __repr__(self):
445    raise NotImplementedError
446
447  def value(self):
448    """Returns the last snapshot of this variable.
449
450    You usually do not need to call this method as all ops that need the value
451    of the variable call it automatically through a `convert_to_tensor()` call.
452
453    Returns a `Tensor` which holds the value of the variable.  You can not
454    assign a new value to this tensor as it is not a reference to the variable.
455
456    To avoid copies, if the consumer of the returned value is on the same device
457    as the variable, this actually returns the live value of the variable, not
458    a copy.  Updates to the variable are seen by the consumer.  If the consumer
459    is on a different device it will get a copy of the variable.
460
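    For example, a small sketch:

    >>> v = tf.Variable(1.)
    >>> v.value()
    <tf.Tensor: shape=(), dtype=float32, numpy=1.0>
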
461    Returns:
462      A `Tensor` containing the value of the variable.
463    """
464    raise NotImplementedError
465
466  def read_value(self):
467    """Returns the value of this variable, read in the current context.
468
469    Can be different from value() if it's on another device, with control
470    dependencies, etc.
471
472    Returns:
473      A `Tensor` containing the value of the variable.
474    """
475    raise NotImplementedError
476
477  def set_shape(self, shape):
478    """Overrides the shape for this variable.
479
480    Args:
481      shape: the `TensorShape` representing the overridden shape.
482    """
483    raise NotImplementedError
484
485  @property
486  def trainable(self):
487    raise NotImplementedError
488
489  @property
490  def synchronization(self):
491    raise NotImplementedError
492
493  @property
494  def aggregation(self):
495    raise NotImplementedError
496
497  def eval(self, session=None):
498    """In a session, computes and returns the value of this variable.
499
500    This is not a graph construction method, it does not add ops to the graph.
501
502    This convenience method requires a session where the graph
503    containing this variable has been launched. If no session is
504    passed, the default session is used.  See `tf.compat.v1.Session` for more
505    information on launching a graph and on sessions.
506
507    ```python
508    v = tf.Variable([1, 2])
509    init = tf.compat.v1.global_variables_initializer()
510
511    with tf.compat.v1.Session() as sess:
512        sess.run(init)
513        # Usage passing the session explicitly.
514        print(v.eval(sess))
515        # Usage with the default session.  The 'with' block
516        # above makes 'sess' the default session.
517        print(v.eval())
518    ```
519
520    Args:
521      session: The session to use to evaluate this variable. If none, the
522        default session is used.
523
524    Returns:
525      A numpy `ndarray` with a copy of the value of this variable.
526    """
527    raise NotImplementedError
528
529  @deprecated(
530      None, "Use Variable.read_value. Variables in 2.X are initialized "
531      "automatically both in eager and graph (inside tf.defun) contexts.")
532  def initialized_value(self):
533    """Returns the value of the initialized variable.
534
535    You should use this instead of the variable itself to initialize another
536    variable with a value that depends on the value of this variable.
537
538    ```python
539    # Initialize 'v' with a random tensor.
540    v = tf.Variable(tf.random.truncated_normal([10, 40]))
541    # Use `initialized_value` to guarantee that `v` has been
542    # initialized before its value is used to initialize `w`.
543    # The random values are picked only once.
544    w = tf.Variable(v.initialized_value() * 2.0)
545    ```
546
547    Returns:
548      A `Tensor` holding the value of this variable after its initializer
549      has run.
550    """
551    with ops.init_scope():
552      return control_flow_ops.cond(
553          is_variable_initialized(self), self.read_value,
554          lambda: self.initial_value)
555
556  @property
557  def initial_value(self):
558    """Returns the Tensor used as the initial value for the variable.
559
560    Note that this is different from `initialized_value()` which runs
561    the op that initializes the variable before returning its value.
562    This method returns the tensor that is used by the op that initializes
563    the variable.
564
565    Returns:
566      A `Tensor`.
567    """
568    raise NotImplementedError
569
570  @property
571  def constraint(self):
572    """Returns the constraint function associated with this variable.
573
574    Returns:
575      The constraint function that was passed to the variable constructor.
576      Can be `None` if no constraint was passed.
577    """
578    raise NotImplementedError
579
580  def assign(self, value, use_locking=False, name=None, read_value=True):
581    """Assigns a new value to the variable.
582
583    This is essentially a shortcut for `assign(self, value)`.
584
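    For example, a small sketch:

    >>> v = tf.Variable([1., 2.])
    >>> v.assign([3., 4.])
    <tf.Variable ... shape=(2,) dtype=float32, numpy=array([3., 4.], dtype=float32)>
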
585    Args:
586      value: A `Tensor`. The new value for this variable.
587      use_locking: If `True`, use locking during the assignment.
588      name: The name of the operation to be created
589      read_value: if True, will return something which evaluates to the new
590        value of the variable; if False will return the assign op.
591
592    Returns:
593      The updated variable. If `read_value` is false, instead returns None in
594      Eager mode and the assign op in graph mode.
595    """
596    raise NotImplementedError
597
598  def assign_add(self, delta, use_locking=False, name=None, read_value=True):
599    """Adds a value to this variable.
600
    This is essentially a shortcut for `assign_add(self, delta)`.
602
603    Args:
604      delta: A `Tensor`. The value to add to this variable.
605      use_locking: If `True`, use locking during the operation.
606      name: The name of the operation to be created
607      read_value: if True, will return something which evaluates to the new
608        value of the variable; if False will return the assign op.
609
610    Returns:
611      The updated variable. If `read_value` is false, instead returns None in
612      Eager mode and the assign op in graph mode.
613    """
614    raise NotImplementedError
615
616  def assign_sub(self, delta, use_locking=False, name=None, read_value=True):
617    """Subtracts a value from this variable.
618
619    This is essentially a shortcut for `assign_sub(self, delta)`.
620
621    Args:
622      delta: A `Tensor`. The value to subtract from this variable.
623      use_locking: If `True`, use locking during the operation.
624      name: The name of the operation to be created
625      read_value: if True, will return something which evaluates to the new
626        value of the variable; if False will return the assign op.
627
628    Returns:
629      The updated variable. If `read_value` is false, instead returns None in
630      Eager mode and the assign op in graph mode.
631    """
632    raise NotImplementedError
633
634  def scatter_sub(self, sparse_delta, use_locking=False, name=None):
635    """Subtracts `tf.IndexedSlices` from this variable.
636
637    Args:
638      sparse_delta: `tf.IndexedSlices` to be subtracted from this variable.
639      use_locking: If `True`, use locking during the operation.
640      name: the name of the operation.
641
642    Returns:
643      The updated variable.
644
645    Raises:
646      TypeError: if `sparse_delta` is not an `IndexedSlices`.
647    """
648    raise NotImplementedError
649
650  def scatter_add(self, sparse_delta, use_locking=False, name=None):
651    """Adds `tf.IndexedSlices` to this variable.
652
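    For example, a small sketch of the expected behavior (the result is shown
    in the comment):

    ```python
    v = tf.Variable([1., 2., 3., 4.])
    delta = tf.IndexedSlices(values=tf.constant([10., 20.]),
                             indices=tf.constant([0, 2]))
    v.scatter_add(delta)  # v is now [11., 2., 23., 4.]
    ```
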
653    Args:
654      sparse_delta: `tf.IndexedSlices` to be added to this variable.
655      use_locking: If `True`, use locking during the operation.
656      name: the name of the operation.
657
658    Returns:
659      The updated variable.
660
661    Raises:
662      TypeError: if `sparse_delta` is not an `IndexedSlices`.
663    """
664    raise NotImplementedError
665
666  def scatter_max(self, sparse_delta, use_locking=False, name=None):
667    """Updates this variable with the max of `tf.IndexedSlices` and itself.
668
669    Args:
670      sparse_delta: `tf.IndexedSlices` to use as an argument of max with this
671        variable.
672      use_locking: If `True`, use locking during the operation.
673      name: the name of the operation.
674
675    Returns:
676      The updated variable.
677
678    Raises:
679      TypeError: if `sparse_delta` is not an `IndexedSlices`.
680    """
681    raise NotImplementedError
682
683  def scatter_min(self, sparse_delta, use_locking=False, name=None):
684    """Updates this variable with the min of `tf.IndexedSlices` and itself.
685
686    Args:
687      sparse_delta: `tf.IndexedSlices` to use as an argument of min with this
688        variable.
689      use_locking: If `True`, use locking during the operation.
690      name: the name of the operation.
691
692    Returns:
693      The updated variable.
694
695    Raises:
696      TypeError: if `sparse_delta` is not an `IndexedSlices`.
697    """
698    raise NotImplementedError
699
700  def scatter_mul(self, sparse_delta, use_locking=False, name=None):
701    """Multiply this variable by `tf.IndexedSlices`.
702
703    Args:
704      sparse_delta: `tf.IndexedSlices` to multiply this variable by.
705      use_locking: If `True`, use locking during the operation.
706      name: the name of the operation.
707
708    Returns:
709      The updated variable.
710
711    Raises:
712      TypeError: if `sparse_delta` is not an `IndexedSlices`.
713    """
714    raise NotImplementedError
715
716  def scatter_div(self, sparse_delta, use_locking=False, name=None):
717    """Divide this variable by `tf.IndexedSlices`.
718
719    Args:
720      sparse_delta: `tf.IndexedSlices` to divide this variable by.
721      use_locking: If `True`, use locking during the operation.
722      name: the name of the operation.
723
724    Returns:
725      The updated variable.
726
727    Raises:
728      TypeError: if `sparse_delta` is not an `IndexedSlices`.
729    """
730    raise NotImplementedError
731
732  def scatter_update(self, sparse_delta, use_locking=False, name=None):
733    """Assigns `tf.IndexedSlices` to this variable.
734
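    For example, a small sketch of the expected behavior (the result is shown
    in the comment):

    ```python
    v = tf.Variable([1., 2., 3., 4.])
    delta = tf.IndexedSlices(values=tf.constant([10., 20.]),
                             indices=tf.constant([0, 2]))
    v.scatter_update(delta)  # v is now [10., 2., 20., 4.]
    ```
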
735    Args:
736      sparse_delta: `tf.IndexedSlices` to be assigned to this variable.
737      use_locking: If `True`, use locking during the operation.
738      name: the name of the operation.
739
740    Returns:
741      The updated variable.
742
743    Raises:
744      TypeError: if `sparse_delta` is not an `IndexedSlices`.
745    """
746    raise NotImplementedError
747
748  def batch_scatter_update(self, sparse_delta, use_locking=False, name=None):
749    """Assigns `tf.IndexedSlices` to this variable batch-wise.
750
751    Analogous to `batch_gather`. This assumes that this variable and the
752    sparse_delta IndexedSlices have a series of leading dimensions that are the
753    same for all of them, and the updates are performed on the last dimension of
754    indices. In other words, the dimensions should be the following:
755
756    `num_prefix_dims = sparse_delta.indices.ndims - 1`
757    `batch_dim = num_prefix_dims + 1`
758    `sparse_delta.updates.shape = sparse_delta.indices.shape + var.shape[
759         batch_dim:]`
760
761    where
762
763    `sparse_delta.updates.shape[:num_prefix_dims]`
764    `== sparse_delta.indices.shape[:num_prefix_dims]`
765    `== var.shape[:num_prefix_dims]`
766
767    And the operation performed can be expressed as:
768
769    `var[i_1, ..., i_n,
770         sparse_delta.indices[i_1, ..., i_n, j]] = sparse_delta.updates[
771            i_1, ..., i_n, j]`
772
773    When sparse_delta.indices is a 1D tensor, this operation is equivalent to
774    `scatter_update`.
775
    An alternative to this operation is to loop over the first `ndims` of the
    variable and use `scatter_update` on the subtensors that result from slicing
    the first dimension. This is a valid option for `ndims = 1`, but less
    efficient than this implementation.
780
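    For example, a sketch of the expected behavior for a rank-2 variable
    (the result is shown in the comment):

    ```python
    v = tf.Variable([[1., 2., 3.],
                     [4., 5., 6.]])
    delta = tf.IndexedSlices(values=tf.constant([[10., 20.], [30., 40.]]),
                             indices=tf.constant([[0, 2], [1, 2]]))
    v.batch_scatter_update(delta)  # v is now [[10., 2., 20.], [4., 30., 40.]]
    ```
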
781    Args:
782      sparse_delta: `tf.IndexedSlices` to be assigned to this variable.
783      use_locking: If `True`, use locking during the operation.
784      name: the name of the operation.
785
786    Returns:
787      The updated variable.
788
789    Raises:
790      TypeError: if `sparse_delta` is not an `IndexedSlices`.
791    """
792    raise NotImplementedError
793
794  def scatter_nd_sub(self, indices, updates, name=None):
795    """Applies sparse subtraction to individual values or slices in a Variable.
796
797    Assuming the variable has rank `P` and `indices` is a `Tensor` of rank `Q`.
798
    `indices` must be an integer tensor, containing indices into self.
    It must be of shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.
801
802    The innermost dimension of `indices` (with length `K`) corresponds to
803    indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
804    dimension of self.
805
    `updates` is a `Tensor` of rank `Q-1+P-K` with shape:
807
808    ```
809    [d_0, ..., d_{Q-2}, self.shape[K], ..., self.shape[P-1]].
810    ```
811
    For example, say we want to subtract 4 scattered elements from a rank-1
    tensor with 8 elements. In Python, that update would look like this:
814
815    ```python
816        v = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
817        indices = tf.constant([[4], [3], [1] ,[7]])
818        updates = tf.constant([9, 10, 11, 12])
819        v.scatter_nd_sub(indices, updates)
820        print(v)
821    ```
822
823    After the update `v` would look like this:
824
825        [1, -9, 3, -6, -4, 6, 7, -4]
826
827    See `tf.scatter_nd` for more details about how to make updates to
828    slices.
829
830    Args:
831      indices: The indices to be used in the operation.
832      updates: The values to be used in the operation.
833      name: the name of the operation.
834
835    Returns:
836      The updated variable.
837    """
838    raise NotImplementedError
839
840  def scatter_nd_add(self, indices, updates, name=None):
841    """Applies sparse addition to individual values or slices in a Variable.
842
843    The Variable has rank `P` and `indices` is a `Tensor` of rank `Q`.
844
    `indices` must be an integer tensor, containing indices into self.
    It must be of shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.
847
848    The innermost dimension of `indices` (with length `K`) corresponds to
849    indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
850    dimension of self.
851
    `updates` is a `Tensor` of rank `Q-1+P-K` with shape:
853
854    ```
855    [d_0, ..., d_{Q-2}, self.shape[K], ..., self.shape[P-1]].
856    ```
857
    For example, say we want to add 4 scattered elements to a rank-1 tensor
    with 8 elements. In Python, that update would look like this:
860
861    ```python
862        v = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
863        indices = tf.constant([[4], [3], [1] ,[7]])
864        updates = tf.constant([9, 10, 11, 12])
865        v.scatter_nd_add(indices, updates)
866        print(v)
867    ```
868
869    The resulting update to v would look like this:
870
871        [1, 13, 3, 14, 14, 6, 7, 20]
872
873    See `tf.scatter_nd` for more details about how to make updates to
874    slices.
875
876    Args:
877      indices: The indices to be used in the operation.
878      updates: The values to be used in the operation.
879      name: the name of the operation.
880
881    Returns:
882      The updated variable.
883    """
884    raise NotImplementedError
885
886  def scatter_nd_update(self, indices, updates, name=None):
887    """Applies sparse assignment to individual values or slices in a Variable.
888
889    The Variable has rank `P` and `indices` is a `Tensor` of rank `Q`.
890
    `indices` must be an integer tensor, containing indices into self.
    It must be of shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.
893
894    The innermost dimension of `indices` (with length `K`) corresponds to
895    indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
896    dimension of self.
897
    `updates` is a `Tensor` of rank `Q-1+P-K` with shape:
899
900    ```
901    [d_0, ..., d_{Q-2}, self.shape[K], ..., self.shape[P-1]].
902    ```
903
    For example, say we want to assign 4 scattered elements to a rank-1 tensor
    with 8 elements. In Python, that update would look like this:
906
907    ```python
908        v = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
909        indices = tf.constant([[4], [3], [1] ,[7]])
910        updates = tf.constant([9, 10, 11, 12])
911        v.scatter_nd_update(indices, updates)
912        print(v)
913    ```
914
915    The resulting update to v would look like this:
916
917        [1, 11, 3, 10, 9, 6, 7, 12]
918
919    See `tf.scatter_nd` for more details about how to make updates to
920    slices.
921
922    Args:
923      indices: The indices to be used in the operation.
924      updates: The values to be used in the operation.
925      name: the name of the operation.
926
927    Returns:
928      The updated variable.
929    """
930    raise NotImplementedError
931
932  def sparse_read(self, indices, name=None):
    r"""Gather slices from the variable along axis 0 according to `indices`.
934
935    This function supports a subset of tf.gather, see tf.gather for details on
936    usage.
937
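    For example, a small sketch (gathering along the first axis):

    ```python
    v = tf.Variable([[1, 2], [3, 4], [5, 6]])
    v.sparse_read([0, 2])  # [[1, 2], [5, 6]]
    ```
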
938    Args:
939      indices: The index `Tensor`.  Must be one of the following types: `int32`,
940        `int64`. Must be in range `[0, params.shape[axis])`.
941      name: A name for the operation (optional).
942
943    Returns:
944      A `Tensor`. Has the same type as `params`.
945    """
946    raise AttributeError
947
948  def gather_nd(self, indices, name=None):
949    r"""Gather slices from `params` into a Tensor with shape specified by `indices`.
950
951    See tf.gather_nd for details.
952
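    For example, a small sketch:

    ```python
    v = tf.Variable([[1, 2], [3, 4], [5, 6]])
    v.gather_nd([[2, 0], [0, 1]])  # [5, 2]
    ```
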
953    Args:
954      indices: A `Tensor`. Must be one of the following types: `int32`, `int64`.
955        Index tensor.
956      name: A name for the operation (optional).
957
958    Returns:
959      A `Tensor`. Has the same type as `params`.
960    """
961    raise AttributeError
962
963  @deprecated(None, "Prefer Dataset.range instead.")
964  def count_up_to(self, limit):
965    """Increments this variable until it reaches `limit`.
966
967    When that Op is run it tries to increment the variable by `1`. If
968    incrementing the variable would bring it above `limit` then the Op raises
969    the exception `OutOfRangeError`.
970
971    If no error is raised, the Op outputs the value of the variable before
972    the increment.
973
974    This is essentially a shortcut for `count_up_to(self, limit)`.
975
976    Args:
977      limit: value at which incrementing the variable raises an error.
978
979    Returns:
980      A `Tensor` that will hold the variable value before the increment. If no
981      other Op modifies this variable, the values produced will all be
982      distinct.
983    """
984    raise NotImplementedError
985
986  @deprecated(None,
987              "Prefer Variable.assign which has equivalent behavior in 2.X.")
988  def load(self, value, session=None):
989    """Load new value into this variable.
990
991    Writes new value to variable's memory. Doesn't add ops to the graph.
992
993    This convenience method requires a session where the graph
994    containing this variable has been launched. If no session is
995    passed, the default session is used.  See `tf.compat.v1.Session` for more
996    information on launching a graph and on sessions.
997
998    ```python
999    v = tf.Variable([1, 2])
1000    init = tf.compat.v1.global_variables_initializer()
1001
1002    with tf.compat.v1.Session() as sess:
1003        sess.run(init)
1004        # Usage passing the session explicitly.
1005        v.load([2, 3], sess)
1006        print(v.eval(sess)) # prints [2 3]
1007        # Usage with the default session.  The 'with' block
1008        # above makes 'sess' the default session.
1009        v.load([3, 4], sess)
1010        print(v.eval()) # prints [3 4]
1011    ```
1012
1013    Args:
1014        value: New variable value
1015        session: The session to use to evaluate this variable. If none, the
1016          default session is used.
1017
1018    Raises:
1019        ValueError: Session is not passed and no default session
1020    """
1021    if context.executing_eagerly():
1022      self.assign(value)
1023    else:
1024      session = session or ops.get_default_session()
1025      if session is None:
1026        raise ValueError(
1027            "Either session argument should be provided or default session "
1028            "should be established")
1029      session.run(self.initializer, {self.initializer.inputs[1]: value})
1030
1031  # Conversion to tensor.
1032  @staticmethod
1033  def _TensorConversionFunction(v, dtype=None, name=None, as_ref=False):  # pylint: disable=invalid-name
1034    """Utility function for converting a Variable to a Tensor."""
1035    _ = name
1036    if dtype and not dtype.is_compatible_with(v.dtype):
1037      raise ValueError(
1038          f"Incompatible type conversion requested to type '{dtype.name}' for "
1039          f"variable of type '{v.dtype.name}' (Variable: {v}).")
1040    if as_ref:
1041      return v._ref()  # pylint: disable=protected-access
1042    else:
1043      return v.value()
1044
1045  @classmethod
1046  def _OverloadAllOperators(cls):  # pylint: disable=invalid-name
1047    """Register overloads for all operators."""
1048    for operator in ops.Tensor.OVERLOADABLE_OPERATORS:
1049      cls._OverloadOperator(operator)
1050    # For slicing, bind getitem differently than a tensor (use SliceHelperVar
1051    # instead)
1052    # pylint: disable=protected-access
1053    setattr(cls, "__getitem__", array_ops._SliceHelperVar)
1054
1055  @classmethod
1056  def _OverloadOperator(cls, operator):  # pylint: disable=invalid-name
1057    """Defer an operator overload to `ops.Tensor`.
1058
1059    We pull the operator out of ops.Tensor dynamically to avoid ordering issues.
1060
1061    Args:
1062      operator: string. The operator name.
1063    """
1064    # We can't use the overload mechanism on __eq__ & __ne__ since __eq__ is
1065    # called when adding a variable to sets. As a result we call a.value() which
1066    # causes infinite recursion when operating within a GradientTape
1067    # TODO(gjn): Consider removing this
1068    if operator == "__eq__" or operator == "__ne__":
1069      return
1070
1071    tensor_oper = getattr(ops.Tensor, operator)
1072
1073    def _run_op(a, *args, **kwargs):
1074      # pylint: disable=protected-access
1075      return tensor_oper(a.value(), *args, **kwargs)
1076
1077    functools.update_wrapper(_run_op, tensor_oper)
1078    setattr(cls, operator, _run_op)
1079
1080  def __hash__(self):
1081    if ops.Tensor._USE_EQUALITY and ops.executing_eagerly_outside_functions():  # pylint: disable=protected-access
1082      raise TypeError(
1083          "Variable is unhashable. "
1084          f"Instead, use variable.ref() as the key. (Variable: {self})")
1085    else:
1086      return id(self)
1087
1088  # TODO(gjn): duplicate of math_ops.tensor_equals, consider removing
1089  def __eq__(self, other):
1090    """Compares two variables element-wise for equality."""
1091    if ops.Tensor._USE_EQUALITY and ops.executing_eagerly_outside_functions():  # pylint: disable=protected-access
1092      return gen_math_ops.equal(self, other, incompatible_shape_error=False)
1093    else:
1094      # In legacy graph mode, tensor equality is object equality
1095      return self is other
1096
1097  # TODO(gjn): duplicate of math_ops.tensor_not_equals, consider removing
1098  def __ne__(self, other):
1099    """Compares two variables element-wise for equality."""
1100    if ops.Tensor._USE_EQUALITY and ops.executing_eagerly_outside_functions():  # pylint: disable=protected-access
1101      return gen_math_ops.not_equal(self, other, incompatible_shape_error=False)
1102    else:
1103      # In legacy graph mode, tensor equality is object equality
1104      return self is not other
1105
1106  def __iter__(self):
1107    """When executing eagerly, iterates over the value of the variable."""
1108    return iter(self.read_value())
1109
1110  # NOTE(mrry): This enables the Variable's overloaded "right" binary
1111  # operators to run when the left operand is an ndarray, because it
1112  # accords the Variable class higher priority than an ndarray, or a
1113  # numpy matrix.
1114  # TODO(mrry): Convert this to using numpy's __numpy_ufunc__
1115  # mechanism, which allows more control over how Variables interact
1116  # with ndarrays.
1117  __array_priority__ = 100
1118
1119  @property
1120  def name(self):
1121    """The name of this variable."""
1122    raise NotImplementedError
1123
1124  @property
1125  def _shared_name(self):
1126    """The shared name of the variable.
1127
    Unlike `name()`, `shared_name` doesn't have the ":0" suffix. It is the
    user-specified name with the name scope prefix.
1130
1131    Returns:
1132      variable name.
1133    """
1134    return self.name[:self.name.index(":")]
1135
1136  @property
1137  def initializer(self):
1138    """The initializer operation for this variable."""
1139    raise NotImplementedError
1140
1141  @property
1142  def device(self):
1143    """The device of this variable."""
1144    raise NotImplementedError
1145
1146  @property
1147  def dtype(self):
1148    """The `DType` of this variable."""
1149    raise NotImplementedError
1150
1151  @property
1152  def op(self):
1153    """The `Operation` of this variable."""
1154    raise NotImplementedError
1155
1156  @property
1157  def graph(self):
1158    """The `Graph` of this variable."""
1159    raise NotImplementedError
1160
1161  @property
1162  def shape(self):
1163    """The `TensorShape` of this variable.
1164
1165    Returns:
1166      A `TensorShape`.
1167    """
1168    raise NotImplementedError
1169
1170  def get_shape(self):
1171    """Alias of `Variable.shape`."""
1172    return self.shape
1173
1174  def _gather_saveables_for_checkpoint(self):
1175    """For implementing `Trackable`. This object is saveable on its own."""
1176    return {trackable.VARIABLE_VALUE_KEY: self}
1177
1178  def to_proto(self, export_scope=None):
1179    """Converts a `Variable` to a `VariableDef` protocol buffer.
1180
1181    Args:
1182      export_scope: Optional `string`. Name scope to remove.
1183
1184    Returns:
1185      A `VariableDef` protocol buffer, or `None` if the `Variable` is not
1186      in the specified name scope.
1187    """
1188    raise NotImplementedError
1189
1190  @staticmethod
1191  def from_proto(variable_def, import_scope=None):
1192    """Returns a `Variable` object created from `variable_def`."""
1193    return RefVariable(variable_def=variable_def, import_scope=import_scope)
1194
1195  def _set_save_slice_info(self, save_slice_info):
1196    """Sets the slice info for this `Variable`.
1197
1198    Args:
1199      save_slice_info: A `Variable.SaveSliceInfo` object.
1200    """
1201    self._save_slice_info = save_slice_info
1202
1203  def _get_save_slice_info(self):
1204    return self._save_slice_info
1205
1206  @deprecated(None, "Use ref() instead.")
1207  def experimental_ref(self):
1208    return self.ref()
1209
1210  def ref(self):
1211    # tf.Tensor also has the same ref() API.  If you update the
1212    # documentation here, please update tf.Tensor.ref() as well.
1213    """Returns a hashable reference object to this Variable.
1214
1215    The primary use case for this API is to put variables in a set/dictionary.
    We can't put variables in a set/dictionary as `variable.__hash__()` is no
    longer available starting from TensorFlow 2.0.

    The following will raise an exception in TensorFlow 2.0 and above:
1220
1221    >>> x = tf.Variable(5)
1222    >>> y = tf.Variable(10)
1223    >>> z = tf.Variable(10)
1224    >>> variable_set = {x, y, z}
1225    Traceback (most recent call last):
1226      ...
1227    TypeError: Variable is unhashable. Instead, use tensor.ref() as the key.
1228    >>> variable_dict = {x: 'five', y: 'ten'}
1229    Traceback (most recent call last):
1230      ...
1231    TypeError: Variable is unhashable. Instead, use tensor.ref() as the key.
1232
1233    Instead, we can use `variable.ref()`.
1234
1235    >>> variable_set = {x.ref(), y.ref(), z.ref()}
1236    >>> x.ref() in variable_set
1237    True
1238    >>> variable_dict = {x.ref(): 'five', y.ref(): 'ten', z.ref(): 'ten'}
1239    >>> variable_dict[y.ref()]
1240    'ten'
1241
    Also, the reference object provides a `.deref()` function that returns the
    original Variable.
1244
1245    >>> x = tf.Variable(5)
1246    >>> x.ref().deref()
1247    <tf.Variable 'Variable:0' shape=() dtype=int32, numpy=5>
1248    """
1249    return object_identity.Reference(self)
1250
1251  class SaveSliceInfo:
1252    """Information on how to save this Variable as a slice.
1253
1254    Provides internal support for saving variables as slices of a larger
1255    variable.  This API is not public and is subject to change.
1256
1257    Available properties:
1258
1259    * full_name
1260    * full_shape
1261    * var_offset
1262    * var_shape
1263    """
1264
1265    def __init__(self,
1266                 full_name=None,
1267                 full_shape=None,
1268                 var_offset=None,
1269                 var_shape=None,
1270                 save_slice_info_def=None,
1271                 import_scope=None):
1272      """Create a `SaveSliceInfo`.
1273
1274      Args:
1275        full_name: Name of the full variable of which this `Variable` is a
1276          slice.
1277        full_shape: Shape of the full variable, as a list of int.
1278        var_offset: Offset of this `Variable` into the full variable, as a list
1279          of int.
1280        var_shape: Shape of this `Variable`, as a list of int.
        save_slice_info_def: `SaveSliceInfoDef` protocol buffer. If not `None`,
          recreates the `SaveSliceInfo` object from its contents.
          `save_slice_info_def` and other arguments are mutually exclusive.
1284        import_scope: Optional `string`. Name scope to add. Only used when
1285          initializing from protocol buffer.
1286      """
1287      if save_slice_info_def:
1288        assert isinstance(save_slice_info_def, variable_pb2.SaveSliceInfoDef)
1289        self.full_name = ops.prepend_name_scope(
1290            save_slice_info_def.full_name, import_scope=import_scope)
1291        self.full_shape = [i for i in save_slice_info_def.full_shape]
1292        self.var_offset = [i for i in save_slice_info_def.var_offset]
1293        self.var_shape = [i for i in save_slice_info_def.var_shape]
1294      else:
1295        self.full_name = full_name
1296        self.full_shape = full_shape
1297        self.var_offset = var_offset
1298        self.var_shape = var_shape
1299
1300    @property
1301    def spec(self):
1302      """Computes the spec string used for saving."""
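      # For example (an illustrative sketch): full_shape=[4, 5],
      # var_offset=[0, 3], var_shape=[4, 2] produces the spec "4 5 0,4:3,2".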
1303      full_shape_str = " ".join("%d" % d for d in self.full_shape) + " "
1304      sl_spec = ":".join(
1305          "%d,%d" % (o, s) for o, s in zip(self.var_offset, self.var_shape))
1306      return full_shape_str + sl_spec
1307
1308    def to_proto(self, export_scope=None):
1309      """Returns a SaveSliceInfoDef() proto.
1310
1311      Args:
1312        export_scope: Optional `string`. Name scope to remove.
1313
1314      Returns:
1315        A `SaveSliceInfoDef` protocol buffer, or None if the `Variable` is not
1316        in the specified name scope.
1317      """
1318      if (export_scope is None or self.full_name.startswith(export_scope)):
1319        save_slice_info_def = variable_pb2.SaveSliceInfoDef()
1320        save_slice_info_def.full_name = ops.strip_name_scope(
1321            self.full_name, export_scope)
1322        for i in self.full_shape:
1323          save_slice_info_def.full_shape.append(i)
1324        for i in self.var_offset:
1325          save_slice_info_def.var_offset.append(i)
1326        for i in self.var_shape:
1327          save_slice_info_def.var_shape.append(i)
1328        return save_slice_info_def
1329      else:
1330        return None
1331
1332
1333Variable._OverloadAllOperators()  # pylint: disable=protected-access
1334_pywrap_utils.RegisterType("Variable", Variable)
1335
1336
1337@tf_export(v1=["Variable"])
1338class VariableV1(Variable):
1339  """See the [Variables Guide](https://tensorflow.org/guide/variables).
1340
1341  A variable maintains state in the graph across calls to `run()`. You add a
1342  variable to the graph by constructing an instance of the class `Variable`.
1343
1344  The `Variable()` constructor requires an initial value for the variable,
1345  which can be a `Tensor` of any type and shape. The initial value defines the
1346  type and shape of the variable. After construction, the type and shape of
1347  the variable are fixed. The value can be changed using one of the assign
1348  methods.
1349
1350  If you want to change the shape of a variable later you have to use an
1351  `assign` Op with `validate_shape=False`.
1352
1353  Just like any `Tensor`, variables created with `Variable()` can be used as
1354  inputs for other Ops in the graph. Additionally, all the operators
1355  overloaded for the `Tensor` class are carried over to variables, so you can
1356  also add nodes to the graph by just doing arithmetic on variables.
1357
1358  ```python
1359  import tensorflow as tf
1360
1361  # Create a variable.
1362  w = tf.Variable(<initial-value>, name=<optional-name>)
1363
1364  # Use the variable in the graph like any Tensor.
1365  y = tf.matmul(w, ...another variable or tensor...)
1366
1367  # The overloaded operators are available too.
1368  z = tf.sigmoid(w + y)
1369
1370  # Assign a new value to the variable with `assign()` or a related method.
1371  w.assign(w + 1.0)
1372  w.assign_add(1.0)
1373  ```
1374
1375  When you launch the graph, variables have to be explicitly initialized before
1376  you can run Ops that use their value. You can initialize a variable by
1377  running its *initializer op*, restoring the variable from a save file, or
1378  simply running an `assign` Op that assigns a value to the variable. In fact,
1379  the variable *initializer op* is just an `assign` Op that assigns the
1380  variable's initial value to the variable itself.
1381
1382  ```python
1383  # Launch the graph in a session.
1384  with tf.compat.v1.Session() as sess:
1385      # Run the variable initializer.
1386      sess.run(w.initializer)
1387      # ...you now can run ops that use the value of 'w'...
1388  ```
1389
1390  The most common initialization pattern is to use the convenience function
1391  `global_variables_initializer()` to add an Op to the graph that initializes
1392  all the variables. You then run that Op after launching the graph.
1393
1394  ```python
1395  # Add an Op to initialize global variables.
1396  init_op = tf.compat.v1.global_variables_initializer()
1397
1398  # Launch the graph in a session.
1399  with tf.compat.v1.Session() as sess:
1400      # Run the Op that initializes global variables.
1401      sess.run(init_op)
1402      # ...you can now run any Op that uses variable values...
1403  ```
1404
1405  If you need to create a variable with an initial value dependent on another
1406  variable, use the other variable's `initialized_value()`. This ensures that
1407  variables are initialized in the right order.
1408
1409  All variables are automatically collected in the graph where they are
1410  created. By default, the constructor adds the new variable to the graph
1411  collection `GraphKeys.GLOBAL_VARIABLES`. The convenience function
1412  `global_variables()` returns the contents of that collection.
1413
1414  When building a machine learning model it is often convenient to distinguish
1415  between variables holding the trainable model parameters and other variables
1416  such as a `global step` variable used to count training steps. To make this
1417  easier, the variable constructor supports a `trainable=<bool>` parameter. If
1418  `True`, the new variable is also added to the graph collection
1419  `GraphKeys.TRAINABLE_VARIABLES`. The convenience function
1420  `trainable_variables()` returns the contents of this collection. The
1421  various `Optimizer` classes use this collection as the default list of
1422  variables to optimize.
1423
1424  WARNING: tf.Variable objects by default have a non-intuitive memory model. A
1425  Variable is represented internally as a mutable Tensor which can
1426  non-deterministically alias other Tensors in a graph. The set of operations
1427  which consume a Variable and can lead to aliasing is undetermined and can
1428  change across TensorFlow versions. Avoid writing code which relies on the
1429  value of a Variable either changing or not changing as other operations
1430  happen. For example, using Variable objects or simple functions thereof as
1431  predicates in a `tf.cond` is dangerous and error-prone:
1432
1433  ```
1434  v = tf.Variable(True)
1435  tf.cond(v, lambda: v.assign(False), my_false_fn)  # Note: this is broken.
1436  ```
1437
1438  Here, adding `use_resource=True` when constructing the variable will
1439  fix any nondeterminism issues:
1440  ```
1441  v = tf.Variable(True, use_resource=True)
1442  tf.cond(v, lambda: v.assign(False), my_false_fn)
1443  ```
1444
1445  To use the replacement for variables which does
1446  not have these issues:
1447
1448  * Add `use_resource=True` when constructing `tf.Variable`;
1449  * Call `tf.compat.v1.get_variable_scope().set_use_resource(True)` inside a
1450    `tf.compat.v1.variable_scope` before the `tf.compat.v1.get_variable()` call.
1451  """
1452
1453  def __init__(
1454      self,  # pylint: disable=super-init-not-called
1455      initial_value=None,
1456      trainable=None,
1457      collections=None,
1458      validate_shape=True,
1459      caching_device=None,
1460      name=None,
1461      variable_def=None,
1462      dtype=None,
1463      expected_shape=None,
1464      import_scope=None,
1465      constraint=None,
1466      use_resource=None,
1467      synchronization=VariableSynchronization.AUTO,
1468      aggregation=VariableAggregation.NONE,
1469      shape=None):
1470    """Creates a new variable with value `initial_value`.
1471
1472    The new variable is added to the graph collections listed in `collections`,
1473    which defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
1474
1475    If `trainable` is `True` the variable is also added to the graph collection
1476    `GraphKeys.TRAINABLE_VARIABLES`.
1477
1478    This constructor creates both a `variable` Op and an `assign` Op to set the
1479    variable to its initial value.
1480
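    For example, a callable initial value can be used as follows (a minimal
    sketch; `dtype` must be given when the initial value is a callable):

    ```python
    v = tf.compat.v1.Variable(lambda: tf.ones([2, 2]), dtype=tf.float32)
    ```
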
1481    Args:
1482      initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
1483        which is the initial value for the Variable. The initial value must have
1484        a shape specified unless `validate_shape` is set to False. Can also be a
1485        callable with no argument that returns the initial value when called. In
1486        that case, `dtype` must be specified. (Note that initializer functions
1487        from init_ops.py must first be bound to a shape before being used here.)
1488      trainable: If `True`, also adds the variable to the graph collection
1489        `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as the default
1490        list of variables to use by the `Optimizer` classes. Defaults to `True`,
1491        unless `synchronization` is set to `ON_READ`, in which case it defaults
1492        to `False`.
1493      collections: List of graph collections keys. The new variable is added to
1494        these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
1495      validate_shape: If `False`, allows the variable to be initialized with a
1496        value of unknown shape. If `True`, the default, the shape of
1497        `initial_value` must be known.
1498      caching_device: Optional device string describing where the Variable
1499        should be cached for reading.  Defaults to the Variable's device. If not
1500        `None`, caches on another device.  Typical use is to cache on the device
1501        where the Ops using the Variable reside, to deduplicate copying through
1502        `Switch` and other conditional statements.
1503      name: Optional name for the variable. Defaults to `'Variable'` and gets
1504        uniquified automatically.
1505      variable_def: `VariableDef` protocol buffer. If not `None`, recreates the
1506        Variable object with its contents, referencing the variable's nodes in
1507        the graph, which must already exist. The graph is not changed.
1508        `variable_def` and the other arguments are mutually exclusive.
1509      dtype: If set, initial_value will be converted to the given type. If
1510        `None`, either the datatype will be kept (if `initial_value` is a
1511        Tensor), or `convert_to_tensor` will decide.
1512      expected_shape: A TensorShape. If set, initial_value is expected to have
1513        this shape.
1514      import_scope: Optional `string`. Name scope to add to the `Variable`. Only
1515        used when initializing from protocol buffer.
1516      constraint: An optional projection function to be applied to the variable
1517        after being updated by an `Optimizer` (e.g. used to implement norm
1518        constraints or value constraints for layer weights). The function must
1519        take as input the unprojected Tensor representing the value of the
1520        variable and return the Tensor for the projected value (which must have
1521        the same shape). Constraints are not safe to use when doing asynchronous
1522        distributed training.
1523      use_resource: Whether to use resource variables.
1524      synchronization: Indicates when a distributed variable will be
1525        aggregated. Accepted values are constants defined in the class
1526        `tf.VariableSynchronization`. By default the synchronization is set to
1527        `AUTO` and the current `DistributionStrategy` chooses when to
1528        synchronize.
1529      aggregation: Indicates how a distributed variable will be aggregated.
1530        Accepted values are constants defined in the class
1531        `tf.VariableAggregation`.
1532      shape: (optional) The shape of this variable. If None, the shape of
1533        `initial_value` will be used. When setting this argument to
1534        `tf.TensorShape(None)` (representing an unspecified shape), the variable
1535        can be assigned with values of different shapes.
1536
1537    Raises:
1538      ValueError: If both `variable_def` and initial_value are specified.
1539      ValueError: If the initial value is not specified, or does not have a
1540        shape and `validate_shape` is `True`.
1541      RuntimeError: If eager execution is enabled.
1542    """
1543
1544  SaveSliceInfo = Variable.SaveSliceInfo
1545
1546
1547# TODO(apassos): do not repeat all comments here
1548class RefVariable(VariableV1, core.Tensor):
1549  """Ref-based implementation of variables."""
1550
1551  def __init__(
1552      self,  # pylint: disable=super-init-not-called
1553      initial_value=None,
1554      trainable=None,
1555      collections=None,
1556      validate_shape=True,
1557      caching_device=None,
1558      name=None,
1559      variable_def=None,
1560      dtype=None,
1561      expected_shape=None,
1562      import_scope=None,
1563      constraint=None,
1564      synchronization=None,
1565      aggregation=None,
1566      shape=None):
1567    """Creates a new variable with value `initial_value`.
1568
1569    The new variable is added to the graph collections listed in `collections`,
1570    which defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
1571
1572    If `trainable` is `True` the variable is also added to the graph collection
1573    `GraphKeys.TRAINABLE_VARIABLES`.
1574
1575    This constructor creates both a `variable` Op and an `assign` Op to set the
1576    variable to its initial value.
1577
1578    Args:
1579      initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
1580        which is the initial value for the Variable. The initial value must have
1581        a shape specified unless `validate_shape` is set to False. Can also be a
1582        callable with no argument that returns the initial value when called. In
1583        that case, `dtype` must be specified. (Note that initializer functions
1584        from init_ops.py must first be bound to a shape before being used here.)
1585      trainable: If `True`, also adds the variable to the graph collection
1586        `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as the default
1587        list of variables to use by the `Optimizer` classes. Defaults to `True`,
1588        unless `synchronization` is set to `ON_READ`, in which case it defaults
1589        to `False`.
1590      collections: List of graph collections keys. The new variable is added to
1591        these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
1592      validate_shape: If `False`, allows the variable to be initialized with a
1593        value of unknown shape. If `True`, the default, the shape of
1594        `initial_value` must be known.
1595      caching_device: Optional device string describing where the Variable
1596        should be cached for reading.  Defaults to the Variable's device. If not
1597        `None`, caches on another device.  Typical use is to cache on the device
1598        where the Ops using the Variable reside, to deduplicate copying through
1599        `Switch` and other conditional statements.
1600      name: Optional name for the variable. Defaults to `'Variable'` and gets
1601        uniquified automatically.
1602      variable_def: `VariableDef` protocol buffer. If not `None`, recreates the
1603        Variable object with its contents, referencing the variable's nodes in
1604        the graph, which must already exist. The graph is not changed.
1605        `variable_def` and the other arguments are mutually exclusive.
1606      dtype: If set, initial_value will be converted to the given type. If
1607        `None`, either the datatype will be kept (if `initial_value` is a
1608        Tensor), or `convert_to_tensor` will decide.
1609      expected_shape: A TensorShape. If set, initial_value is expected to have
1610        this shape.
1611      import_scope: Optional `string`. Name scope to add to the `Variable`. Only
1612        used when initializing from protocol buffer.
1613      constraint: An optional projection function to be applied to the variable
1614        after being updated by an `Optimizer` (e.g. used to implement norm
1615        constraints or value constraints for layer weights). The function must
1616        take as input the unprojected Tensor representing the value of the
1617        variable and return the Tensor for the projected value (which must have
1618        the same shape). Constraints are not safe to use when doing asynchronous
1619        distributed training.
1620      synchronization: Indicates when a distributed variable will be
1621        aggregated. Accepted values are constants defined in the class
1622        `tf.VariableSynchronization`. By default the synchronization is set to
1623        `AUTO` and the current `DistributionStrategy` chooses when to
1624        synchronize.
1625      aggregation: Indicates how a distributed variable will be aggregated.
1626        Accepted values are constants defined in the class
1627        `tf.VariableAggregation`.
1628      shape: (optional) The shape of this variable. If None, the shape of
1629        `initial_value` will be used. When setting this argument to
1630        `tf.TensorShape(None)` (representing an unspecified shape), the variable
1631        can be assigned with values of different shapes.
1632
1633    Raises:
1634      ValueError: If both `variable_def` and initial_value are specified.
1635      ValueError: If the initial value is not specified, or does not have a
1636        shape and `validate_shape` is `True`.
1637      RuntimeError: If eager execution is enabled.
1638    """
1639    self._in_graph_mode = True
1640    if variable_def:
1641      # If variable_def is provided, recreates the variable from its fields.
1642      if initial_value:
1643        raise ValueError("variable_def and initial_value are mutually "
1644                         "exclusive.")
1645      self._init_from_proto(variable_def, import_scope=import_scope)
1646    else:
1647      # Create from initial_value.
1648      self._init_from_args(
1649          initial_value=initial_value,
1650          trainable=trainable,
1651          collections=collections,
1652          validate_shape=validate_shape,
1653          caching_device=caching_device,
1654          name=name,
1655          dtype=dtype,
1656          expected_shape=expected_shape,
1657          constraint=constraint,
1658          synchronization=synchronization,
1659          aggregation=aggregation,
1660          shape=shape)
1661
1662  def __repr__(self):
1663    if context.executing_eagerly() and not self._in_graph_mode:
1664      return "<tf.Variable '%s' shape=%s dtype=%s, numpy=%s>" % (
1665          self.name, self.get_shape(), self.dtype.name,
1666          ops.numpy_text(self.read_value(), is_repr=True))
1667    else:
1668      return "<tf.Variable '%s' shape=%s dtype=%s>" % (
1669          self.name, self.get_shape(), self.dtype.name)
1670
1671  def _init_from_args(self,
1672                      initial_value=None,
1673                      trainable=None,
1674                      collections=None,
1675                      validate_shape=True,
1676                      caching_device=None,
1677                      name=None,
1678                      dtype=None,
1679                      expected_shape=None,
1680                      constraint=None,
1681                      synchronization=None,
1682                      aggregation=None,
1683                      shape=None):
1684    """Creates a new variable from arguments.
1685
1686    Args:
1687      initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
1688        which is the initial value for the Variable. The initial value must have
1689        a shape specified unless `validate_shape` is set to False. Can also be a
1690        callable with no argument that returns the initial value when called.
1691        (Note that initializer functions from init_ops.py must first be bound to
1692        a shape before being used here.)
1693      trainable: If `True`, also adds the variable to the graph collection
1694        `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as the default
1695        list of variables to use by the `Optimizer` classes. Defaults to `True`,
1696        unless `synchronization` is set to `ON_READ`, in which case it defaults
1697        to `False`.
1698      collections: List of graph collections keys. The new variable is added to
1699        these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
1700      validate_shape: If `False`, allows the variable to be initialized with a
1701        value of unknown shape. If `True`, the default, the shape of
1702        `initial_value` must be known.
1703      caching_device: Optional device string or function describing where the
1704        Variable should be cached for reading.  Defaults to the Variable's
1705        device.  If not `None`, caches on another device.  Typical use is to
1706        cache on the device where the Ops using the Variable reside, to
1707        deduplicate copying through `Switch` and other conditional statements.
1708      name: Optional name for the variable. Defaults to `'Variable'` and gets
1709        uniquified automatically.
1710      dtype: If set, initial_value will be converted to the given type. If None,
1711        either the datatype will be kept (if initial_value is a Tensor) or
1712        float32 will be used (if it is a Python object convertible to a Tensor).
1713      expected_shape: Deprecated. Ignored.
1714      constraint: An optional projection function to be applied to the variable
1715        after being updated by an `Optimizer` (e.g. used to implement norm
1716        constraints or value constraints for layer weights). The function must
1717        take as input the unprojected Tensor representing the value of the
1718        variable and return the Tensor for the projected value (which must have
1719        the same shape). Constraints are not safe to use when doing asynchronous
1720        distributed training.
1721      synchronization: Indicates when a distributed variable will be
1722        aggregated. Accepted values are constants defined in the class
1723        `tf.VariableSynchronization`. By default the synchronization is set to
1724        `AUTO` and the current `DistributionStrategy` chooses when to
1725        synchronize.
1726      aggregation: Indicates how a distributed variable will be aggregated.
1727        Accepted values are constants defined in the class
1728        `tf.VariableAggregation`.
1729      shape: (optional) The shape of this variable. If None, the shape of
1730        `initial_value` will be used. When setting this argument to
1731        `tf.TensorShape(None)` (representing an unspecified shape), the variable
1732        can be assigned with values of different shapes.
1733
1734    Raises:
1735      ValueError: If the initial value is not specified, or does not have a
1736        shape and `validate_shape` is `True`.
1737      RuntimeError: If lifted into the eager context.
1738    """
1739    _ = expected_shape
1740    if initial_value is None:
1741      raise ValueError("initial_value must be specified.")
1742    init_from_fn = callable(initial_value)
1743
1744    if collections is None:
1745      collections = [ops.GraphKeys.GLOBAL_VARIABLES]
1746    if not isinstance(collections, (list, tuple, set)):
1747      raise ValueError(
1748          "collections argument to Variable constructor must be a list, tuple, "
1749          "or set. Got %s of type %s" % (collections, type(collections)))
1750    if constraint is not None and not callable(constraint):
1751      raise ValueError("The `constraint` argument must be a callable.")
1752
1753    # Store the graph key so optimizers know how to only retrieve variables from
1754    # this graph.
1755    self._graph_key = ops.get_default_graph()._graph_key  # pylint: disable=protected-access
1756    if isinstance(initial_value, trackable.CheckpointInitialValue):
1757      self._maybe_initialize_trackable()
1758      self._update_uid = initial_value.checkpoint_position.restore_uid
1759      initial_value = initial_value.wrapped_value
1760
1761    synchronization, aggregation, trainable = (
1762        validate_synchronization_aggregation_trainable(synchronization,
1763                                                       aggregation, trainable,
1764                                                       name))
1765    self._synchronization = synchronization
1766    self._aggregation = aggregation
1767    self._trainable = trainable
1768    if trainable and ops.GraphKeys.TRAINABLE_VARIABLES not in collections:
1769      collections = list(collections) + [ops.GraphKeys.TRAINABLE_VARIABLES]
1770    with ops.init_scope():
1771      # Ensure that we weren't lifted into the eager context.
1772      if context.executing_eagerly():
1773        raise RuntimeError(
1774            "Reference variables are not supported when eager execution is "
1775            "enabled. Please run `tf.compat.v1.enable_resource_variables()` to "
1776            "switch to resource variables.")
1777      with ops.name_scope(name, "Variable",
1778                          [] if init_from_fn else [initial_value]) as name:
1779
1780        if init_from_fn:
1781          # Use attr_scope and device(None) to simulate the behavior of
1782          # colocate_with when the variable we want to colocate with doesn't
1783          # yet exist.
1784          true_name = ops.name_from_scope_name(name)  # pylint: disable=protected-access
1785          attr = attr_value_pb2.AttrValue(
1786              list=attr_value_pb2.AttrValue.ListValue(
1787                  s=[compat.as_bytes("loc:@%s" % true_name)]))
1788          # pylint: disable=protected-access
1789          with ops.get_default_graph()._attr_scope({"_class": attr}):
1790            with ops.name_scope("Initializer"), ops.device(None):
1791              initial_value = initial_value()
1792              if isinstance(initial_value, trackable.CheckpointInitialValue):
1793                self._maybe_initialize_trackable()
1794                self._update_uid = initial_value.checkpoint_position.restore_uid
1795                initial_value = initial_value.wrapped_value
1796              self._initial_value = ops.convert_to_tensor(
1797                  initial_value, name="initial_value", dtype=dtype)
1798              if shape is None:
1799                shape = (
1800                    self._initial_value.get_shape()
1801                    if validate_shape else tensor_shape.unknown_shape())
1802            self._variable = state_ops.variable_op_v2(
1803                shape, self._initial_value.dtype.base_dtype, name=name)
1804          # pylint: enable=protected-access
1805
1806        # Or get the initial value from a Tensor or Python object.
1807        else:
1808          self._initial_value = ops.convert_to_tensor(
1809              initial_value, name="initial_value", dtype=dtype)
1810          # pylint: disable=protected-access
1811          if self._initial_value.op._get_control_flow_context() is not None:
1812            raise ValueError(
1813                "Initializer for variable %s is from inside a control-flow "
1814                "construct, such as a loop or conditional. When creating a "
1815                "variable inside a loop or conditional, use a lambda as the "
1816                "initializer." % name)
1817          if shape is None:
1818            # pylint: enable=protected-access
1819            shape = (
1820                self._initial_value.get_shape()
1821                if validate_shape else tensor_shape.unknown_shape())
1822          # In this case, the variable op can't be created until after the
1823          # initial_value has been converted to a Tensor with a known type.
1824          self._variable = state_ops.variable_op_v2(
1825              shape, self._initial_value.dtype.base_dtype, name=name)
1826
1827        # Cache the name in `self`, because some APIs call `Variable.name` in a
1828        # tight loop, and this halves the cost.
1829        self._name = self._variable.name
1830
1831        # Manually overrides the variable's shape with the initial value's.
1832        if validate_shape:
1833          initial_value_shape = self._initial_value.get_shape()
1834          if not initial_value_shape.is_fully_defined():
1835            raise ValueError("initial_value must have a shape specified: %s" %
1836                             self._initial_value)
1837
1838        # If 'initial_value' makes use of other variables, guard against
1839        # those variables not being initialized yet by rewriting references
1840        # to them to use their initialized_value() outputs instead.
1841        self._initializer_op = state_ops.assign(
1842            self._variable,
1843            _try_guard_against_uninitialized_dependencies(
1844                name, self._initial_value),
1845            validate_shape=validate_shape).op
1846
1847        # TODO(vrv): Change this class to not take caching_device, but
1848        # to take the op to colocate the snapshot with, so we can use
1849        # colocation rather than devices.
1850        if caching_device is not None:
1851          with ops.device(caching_device):
1852            self._snapshot = array_ops.identity(self._variable, name="read")
1853        else:
1854          with ops.colocate_with(self._variable.op):
1855            self._snapshot = array_ops.identity(self._variable, name="read")
1856      ops.add_to_collections(collections, self)
1857
1858    self._caching_device = caching_device
1859    self._save_slice_info = None
1860    self._constraint = constraint
1861
1862  def _init_from_proto(self, variable_def, import_scope=None):
1863    """Recreates the Variable object from a `VariableDef` protocol buffer.
1864
1865    Args:
1866      variable_def: `VariableDef` protocol buffer, describing a variable whose
1867        nodes already exists in the graph.
1868      import_scope: Optional `string`. Name scope to add.
1869    """
1870    assert isinstance(variable_def, variable_pb2.VariableDef)
1871    # Create from variable_def.
1872    g = ops.get_default_graph()
1873    self._variable = g.as_graph_element(
1874        ops.prepend_name_scope(
1875            variable_def.variable_name, import_scope=import_scope))
1876    self._name = self._variable.name
1877    self._initializer_op = g.as_graph_element(
1878        ops.prepend_name_scope(
1879            variable_def.initializer_name, import_scope=import_scope))
1880    # Tests whether initial_value_name exists first for backwards compatibility.
1881    if (hasattr(variable_def, "initial_value_name") and
1882        variable_def.initial_value_name):
1883      self._initial_value = g.as_graph_element(
1884          ops.prepend_name_scope(
1885              variable_def.initial_value_name, import_scope=import_scope))
1886    else:
1887      self._initial_value = None
1888    synchronization, aggregation, trainable = (
1889        validate_synchronization_aggregation_trainable(
1890            variable_def.synchronization, variable_def.aggregation,
1891            variable_def.trainable, variable_def.variable_name))
1892    self._synchronization = synchronization
1893    self._aggregation = aggregation
1894    self._trainable = trainable
1895    self._snapshot = g.as_graph_element(
1896        ops.prepend_name_scope(
1897            variable_def.snapshot_name, import_scope=import_scope))
1898    if variable_def.HasField("save_slice_info_def"):
1899      self._save_slice_info = Variable.SaveSliceInfo(
1900          save_slice_info_def=variable_def.save_slice_info_def,
1901          import_scope=import_scope)
1902    else:
1903      self._save_slice_info = None
1904    self._caching_device = None
1905    self._constraint = None
1906
1907  def _as_graph_element(self):
1908    """Conversion function for Graph.as_graph_element()."""
1909    return self._variable
1910
1911  def value(self):
1912    """Returns the last snapshot of this variable.
1913
1914    You usually do not need to call this method as all ops that need the value
1915    of the variable call it automatically through a `convert_to_tensor()` call.
1916
1917    Returns a `Tensor` which holds the value of the variable.  You can not
1918    assign a new value to this tensor as it is not a reference to the variable.
1919
1920    To avoid copies, if the consumer of the returned value is on the same device
1921    as the variable, this actually returns the live value of the variable, not
1922    a copy.  Updates to the variable are seen by the consumer.  If the consumer
1923    is on a different device it will get a copy of the variable.
1924
1925    Returns:
1926      A `Tensor` containing the value of the variable.
1927    """
1928    return self._snapshot
1929
1930  def read_value(self):
1931    """Returns the value of this variable, read in the current context.
1932
1933    Can be different from value() if it's on another device, with control
1934    dependencies, etc.
1935
1936    Returns:
1937      A `Tensor` containing the value of the variable.
1938    """
1939    return array_ops.identity(self._variable, name="read")
1940
1941  def _ref(self):
1942    """Returns a reference to this variable.
1943
1944    You usually do not need to call this method as all ops that need a reference
1945    to the variable call it automatically.
1946
1947    Returns a `Tensor` which holds a reference to the variable.  You can
1948    assign a new value to the variable by passing the tensor to an assign op.
1949    See `tf.Variable.value` if you want to get the value of the
1950    variable.
1951
1952    Returns:
1953      A `Tensor` that is a reference to the variable.
1954    """
1955    return self._variable
1956
1957  def set_shape(self, shape):
1958    """Overrides the shape for this variable.
1959
1960    Args:
1961      shape: the `TensorShape` representing the overridden shape.
1962    """
1963    self._ref().set_shape(shape)
1964    self.value().set_shape(shape)
1965
1966  @property
1967  def trainable(self):
1968    return self._trainable
1969
1970  @property
1971  def synchronization(self):
1972    return self._synchronization
1973
1974  @property
1975  def aggregation(self):
1976    return self._aggregation
1977
1978  def eval(self, session=None):
1979    """In a session, computes and returns the value of this variable.
1980
1981    This is not a graph construction method; it does not add ops to the graph.
1982
1983    This convenience method requires a session where the graph
1984    containing this variable has been launched. If no session is
1985    passed, the default session is used.  See `tf.compat.v1.Session` for more
1986    information on launching a graph and on sessions.
1987
1988    ```python
1989    v = tf.Variable([1, 2])
1990    init = tf.compat.v1.global_variables_initializer()
1991
1992    with tf.compat.v1.Session() as sess:
1993        sess.run(init)
1994        # Usage passing the session explicitly.
1995        print(v.eval(sess))
1996        # Usage with the default session.  The 'with' block
1997        # above makes 'sess' the default session.
1998        print(v.eval())
1999    ```
2000
2001    Args:
2002      session: The session to use to evaluate this variable. If none, the
2003        default session is used.
2004
2005    Returns:
2006      A numpy `ndarray` with a copy of the value of this variable.
2007    """
2008    return self._variable.eval(session=session)
2009
2010  @property
2011  def initial_value(self):
2012    """Returns the Tensor used as the initial value for the variable.
2013
2014    Note that this is different from `initialized_value()` which runs
2015    the op that initializes the variable before returning its value.
2016    This method returns the tensor that is used by the op that initializes
2017    the variable.
2018
2019    Returns:
2020      A `Tensor`.
2021    """
2022    return self._initial_value
2023
2024  @property
2025  def constraint(self):
2026    """Returns the constraint function associated with this variable.
2027
2028    Returns:
2029      The constraint function that was passed to the variable constructor.
2030      Can be `None` if no constraint was passed.
2031    """
2032    return self._constraint
2033
2034  def assign(self, value, use_locking=False, name=None, read_value=True):
2035    """Assigns a new value to the variable.
2036
2037    This is essentially a shortcut for `assign(self, value)`.
2038
2039    Args:
2040      value: A `Tensor`. The new value for this variable.
2041      use_locking: If `True`, use locking during the assignment.
2042      name: The name of the operation to be created
2043      read_value: If `True`, returns a tensor that evaluates to the new
2044        value of the variable; if `False`, returns the assign op.
2045
2046    Returns:
2047      A `Tensor` that will hold the new value of this variable after
2048      the assignment has completed.
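
    For example (a minimal sketch):

    ```python
    v = tf.Variable(1.0)
    assign_op = v.assign(2.0)
    with tf.compat.v1.Session() as sess:
      sess.run(v.initializer)
      print(sess.run(assign_op))  # ==> 2.0
    ```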
2049    """
2050    assign = state_ops.assign(
2051        self._variable, value, use_locking=use_locking, name=name)
2052    if read_value:
2053      return assign
2054    return assign.op
2055
2056  def assign_add(self, delta, use_locking=False, name=None, read_value=True):
2057    """Adds a value to this variable.
2058
2059    This is essentially a shortcut for `assign_add(self, delta)`.
2060
2061    Args:
2062      delta: A `Tensor`. The value to add to this variable.
2063      use_locking: If `True`, use locking during the operation.
2064      name: The name of the operation to be created
2065      read_value: If `True`, returns a tensor that evaluates to the new
2066        value of the variable; if `False`, returns the assign op.
2067
2068    Returns:
2069      A `Tensor` that will hold the new value of this variable after
2070      the addition has completed.
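
    For example (a minimal sketch):

    ```python
    v = tf.Variable(1.0)
    inc_op = v.assign_add(0.5)
    with tf.compat.v1.Session() as sess:
      sess.run(v.initializer)
      print(sess.run(inc_op))  # ==> 1.5
    ```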
2071    """
2072    assign = state_ops.assign_add(
2073        self._variable, delta, use_locking=use_locking, name=name)
2074    if read_value:
2075      return assign
2076    return assign.op
2077
2078  def assign_sub(self, delta, use_locking=False, name=None, read_value=True):
2079    """Subtracts a value from this variable.
2080
2081    This is essentially a shortcut for `assign_sub(self, delta)`.
2082
2083    Args:
2084      delta: A `Tensor`. The value to subtract from this variable.
2085      use_locking: If `True`, use locking during the operation.
2086      name: The name of the operation to be created
2087      read_value: If `True`, returns a tensor that evaluates to the new
2088        value of the variable; if `False`, returns the assign op.
2089
2090    Returns:
2091      A `Tensor` that will hold the new value of this variable after
2092      the subtraction has completed.
2093    """
2094    assign = state_ops.assign_sub(
2095        self._variable, delta, use_locking=use_locking, name=name)
2096    if read_value:
2097      return assign
2098    return assign.op
2099
2100  def scatter_sub(self, sparse_delta, use_locking=False, name=None):
2101    """Subtracts `tf.IndexedSlices` from this variable.
2102
2103    Args:
2104      sparse_delta: `tf.IndexedSlices` to be subtracted from this variable.
2105      use_locking: If `True`, use locking during the operation.
2106      name: the name of the operation.
2107
2108    Returns:
2109      A `Tensor` that will hold the new value of this variable after
2110      the scattered subtraction has completed.
2111
2112    Raises:
2113      TypeError: if `sparse_delta` is not an `IndexedSlices`.
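
    For example (a minimal sketch with illustrative values):

    ```python
    v = tf.Variable([1.0, 2.0, 3.0])
    delta = tf.IndexedSlices(values=tf.constant([0.5, 1.0]),
                             indices=tf.constant([0, 2]))
    sub_op = v.scatter_sub(delta)
    with tf.compat.v1.Session() as sess:
      sess.run(v.initializer)
      print(sess.run(sub_op))  # ==> [0.5, 2.0, 2.0]
    ```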
2114    """
2115    if not isinstance(sparse_delta, indexed_slices.IndexedSlices):
2116      raise TypeError("sparse_delta is not IndexedSlices: %s" % sparse_delta)
2117    return gen_state_ops.scatter_sub(
2118        self._variable,
2119        sparse_delta.indices,
2120        sparse_delta.values,
2121        use_locking=use_locking,
2122        name=name)
2123
2124  def scatter_add(self, sparse_delta, use_locking=False, name=None):
2125    """Adds `tf.IndexedSlices` to this variable.
2126
2127    Args:
2128      sparse_delta: `tf.IndexedSlices` to be added to this variable.
2129      use_locking: If `True`, use locking during the operation.
2130      name: the name of the operation.
2131
2132    Returns:
2133      A `Tensor` that will hold the new value of this variable after
2134      the scattered addition has completed.
2135
2136    Raises:
2137      TypeError: if `sparse_delta` is not an `IndexedSlices`.
2138    """
2139    if not isinstance(sparse_delta, indexed_slices.IndexedSlices):
2140      raise TypeError("sparse_delta is not IndexedSlices: %s" % sparse_delta)
2141    return gen_state_ops.scatter_add(
2142        self._variable,
2143        sparse_delta.indices,
2144        sparse_delta.values,
2145        use_locking=use_locking,
2146        name=name)
2147
2148  def scatter_max(self, sparse_delta, use_locking=False, name=None):
2149    """Updates this variable with the max of `tf.IndexedSlices` and itself.
2150
2151    Args:
2152      sparse_delta: `tf.IndexedSlices` to use as an argument of max with this
2153        variable.
2154      use_locking: If `True`, use locking during the operation.
2155      name: the name of the operation.
2156
2157    Returns:
2158      A `Tensor` that will hold the new value of this variable after
2159      the scattered maximization has completed.
2160
2161    Raises:
2162      TypeError: if `sparse_delta` is not an `IndexedSlices`.
2163    """
2164    if not isinstance(sparse_delta, indexed_slices.IndexedSlices):
2165      raise TypeError("sparse_delta is not IndexedSlices: %s" % sparse_delta)
2166    return gen_state_ops.scatter_max(
2167        self._variable,
2168        sparse_delta.indices,
2169        sparse_delta.values,
2170        use_locking=use_locking,
2171        name=name)
2172
2173  def scatter_min(self, sparse_delta, use_locking=False, name=None):
2174    """Updates this variable with the min of `tf.IndexedSlices` and itself.
2175
2176    Args:
2177      sparse_delta: `tf.IndexedSlices` to use as an argument of min with this
2178        variable.
2179      use_locking: If `True`, use locking during the operation.
2180      name: the name of the operation.
2181
2182    Returns:
2183      A `Tensor` that will hold the new value of this variable after
2184      the scattered minimization has completed.
2185
2186    Raises:
2187      TypeError: if `sparse_delta` is not an `IndexedSlices`.
2188    """
2189    if not isinstance(sparse_delta, indexed_slices.IndexedSlices):
2190      raise TypeError("sparse_delta is not IndexedSlices: %s" % sparse_delta)
2191    return gen_state_ops.scatter_min(
2192        self._variable,
2193        sparse_delta.indices,
2194        sparse_delta.values,
2195        use_locking=use_locking,
2196        name=name)
2197
2198  def scatter_mul(self, sparse_delta, use_locking=False, name=None):
2199    """Multiply this variable by `tf.IndexedSlices`.
2200
2201    Args:
2202      sparse_delta: `tf.IndexedSlices` to multiply this variable by.
2203      use_locking: If `True`, use locking during the operation.
2204      name: the name of the operation.
2205
2206    Returns:
2207      A `Tensor` that will hold the new value of this variable after
2208      the scattered multiplication has completed.
2209
2210    Raises:
2211      TypeError: if `sparse_delta` is not an `IndexedSlices`.
2212    """
2213    if not isinstance(sparse_delta, indexed_slices.IndexedSlices):
2214      raise TypeError("sparse_delta is not IndexedSlices: %s" % sparse_delta)
2215    return gen_state_ops.scatter_mul(
2216        self._variable,
2217        sparse_delta.indices,
2218        sparse_delta.values,
2219        use_locking=use_locking,
2220        name=name)
2221
2222  def scatter_div(self, sparse_delta, use_locking=False, name=None):
2223    """Divide this variable by `tf.IndexedSlices`.
2224
2225    Args:
2226      sparse_delta: `tf.IndexedSlices` to divide this variable by.
2227      use_locking: If `True`, use locking during the operation.
2228      name: the name of the operation.
2229
2230    Returns:
2231      A `Tensor` that will hold the new value of this variable after
2232      the scattered division has completed.
2233
2234    Raises:
2235      TypeError: if `sparse_delta` is not an `IndexedSlices`.
2236    """
2237    if not isinstance(sparse_delta, indexed_slices.IndexedSlices):
2238      raise TypeError("sparse_delta is not IndexedSlices: %s" % sparse_delta)
2239    return gen_state_ops.scatter_div(
2240        self._variable,
2241        sparse_delta.indices,
2242        sparse_delta.values,
2243        use_locking=use_locking,
2244        name=name)
2245
2246  def scatter_update(self, sparse_delta, use_locking=False, name=None):
2247    """Assigns `tf.IndexedSlices` to this variable.
2248
2249    Args:
2250      sparse_delta: `tf.IndexedSlices` to be assigned to this variable.
2251      use_locking: If `True`, use locking during the operation.
2252      name: the name of the operation.
2253
2254    Returns:
2255      A `Tensor` that will hold the new value of this variable after
2256      the scattered assignment has completed.
2257
2258    Raises:
2259      TypeError: if `sparse_delta` is not an `IndexedSlices`.
2260    """
2261    if not isinstance(sparse_delta, indexed_slices.IndexedSlices):
2262      raise TypeError("sparse_delta is not IndexedSlices: %s" % sparse_delta)
2263    return gen_state_ops.scatter_update(
2264        self._variable,
2265        sparse_delta.indices,
2266        sparse_delta.values,
2267        use_locking=use_locking,
2268        name=name)
2269
2270  def batch_scatter_update(self, sparse_delta, use_locking=False, name=None):
2271    """Assigns `tf.IndexedSlices` to this variable batch-wise.
2272
2273    Analogous to `batch_gather`. This assumes that this variable and the
2274    sparse_delta IndexedSlices have a series of leading dimensions that are the
2275    same for all of them, and the updates are performed on the last dimension of
2276    indices. In other words, the dimensions should be the following:
2277
2278    `num_prefix_dims = sparse_delta.indices.ndims - 1`
2279    `batch_dim = num_prefix_dims + 1`
2280    `sparse_delta.updates.shape = sparse_delta.indices.shape + var.shape[
2281         batch_dim:]`
2282
2283    where
2284
2285    `sparse_delta.updates.shape[:num_prefix_dims]`
2286    `== sparse_delta.indices.shape[:num_prefix_dims]`
2287    `== var.shape[:num_prefix_dims]`
2288
2289    And the operation performed can be expressed as:
2290
2291    `var[i_1, ..., i_n,
2292         sparse_delta.indices[i_1, ..., i_n, j]] = sparse_delta.updates[
2293            i_1, ..., i_n, j]`
2294
2295    When sparse_delta.indices is a 1D tensor, this operation is equivalent to
2296    `scatter_update`.
2297
2298    To avoid this operation one can loop over the first `ndims` of the
2299    variable and use `scatter_update` on the subtensors that result from
2300    slicing the first dimension. This is a valid option for `ndims = 1`, but
2301    less efficient than this implementation.
2302
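    For example (a minimal sketch with one batch dimension; the values are
    illustrative):

    ```python
    var = tf.Variable([[1., 2., 3.], [4., 5., 6.]])
    delta = tf.IndexedSlices(
        values=tf.constant([[10., 20.], [30., 40.]]),
        indices=tf.constant([[0, 2], [1, 2]]))
    op = var.batch_scatter_update(delta)
    # After running `op`, var holds [[10., 2., 20.], [4., 30., 40.]].
    ```
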
2303    Args:
2304      sparse_delta: `tf.IndexedSlices` to be assigned to this variable.
2305      use_locking: If `True`, use locking during the operation.
2306      name: the name of the operation.
2307
2308    Returns:
2309      A `Tensor` that will hold the new value of this variable after
2310      the scattered assignment has completed.
2311
2312    Raises:
2313      TypeError: if `sparse_delta` is not an `IndexedSlices`.
2314    """
2315    return state_ops.batch_scatter_update(
2316        self,
2317        sparse_delta.indices,
2318        sparse_delta.values,
2319        use_locking=use_locking,
2320        name=name)
2321
2322  def scatter_nd_sub(self, indices, updates, name=None):
2323    """Applies sparse subtraction to individual values or slices in a Variable.
2324
2325    `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.
2326
2327    `indices` must be an integer tensor, containing indices into `ref`.
2328    It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.
2329
2330    The innermost dimension of `indices` (with length `K`) corresponds to
2331    indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
2332    dimension of `ref`.
2333
2334    `updates` is `Tensor` of rank `Q-1+P-K` with shape:
2335
2336    ```
2337    [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
2338    ```
2339
2340    For example, say we want to subtract 4 scattered elements from a rank-1
2341    tensor with 8 elements. In Python, that subtraction would look like this:
2342
2343    ```python
2344        ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
2345        indices = tf.constant([[4], [3], [1], [7]])
2346        updates = tf.constant([9, 10, 11, 12])
2347        op = ref.scatter_nd_sub(indices, updates)
2348        with tf.compat.v1.Session() as sess:
2349          print(sess.run(op))
2350    ```
2351
2352    The resulting update to ref would look like this:
2353
2354        [1, -9, 3, -6, -4, 6, 7, -4]
2355
2356    See `tf.scatter_nd` for more details about how to make updates to
2357    slices.
2358
2359    Args:
2360      indices: The indices to be used in the operation.
2361      updates: The values to be used in the operation.
2362      name: the name of the operation.
2363
2364    Returns:
2365      A `Tensor` that will hold the new value of this variable after
2366      the scattered subtraction has completed.
2367    """
2368    return gen_state_ops.scatter_nd_sub(
2369        self._variable, indices, updates, use_locking=True, name=name)
2370
2371  def scatter_nd_add(self, indices, updates, name=None):
2372    """Applies sparse addition to individual values or slices in a Variable.
2373
2374    `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.
2375
2376    `indices` must be an integer tensor, containing indices into `ref`.
2377    It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.
2378
2379    The innermost dimension of `indices` (with length `K`) corresponds to
2380    indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
2381    dimension of `ref`.
2382
2383    `updates` is `Tensor` of rank `Q-1+P-K` with shape:
2384
2385    ```
2386    [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
2387    ```
2388
2389    For example, say we want to add 4 scattered elements to a rank-1 tensor
2390    with 8 elements. In Python, that addition would look like this:
2391
2392    ```python
2393        ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
2394        indices = tf.constant([[4], [3], [1], [7]])
2395        updates = tf.constant([9, 10, 11, 12])
2396        add = ref.scatter_nd_add(indices, updates)
2397        with tf.compat.v1.Session() as sess:
2398          print(sess.run(add))
2399    ```
2400
2401    The resulting update to ref would look like this:
2402
2403        [1, 13, 3, 14, 14, 6, 7, 20]
2404
2405    See `tf.scatter_nd` for more details about how to make updates to
2406    slices.
2407
2408    Args:
2409      indices: The indices to be used in the operation.
2410      updates: The values to be used in the operation.
2411      name: the name of the operation.
2412
2413    Returns:
2414      A `Tensor` that will hold the new value of this variable after
2415      the scattered addition has completed.
2416    """
2417    return gen_state_ops.scatter_nd_add(
2418        self._variable, indices, updates, use_locking=True, name=name)
2419
2420  def scatter_nd_update(self, indices, updates, name=None):
2421    """Applies sparse assignment to individual values or slices in a Variable.
2422
2423    `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.
2424
2425    `indices` must be an integer tensor, containing indices into `ref`.
2426    It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.
2427
2428    The innermost dimension of `indices` (with length `K`) corresponds to
2429    indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
2430    dimension of `ref`.
2431
2432    `updates` is `Tensor` of rank `Q-1+P-K` with shape:
2433
2434    ```
2435    [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
2436    ```
2437
2438    For example, say we want to assign 4 scattered elements to a rank-1 tensor
2439    with 8 elements. In Python, that update would look like this:
2440
2441    ```python
2442        ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
2443        indices = tf.constant([[4], [3], [1], [7]])
2444        updates = tf.constant([9, 10, 11, 12])
2445        op = ref.scatter_nd_update(indices, updates)
2446        with tf.compat.v1.Session() as sess:
2447          print(sess.run(op))
2448    ```
2449
2450    The resulting update to ref would look like this:
2451
2452        [1, 11, 3, 10, 9, 6, 7, 12]
2453
2454    See `tf.scatter_nd` for more details about how to make updates to
2455    slices.
2456
2457    Args:
2458      indices: The indices to be used in the operation.
2459      updates: The values to be used in the operation.
2460      name: the name of the operation.
2461
2462    Returns:
2463      A `Tensor` that will hold the new value of this variable after
2464      the scattered assignment has completed.
2465    """
2466    return gen_state_ops.scatter_nd_update(
2467        self._variable, indices, updates, use_locking=True, name=name)
2468
2469  def scatter_nd_max(self, indices, updates, name=None):
2470    """Applies sparse maximization to individual values or slices in a Variable.
2471
2472    `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.
2473
2474    `indices` must be an integer tensor, containing indices into `ref`.
2475    It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.
2476
2477    The innermost dimension of `indices` (with length `K`) corresponds to
2478    indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
2479    dimension of `ref`.
2480
2481    `updates` is `Tensor` of rank `Q-1+P-K` with shape:
2482
2483    ```
2484    [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
2485    ```
2486
2487    See `tf.scatter_nd` for more details about how to make updates to
2488    slices.
2489
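    For example (a minimal sketch with illustrative values):

    ```python
    ref = tf.Variable([1, 2, 3, 4])
    indices = tf.constant([[1], [3]])
    updates = tf.constant([10, 0])
    op = ref.scatter_nd_max(indices, updates)
    # After running `op`, ref holds [1, 10, 3, 4].
    ```
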
2490    Args:
2491      indices: The indices to be used in the operation.
2492      updates: The values to be used in the operation.
2493      name: the name of the operation.
2494
2495    Returns:
2496      A `Tensor` that will hold the new value of this variable after
2497      the scattered maximization has completed.
2498    """
2499    return gen_state_ops.scatter_nd_max(
2500        self._variable, indices, updates, use_locking=True, name=name)
2501
2502  def scatter_nd_min(self, indices, updates, name=None):
2503    """Applies sparse minimization to individual values or slices in a Variable.
2504
2505    `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.
2506
2507    `indices` must be an integer tensor, containing indices into `ref`.
2508    It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.
2509
2510    The innermost dimension of `indices` (with length `K`) corresponds to
2511    indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
2512    dimension of `ref`.
2513
2514    `updates` is `Tensor` of rank `Q-1+P-K` with shape:
2515
2516    ```
2517    [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
2518    ```
2519
2520    See `tf.scatter_nd` for more details about how to make updates to
2521    slices.
2522
2523    Args:
2524      indices: The indices to be used in the operation.
2525      updates: The values to be used in the operation.
2526      name: the name of the operation.
2527
2528    Returns:
2529      A `Tensor` that will hold the new value of this variable after
2530      the scattered minimization has completed.
2531    """
2532    return gen_state_ops.scatter_nd_min(
2533        self._variable, indices, updates, use_locking=True, name=name)
2534
2535  def _strided_slice_assign(self, begin, end, strides, value, name, begin_mask,
2536                            end_mask, ellipsis_mask, new_axis_mask,
2537                            shrink_axis_mask):
2538    return gen_array_ops.strided_slice_assign(
2539        ref=self._ref(),
2540        begin=begin,
2541        end=end,
2542        strides=strides,
2543        value=value,
2544        name=name,
2545        begin_mask=begin_mask,
2546        end_mask=end_mask,
2547        ellipsis_mask=ellipsis_mask,
2548        new_axis_mask=new_axis_mask,
2549        shrink_axis_mask=shrink_axis_mask)
2550
2551  @deprecated(None, "Prefer Dataset.range instead.")
2552  def count_up_to(self, limit):
2553    """Increments this variable until it reaches `limit`.
2554
2555    When that Op is run it tries to increment the variable by `1`. If
2556    incrementing the variable would bring it above `limit` then the Op raises
2557    the exception `OutOfRangeError`.
2558
2559    If no error is raised, the Op outputs the value of the variable before
2560    the increment.
2561
2562    This is essentially a shortcut for `count_up_to(self, limit)`.
2563
2564    Args:
2565      limit: value at which incrementing the variable raises an error.
2566
2567    Returns:
2568      A `Tensor` that will hold the variable value before the increment. If no
2569      other Op modifies this variable, the values produced will all be
2570      distinct.
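
    For example (a minimal sketch):

    ```python
    counter = tf.Variable(0, trainable=False)
    next_count = counter.count_up_to(3)
    # After initializing `counter`, each `sess.run(next_count)` returns
    # 0, 1, 2 in turn and then raises `tf.errors.OutOfRangeError`.
    ```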
2571    """
2572    return state_ops.count_up_to(self._variable, limit=limit)
2573
2574  # Conversion to tensor.
2575  @staticmethod
2576  def _TensorConversionFunction(v, dtype=None, name=None, as_ref=False):  # pylint: disable=invalid-name
2577    """Utility function for converting a Variable to a Tensor."""
2578    _ = name
2579    if dtype and not dtype.is_compatible_with(v.dtype):
2580      raise ValueError(
2581          "Incompatible type conversion requested to type '%s' for variable "
2582          "of type '%s'" % (dtype.name, v.dtype.name))
2583    if as_ref:
2584      return v._ref()  # pylint: disable=protected-access
2585    else:
2586      return v.value()
2587
2588  # NOTE(mrry): This enables the Variable's overloaded "right" binary
2589  # operators to run when the left operand is an ndarray, because it
2590  # accords the Variable class higher priority than an ndarray, or a
2591  # numpy matrix.
2592  # TODO(mrry): Convert this to using numpy's __numpy_ufunc__
2593  # mechanism, which allows more control over how Variables interact
2594  # with ndarrays.
2595  __array_priority__ = 100
2596
2597  @property
2598  def name(self):
2599    """The name of this variable."""
2600    return self._name
2601
2602  @property
2603  def initializer(self):
2604    """The initializer operation for this variable."""
2605    return self._initializer_op
2606
2607  @property
2608  def device(self):
2609    """The device of this variable."""
2610    return self._variable.device
2611
2612  @property
2613  def dtype(self):
2614    """The `DType` of this variable."""
2615    return self._variable.dtype
2616
2617  @property
2618  def op(self):
2619    """The `Operation` of this variable."""
2620    return self._variable.op
2621
2622  @property
2623  def graph(self):
2624    """The `Graph` of this variable."""
2625    return self._variable.graph
2626
2627  @property
2628  def _distribute_strategy(self):
2629    """The `tf.distribute.Strategy` that this variable was created under."""
2630    return None  # Ref variables are never created inside a strategy.
2631
2632  @property
2633  def shape(self):
2634    """The `TensorShape` of this variable.
2635
2636    Returns:
2637      A `TensorShape`.
2638    """
2639    return self._variable.get_shape()
2640
2641  def to_proto(self, export_scope=None):
2642    """Converts a `Variable` to a `VariableDef` protocol buffer.
2643
2644    Args:
2645      export_scope: Optional `string`. Name scope to remove.
2646
2647    Returns:
2648      A `VariableDef` protocol buffer, or `None` if the `Variable` is not
2649      in the specified name scope.
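
    For example (a minimal sketch, run inside a v1 graph):

    ```python
    v = tf.Variable([1.0], name="v")
    var_def = v.to_proto()  # A `VariableDef` proto describing `v`'s nodes.
    ```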
2650    """
2651    if (export_scope is None or self._variable.name.startswith(export_scope)):
2652      var_def = variable_pb2.VariableDef()
2653      var_def.variable_name = ops.strip_name_scope(self._variable.name,
2654                                                   export_scope)
2655      if self._initial_value is not None:
2656        # For backwards compatibility.
2657        var_def.initial_value_name = ops.strip_name_scope(
2658            self._initial_value.name, export_scope)
2659      var_def.trainable = self.trainable
2660      var_def.synchronization = self.synchronization.value
2661      var_def.aggregation = self.aggregation.value
2662      var_def.initializer_name = ops.strip_name_scope(self.initializer.name,
2663                                                      export_scope)
2664      var_def.snapshot_name = ops.strip_name_scope(self._snapshot.name,
2665                                                   export_scope)
2666      if self._save_slice_info:
2667        var_def.save_slice_info_def.MergeFrom(
2668            self._save_slice_info.to_proto(export_scope=export_scope))
2669      return var_def
2670    else:
2671      return None
2672
2673  def __iadd__(self, other):
2674    logging.log_first_n(
2675        logging.WARN, "Variable += will be deprecated. Use variable.assign_add"
2676        " if you want assignment to the variable value or 'x = x + y'"
2677        " if you want a new python Tensor object.", 1)
2678    return self + other
2679
2680  def __isub__(self, other):
2681    logging.log_first_n(
2682        logging.WARN, "Variable -= will be deprecated. Use variable.assign_sub"
2683        " if you want assignment to the variable value or 'x = x - y'"
2684        " if you want a new python Tensor object.", 1)
2685    return self - other
2686
2687  def __imul__(self, other):
2688    logging.log_first_n(
2689        logging.WARN,
2690        "Variable *= will be deprecated. Use `var.assign(var * other)`"
2691        " if you want assignment to the variable value or `x = x * y`"
2692        " if you want a new python Tensor object.", 1)
2693    return self * other
2694
2695  def __idiv__(self, other):
2696    logging.log_first_n(
2697        logging.WARN,
2698        "Variable /= will be deprecated. Use `var.assign(var / other)`"
2699        " if you want assignment to the variable value or `x = x / y`"
2700        " if you want a new python Tensor object.", 1)
2701    return self / other
2702
2703  def __itruediv__(self, other):
2704    logging.log_first_n(
2705        logging.WARN,
2706        "Variable /= will be deprecated. Use `var.assign(var / other)`"
2707        " if you want assignment to the variable value or `x = x / y`"
2708        " if you want a new python Tensor object.", 1)
2709    return self / other
2710
2711  def __irealdiv__(self, other):
2712    logging.log_first_n(
2713        logging.WARN,
2714        "Variable /= will be deprecated. Use `var.assign(var / other)`"
2715        " if you want assignment to the variable value or `x = x / y`"
2716        " if you want a new python Tensor object.", 1)
2717    return self / other
2718
2719  def __ipow__(self, other):
2720    logging.log_first_n(
2721        logging.WARN,
2722        "Variable **= will be deprecated. Use `var.assign(var ** other)`"
2723        " if you want assignment to the variable value or `x = x ** y`"
2724        " if you want a new python Tensor object.", 1)
2725    return self**other
2726
2727
2728def _try_guard_against_uninitialized_dependencies(name, initial_value):
2729  """Attempt to guard against dependencies on uninitialized variables.
2730
2731  Replace references to variables in `initial_value` with references to the
2732  variable's initialized values. The initialized values are essentially
2733  conditional TensorFlow graphs that return a variable's value if it is
2734  initialized or its `initial_value` if it hasn't been initialized. This
2735  replacement is done on a best effort basis:
2736
2737  - If the `initial_value` graph contains cycles, we don't do any
2738    replacements for that graph.
2739  - If the variables that `initial_value` depends on are not present in the
2740    `GLOBAL_VARIABLES` or `LOCAL_VARIABLES` collections, we don't replace them.
2741
2742  In these cases, it is up to the caller to ensure that the `initial_value`
2743  graph uses initialized variables or to guard access to variables
2744  using their `initialized_value` method.
2745
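  For example, a minimal graph-mode sketch of the kind of dependency this
  guard targets (the variable names are purely illustrative):

  >>> with tf.compat.v1.Graph().as_default():
  ...   w = tf.Variable([1.0], name="w")
  ...   # The initial value of `b` reads `w`, which may not be initialized
  ...   # yet when `b.initializer` runs.
  ...   b = tf.Variable(w * 2.0, name="b")
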
2746  Args:
2747    name: Variable name.
2748    initial_value: `Tensor`. The initial value.
2749
2750  Returns:
2751    A `Tensor` suitable to initialize a variable.
2752  Raises:
2753    TypeError: If `initial_value` is not a `Tensor`.
2754  """
2755  if not isinstance(initial_value, ops.Tensor):
2756    raise TypeError("initial_value needs to be a Tensor: %s" % initial_value)
2757
2758  # Don't modify initial_value if it contains any cyclic dependencies.
2759  if _has_cycle(initial_value.op, state={}):
2760    return initial_value
2761  return _safe_initial_value_from_tensor(name, initial_value, op_cache={})
2762
2763
2764_UNKNOWN, _STARTED, _FINISHED = range(3)
2765
2766
2767def _has_cycle(op, state):
2768  """Detect cycles in the dependencies of `initial_value`."""
2769  op_state = state.get(op.name, _UNKNOWN)
2770  if op_state == _STARTED:
2771    return True
2772  elif op_state == _FINISHED:
2773    return False
2774
2775  state[op.name] = _STARTED
2776  for i in itertools.chain((i.op for i in op.inputs), op.control_inputs):
2777    if _has_cycle(i, state):
2778      return True
2779  state[op.name] = _FINISHED
2780  return False
2781
2782
2783def _safe_initial_value_from_tensor(name, tensor, op_cache):
2784  """Replace dependencies on variables with their initialized values.
2785
2786  Args:
2787    name: Variable name.
2788    tensor: A `Tensor`. The tensor to replace.
2789    op_cache: A dict mapping operation names to `Operation`s. Used to memoize
2790      the results so as to avoid creating redundant operations.
2791
2792  Returns:
2793    A `Tensor` compatible with `tensor`. Any inputs that lead to variable
2794    values will be replaced with a corresponding graph that uses the
2795    variable's initialized values. This is done on a best-effort basis. If no
2796    modifications need to be made then `tensor` will be returned unchanged.
2797  """
2798  op = tensor.op
2799  new_op = op_cache.get(op.name)
2800  if new_op is None:
2801    new_op = _safe_initial_value_from_op(name, op, op_cache)
2802    op_cache[op.name] = new_op
2803  return new_op.outputs[tensor.value_index]
2804
2805
2806def _safe_initial_value_from_op(name, op, op_cache):
2807  """Replace dependencies on variables with their initialized values.
2808
2809  Args:
2810    name: Variable name.
2811    op: An `Operation`. The operation to replace.
2812    op_cache: A dict mapping operation names to `Operation`s. Used to memoize
2813      the results so as to avoid creating redundant operations.
2814
2815  Returns:
2816    An `Operation` compatible with `op`. Any inputs that lead to variable
2817    values will be replaced with a corresponding graph that uses the
2818    variable's initialized values. This is done on a best-effort basis. If no
2819    modifications need to be made then `op` will be returned unchanged.
2820  """
2821  op_type = op.node_def.op
2822  if op_type in ("IsVariableInitialized", "VarIsInitializedOp",
2823                 "ReadVariableOp", "If"):
2824    return op
2825
2826  # Attempt to find the initialized_value of any variable reference / handles.
2827  # TODO(b/70206927): Fix handling of ResourceVariables.
2828  if op_type in ("Variable", "VariableV2", "VarHandleOp"):
2829    initialized_value = _find_initialized_value_for_variable(op)
2830    return op if initialized_value is None else initialized_value.op
2831
2832  # Recursively build initializer expressions for inputs.
2833  modified = False
2834  new_op_inputs = []
2835  for op_input in op.inputs:
2836    new_op_input = _safe_initial_value_from_tensor(name, op_input, op_cache)
2837    new_op_inputs.append(new_op_input)
2838    modified = modified or (new_op_input != op_input)
2839
2840  # If at least one input was modified, replace the op.
2841  if modified:
2842    new_op_type = op_type
2843    if new_op_type == "RefSwitch":
2844      new_op_type = "Switch"
2845    new_op_name = op.node_def.name + "_" + name
2846    new_op_name = new_op_name.replace(":", "_")
2847    return op.graph.create_op(
2848        new_op_type,
2849        new_op_inputs,
2850        op._output_types,  # pylint: disable=protected-access
2851        name=new_op_name,
2852        attrs=op.node_def.attr)
2853
2854  return op
2855
2856
2857def _find_initialized_value_for_variable(variable_op):
2858  """Find the initialized value for a variable op.
2859
2860  To do so, look up the variable op in the variables collections.
2861
2862  Args:
2863    variable_op: A variable `Operation`.
2864
2865  Returns:
2866    A `Tensor` representing the initialized value for the variable or `None`
2867    if the initialized value could not be found.
2868  """
2869  try:
2870    var_names = [variable_op.node_def.name, variable_op.node_def.name + ":0"]
2871    for collection_name in (ops.GraphKeys.GLOBAL_VARIABLES,
2872                            ops.GraphKeys.LOCAL_VARIABLES):
2873      for var in variable_op.graph.get_collection(collection_name):
2874        if var.name in var_names:
2875          return var.initialized_value()
2876  except AttributeError:
2877    # Return None when an incomplete user-defined variable type was put in
2878    # the collection.
2879    return None
2880  return None
2881
2882
2883class PartitionedVariable:
2884  """A container for partitioned `Variable` objects.
2885
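  Such a variable is usually created indirectly, e.g. by
  `tf.compat.v1.get_variable` with a `partitioner`. A minimal graph-mode
  sketch (the name, shape and partitioner below are illustrative):

  >>> with tf.compat.v1.Graph().as_default():
  ...   pv = tf.compat.v1.get_variable(
  ...       "pv", shape=[4],
  ...       partitioner=tf.compat.v1.fixed_size_partitioner(2))
  ...   print(len(list(pv)))  # Number of underlying partitions.
  2
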
2886  @compatibility(eager) `tf.PartitionedVariable` is not compatible with
2887  eager execution.  Use `tf.Variable` instead, which is compatible
2888  with both eager execution and graph construction.  See [the
2889  TensorFlow Eager Execution
2890  guide](https://www.tensorflow.org/guide/eager#variables_and_optimizers)
2891  for details on how variables work in eager execution.
2892  @end_compatibility
2893  """
2894
2895  def __init__(self, name, shape, dtype, variable_list, partitions):
2896    """Creates a new partitioned variable wrapper.
2897
2898    Variables passed via `variable_list` must contain a save_slice_info
2899    field.  Concatenation and iteration are in lexicographic order according
2900    to the var_offset property of the save_slice_info.
2901
2902    Args:
2903      name: String. Overall name of the variables.
2904      shape: List of integers.  Overall shape of the variables.
2905      dtype: Type of the variables.
2906      variable_list: List of `Variable` that comprise this partitioned variable.
2907      partitions: List of integers.  Number of partitions for each dimension.
2908
2909    Raises:
2910      TypeError: If `variable_list` is not a list of `Variable` objects, or
2911        `partitions` is not a list.
2912      ValueError: If `variable_list` is empty, or the `Variable` shape
2913        information does not match `shape`, or `partitions` has invalid values.
2914    """
2915    if not isinstance(variable_list, (list, tuple)):
2916      raise TypeError("variable_list is not a list or tuple: %s" %
2917                      variable_list)
2918    if not isinstance(partitions, (list, tuple)):
2919      raise TypeError("partitions is not a list or tuple: %s" % partitions)
2920    if not all(p >= 1 for p in partitions):
2921      raise ValueError("partition values must be positive: %s" % partitions)
2922    if not variable_list:
2923      raise ValueError("variable_list may not be empty")
2924    # pylint: disable=protected-access
2925    for v in variable_list:
2926      # Validate each variable and its save_slice_info; sorting happens below.
2927      if not all(v._get_save_slice_info() is not None for v in variable_list):
2928        raise ValueError(
2929            "All variables must have a save_slice_info available: %s" %
2930            [v.name for v in variable_list])
2931      if len(shape) != len(partitions):
2932        raise ValueError("len(shape) != len(partitions): %s vs. %s" %
2933                         (shape, partitions))
2934      if v._get_save_slice_info().full_shape != shape:
2935        raise ValueError("All variables' full shapes must match shape: %s; "
2936                         "but full shapes were: %s" %
2937                         (shape, str([v._get_save_slice_info().full_shape])))
2938    self._variable_list = sorted(
2939        variable_list, key=lambda v: v._get_save_slice_info().var_offset)
2940    # pylint: enable=protected-access
2941
2942    self._name = name
2943    self._shape = shape
2944    self._dtype = dtype
2945    self._partitions = partitions
2946    self._as_tensor = None
2947
2948  def __iter__(self):
2949    """Return an iterable for accessing the underlying partition Variables."""
2950    return iter(self._variable_list)
2951
2952  def __len__(self):
2953    num_partition_axes = len(self._partition_axes())
2954    if num_partition_axes > 1:
2955      raise ValueError("Cannot get a length for %d > 1 partition axes" %
2956                       num_partition_axes)
2957    return len(self._variable_list)
2958
2959  def _partition_axes(self):
2960    if all(p == 1 for p in self._partitions):
2961      return [0]
2962    else:
2963      return [i for i, p in enumerate(self._partitions) if p > 1]
2964
2965  def _concat(self):
2966    """Returns the overall concatenated value as a `Tensor`.
2967
2968    This is different from using the partitioned variable directly as a tensor
2969    (through tensor conversion and `as_tensor`) in that it creates a new set of
2970    operations that keeps the control dependencies from its scope.
2971
2972    Returns:
2973      `Tensor` containing the concatenated value.
2974    """
2975    if len(self._variable_list) == 1:
2976      with ops.name_scope(None):
2977        return array_ops.identity(self._variable_list[0], name=self._name)
2978
2979    partition_axes = self._partition_axes()
2980
2981    if len(partition_axes) > 1:
2982      raise NotImplementedError(
2983          "Cannot concatenate along more than one dimension: %s.  "
2984          "Multi-axis partition concat is not supported" % str(partition_axes))
2985    partition_ix = partition_axes[0]
2986
2987    with ops.name_scope(self._name + "/ConcatPartitions/"):
2988      concatenated = array_ops.concat(self._variable_list, partition_ix)
2989
2990    with ops.name_scope(None):
2991      return array_ops.identity(concatenated, name=self._name)
2992
2993  def as_tensor(self):
2994    """Returns the overall concatenated value as a `Tensor`.
2995
2996    The returned tensor will not inherit the control dependencies from the scope
2997    where the value is used, which is similar to getting the value of
2998    `Variable`.
2999
3000    Returns:
3001      `Tensor` containing the concatenated value.
3002    """
3003    with ops.control_dependencies(None):
3004      return self._concat()
3005
3006  @staticmethod
3007  def _TensorConversionFunction(v, dtype=None, name=None, as_ref=False):
3008    # pylint: disable=invalid-name
3009    _ = name
3010    if dtype is not None and not dtype.is_compatible_with(v.dtype):
3011      raise ValueError(
3012          "Incompatible type conversion requested to type '%s' for variable "
3013          "of type '%s'" % (dtype.name, v.dtype.name))
3014    if as_ref:
3015      raise NotImplementedError(
3016          "PartitionedVariable doesn't support being used as a reference.")
3017    else:
3018      return v.as_tensor()
3019
3020  @property
3021  def name(self):
3022    return self._name
3023
3024  @property
3025  def dtype(self):
3026    return self._dtype
3027
3028  @property
3029  def shape(self):
3030    return self.get_shape()
3031
3032  @property
3033  def _distribute_strategy(self):
3034    """The `tf.distribute.Strategy` that this variable was created under."""
3035    # NOTE(yuefengz): Today, no partitioned variables in a distribute strategy.
3036    return None
3037
3038  def get_shape(self):
3039    return self._shape
3040
3041  def _get_variable_list(self):
3042    return self._variable_list
3043
3044  def _get_partitions(self):
3045    return self._partitions
3046
3047  def _apply_assign_fn(self, assign_fn, value):
3048    partition_axes = self._partition_axes()
3049    if len(partition_axes) > 1:
3050      raise NotImplementedError(
3051          "Cannot do assign action along more than one dimension: %s.  "
3052          "Multi-axis partition assign action is not supported " %
3053          str(partition_axes))
3054    if isinstance(value, list):
3055      assert len(value) == len(self._variable_list)
3056      value_list = value
3057    elif isinstance(value, PartitionedVariable):
3058      value_list = [var_part for var_part in value]
3059    else:
3060      partition_ix = partition_axes[0]
3061      size_splits_list = [
3062          tensor_shape.dimension_value(var.shape[partition_ix])
3063          for var in self._variable_list
3064      ]
3065      value_list = array_ops.split(value, size_splits_list, axis=partition_ix)
3066
3067    op_list = [
3068        assign_fn(var, value_list[idx])
3069        for idx, var in enumerate(self._variable_list)
3070    ]
3071    return op_list
3072
3073  def assign(self, value, use_locking=False, name=None, read_value=True):
3074    assign_fn = lambda var, r_value: var.assign(
3075        r_value, use_locking=use_locking, name=name, read_value=read_value)
3076    assign_list = self._apply_assign_fn(assign_fn, value)
3077    if read_value:
3078      return assign_list
3079    return [assign.op for assign in assign_list]
3080
3081  def assign_add(self, value, use_locking=False, name=None, read_value=True):
3082    assign_fn = lambda var, r_value: var.assign_add(
3083        r_value, use_locking=use_locking, name=name, read_value=read_value)
3084    assign_list = self._apply_assign_fn(assign_fn, value)
3085    if read_value:
3086      return assign_list
3087    return [assign.op for assign in assign_list]
3088
3089  def assign_sub(self, value, use_locking=False, name=None, read_value=True):
3090    assign_fn = lambda var, r_value: var.assign_sub(
3091        r_value, use_locking=use_locking, name=name, read_value=read_value)
3092    assign_list = self._apply_assign_fn(assign_fn, value)
3093    if read_value:
3094      return assign_list
3095    return [assign.op for assign in assign_list]
3096
3097
3098# Register a conversion function which reads the value of the variable,
3099# allowing instances of the class to be used as tensors.
3100ops.register_tensor_conversion_function(RefVariable,
3101                                        RefVariable._TensorConversionFunction)  # pylint: disable=protected-access
3102
3103
3104@tf_export(v1=["global_variables"])
3105def global_variables(scope=None):
3106  """Returns global variables.
3107
3108  Global variables are variables that are shared across machines in a
3109  distributed environment. The `Variable()` constructor or `get_variable()`
3110  automatically adds new variables to the graph collection
3111  `GraphKeys.GLOBAL_VARIABLES`.
3112  This convenience function returns the contents of that collection.
3113
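  Example (a minimal graph-mode sketch; the variable names are illustrative):

  >>> with tf.compat.v1.Graph().as_default():
  ...   v1 = tf.Variable(1.0, name="v1")
  ...   v2 = tf.Variable(2.0, name="v2")
  ...   # Both variables were added to the GLOBAL_VARIABLES collection.
  ...   print(len(tf.compat.v1.global_variables()))
  2
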
3114  Local variables are an alternative to global variables. See
3115  `tf.compat.v1.local_variables`.
3116
3117  @compatibility(TF2)
3118  Not compatible with eager execution and `tf.function`. In particular, Graph
3119  collections are deprecated in TF2. Instead please create a
3120  [tf.Module](https://www.tensorflow.org/guide/intro_to_modules)
3121  container for all your model state, including variables.
3122  You can then list all the variables in your `tf.Module` through the
3123  `variables` attribute.
3124  @end_compatibility
3125
3126  Args:
3127    scope: (Optional.) A string. If supplied, the resulting list is filtered to
3128      include only items whose `name` attribute matches `scope` using
3129      `re.match`. Items without a `name` attribute are never returned if a scope
3130      is supplied. The choice of `re.match` means that a `scope` without special
3131      tokens filters by prefix.
3132
3133  Returns:
3134    A list of `Variable` objects.
3135  """
3136  return ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES, scope)
3137
3138
3139@tf_export(v1=["all_variables"])
3140@deprecated("2017-03-02", "Please use tf.global_variables instead.")
3141def all_variables():
3142  """Use `tf.compat.v1.global_variables` instead."""
3143  return global_variables()
3144
3145
3146def _all_saveable_objects(scope=None):
3147  """Returns all variables and `SaveableObject`s that must be checkpointed.
3148
3149  Args:
3150    scope: (Optional.) A string. If supplied, the resulting list is filtered to
3151      include only items whose `name` attribute matches `scope` using
3152      `re.match`. Items without a `name` attribute are never returned if a scope
3153      is supplied. The choice of `re.match` means that a `scope` without special
3154      tokens filters by prefix.
3155
3156  Returns:
3157    A list of `Variable` and `SaveableObject` to be checkpointed
3158  """
3159  # TODO(andreasst): make this function public once things are settled.
3160  return (ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES, scope) +
3161          ops.get_collection(ops.GraphKeys.SAVEABLE_OBJECTS, scope))
3162
3163
3164@tf_export(v1=["local_variables"])
3165def local_variables(scope=None):
3166  """Returns local variables.
3167
3168  Local variables are per-process variables, usually not saved/restored to
3169  a checkpoint, and used for temporary or intermediate values.
3170  For example, they can be used as counters for metrics computation or the
3171  number of epochs this machine has read data.
3172  The `tf.contrib.framework.local_variable()` function automatically adds the
3173  new variable to `GraphKeys.LOCAL_VARIABLES`.
3174  This convenience function returns the contents of that collection.
3175
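  Example (a minimal graph-mode sketch; the variable name and the explicit
  `collections` argument are illustrative):

  >>> with tf.compat.v1.Graph().as_default():
  ...   counter = tf.compat.v1.Variable(
  ...       0, trainable=False, name="counter",
  ...       collections=[tf.compat.v1.GraphKeys.LOCAL_VARIABLES])
  ...   print(len(tf.compat.v1.local_variables()))
  1
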
3176  Global variables are an alternative to local variables. See
3177  `tf.compat.v1.global_variables`.
3178
3179  Args:
3180    scope: (Optional.) A string. If supplied, the resulting list is filtered to
3181      include only items whose `name` attribute matches `scope` using
3182      `re.match`. Items without a `name` attribute are never returned if a scope
3183      is supplied. The choice of `re.match` means that a `scope` without special
3184      tokens filters by prefix.
3185
3186  Returns:
3187    A list of local `Variable` objects.
3188  """
3189  return ops.get_collection(ops.GraphKeys.LOCAL_VARIABLES, scope)
3190
3191
3192@tf_export(v1=["model_variables"])
3193def model_variables(scope=None):
3194  """Returns all variables in the MODEL_VARIABLES collection.
3195
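  Example (a minimal graph-mode sketch; the manual `add_to_collection` call
  below stands in for libraries that register model variables themselves):

  >>> with tf.compat.v1.Graph().as_default():
  ...   v = tf.Variable(1.0, name="v")
  ...   tf.compat.v1.add_to_collection(
  ...       tf.compat.v1.GraphKeys.MODEL_VARIABLES, v)
  ...   print(len(tf.compat.v1.model_variables()))
  1
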
3196  Args:
3197    scope: (Optional.) A string. If supplied, the resulting list is filtered to
3198      include only items whose `name` attribute matches `scope` using
3199      `re.match`. Items without a `name` attribute are never returned if a scope
3200      is supplied. The choice of `re.match` means that a `scope` without special
3201      tokens filters by prefix.
3202
3203  Returns:
3204    A list of local Variable objects.
3205  """
3206  return ops.get_collection(ops.GraphKeys.MODEL_VARIABLES, scope)
3207
3208
3209@tf_export(v1=["trainable_variables"])
3210def trainable_variables(scope=None):
3211  """Returns all variables created with `trainable=True`.
3212
3213  When passed `trainable=True`, the `Variable()` constructor automatically
3214  adds new variables to the graph collection
3215  `GraphKeys.TRAINABLE_VARIABLES`. This convenience function returns the
3216  contents of that collection.
3217
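  Example (a minimal graph-mode sketch; the variable names are illustrative):

  >>> with tf.compat.v1.Graph().as_default():
  ...   w = tf.Variable(1.0, name="w")                    # trainable
  ...   c = tf.Variable(2.0, trainable=False, name="c")   # not trainable
  ...   print(len(tf.compat.v1.trainable_variables()))
  1
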
3218  @compatibility(TF2)
3219  Not compatible with eager execution and `tf.function`. In particular, Graph
3220  collections are deprecated in TF2. Instead please create a `tf.Module`
3221  container for all your model state, including variables.
3222  You can then list all the trainable variables in your `tf.Module` through the
3223  `trainable_variables` attribute.
3224  @end_compatibility
3225
3226  Args:
3227    scope: (Optional.) A string. If supplied, the resulting list is filtered to
3228      include only items whose `name` attribute matches `scope` using
3229      `re.match`. Items without a `name` attribute are never returned if a scope
3230      is supplied. The choice of `re.match` means that a `scope` without special
3231      tokens filters by prefix.
3232
3233  Returns:
3234    A list of Variable objects.
3235  """
3236  return ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES, scope)
3237
3238
3239@tf_export(v1=["moving_average_variables"])
3240def moving_average_variables(scope=None):
3241  """Returns all variables that maintain their moving averages.
3242
3243  If an `ExponentialMovingAverage` object is created and the `apply()`
3244  method is called on a list of variables, these variables will
3245  be added to the `GraphKeys.MOVING_AVERAGE_VARIABLES` collection.
3246  This convenience function returns the contents of that collection.
3247
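  Example (a minimal graph-mode sketch; the variable name and decay are
  illustrative, and no output is asserted):

  >>> with tf.compat.v1.Graph().as_default():
  ...   v = tf.Variable(1.0, name="v")
  ...   ema = tf.train.ExponentialMovingAverage(decay=0.99)
  ...   maintain_op = ema.apply([v])  # Adds `v` to MOVING_AVERAGE_VARIABLES.
  ...   mav = tf.compat.v1.moving_average_variables()
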
3248  Args:
3249    scope: (Optional.) A string. If supplied, the resulting list is filtered to
3250      include only items whose `name` attribute matches `scope` using
3251      `re.match`. Items without a `name` attribute are never returned if a scope
3252      is supplied. The choice of `re.match` means that a `scope` without special
3253      tokens filters by prefix.
3254
3255  Returns:
3256    A list of Variable objects.
3257  """
3258  return ops.get_collection(ops.GraphKeys.MOVING_AVERAGE_VARIABLES, scope)
3259
3260
3261@tf_export(v1=["initializers.variables", "variables_initializer"])
3262def variables_initializer(var_list, name="init"):
3263  """Returns an Op that initializes a list of variables.
3264
3265  After you launch the graph in a session, you can run the returned Op to
3266  initialize all the variables in `var_list`. This Op runs all the
3267  initializers of the variables in `var_list` in parallel.
3268
3269  Calling `variables_initializer()` is equivalent to passing the list of
3270  initializers to `tf.group()`.
3271
3272  If `var_list` is empty, however, the function still returns an Op that can
3273  be run. That Op just has no effect.
3274
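  Example (a minimal graph-mode sketch; the session and variable are
  illustrative):

  >>> with tf.compat.v1.Graph().as_default():
  ...   v = tf.Variable(1.0, name="v")
  ...   init = tf.compat.v1.variables_initializer([v])
  ...   with tf.compat.v1.Session() as sess:
  ...     sess.run(init)
  ...     print(sess.run(v))
  1.0
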
3275  @compatibility(TF2)
3276  In TF2, variables are initialized immediately when they are created. There is
3277  no longer a need to run variable initializers before using them.
3278  @end_compatibility
3279
3280  Args:
3281    var_list: List of `Variable` objects to initialize.
3282    name: Optional name for the returned operation.
3283
3284  Returns:
3285    An Op that runs the initializers of all the specified variables.
3286  """
3287  if var_list and not context.executing_eagerly():
3288    return control_flow_ops.group(*[v.initializer for v in var_list], name=name)
3289  return control_flow_ops.no_op(name=name)
3290
3291
3292@tf_export(v1=["initialize_variables"])
3293@tf_should_use.should_use_result
3294@deprecated("2017-03-02", "Use `tf.variables_initializer` instead.")
3295def initialize_variables(var_list, name="init"):
3296  """See `tf.compat.v1.variables_initializer`."""
3297  return variables_initializer(var_list, name=name)
3298
3299
3300@tf_export(v1=["initializers.global_variables", "global_variables_initializer"])
3301def global_variables_initializer():
3302  """Returns an Op that initializes global variables.
3303
3304  This is just a shortcut for `variables_initializer(global_variables())`
3305
3306  @compatibility(TF2)
3307  In TF2, variables are initialized immediately when they are created. There is
3308  no longer a need to run variable initializers before using them.
3309  @end_compatibility
3310
3311  Returns:
3312    An Op that initializes global variables in the graph.
3313  """
3314  if context.executing_eagerly():
3315    return control_flow_ops.no_op(name="global_variables_initializer")
3316  return variables_initializer(global_variables())
3317
3318
3319@tf_export(v1=["initialize_all_variables"])
3320@tf_should_use.should_use_result
3321@deprecated("2017-03-02", "Use `tf.global_variables_initializer` instead.")
3322def initialize_all_variables():
3323  """See `tf.compat.v1.global_variables_initializer`."""
3324  return global_variables_initializer()
3325
3326
3327@tf_export(v1=["initializers.local_variables", "local_variables_initializer"])
3328def local_variables_initializer():
3329  """Returns an Op that initializes all local variables.
3330
3331  This is just a shortcut for `variables_initializer(local_variables())`
3332
3333  @compatibility(TF2)
3334  In TF2, variables are initialized immediately when they are created. There is
3335  no longer a need to run variable initializers before using them.
3336  @end_compatibility
3337
3338  Returns:
3339    An Op that initializes all local variables in the graph.
3340  """
3341  if context.executing_eagerly():
3342    return control_flow_ops.no_op(name="local_variables_initializer")
3343  return variables_initializer(local_variables())
3344
3345
3346@tf_export(v1=["initialize_local_variables"])
3347@tf_should_use.should_use_result
3348@deprecated("2017-03-02", "Use `tf.local_variables_initializer` instead.")
3349def initialize_local_variables():
3350  """See `tf.compat.v1.local_variables_initializer`."""
3351  return local_variables_initializer()
3352
3353
3354@tf_export(v1=["is_variable_initialized"])
3355@tf_should_use.should_use_result
3356def is_variable_initialized(variable):
3357  """Tests if a variable has been initialized.
3358
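  Example (a minimal graph-mode sketch; the variable name is illustrative):

  >>> with tf.compat.v1.Graph().as_default():
  ...   v = tf.Variable(1.0, name="v")
  ...   is_init = tf.compat.v1.is_variable_initialized(v)
  ...   with tf.compat.v1.Session() as sess:
  ...     print(sess.run(is_init))  # `v.initializer` has not been run yet.
  False
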
3359  Args:
3360    variable: A `Variable`.
3361
3362  Returns:
3363    A scalar boolean Tensor: `True` if the variable has been
3364    initialized, `False` otherwise.
3365  """
3366  return state_ops.is_variable_initialized(variable)
3367
3368
3369@tf_export(v1=["assert_variables_initialized"])
3370@tf_should_use.should_use_result
3371def assert_variables_initialized(var_list=None):
3372  """Returns an Op to check if variables are initialized.
3373
3374  NOTE: This function is obsolete and will be removed in 6 months.  Please
3375  change your implementation to use `report_uninitialized_variables()`.
3376
3377  When run, the returned Op will raise the exception `FailedPreconditionError`
3378  if any of the variables has not yet been initialized.
3379
3380  Note: This function is implemented by trying to fetch the values of the
3381  variables. If one of the variables is not initialized, a message may be
3382  logged by the C++ runtime. This is expected.
3383
3384  Args:
3385    var_list: List of `Variable` objects to check. Defaults to the value of
3386      `global_variables().`
3387
3388  Returns:
3389    An Op, or None if there are no variables.
3390  """
3391  if var_list is None:
3392    var_list = global_variables() + local_variables()
3393  # Backwards compatibility for old-style variables. TODO(touts): remove.
3394  if not var_list:
3395    var_list = []
3396    for op in ops.get_default_graph().get_operations():
3397      if op.type in ["Variable", "VariableV2", "AutoReloadVariable"]:
3398        var_list.append(op.outputs[0])
3399  if not var_list:
3400    return None
3401  else:
3402    ranks = []
3403    for var in var_list:
3404      with ops.colocate_with(var.op):
3405        ranks.append(array_ops.rank_internal(var, optimize=False))
3406    if len(ranks) == 1:
3407      return ranks[0]
3408    else:
3409      return array_ops.stack(ranks)
3410
3411
3412@tf_export(v1=["report_uninitialized_variables"])
3413@tf_should_use.should_use_result
3414def report_uninitialized_variables(var_list=None,
3415                                   name="report_uninitialized_variables"):
3416  """Adds ops to list the names of uninitialized variables.
3417
3418  When run, it returns a 1-D tensor containing the names of uninitialized
3419  variables if there are any, or an empty array if there are none.
3420
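  Example (a minimal graph-mode sketch; the variable name is illustrative):

  >>> with tf.compat.v1.Graph().as_default():
  ...   v = tf.Variable(1.0, name="v")
  ...   report = tf.compat.v1.report_uninitialized_variables()
  ...   with tf.compat.v1.Session() as sess:
  ...     print(sess.run(report))  # `v` has not been initialized yet.
  [b'v']
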
3421  Args:
3422    var_list: List of `Variable` objects to check. Defaults to the value of
3423      `global_variables() + local_variables()`
3424    name: Optional name of the `Operation`.
3425
3426  Returns:
3427    A 1-D tensor containing names of the uninitialized variables, or an empty
3428    1-D tensor if there are no variables or no uninitialized variables.
3429  """
3430  if var_list is None:
3431    var_list = global_variables() + local_variables()
3432    # Backwards compatibility for old-style variables. TODO(touts): remove.
3433    if not var_list:
3434      var_list = []
3435      for op in ops.get_default_graph().get_operations():
3436        if op.type in ["Variable", "VariableV2", "AutoReloadVariable"]:
3437          var_list.append(op.outputs[0])
3438  with ops.name_scope(name):
3439    # Run the reporting ops on CPU by default (see `local_device` below).
3440    if var_list:
3441      init_vars = [state_ops.is_variable_initialized(v) for v in var_list]
3442    local_device = os.environ.get(
3443        "TF_DEVICE_FOR_UNINITIALIZED_VARIABLE_REPORTING", "/cpu:0")
3444    with ops.device(local_device):
3445      if not var_list:
3446        # Return an empty tensor so we only need to check for returned tensor
3447        # size being 0 as an indication of model ready.
3448        return array_ops.constant([], dtype=dtypes.string)
3449      else:
3450        # Get a 1-D boolean tensor marking which variables are uninitialized.
3451        variables_mask = math_ops.logical_not(array_ops.stack(init_vars))
3452        # Get a 1-D string tensor containing all the variable names.
3453        variable_names_tensor = array_ops.constant(
3454            [s.op.name for s in var_list])
3455        # Return a 1-D tensor containing all the names of
3456        # uninitialized variables.
3457        return array_ops.boolean_mask(variable_names_tensor, variables_mask)
3458
3459
3460ops.register_tensor_conversion_function(
3461    PartitionedVariable, PartitionedVariable._TensorConversionFunction)  # pylint: disable=protected-access
3462