1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Variable class."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import enum  # pylint: disable=g-bad-import-order
21import itertools
22import functools
23import os
24
25import six
26
27from tensorflow.core.framework import attr_value_pb2
28from tensorflow.core.framework import variable_pb2
29from tensorflow.python import pywrap_tensorflow  # pylint: disable=unused-import
30from tensorflow.python.eager import context
31from tensorflow.python.framework import dtypes
32from tensorflow.python.framework import ops
33from tensorflow.python.framework import tensor_shape
34from tensorflow.python.ops import array_ops
35from tensorflow.python.ops import control_flow_ops
36from tensorflow.python.ops import gen_array_ops
37from tensorflow.python.ops import gen_state_ops
38from tensorflow.python.ops import gen_math_ops
39from tensorflow.python.ops import math_ops
40from tensorflow.python.ops import state_ops
41from tensorflow.python.platform import tf_logging as logging
42from tensorflow.python.training.tracking import base as trackable
43from tensorflow.python.util import _pywrap_utils
44from tensorflow.python.util import compat
45from tensorflow.python.util import object_identity
46from tensorflow.python.util import tf_should_use
47from tensorflow.python.util.deprecation import deprecated
48from tensorflow.python.util.deprecation import deprecated_args
49from tensorflow.python.util.tf_export import tf_export
50from tensorflow.python.types import core
51
52
53def default_variable_creator(_, **kwds):
54  del kwds
55  raise NotImplementedError("variable_scope needs to be imported")
56
57
58def default_variable_creator_v2(_, **kwds):
59  del kwds
60  raise NotImplementedError("variable_scope needs to be imported")
61
62
63def _make_getter(captured_getter, captured_previous):
64  """To avoid capturing loop variables."""
65
66  def getter(**kwargs):
67    return captured_getter(captured_previous, **kwargs)
68
69  return getter
70
71
72@tf_export("VariableSynchronization")
73class VariableSynchronization(enum.Enum):
74  """Indicates when a distributed variable will be synced.
75
  * `AUTO`: Indicates that the synchronization will be determined by the current
    `DistributionStrategy` (e.g. with `MirroredStrategy` this would be
    `ON_WRITE`).
  * `NONE`: Indicates that there will only be one copy of the variable, so
    there is no need to sync.
  * `ON_WRITE`: Indicates that the variable will be updated across devices
    every time it is written.
  * `ON_READ`: Indicates that the variable will be aggregated across devices
    when it is read (e.g. when checkpointing or when evaluating an op that uses
    the variable).
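
  For example, a minimal sketch (assuming a `tf.distribute` strategy is in
  scope) of passing these modes to the `tf.Variable` constructor:

  ```python
  v = tf.Variable(
      0.0,
      synchronization=tf.VariableSynchronization.ON_READ,
      aggregation=tf.VariableAggregation.MEAN)
  ```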
86  """
87  AUTO = 0
88  NONE = 1
89  ON_WRITE = 2
90  ON_READ = 3
91
92
93# LINT.IfChange
94@tf_export("VariableAggregation", v1=[])
95class VariableAggregationV2(enum.Enum):
96  """Indicates how a distributed variable will be aggregated.
97
98  `tf.distribute.Strategy` distributes a model by making multiple copies
99  (called "replicas") acting data-parallel on different elements of the input
100  batch. When performing some variable-update operation, say
101  `var.assign_add(x)`, in a model, we need to resolve how to combine the
102  different values for `x` computed in the different replicas.
103
104  * `NONE`: This is the default, giving an error if you use a
105    variable-update operation with multiple replicas.
106  * `SUM`: Add the updates across replicas.
107  * `MEAN`: Take the arithmetic mean ("average") of the updates across replicas.
108  * `ONLY_FIRST_REPLICA`: This is for when every replica is performing the same
109    update, but we only want to perform the update once. Used, e.g., for the
110    global step counter.
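
  A minimal sketch (assuming a `MirroredStrategy` is available) of a counter
  that is summed across replicas:

  ```python
  strategy = tf.distribute.MirroredStrategy()
  with strategy.scope():
    counter = tf.Variable(
        0.0,
        synchronization=tf.VariableSynchronization.ON_READ,
        aggregation=tf.VariableAggregation.SUM)
  ```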
111  """
112  NONE = 0
113  SUM = 1
114  MEAN = 2
115  ONLY_FIRST_REPLICA = 3
116
117  def __hash__(self):
118    return hash(self.value)
119
120  def __eq__(self, other):
121    if self is other:
122      return True
123    elif isinstance(other, VariableAggregation):
124      return int(self.value) == int(other.value)
125    else:
126      return False
127
128
129@tf_export(v1=["VariableAggregation"])
130class VariableAggregation(enum.Enum):
131  NONE = 0
132  SUM = 1
133  MEAN = 2
134  ONLY_FIRST_REPLICA = 3
135  ONLY_FIRST_TOWER = 3  # DEPRECATED
136
137  def __hash__(self):
138    return hash(self.value)
139
140
141# LINT.ThenChange(//tensorflow/core/framework/variable.proto)
142#
143# Note that we are currently relying on the integer values of the Python enums
144# matching the integer values of the proto enums.
145
146VariableAggregation.__doc__ = (
147    VariableAggregationV2.__doc__ +
148    "* `ONLY_FIRST_TOWER`: Deprecated alias for `ONLY_FIRST_REPLICA`.\n  ")
149
150
151def validate_synchronization_aggregation_trainable(synchronization, aggregation,
152                                                   trainable, name):
153  """Given user-provided variable properties, sets defaults and validates."""
154  if aggregation is None:
155    aggregation = VariableAggregation.NONE
156  else:
157    if not isinstance(aggregation,
158                      (VariableAggregation, VariableAggregationV2)):
159      try:
160        aggregation = VariableAggregationV2(aggregation)
161      except ValueError:
162        raise ValueError(
163            "Invalid variable aggregation mode: {} for variable: {}".format(
164                aggregation, name))
165  if synchronization is None:
166    synchronization = VariableSynchronization.AUTO
167  else:
168    try:
169      synchronization = VariableSynchronization(synchronization)
170    except ValueError:
171      raise ValueError(
172          "Invalid variable synchronization mode: {} for variable: {}".format(
173              synchronization, name))
174  if trainable is None:
175    trainable = synchronization != VariableSynchronization.ON_READ
176  return synchronization, aggregation, trainable
177
178
179class VariableMetaclass(type):
180  """Metaclass to allow construction of tf.Variable to be overridden."""
181
182  def _variable_v1_call(cls,
183                        initial_value=None,
184                        trainable=None,
185                        collections=None,
186                        validate_shape=True,
187                        caching_device=None,
188                        name=None,
189                        variable_def=None,
190                        dtype=None,
191                        expected_shape=None,
192                        import_scope=None,
193                        constraint=None,
194                        use_resource=None,
195                        synchronization=VariableSynchronization.AUTO,
196                        aggregation=VariableAggregation.NONE,
197                        shape=None):
198    """Call on Variable class. Useful to force the signature."""
199    previous_getter = lambda **kwargs: default_variable_creator(None, **kwargs)
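    # Compose the graph's registered variable creators around the default
    # creator; the most recently pushed creator ends up outermost and is
    # invoked first, delegating inward via `previous_getter`.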
200    for _, getter in ops.get_default_graph()._variable_creator_stack:  # pylint: disable=protected-access
201      previous_getter = _make_getter(getter, previous_getter)
202
203    # Reset `aggregation` that is explicitly set as `None` to the enum NONE.
204    if aggregation is None:
205      aggregation = VariableAggregation.NONE
206    return previous_getter(
207        initial_value=initial_value,
208        trainable=trainable,
209        collections=collections,
210        validate_shape=validate_shape,
211        caching_device=caching_device,
212        name=name,
213        variable_def=variable_def,
214        dtype=dtype,
215        expected_shape=expected_shape,
216        import_scope=import_scope,
217        constraint=constraint,
218        use_resource=use_resource,
219        synchronization=synchronization,
220        aggregation=aggregation,
221        shape=shape)
222
223  def _variable_v2_call(cls,
224                        initial_value=None,
225                        trainable=None,
226                        validate_shape=True,
227                        caching_device=None,
228                        name=None,
229                        variable_def=None,
230                        dtype=None,
231                        import_scope=None,
232                        constraint=None,
233                        synchronization=VariableSynchronization.AUTO,
234                        aggregation=VariableAggregation.NONE,
235                        shape=None):
236    """Call on Variable class. Useful to force the signature."""
237    previous_getter = lambda **kws: default_variable_creator_v2(None, **kws)
238    for _, getter in ops.get_default_graph()._variable_creator_stack:  # pylint: disable=protected-access
239      previous_getter = _make_getter(getter, previous_getter)
240
241    # Reset `aggregation` that is explicitly set as `None` to the enum NONE.
242    if aggregation is None:
243      aggregation = VariableAggregation.NONE
244    return previous_getter(
245        initial_value=initial_value,
246        trainable=trainable,
247        validate_shape=validate_shape,
248        caching_device=caching_device,
249        name=name,
250        variable_def=variable_def,
251        dtype=dtype,
252        import_scope=import_scope,
253        constraint=constraint,
254        synchronization=synchronization,
255        aggregation=aggregation,
256        shape=shape)
257
258  def __call__(cls, *args, **kwargs):
259    if cls is VariableV1:
260      return cls._variable_v1_call(*args, **kwargs)
261    elif cls is Variable:
262      return cls._variable_v2_call(*args, **kwargs)
263    else:
264      return super(VariableMetaclass, cls).__call__(*args, **kwargs)
265
266
267@tf_export("Variable", v1=[])
268# TODO(mdan): This should subclass core.Tensor, and not all its subclasses?
269class Variable(six.with_metaclass(VariableMetaclass, trackable.Trackable)):
270  """See the [variable guide](https://tensorflow.org/guide/variable).
271
272  A variable maintains shared, persistent state manipulated by a program.
273
274  The `Variable()` constructor requires an initial value for the variable, which
275  can be a `Tensor` of any type and shape. This initial value defines the type
276  and shape of the variable. After construction, the type and shape of the
277  variable are fixed. The value can be changed using one of the assign methods.
278
279  >>> v = tf.Variable(1.)
280  >>> v.assign(2.)
281  <tf.Variable ... shape=() dtype=float32, numpy=2.0>
282  >>> v.assign_add(0.5)
283  <tf.Variable ... shape=() dtype=float32, numpy=2.5>
284
285  The `shape` argument to `Variable`'s constructor allows you to construct a
286  variable with a less defined shape than its `initial_value`:
287
288  >>> v = tf.Variable(1., shape=tf.TensorShape(None))
289  >>> v.assign([[1.]])
290  <tf.Variable ... shape=<unknown> dtype=float32, numpy=array([[1.]], ...)>
291
292  Just like any `Tensor`, variables created with `Variable()` can be used as
293  inputs to operations. Additionally, all the operators overloaded for the
294  `Tensor` class are carried over to variables.
295
296  >>> w = tf.Variable([[1.], [2.]])
297  >>> x = tf.constant([[3., 4.]])
298  >>> tf.matmul(w, x)
299  <tf.Tensor:... shape=(2, 2), ... numpy=
300    array([[3., 4.],
301           [6., 8.]], dtype=float32)>
302  >>> tf.sigmoid(w + x)
303  <tf.Tensor:... shape=(2, 2), ...>
304
305  When building a machine learning model it is often convenient to distinguish
306  between variables holding trainable model parameters and other variables such
307  as a `step` variable used to count training steps. To make this easier, the
308  variable constructor supports a `trainable=<bool>`
309  parameter. `tf.GradientTape` watches trainable variables by default:
310
311  >>> with tf.GradientTape(persistent=True) as tape:
312  ...   trainable = tf.Variable(1.)
313  ...   non_trainable = tf.Variable(2., trainable=False)
314  ...   x1 = trainable * 2.
315  ...   x2 = non_trainable * 3.
316  >>> tape.gradient(x1, trainable)
317  <tf.Tensor:... shape=(), dtype=float32, numpy=2.0>
318  >>> assert tape.gradient(x2, non_trainable) is None  # Unwatched
319
320  Variables are automatically tracked when assigned to attributes of types
321  inheriting from `tf.Module`.
322
323  >>> m = tf.Module()
324  >>> m.v = tf.Variable([1.])
325  >>> m.trainable_variables
326  (<tf.Variable ... shape=(1,) ... numpy=array([1.], dtype=float32)>,)
327
328  This tracking then allows saving variable values to
329  [training checkpoints](https://www.tensorflow.org/guide/checkpoint), or to
330  [SavedModels](https://www.tensorflow.org/guide/saved_model) which include
331  serialized TensorFlow graphs.
332
  Variables are often captured and manipulated by `tf.function`s. This works
  the same way as it would with an un-decorated function:
335
336  >>> v = tf.Variable(0.)
337  >>> read_and_decrement = tf.function(lambda: v.assign_sub(0.1))
338  >>> read_and_decrement()
339  <tf.Tensor: shape=(), dtype=float32, numpy=-0.1>
340  >>> read_and_decrement()
341  <tf.Tensor: shape=(), dtype=float32, numpy=-0.2>
342
343  Variables created inside a `tf.function` must be owned outside the function
344  and be created only once:
345
346  >>> class M(tf.Module):
347  ...   @tf.function
348  ...   def __call__(self, x):
349  ...     if not hasattr(self, "v"):  # Or set self.v to None in __init__
350  ...       self.v = tf.Variable(x)
351  ...     return self.v * x
352  >>> m = M()
353  >>> m(2.)
354  <tf.Tensor: shape=(), dtype=float32, numpy=4.0>
355  >>> m(3.)
356  <tf.Tensor: shape=(), dtype=float32, numpy=6.0>
357  >>> m.v
358  <tf.Variable ... shape=() dtype=float32, numpy=2.0>
359
360  See the `tf.function` documentation for details.
361  """
362
363  @deprecated_args(
364      None,
365      "A variable's value can be manually cached by calling "
366      "tf.Variable.read_value() under a tf.device scope. The caching_device "
367      "argument does not work properly.",
368      "caching_device")
369  def __init__(self,
370               initial_value=None,
371               trainable=None,
372               validate_shape=True,
373               caching_device=None,
374               name=None,
375               variable_def=None,
376               dtype=None,
377               import_scope=None,
378               constraint=None,
379               synchronization=VariableSynchronization.AUTO,
380               aggregation=VariableAggregation.NONE,
381               shape=None):
382    """Creates a new variable with value `initial_value`.
383
384    Args:
385      initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
386        which is the initial value for the Variable. The initial value must have
387        a shape specified unless `validate_shape` is set to False. Can also be a
388        callable with no argument that returns the initial value when called. In
389        that case, `dtype` must be specified. (Note that initializer functions
390        from init_ops.py must first be bound to a shape before being used here.)
391      trainable: If `True`, GradientTapes automatically watch uses of this
392        variable. Defaults to `True`, unless `synchronization` is set to
393        `ON_READ`, in which case it defaults to `False`.
394      validate_shape: If `False`, allows the variable to be initialized with a
395        value of unknown shape. If `True`, the default, the shape of
396        `initial_value` must be known.
397      caching_device: Optional device string describing where the Variable
398        should be cached for reading.  Defaults to the Variable's device. If not
399        `None`, caches on another device.  Typical use is to cache on the device
400        where the Ops using the Variable reside, to deduplicate copying through
401        `Switch` and other conditional statements.
402      name: Optional name for the variable. Defaults to `'Variable'` and gets
403        uniquified automatically.
404      variable_def: `VariableDef` protocol buffer. If not `None`, recreates the
405        Variable object with its contents, referencing the variable's nodes in
406        the graph, which must already exist. The graph is not changed.
407        `variable_def` and the other arguments are mutually exclusive.
408      dtype: If set, initial_value will be converted to the given type. If
409        `None`, either the datatype will be kept (if `initial_value` is a
410        Tensor), or `convert_to_tensor` will decide.
411      import_scope: Optional `string`. Name scope to add to the `Variable.` Only
412        used when initializing from protocol buffer.
413      constraint: An optional projection function to be applied to the variable
414        after being updated by an `Optimizer` (e.g. used to implement norm
415        constraints or value constraints for layer weights). The function must
416        take as input the unprojected Tensor representing the value of the
417        variable and return the Tensor for the projected value (which must have
418        the same shape). Constraints are not safe to use when doing asynchronous
419        distributed training.
      synchronization: Indicates when a distributed variable will be
        synchronized. Accepted values are constants defined in the class
422        `tf.VariableSynchronization`. By default the synchronization is set to
423        `AUTO` and the current `DistributionStrategy` chooses when to
424        synchronize.
425      aggregation: Indicates how a distributed variable will be aggregated.
426        Accepted values are constants defined in the class
427        `tf.VariableAggregation`.
428      shape: (optional) The shape of this variable. If None, the shape of
429        `initial_value` will be used. When setting this argument to
430        `tf.TensorShape(None)` (representing an unspecified shape), the variable
431        can be assigned with values of different shapes.
432
433    Raises:
434      ValueError: If both `variable_def` and initial_value are specified.
435      ValueError: If the initial value is not specified, or does not have a
436        shape and `validate_shape` is `True`.
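
    For the `constraint` argument, a minimal sketch (the clipping function here
    is only an illustration):

    ```python
    v = tf.Variable(
        2.0, constraint=lambda t: tf.clip_by_value(t, 0.0, 1.0))
    # An optimizer that supports constraints applies the function after each
    # update.
    ```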
437    """
438    raise NotImplementedError
439
440  def __repr__(self):
441    raise NotImplementedError
442
443  def value(self):
444    """Returns the last snapshot of this variable.
445
446    You usually do not need to call this method as all ops that need the value
447    of the variable call it automatically through a `convert_to_tensor()` call.
448
449    Returns a `Tensor` which holds the value of the variable.  You can not
450    assign a new value to this tensor as it is not a reference to the variable.
451
452    To avoid copies, if the consumer of the returned value is on the same device
453    as the variable, this actually returns the live value of the variable, not
454    a copy.  Updates to the variable are seen by the consumer.  If the consumer
455    is on a different device it will get a copy of the variable.
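
    A small sketch of the difference from `read_value()` (assuming eager
    execution):

    ```python
    v = tf.Variable(3.0)
    t = v.value()       # a Tensor snapshot of the current value
    r = v.read_value()  # a read performed in the current context
    ```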
456
457    Returns:
458      A `Tensor` containing the value of the variable.
459    """
460    raise NotImplementedError
461
462  def read_value(self):
463    """Returns the value of this variable, read in the current context.
464
465    Can be different from value() if it's on another device, with control
466    dependencies, etc.
467
468    Returns:
469      A `Tensor` containing the value of the variable.
470    """
471    raise NotImplementedError
472
473  def set_shape(self, shape):
474    """Overrides the shape for this variable.
475
476    Args:
477      shape: the `TensorShape` representing the overridden shape.
478    """
479    raise NotImplementedError
480
481  @property
482  def trainable(self):
483    raise NotImplementedError
484
485  @property
486  def synchronization(self):
487    raise NotImplementedError
488
489  @property
490  def aggregation(self):
491    raise NotImplementedError
492
493  def eval(self, session=None):
494    """In a session, computes and returns the value of this variable.
495
496    This is not a graph construction method, it does not add ops to the graph.
497
498    This convenience method requires a session where the graph
499    containing this variable has been launched. If no session is
500    passed, the default session is used.  See `tf.compat.v1.Session` for more
501    information on launching a graph and on sessions.
502
503    ```python
504    v = tf.Variable([1, 2])
505    init = tf.compat.v1.global_variables_initializer()
506
507    with tf.compat.v1.Session() as sess:
508        sess.run(init)
509        # Usage passing the session explicitly.
510        print(v.eval(sess))
511        # Usage with the default session.  The 'with' block
512        # above makes 'sess' the default session.
513        print(v.eval())
514    ```
515
516    Args:
517      session: The session to use to evaluate this variable. If none, the
518        default session is used.
519
520    Returns:
521      A numpy `ndarray` with a copy of the value of this variable.
522    """
523    raise NotImplementedError
524
525  @deprecated(
526      None, "Use Variable.read_value. Variables in 2.X are initialized "
527      "automatically both in eager and graph (inside tf.defun) contexts.")
528  def initialized_value(self):
529    """Returns the value of the initialized variable.
530
531    You should use this instead of the variable itself to initialize another
532    variable with a value that depends on the value of this variable.
533
534    ```python
535    # Initialize 'v' with a random tensor.
536    v = tf.Variable(tf.random.truncated_normal([10, 40]))
537    # Use `initialized_value` to guarantee that `v` has been
538    # initialized before its value is used to initialize `w`.
539    # The random values are picked only once.
540    w = tf.Variable(v.initialized_value() * 2.0)
541    ```
542
543    Returns:
544      A `Tensor` holding the value of this variable after its initializer
545      has run.
546    """
547    with ops.init_scope():
548      return control_flow_ops.cond(
549          is_variable_initialized(self), self.read_value,
550          lambda: self.initial_value)
551
552  @property
553  def initial_value(self):
554    """Returns the Tensor used as the initial value for the variable.
555
556    Note that this is different from `initialized_value()` which runs
557    the op that initializes the variable before returning its value.
558    This method returns the tensor that is used by the op that initializes
559    the variable.
560
561    Returns:
562      A `Tensor`.
563    """
564    raise NotImplementedError
565
566  @property
567  def constraint(self):
568    """Returns the constraint function associated with this variable.
569
570    Returns:
571      The constraint function that was passed to the variable constructor.
572      Can be `None` if no constraint was passed.
573    """
574    raise NotImplementedError
575
576  def assign(self, value, use_locking=False, name=None, read_value=True):
577    """Assigns a new value to the variable.
578
579    This is essentially a shortcut for `assign(self, value)`.
580
581    Args:
582      value: A `Tensor`. The new value for this variable.
583      use_locking: If `True`, use locking during the assignment.
584      name: The name of the operation to be created
585      read_value: if True, will return something which evaluates to the new
586        value of the variable; if False will return the assign op.
587
588    Returns:
589      The updated variable. If `read_value` is false, instead returns None in
590      Eager mode and the assign op in graph mode.
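
    A minimal usage sketch (assuming eager execution):

    ```python
    v = tf.Variable(1.0)
    v.assign(2.0)                    # returns the updated variable
    v.assign(3.0, read_value=False)  # returns None in eager mode
    ```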
591    """
592    raise NotImplementedError
593
594  def assign_add(self, delta, use_locking=False, name=None, read_value=True):
595    """Adds a value to this variable.
596
    This is essentially a shortcut for `assign_add(self, delta)`.
598
599    Args:
600      delta: A `Tensor`. The value to add to this variable.
601      use_locking: If `True`, use locking during the operation.
602      name: The name of the operation to be created
603      read_value: if True, will return something which evaluates to the new
604        value of the variable; if False will return the assign op.
605
606    Returns:
607      The updated variable. If `read_value` is false, instead returns None in
608      Eager mode and the assign op in graph mode.
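
    A minimal usage sketch (assuming eager execution):

    ```python
    v = tf.Variable(10.0)
    v.assign_add(2.0)  # v is now 12.0
    v.assign_sub(5.0)  # v is now 7.0
    ```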
609    """
610    raise NotImplementedError
611
612  def assign_sub(self, delta, use_locking=False, name=None, read_value=True):
613    """Subtracts a value from this variable.
614
615    This is essentially a shortcut for `assign_sub(self, delta)`.
616
617    Args:
618      delta: A `Tensor`. The value to subtract from this variable.
619      use_locking: If `True`, use locking during the operation.
620      name: The name of the operation to be created
621      read_value: if True, will return something which evaluates to the new
622        value of the variable; if False will return the assign op.
623
624    Returns:
625      The updated variable. If `read_value` is false, instead returns None in
626      Eager mode and the assign op in graph mode.
627    """
628    raise NotImplementedError
629
630  def scatter_sub(self, sparse_delta, use_locking=False, name=None):
631    """Subtracts `tf.IndexedSlices` from this variable.
632
633    Args:
634      sparse_delta: `tf.IndexedSlices` to be subtracted from this variable.
635      use_locking: If `True`, use locking during the operation.
636      name: the name of the operation.
637
638    Returns:
639      The updated variable.
640
641    Raises:
642      TypeError: if `sparse_delta` is not an `IndexedSlices`.
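
    A minimal sketch (assuming eager execution):

    ```python
    v = tf.Variable([1.0, 2.0, 3.0, 4.0])
    v.scatter_sub(tf.IndexedSlices(
        values=tf.constant([1.0, 2.0]), indices=tf.constant([0, 2])))
    # v is now [0.0, 2.0, 1.0, 4.0]
    ```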
643    """
644    raise NotImplementedError
645
646  def scatter_add(self, sparse_delta, use_locking=False, name=None):
647    """Adds `tf.IndexedSlices` to this variable.
648
649    Args:
650      sparse_delta: `tf.IndexedSlices` to be added to this variable.
651      use_locking: If `True`, use locking during the operation.
652      name: the name of the operation.
653
654    Returns:
655      The updated variable.
656
657    Raises:
658      TypeError: if `sparse_delta` is not an `IndexedSlices`.
659    """
660    raise NotImplementedError
661
662  def scatter_max(self, sparse_delta, use_locking=False, name=None):
663    """Updates this variable with the max of `tf.IndexedSlices` and itself.
664
665    Args:
666      sparse_delta: `tf.IndexedSlices` to use as an argument of max with this
667        variable.
668      use_locking: If `True`, use locking during the operation.
669      name: the name of the operation.
670
671    Returns:
672      The updated variable.
673
674    Raises:
675      TypeError: if `sparse_delta` is not an `IndexedSlices`.
676    """
677    raise NotImplementedError
678
679  def scatter_min(self, sparse_delta, use_locking=False, name=None):
680    """Updates this variable with the min of `tf.IndexedSlices` and itself.
681
682    Args:
683      sparse_delta: `tf.IndexedSlices` to use as an argument of min with this
684        variable.
685      use_locking: If `True`, use locking during the operation.
686      name: the name of the operation.
687
688    Returns:
689      The updated variable.
690
691    Raises:
692      TypeError: if `sparse_delta` is not an `IndexedSlices`.
693    """
694    raise NotImplementedError
695
696  def scatter_mul(self, sparse_delta, use_locking=False, name=None):
697    """Multiply this variable by `tf.IndexedSlices`.
698
699    Args:
700      sparse_delta: `tf.IndexedSlices` to multiply this variable by.
701      use_locking: If `True`, use locking during the operation.
702      name: the name of the operation.
703
704    Returns:
705      The updated variable.
706
707    Raises:
708      TypeError: if `sparse_delta` is not an `IndexedSlices`.
709    """
710    raise NotImplementedError
711
712  def scatter_div(self, sparse_delta, use_locking=False, name=None):
713    """Divide this variable by `tf.IndexedSlices`.
714
715    Args:
716      sparse_delta: `tf.IndexedSlices` to divide this variable by.
717      use_locking: If `True`, use locking during the operation.
718      name: the name of the operation.
719
720    Returns:
721      The updated variable.
722
723    Raises:
724      TypeError: if `sparse_delta` is not an `IndexedSlices`.
725    """
726    raise NotImplementedError
727
728  def scatter_update(self, sparse_delta, use_locking=False, name=None):
729    """Assigns `tf.IndexedSlices` to this variable.
730
731    Args:
732      sparse_delta: `tf.IndexedSlices` to be assigned to this variable.
733      use_locking: If `True`, use locking during the operation.
734      name: the name of the operation.
735
736    Returns:
737      The updated variable.
738
739    Raises:
740      TypeError: if `sparse_delta` is not an `IndexedSlices`.
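
    A minimal sketch (assuming eager execution):

    ```python
    v = tf.Variable([1.0, 2.0, 3.0])
    v.scatter_update(tf.IndexedSlices(
        values=tf.constant([9.0]), indices=tf.constant([1])))
    # v is now [1.0, 9.0, 3.0]
    ```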
741    """
742    raise NotImplementedError
743
744  def batch_scatter_update(self, sparse_delta, use_locking=False, name=None):
745    """Assigns `tf.IndexedSlices` to this variable batch-wise.
746
747    Analogous to `batch_gather`. This assumes that this variable and the
748    sparse_delta IndexedSlices have a series of leading dimensions that are the
749    same for all of them, and the updates are performed on the last dimension of
750    indices. In other words, the dimensions should be the following:
751
752    `num_prefix_dims = sparse_delta.indices.ndims - 1`
753    `batch_dim = num_prefix_dims + 1`
754    `sparse_delta.updates.shape = sparse_delta.indices.shape + var.shape[
755         batch_dim:]`
756
757    where
758
759    `sparse_delta.updates.shape[:num_prefix_dims]`
760    `== sparse_delta.indices.shape[:num_prefix_dims]`
761    `== var.shape[:num_prefix_dims]`
762
763    And the operation performed can be expressed as:
764
765    `var[i_1, ..., i_n,
766         sparse_delta.indices[i_1, ..., i_n, j]] = sparse_delta.updates[
767            i_1, ..., i_n, j]`
768
769    When sparse_delta.indices is a 1D tensor, this operation is equivalent to
770    `scatter_update`.
771
    This operation can be avoided by looping over the first `ndims` dimensions
    of the variable and using `scatter_update` on the subtensors that result
    from slicing the first dimension. This is a valid option for `ndims = 1`,
    but less efficient than this implementation.
776
777    Args:
778      sparse_delta: `tf.IndexedSlices` to be assigned to this variable.
779      use_locking: If `True`, use locking during the operation.
780      name: the name of the operation.
781
782    Returns:
783      The updated variable.
784
785    Raises:
786      TypeError: if `sparse_delta` is not an `IndexedSlices`.
787    """
788    raise NotImplementedError
789
790  def scatter_nd_sub(self, indices, updates, name=None):
791    """Applies sparse subtraction to individual values or slices in a Variable.
792
    The variable has rank `P` and `indices` is a `Tensor` of rank `Q`.
794
795    `indices` must be integer tensor, containing indices into self.
796    It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.
797
798    The innermost dimension of `indices` (with length `K`) corresponds to
799    indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
800    dimension of self.
801
802    `updates` is `Tensor` of rank `Q-1+P-K` with shape:
803
804    ```
805    [d_0, ..., d_{Q-2}, self.shape[K], ..., self.shape[P-1]].
806    ```
807
    For example, say we want to subtract 4 scattered elements from a rank-1
    tensor with 8 elements. In Python, that update would look like this:

    ```python
        v = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
        indices = tf.constant([[4], [3], [1], [7]])
        updates = tf.constant([9, 10, 11, 12])
        op = v.scatter_nd_sub(indices, updates)
        with tf.compat.v1.Session() as sess:
          print(sess.run(op))
    ```

    The resulting update to v would look like this:

        [1, -9, 3, -6, -4, 6, 7, -4]
823
824    See `tf.scatter_nd` for more details about how to make updates to
825    slices.
826
827    Args:
828      indices: The indices to be used in the operation.
829      updates: The values to be used in the operation.
830      name: the name of the operation.
831
832    Returns:
833      The updated variable.
834    """
835    raise NotImplementedError
836
837  def scatter_nd_add(self, indices, updates, name=None):
838    """Applies sparse addition to individual values or slices in a Variable.
839
840    The Variable has rank `P` and `indices` is a `Tensor` of rank `Q`.
841
842    `indices` must be integer tensor, containing indices into self.
843    It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.
844
845    The innermost dimension of `indices` (with length `K`) corresponds to
846    indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
847    dimension of self.
848
849    `updates` is `Tensor` of rank `Q-1+P-K` with shape:
850
851    ```
852    [d_0, ..., d_{Q-2}, self.shape[K], ..., self.shape[P-1]].
853    ```
854
    For example, say we want to add 4 scattered elements to a rank-1 tensor
    with 8 elements. In Python, that update would look like this:

    ```python
        v = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
        indices = tf.constant([[4], [3], [1], [7]])
        updates = tf.constant([9, 10, 11, 12])
        add = v.scatter_nd_add(indices, updates)
        with tf.compat.v1.Session() as sess:
          print(sess.run(add))
    ```

    The resulting update to v would look like this:

        [1, 13, 3, 14, 14, 6, 7, 20]
870
871    See `tf.scatter_nd` for more details about how to make updates to
872    slices.
873
874    Args:
875      indices: The indices to be used in the operation.
876      updates: The values to be used in the operation.
877      name: the name of the operation.
878
879    Returns:
880      The updated variable.
881    """
882    raise NotImplementedError
883
884  def scatter_nd_update(self, indices, updates, name=None):
885    """Applies sparse assignment to individual values or slices in a Variable.
886
887    The Variable has rank `P` and `indices` is a `Tensor` of rank `Q`.
888
889    `indices` must be integer tensor, containing indices into self.
890    It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.
891
892    The innermost dimension of `indices` (with length `K`) corresponds to
893    indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
894    dimension of self.
895
896    `updates` is `Tensor` of rank `Q-1+P-K` with shape:
897
898    ```
899    [d_0, ..., d_{Q-2}, self.shape[K], ..., self.shape[P-1]].
900    ```
901
    For example, say we want to assign 4 scattered elements of a rank-1 tensor
    with 8 elements. In Python, that update would look like this:

    ```python
        v = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
        indices = tf.constant([[4], [3], [1], [7]])
        updates = tf.constant([9, 10, 11, 12])
        op = v.scatter_nd_update(indices, updates)
        with tf.compat.v1.Session() as sess:
          print(sess.run(op))
    ```

    The resulting update to v would look like this:

        [1, 11, 3, 10, 9, 6, 7, 12]
917
918    See `tf.scatter_nd` for more details about how to make updates to
919    slices.
920
921    Args:
922      indices: The indices to be used in the operation.
923      updates: The values to be used in the operation.
924      name: the name of the operation.
925
926    Returns:
927      The updated variable.
928    """
929    raise NotImplementedError
930
931  def sparse_read(self, indices, name=None):
932    r"""Gather slices from params axis axis according to indices.
933
934    This function supports a subset of tf.gather, see tf.gather for details on
935    usage.
936
937    Args:
938      indices: The index `Tensor`.  Must be one of the following types: `int32`,
939        `int64`. Must be in range `[0, params.shape[axis])`.
940      name: A name for the operation (optional).
941
942    Returns:
943      A `Tensor`. Has the same type as `params`.
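
    A minimal sketch (assuming eager execution):

    ```python
    v = tf.Variable([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
    v.sparse_read([0, 2])  # rows 0 and 2 of the variable
    ```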
944    """
945    raise AttributeError
946
947  def gather_nd(self, indices, name=None):
948    r"""Gather slices from `params` into a Tensor with shape specified by `indices`.
949
950    See tf.gather_nd for details.
951
952    Args:
953      indices: A `Tensor`. Must be one of the following types: `int32`, `int64`.
954        Index tensor.
955      name: A name for the operation (optional).
956
957    Returns:
958      A `Tensor`. Has the same type as `params`.
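
    A minimal sketch (assuming eager execution):

    ```python
    v = tf.Variable([[1.0, 2.0], [3.0, 4.0]])
    v.gather_nd([[0, 1], [1, 0]])  # [2.0, 3.0]
    ```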
959    """
960    raise AttributeError
961
962  @deprecated(None, "Prefer Dataset.range instead.")
963  def count_up_to(self, limit):
964    """Increments this variable until it reaches `limit`.
965
966    When that Op is run it tries to increment the variable by `1`. If
967    incrementing the variable would bring it above `limit` then the Op raises
968    the exception `OutOfRangeError`.
969
970    If no error is raised, the Op outputs the value of the variable before
971    the increment.
972
973    This is essentially a shortcut for `count_up_to(self, limit)`.
974
975    Args:
976      limit: value at which incrementing the variable raises an error.
977
978    Returns:
979      A `Tensor` that will hold the variable value before the increment. If no
980      other Op modifies this variable, the values produced will all be
981      distinct.
982    """
983    raise NotImplementedError
984
985  @deprecated(None,
986              "Prefer Variable.assign which has equivalent behavior in 2.X.")
987  def load(self, value, session=None):
988    """Load new value into this variable.
989
    Writes the new value to the variable's memory. Doesn't add ops to the graph.
991
992    This convenience method requires a session where the graph
993    containing this variable has been launched. If no session is
994    passed, the default session is used.  See `tf.compat.v1.Session` for more
995    information on launching a graph and on sessions.
996
997    ```python
998    v = tf.Variable([1, 2])
999    init = tf.compat.v1.global_variables_initializer()
1000
1001    with tf.compat.v1.Session() as sess:
1002        sess.run(init)
1003        # Usage passing the session explicitly.
1004        v.load([2, 3], sess)
1005        print(v.eval(sess)) # prints [2 3]
1006        # Usage with the default session.  The 'with' block
1007        # above makes 'sess' the default session.
1008        v.load([3, 4], sess)
1009        print(v.eval()) # prints [3 4]
1010    ```
1011
1012    Args:
1013        value: New variable value
1014        session: The session to use to evaluate this variable. If none, the
1015          default session is used.
1016
1017    Raises:
1018        ValueError: Session is not passed and no default session
1019    """
1020    if context.executing_eagerly():
1021      self.assign(value)
1022    else:
1023      session = session or ops.get_default_session()
1024      if session is None:
1025        raise ValueError(
1026            "Either session argument should be provided or default session "
1027            "should be established")
1028      session.run(self.initializer, {self.initializer.inputs[1]: value})
1029
1030  # Conversion to tensor.
1031  @staticmethod
1032  def _TensorConversionFunction(v, dtype=None, name=None, as_ref=False):  # pylint: disable=invalid-name
1033    """Utility function for converting a Variable to a Tensor."""
1034    _ = name
1035    if dtype and not dtype.is_compatible_with(v.dtype):
1036      raise ValueError(
1037          "Incompatible type conversion requested to type '%s' for variable "
1038          "of type '%s'" % (dtype.name, v.dtype.name))
1039    if as_ref:
1040      return v._ref()  # pylint: disable=protected-access
1041    else:
1042      return v.value()
1043
1044  @classmethod
1045  def _OverloadAllOperators(cls):  # pylint: disable=invalid-name
1046    """Register overloads for all operators."""
1047    for operator in ops.Tensor.OVERLOADABLE_OPERATORS:
1048      cls._OverloadOperator(operator)
1049    # For slicing, bind getitem differently than a tensor (use SliceHelperVar
1050    # instead)
1051    # pylint: disable=protected-access
1052    setattr(cls, "__getitem__", array_ops._SliceHelperVar)
1053
1054  @classmethod
1055  def _OverloadOperator(cls, operator):  # pylint: disable=invalid-name
1056    """Defer an operator overload to `ops.Tensor`.
1057
1058    We pull the operator out of ops.Tensor dynamically to avoid ordering issues.
1059
1060    Args:
1061      operator: string. The operator name.
1062    """
1063    # We can't use the overload mechanism on __eq__ & __ne__ since __eq__ is
1064    # called when adding a variable to sets. As a result we call a.value() which
1065    # causes infinite recursion when operating within a GradientTape
1066    # TODO(gjn): Consider removing this
1067    if operator == "__eq__" or operator == "__ne__":
1068      return
1069
1070    tensor_oper = getattr(ops.Tensor, operator)
1071
1072    def _run_op(a, *args, **kwargs):
1073      # pylint: disable=protected-access
1074      return tensor_oper(a.value(), *args, **kwargs)
1075
1076    functools.update_wrapper(_run_op, tensor_oper)
1077    setattr(cls, operator, _run_op)
1078
1079  def __hash__(self):
1080    if ops.Tensor._USE_EQUALITY and ops.executing_eagerly_outside_functions():  # pylint: disable=protected-access
1081      raise TypeError("Variable is unhashable. "
1082                      "Instead, use tensor.ref() as the key.")
1083    else:
1084      return id(self)
1085
1086  # TODO(gjn): duplicate of math_ops.tensor_equals, consider removing
1087  def __eq__(self, other):
1088    """Compares two variables element-wise for equality."""
1089    if ops.Tensor._USE_EQUALITY and ops.executing_eagerly_outside_functions():  # pylint: disable=protected-access
1090      return gen_math_ops.equal(self, other, incompatible_shape_error=False)
1091    else:
1092      # In legacy graph mode, tensor equality is object equality
1093      return self is other
1094
1095  # TODO(gjn): duplicate of math_ops.tensor_not_equals, consider removing
1096  def __ne__(self, other):
1097    """Compares two variables element-wise for equality."""
1098    if ops.Tensor._USE_EQUALITY and ops.executing_eagerly_outside_functions():  # pylint: disable=protected-access
1099      return gen_math_ops.not_equal(self, other, incompatible_shape_error=False)
1100    else:
1101      # In legacy graph mode, tensor equality is object equality
1102      return self is not other
1103
1104  def __iter__(self):
1105    """Dummy method to prevent iteration.
1106
1107    Do not call.
1108
1109    NOTE(mrry): If we register __getitem__ as an overloaded operator,
1110    Python will valiantly attempt to iterate over the variable's Tensor from 0
1111    to infinity.  Declaring this method prevents this unintended behavior.
1112
1113    Raises:
1114      TypeError: when invoked.
1115    """
1116    raise TypeError("'Variable' object is not iterable.")
1117
1118  # NOTE(mrry): This enables the Variable's overloaded "right" binary
1119  # operators to run when the left operand is an ndarray, because it
1120  # accords the Variable class higher priority than an ndarray, or a
1121  # numpy matrix.
1122  # TODO(mrry): Convert this to using numpy's __numpy_ufunc__
1123  # mechanism, which allows more control over how Variables interact
1124  # with ndarrays.
1125  __array_priority__ = 100
1126
1127  @property
1128  def name(self):
1129    """The name of this variable."""
1130    raise NotImplementedError
1131
1132  @property
1133  def _shared_name(self):
1134    """The shared name of the variable.
1135
      Unlike `name()`, `shared_name` doesn't have a ":0" suffix. It is the
      user-specified name with the name scope prefix.
1138
1139    Returns:
1140      variable name.
1141    """
1142    return self.name[:self.name.index(":")]
1143
1144  @property
1145  def initializer(self):
1146    """The initializer operation for this variable."""
1147    raise NotImplementedError
1148
1149  @property
1150  def device(self):
1151    """The device of this variable."""
1152    raise NotImplementedError
1153
1154  @property
1155  def dtype(self):
1156    """The `DType` of this variable."""
1157    raise NotImplementedError
1158
1159  @property
1160  def op(self):
1161    """The `Operation` of this variable."""
1162    raise NotImplementedError
1163
1164  @property
1165  def graph(self):
1166    """The `Graph` of this variable."""
1167    raise NotImplementedError
1168
1169  @property
1170  def shape(self):
1171    """The `TensorShape` of this variable.
1172
1173    Returns:
1174      A `TensorShape`.
1175    """
1176    raise NotImplementedError
1177
1178  def get_shape(self):
1179    """Alias of `Variable.shape`."""
1180    return self.shape
1181
1182  def _gather_saveables_for_checkpoint(self):
1183    """For implementing `Trackable`. This object is saveable on its own."""
1184    return {trackable.VARIABLE_VALUE_KEY: self}
1185
1186  def to_proto(self, export_scope=None):
1187    """Converts a `Variable` to a `VariableDef` protocol buffer.
1188
1189    Args:
1190      export_scope: Optional `string`. Name scope to remove.
1191
1192    Returns:
1193      A `VariableDef` protocol buffer, or `None` if the `Variable` is not
1194      in the specified name scope.
1195    """
1196    raise NotImplementedError
1197
1198  @staticmethod
1199  def from_proto(variable_def, import_scope=None):
1200    """Returns a `Variable` object created from `variable_def`."""
1201    return RefVariable(variable_def=variable_def, import_scope=import_scope)
1202
1203  def _set_save_slice_info(self, save_slice_info):
1204    """Sets the slice info for this `Variable`.
1205
1206    Args:
1207      save_slice_info: A `Variable.SaveSliceInfo` object.
1208    """
1209    self._save_slice_info = save_slice_info
1210
1211  def _get_save_slice_info(self):
1212    return self._save_slice_info
1213
1214  @deprecated(None, "Use ref() instead.")
1215  def experimental_ref(self):
1216    return self.ref()
1217
1218  def ref(self):
1219    # tf.Tensor also has the same ref() API.  If you update the
1220    # documentation here, please update tf.Tensor.ref() as well.
1221    """Returns a hashable reference object to this Variable.
1222
1223    The primary use case for this API is to put variables in a set/dictionary.
    We can't put variables in a set/dictionary as `variable.__hash__()` is no
    longer available starting in TensorFlow 2.0.

    The following will raise an exception starting in TensorFlow 2.0:
1228
1229    >>> x = tf.Variable(5)
1230    >>> y = tf.Variable(10)
1231    >>> z = tf.Variable(10)
1232    >>> variable_set = {x, y, z}
1233    Traceback (most recent call last):
1234      ...
1235    TypeError: Variable is unhashable. Instead, use tensor.ref() as the key.
1236    >>> variable_dict = {x: 'five', y: 'ten'}
1237    Traceback (most recent call last):
1238      ...
1239    TypeError: Variable is unhashable. Instead, use tensor.ref() as the key.
1240
1241    Instead, we can use `variable.ref()`.
1242
1243    >>> variable_set = {x.ref(), y.ref(), z.ref()}
1244    >>> x.ref() in variable_set
1245    True
1246    >>> variable_dict = {x.ref(): 'five', y.ref(): 'ten', z.ref(): 'ten'}
1247    >>> variable_dict[y.ref()]
1248    'ten'
1249
    Also, the reference object provides a `.deref()` function that returns the
    original Variable.
1252
1253    >>> x = tf.Variable(5)
1254    >>> x.ref().deref()
1255    <tf.Variable 'Variable:0' shape=() dtype=int32, numpy=5>
1256    """
1257    return object_identity.Reference(self)
1258
1259  class SaveSliceInfo(object):
1260    """Information on how to save this Variable as a slice.
1261
1262    Provides internal support for saving variables as slices of a larger
1263    variable.  This API is not public and is subject to change.
1264
1265    Available properties:
1266
1267    * full_name
1268    * full_shape
1269    * var_offset
1270    * var_shape
1271    """
1272
1273    def __init__(self,
1274                 full_name=None,
1275                 full_shape=None,
1276                 var_offset=None,
1277                 var_shape=None,
1278                 save_slice_info_def=None,
1279                 import_scope=None):
1280      """Create a `SaveSliceInfo`.
1281
1282      Args:
1283        full_name: Name of the full variable of which this `Variable` is a
1284          slice.
1285        full_shape: Shape of the full variable, as a list of int.
1286        var_offset: Offset of this `Variable` into the full variable, as a list
1287          of int.
1288        var_shape: Shape of this `Variable`, as a list of int.
1289        save_slice_info_def: `SaveSliceInfoDef` protocol buffer. If not `None`,
          recreates the SaveSliceInfo object from its contents. `save_slice_info_def`
1291          and other arguments are mutually exclusive.
1292        import_scope: Optional `string`. Name scope to add. Only used when
1293          initializing from protocol buffer.
1294      """
1295      if save_slice_info_def:
1296        assert isinstance(save_slice_info_def, variable_pb2.SaveSliceInfoDef)
1297        self.full_name = ops.prepend_name_scope(
1298            save_slice_info_def.full_name, import_scope=import_scope)
1299        self.full_shape = [i for i in save_slice_info_def.full_shape]
1300        self.var_offset = [i for i in save_slice_info_def.var_offset]
1301        self.var_shape = [i for i in save_slice_info_def.var_shape]
1302      else:
1303        self.full_name = full_name
1304        self.full_shape = full_shape
1305        self.var_offset = var_offset
1306        self.var_shape = var_shape
1307
1308    @property
1309    def spec(self):
1310      """Computes the spec string used for saving."""
1311      full_shape_str = " ".join("%d" % d for d in self.full_shape) + " "
1312      sl_spec = ":".join(
1313          "%d,%d" % (o, s) for o, s in zip(self.var_offset, self.var_shape))
1314      return full_shape_str + sl_spec
1315
1316    def to_proto(self, export_scope=None):
1317      """Returns a SaveSliceInfoDef() proto.
1318
1319      Args:
1320        export_scope: Optional `string`. Name scope to remove.
1321
1322      Returns:
1323        A `SaveSliceInfoDef` protocol buffer, or None if the `Variable` is not
1324        in the specified name scope.
1325      """
1326      if (export_scope is None or self.full_name.startswith(export_scope)):
1327        save_slice_info_def = variable_pb2.SaveSliceInfoDef()
1328        save_slice_info_def.full_name = ops.strip_name_scope(
1329            self.full_name, export_scope)
1330        for i in self.full_shape:
1331          save_slice_info_def.full_shape.append(i)
1332        for i in self.var_offset:
1333          save_slice_info_def.var_offset.append(i)
1334        for i in self.var_shape:
1335          save_slice_info_def.var_shape.append(i)
1336        return save_slice_info_def
1337      else:
1338        return None
1339
1340
1341Variable._OverloadAllOperators()  # pylint: disable=protected-access
1342_pywrap_utils.RegisterType("Variable", Variable)
1343
1344
1345@tf_export(v1=["Variable"])
1346class VariableV1(Variable):
1347  """See the [Variables Guide](https://tensorflow.org/guide/variables).
1348
1349  A variable maintains state in the graph across calls to `run()`. You add a
1350  variable to the graph by constructing an instance of the class `Variable`.
1351
1352  The `Variable()` constructor requires an initial value for the variable,
1353  which can be a `Tensor` of any type and shape. The initial value defines the
1354  type and shape of the variable. After construction, the type and shape of
1355  the variable are fixed. The value can be changed using one of the assign
1356  methods.
1357
1358  If you want to change the shape of a variable later you have to use an
1359  `assign` Op with `validate_shape=False`.
1360
1361  Just like any `Tensor`, variables created with `Variable()` can be used as
1362  inputs for other Ops in the graph. Additionally, all the operators
1363  overloaded for the `Tensor` class are carried over to variables, so you can
1364  also add nodes to the graph by just doing arithmetic on variables.
1365
1366  ```python
1367  import tensorflow as tf
1368
1369  # Create a variable.
1370  w = tf.Variable(<initial-value>, name=<optional-name>)
1371
1372  # Use the variable in the graph like any Tensor.
1373  y = tf.matmul(w, ...another variable or tensor...)
1374
1375  # The overloaded operators are available too.
1376  z = tf.sigmoid(w + y)
1377
1378  # Assign a new value to the variable with `assign()` or a related method.
1379  w.assign(w + 1.0)
1380  w.assign_add(1.0)
1381  ```
1382
1383  When you launch the graph, variables have to be explicitly initialized before
1384  you can run Ops that use their value. You can initialize a variable by
1385  running its *initializer op*, restoring the variable from a save file, or
1386  simply running an `assign` Op that assigns a value to the variable. In fact,
1387  the variable *initializer op* is just an `assign` Op that assigns the
1388  variable's initial value to the variable itself.
1389
1390  ```python
1391  # Launch the graph in a session.
1392  with tf.compat.v1.Session() as sess:
1393      # Run the variable initializer.
1394      sess.run(w.initializer)
1395      # ...you now can run ops that use the value of 'w'...
1396  ```
1397
1398  The most common initialization pattern is to use the convenience function
1399  `global_variables_initializer()` to add an Op to the graph that initializes
1400  all the variables. You then run that Op after launching the graph.
1401
1402  ```python
1403  # Add an Op to initialize global variables.
1404  init_op = tf.compat.v1.global_variables_initializer()
1405
1406  # Launch the graph in a session.
1407  with tf.compat.v1.Session() as sess:
1408      # Run the Op that initializes global variables.
1409      sess.run(init_op)
1410      # ...you can now run any Op that uses variable values...
1411  ```
1412
1413  If you need to create a variable with an initial value dependent on another
1414  variable, use the other variable's `initialized_value()`. This ensures that
1415  variables are initialized in the right order.
1416
1417  All variables are automatically collected in the graph where they are
1418  created. By default, the constructor adds the new variable to the graph
1419  collection `GraphKeys.GLOBAL_VARIABLES`. The convenience function
1420  `global_variables()` returns the contents of that collection.
1421
1422  When building a machine learning model it is often convenient to distinguish
1423  between variables holding the trainable model parameters and other variables
1424  such as a `global step` variable used to count training steps. To make this
1425  easier, the variable constructor supports a `trainable=<bool>` parameter. If
1426  `True`, the new variable is also added to the graph collection
1427  `GraphKeys.TRAINABLE_VARIABLES`. The convenience function
1428  `trainable_variables()` returns the contents of this collection. The
1429  various `Optimizer` classes use this collection as the default list of
1430  variables to optimize.
1431
1432  WARNING: tf.Variable objects by default have a non-intuitive memory model. A
1433  Variable is represented internally as a mutable Tensor which can
1434  non-deterministically alias other Tensors in a graph. The set of operations
1435  which consume a Variable and can lead to aliasing is undetermined and can
1436  change across TensorFlow versions. Avoid writing code which relies on the
1437  value of a Variable either changing or not changing as other operations
1438  happen. For example, using Variable objects or simple functions thereof as
1439  predicates in a `tf.cond` is dangerous and error-prone:
1440
1441  ```
1442  v = tf.Variable(True)
1443  tf.cond(v, lambda: v.assign(False), my_false_fn)  # Note: this is broken.
1444  ```
1445
1446  Here, adding `use_resource=True` when constructing the variable will
1447  fix any nondeterminism issues:
1448  ```
1449  v = tf.Variable(True, use_resource=True)
1450  tf.cond(v, lambda: v.assign(False), my_false_fn)
1451  ```
1452
  To use resource variables instead, which do not have these issues:
1455
1456  * Add `use_resource=True` when constructing `tf.Variable`;
1457  * Call `tf.compat.v1.get_variable_scope().set_use_resource(True)` inside a
1458    `tf.compat.v1.variable_scope` before the `tf.compat.v1.get_variable()` call.
1459  """
1460
1461  def __init__(
1462      self,  # pylint: disable=super-init-not-called
1463      initial_value=None,
1464      trainable=None,
1465      collections=None,
1466      validate_shape=True,
1467      caching_device=None,
1468      name=None,
1469      variable_def=None,
1470      dtype=None,
1471      expected_shape=None,
1472      import_scope=None,
1473      constraint=None,
1474      use_resource=None,
1475      synchronization=VariableSynchronization.AUTO,
1476      aggregation=VariableAggregation.NONE,
1477      shape=None):
1478    """Creates a new variable with value `initial_value`.
1479
1480    The new variable is added to the graph collections listed in `collections`,
1481    which defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
1482
1483    If `trainable` is `True` the variable is also added to the graph collection
1484    `GraphKeys.TRAINABLE_VARIABLES`.
1485
1486    This constructor creates both a `variable` Op and an `assign` Op to set the
1487    variable to its initial value.
1488
1489    Args:
1490      initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
1491        which is the initial value for the Variable. The initial value must have
1492        a shape specified unless `validate_shape` is set to False. Can also be a
1493        callable with no argument that returns the initial value when called. In
1494        that case, `dtype` must be specified. (Note that initializer functions
1495        from init_ops.py must first be bound to a shape before being used here.)
1496      trainable: If `True`, also adds the variable to the graph collection
1497        `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as the default
1498        list of variables to use by the `Optimizer` classes. Defaults to `True`,
1499        unless `synchronization` is set to `ON_READ`, in which case it defaults
1500        to `False`.
1501      collections: List of graph collections keys. The new variable is added to
1502        these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
1503      validate_shape: If `False`, allows the variable to be initialized with a
1504        value of unknown shape. If `True`, the default, the shape of
1505        `initial_value` must be known.
1506      caching_device: Optional device string describing where the Variable
1507        should be cached for reading.  Defaults to the Variable's device. If not
1508        `None`, caches on another device.  Typical use is to cache on the device
1509        where the Ops using the Variable reside, to deduplicate copying through
1510        `Switch` and other conditional statements.
1511      name: Optional name for the variable. Defaults to `'Variable'` and gets
1512        uniquified automatically.
1513      variable_def: `VariableDef` protocol buffer. If not `None`, recreates the
1514        Variable object with its contents, referencing the variable's nodes in
1515        the graph, which must already exist. The graph is not changed.
1516        `variable_def` and the other arguments are mutually exclusive.
1517      dtype: If set, initial_value will be converted to the given type. If
1518        `None`, either the datatype will be kept (if `initial_value` is a
1519        Tensor), or `convert_to_tensor` will decide.
1520      expected_shape: A TensorShape. If set, initial_value is expected to have
1521        this shape.
      import_scope: Optional `string`. Name scope to add to the `Variable`. Only
1523        used when initializing from protocol buffer.
1524      constraint: An optional projection function to be applied to the variable
1525        after being updated by an `Optimizer` (e.g. used to implement norm
1526        constraints or value constraints for layer weights). The function must
1527        take as input the unprojected Tensor representing the value of the
1528        variable and return the Tensor for the projected value (which must have
1529        the same shape). Constraints are not safe to use when doing asynchronous
1530        distributed training.
1531      use_resource: whether to use resource variables.
      synchronization: Indicates when a distributed variable will be
1533        aggregated. Accepted values are constants defined in the class
1534        `tf.VariableSynchronization`. By default the synchronization is set to
1535        `AUTO` and the current `DistributionStrategy` chooses when to
1536        synchronize.
1537      aggregation: Indicates how a distributed variable will be aggregated.
1538        Accepted values are constants defined in the class
1539        `tf.VariableAggregation`.
1540      shape: (optional) The shape of this variable. If None, the shape of
1541        `initial_value` will be used. When setting this argument to
1542        `tf.TensorShape(None)` (representing an unspecified shape), the variable
1543        can be assigned with values of different shapes.
1544
1545    Raises:
1546      ValueError: If both `variable_def` and initial_value are specified.
1547      ValueError: If the initial value is not specified, or does not have a
1548        shape and `validate_shape` is `True`.
1549      RuntimeError: If eager execution is enabled.
1550    """
1551
1552  SaveSliceInfo = Variable.SaveSliceInfo
1553
1554
1555# TODO(apassos): do not repeat all comments here
1556class RefVariable(VariableV1, core.Tensor):
1557  """Ref-based implementation of variables."""
1558
1559  def __init__(
1560      self,  # pylint: disable=super-init-not-called
1561      initial_value=None,
1562      trainable=None,
1563      collections=None,
1564      validate_shape=True,
1565      caching_device=None,
1566      name=None,
1567      variable_def=None,
1568      dtype=None,
1569      expected_shape=None,
1570      import_scope=None,
1571      constraint=None,
1572      synchronization=None,
1573      aggregation=None,
1574      shape=None):
1575    """Creates a new variable with value `initial_value`.
1576
1577    The new variable is added to the graph collections listed in `collections`,
1578    which defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
1579
1580    If `trainable` is `True` the variable is also added to the graph collection
1581    `GraphKeys.TRAINABLE_VARIABLES`.
1582
1583    This constructor creates both a `variable` Op and an `assign` Op to set the
1584    variable to its initial value.
1585
1586    Args:
1587      initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
1588        which is the initial value for the Variable. The initial value must have
1589        a shape specified unless `validate_shape` is set to False. Can also be a
1590        callable with no argument that returns the initial value when called. In
1591        that case, `dtype` must be specified. (Note that initializer functions
1592        from init_ops.py must first be bound to a shape before being used here.)
1593      trainable: If `True`, also adds the variable to the graph collection
1594        `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as the default
1595        list of variables to use by the `Optimizer` classes. Defaults to `True`,
1596        unless `synchronization` is set to `ON_READ`, in which case it defaults
1597        to `False`.
1598      collections: List of graph collections keys. The new variable is added to
1599        these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
1600      validate_shape: If `False`, allows the variable to be initialized with a
1601        value of unknown shape. If `True`, the default, the shape of
1602        `initial_value` must be known.
1603      caching_device: Optional device string describing where the Variable
1604        should be cached for reading.  Defaults to the Variable's device. If not
1605        `None`, caches on another device.  Typical use is to cache on the device
1606        where the Ops using the Variable reside, to deduplicate copying through
1607        `Switch` and other conditional statements.
1608      name: Optional name for the variable. Defaults to `'Variable'` and gets
1609        uniquified automatically.
1610      variable_def: `VariableDef` protocol buffer. If not `None`, recreates the
1611        Variable object with its contents, referencing the variable's nodes in
1612        the graph, which must already exist. The graph is not changed.
1613        `variable_def` and the other arguments are mutually exclusive.
1614      dtype: If set, initial_value will be converted to the given type. If
1615        `None`, either the datatype will be kept (if `initial_value` is a
1616        Tensor), or `convert_to_tensor` will decide.
1617      expected_shape: A TensorShape. If set, initial_value is expected to have
1618        this shape.
      import_scope: Optional `string`. Name scope to add to the `Variable`. Only
1620        used when initializing from protocol buffer.
1621      constraint: An optional projection function to be applied to the variable
1622        after being updated by an `Optimizer` (e.g. used to implement norm
1623        constraints or value constraints for layer weights). The function must
1624        take as input the unprojected Tensor representing the value of the
1625        variable and return the Tensor for the projected value (which must have
1626        the same shape). Constraints are not safe to use when doing asynchronous
1627        distributed training.
      synchronization: Indicates when a distributed variable will be
1629        aggregated. Accepted values are constants defined in the class
1630        `tf.VariableSynchronization`. By default the synchronization is set to
1631        `AUTO` and the current `DistributionStrategy` chooses when to
1632        synchronize.
1633      aggregation: Indicates how a distributed variable will be aggregated.
1634        Accepted values are constants defined in the class
1635        `tf.VariableAggregation`.
1636      shape: (optional) The shape of this variable. If None, the shape of
1637        `initial_value` will be used. When setting this argument to
1638        `tf.TensorShape(None)` (representing an unspecified shape), the variable
1639        can be assigned with values of different shapes.
1640
1641    Raises:
1642      ValueError: If both `variable_def` and initial_value are specified.
1643      ValueError: If the initial value is not specified, or does not have a
1644        shape and `validate_shape` is `True`.
1645      RuntimeError: If eager execution is enabled.
1646    """
1647    self._in_graph_mode = True
1648    if variable_def:
1649      # If variable_def is provided, recreates the variable from its fields.
1650      if initial_value:
1651        raise ValueError("variable_def and initial_value are mutually "
1652                         "exclusive.")
1653      self._init_from_proto(variable_def, import_scope=import_scope)
1654    else:
1655      # Create from initial_value.
1656      self._init_from_args(
1657          initial_value=initial_value,
1658          trainable=trainable,
1659          collections=collections,
1660          validate_shape=validate_shape,
1661          caching_device=caching_device,
1662          name=name,
1663          dtype=dtype,
1664          expected_shape=expected_shape,
1665          constraint=constraint,
1666          synchronization=synchronization,
1667          aggregation=aggregation,
1668          shape=shape)
1669
1670  def __repr__(self):
1671    if context.executing_eagerly() and not self._in_graph_mode:
1672      return "<tf.Variable '%s' shape=%s dtype=%s, numpy=%s>" % (
1673          self.name, self.get_shape(), self.dtype.name,
1674          ops.numpy_text(self.read_value(), is_repr=True))
1675    else:
1676      return "<tf.Variable '%s' shape=%s dtype=%s>" % (
1677          self.name, self.get_shape(), self.dtype.name)
1678
1679  def _init_from_args(self,
1680                      initial_value=None,
1681                      trainable=None,
1682                      collections=None,
1683                      validate_shape=True,
1684                      caching_device=None,
1685                      name=None,
1686                      dtype=None,
1687                      expected_shape=None,
1688                      constraint=None,
1689                      synchronization=None,
1690                      aggregation=None,
1691                      shape=None):
1692    """Creates a new variable from arguments.
1693
1694    Args:
1695      initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
1696        which is the initial value for the Variable. The initial value must have
1697        a shape specified unless `validate_shape` is set to False. Can also be a
1698        callable with no argument that returns the initial value when called.
1699        (Note that initializer functions from init_ops.py must first be bound to
1700        a shape before being used here.)
1701      trainable: If `True`, also adds the variable to the graph collection
1702        `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as the default
1703        list of variables to use by the `Optimizer` classes. Defaults to `True`,
1704        unless `synchronization` is set to `ON_READ`, in which case it defaults
1705        to `False`.
1706      collections: List of graph collections keys. The new variable is added to
1707        these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
1708      validate_shape: If `False`, allows the variable to be initialized with a
1709        value of unknown shape. If `True`, the default, the shape of
1710        `initial_value` must be known.
1711      caching_device: Optional device string or function describing where the
1712        Variable should be cached for reading.  Defaults to the Variable's
1713        device.  If not `None`, caches on another device.  Typical use is to
1714        cache on the device where the Ops using the Variable reside, to
1715        deduplicate copying through `Switch` and other conditional statements.
1716      name: Optional name for the variable. Defaults to `'Variable'` and gets
1717        uniquified automatically.
1718      dtype: If set, initial_value will be converted to the given type. If None,
1719        either the datatype will be kept (if initial_value is a Tensor) or
1720        float32 will be used (if it is a Python object convertible to a Tensor).
1721      expected_shape: Deprecated. Ignored.
1722      constraint: An optional projection function to be applied to the variable
1723        after being updated by an `Optimizer` (e.g. used to implement norm
1724        constraints or value constraints for layer weights). The function must
1725        take as input the unprojected Tensor representing the value of the
1726        variable and return the Tensor for the projected value (which must have
1727        the same shape). Constraints are not safe to use when doing asynchronous
1728        distributed training.
      synchronization: Indicates when a distributed variable will be
1730        aggregated. Accepted values are constants defined in the class
1731        `tf.VariableSynchronization`. By default the synchronization is set to
1732        `AUTO` and the current `DistributionStrategy` chooses when to
1733        synchronize.
1734      aggregation: Indicates how a distributed variable will be aggregated.
1735        Accepted values are constants defined in the class
1736        `tf.VariableAggregation`.
1737      shape: (optional) The shape of this variable. If None, the shape of
1738        `initial_value` will be used. When setting this argument to
1739        `tf.TensorShape(None)` (representing an unspecified shape), the variable
1740        can be assigned with values of different shapes.
1741
1742    Raises:
1743      ValueError: If the initial value is not specified, or does not have a
1744        shape and `validate_shape` is `True`.
1745      RuntimeError: If lifted into the eager context.
1746    """
1747    _ = expected_shape
1748    if initial_value is None:
1749      raise ValueError("initial_value must be specified.")
1750    init_from_fn = callable(initial_value)
1751
1752    if collections is None:
1753      collections = [ops.GraphKeys.GLOBAL_VARIABLES]
1754    if not isinstance(collections, (list, tuple, set)):
1755      raise ValueError(
1756          "collections argument to Variable constructor must be a list, tuple, "
1757          "or set. Got %s of type %s" % (collections, type(collections)))
1758    if constraint is not None and not callable(constraint):
1759      raise ValueError("The `constraint` argument must be a callable.")
1760
1761    # Store the graph key so optimizers know how to only retrieve variables from
1762    # this graph.
1763    self._graph_key = ops.get_default_graph()._graph_key  # pylint: disable=protected-access
1764    if isinstance(initial_value, trackable.CheckpointInitialValue):
1765      self._maybe_initialize_trackable()
1766      self._update_uid = initial_value.checkpoint_position.restore_uid
1767      initial_value = initial_value.wrapped_value
1768
1769    synchronization, aggregation, trainable = (
1770        validate_synchronization_aggregation_trainable(synchronization,
1771                                                       aggregation, trainable,
1772                                                       name))
1773    self._synchronization = synchronization
1774    self._aggregation = aggregation
1775    self._trainable = trainable
1776    if trainable and ops.GraphKeys.TRAINABLE_VARIABLES not in collections:
1777      collections = list(collections) + [ops.GraphKeys.TRAINABLE_VARIABLES]
1778    with ops.init_scope():
1779      # Ensure that we weren't lifted into the eager context.
1780      if context.executing_eagerly():
1781        raise RuntimeError(
1782            "RefVariable not supported when eager execution is enabled. ")
1783      with ops.name_scope(name, "Variable",
1784                          [] if init_from_fn else [initial_value]) as name:
1785
1786        if init_from_fn:
1787          # Use attr_scope and device(None) to simulate the behavior of
1788          # colocate_with when the variable we want to colocate with doesn't
1789          # yet exist.
1790          true_name = ops.name_from_scope_name(name)  # pylint: disable=protected-access
1791          attr = attr_value_pb2.AttrValue(
1792              list=attr_value_pb2.AttrValue.ListValue(
1793                  s=[compat.as_bytes("loc:@%s" % true_name)]))
1794          # pylint: disable=protected-access
1795          with ops.get_default_graph()._attr_scope({"_class": attr}):
1796            with ops.name_scope("Initializer"), ops.device(None):
1797              initial_value = initial_value()
1798              if isinstance(initial_value, trackable.CheckpointInitialValue):
1799                self._maybe_initialize_trackable()
1800                self._update_uid = initial_value.checkpoint_position.restore_uid
1801                initial_value = initial_value.wrapped_value
1802              self._initial_value = ops.convert_to_tensor(
1803                  initial_value, name="initial_value", dtype=dtype)
1804              if shape is None:
1805                shape = (
1806                    self._initial_value.get_shape()
1807                    if validate_shape else tensor_shape.unknown_shape())
1808            self._variable = state_ops.variable_op_v2(
1809                shape, self._initial_value.dtype.base_dtype, name=name)
1810          # pylint: enable=protected-access
1811
1812        # Or get the initial value from a Tensor or Python object.
1813        else:
1814          self._initial_value = ops.convert_to_tensor(
1815              initial_value, name="initial_value", dtype=dtype)
1816          # pylint: disable=protected-access
1817          if self._initial_value.op._get_control_flow_context() is not None:
1818            raise ValueError(
1819                "Initializer for variable %s is from inside a control-flow "
1820                "construct, such as a loop or conditional. When creating a "
1821                "variable inside a loop or conditional, use a lambda as the "
1822                "initializer." % name)
1823          if shape is None:
1824            # pylint: enable=protected-access
1825            shape = (
1826                self._initial_value.get_shape()
1827                if validate_shape else tensor_shape.unknown_shape())
1828          # In this case, the variable op can't be created until after the
1829          # initial_value has been converted to a Tensor with a known type.
1830          self._variable = state_ops.variable_op_v2(
1831              shape, self._initial_value.dtype.base_dtype, name=name)
1832
1833        # Cache the name in `self`, because some APIs call `Variable.name` in a
1834        # tight loop, and this halves the cost.
1835        self._name = self._variable.name
1836
1837        # Manually overrides the variable's shape with the initial value's.
1838        if validate_shape:
1839          initial_value_shape = self._initial_value.get_shape()
1840          if not initial_value_shape.is_fully_defined():
1841            raise ValueError("initial_value must have a shape specified: %s" %
1842                             self._initial_value)
1843
1844        # If 'initial_value' makes use of other variables, make sure we don't
1845        # have an issue if these other variables aren't initialized first by
1846        # using their initialized_value() method.
1847        self._initializer_op = state_ops.assign(
1848            self._variable,
1849            _try_guard_against_uninitialized_dependencies(
1850                name, self._initial_value),
1851            validate_shape=validate_shape).op
1852
1853        # TODO(vrv): Change this class to not take caching_device, but
1854        # to take the op to colocate the snapshot with, so we can use
1855        # colocation rather than devices.
1856        if caching_device is not None:
1857          with ops.device(caching_device):
1858            self._snapshot = array_ops.identity(self._variable, name="read")
1859        else:
1860          with ops.colocate_with(self._variable.op):
1861            self._snapshot = array_ops.identity(self._variable, name="read")
1862      ops.add_to_collections(collections, self)
1863
1864    self._caching_device = caching_device
1865    self._save_slice_info = None
1866    self._constraint = constraint
1867
1868  def _init_from_proto(self, variable_def, import_scope=None):
1869    """Recreates the Variable object from a `VariableDef` protocol buffer.
1870
1871    Args:
1872      variable_def: `VariableDef` protocol buffer, describing a variable whose
1873        nodes already exists in the graph.
1874      import_scope: Optional `string`. Name scope to add.
1875    """
1876    assert isinstance(variable_def, variable_pb2.VariableDef)
1877    # Create from variable_def.
1878    g = ops.get_default_graph()
1879    self._variable = g.as_graph_element(
1880        ops.prepend_name_scope(
1881            variable_def.variable_name, import_scope=import_scope))
1882    self._name = self._variable.name
1883    self._initializer_op = g.as_graph_element(
1884        ops.prepend_name_scope(
1885            variable_def.initializer_name, import_scope=import_scope))
1886    # Tests whether initial_value_name exists first for backwards compatibility.
1887    if (hasattr(variable_def, "initial_value_name") and
1888        variable_def.initial_value_name):
1889      self._initial_value = g.as_graph_element(
1890          ops.prepend_name_scope(
1891              variable_def.initial_value_name, import_scope=import_scope))
1892    else:
1893      self._initial_value = None
1894    synchronization, aggregation, trainable = (
1895        validate_synchronization_aggregation_trainable(
1896            variable_def.synchronization, variable_def.aggregation,
1897            variable_def.trainable, variable_def.variable_name))
1898    self._synchronization = synchronization
1899    self._aggregation = aggregation
1900    self._trainable = trainable
1901    self._snapshot = g.as_graph_element(
1902        ops.prepend_name_scope(
1903            variable_def.snapshot_name, import_scope=import_scope))
1904    if variable_def.HasField("save_slice_info_def"):
1905      self._save_slice_info = Variable.SaveSliceInfo(
1906          save_slice_info_def=variable_def.save_slice_info_def,
1907          import_scope=import_scope)
1908    else:
1909      self._save_slice_info = None
1910    self._caching_device = None
1911    self._constraint = None
1912
1913  def _as_graph_element(self):
1914    """Conversion function for Graph.as_graph_element()."""
1915    return self._variable
1916
1917  def value(self):
1918    """Returns the last snapshot of this variable.
1919
1920    You usually do not need to call this method as all ops that need the value
1921    of the variable call it automatically through a `convert_to_tensor()` call.
1922
    Returns a `Tensor` which holds the value of the variable.  You cannot
1924    assign a new value to this tensor as it is not a reference to the variable.
1925
1926    To avoid copies, if the consumer of the returned value is on the same device
1927    as the variable, this actually returns the live value of the variable, not
1928    a copy.  Updates to the variable are seen by the consumer.  If the consumer
1929    is on a different device it will get a copy of the variable.
1930
1931    Returns:
1932      A `Tensor` containing the value of the variable.
1933    """
1934    return self._snapshot
1935
1936  def read_value(self):
1937    """Returns the value of this variable, read in the current context.
1938
1939    Can be different from value() if it's on another device, with control
1940    dependencies, etc.
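
    For example (a minimal sketch, assuming `v` is an existing variable):

    ```python
    with tf.device("/cpu:0"):
      fresh = v.read_value()   # builds a new read op in the current context
    cached = v.value()         # the snapshot created when 'v' was constructed
    ```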
1941
1942    Returns:
1943      A `Tensor` containing the value of the variable.
1944    """
1945    return array_ops.identity(self._variable, name="read")
1946
1947  def _ref(self):
1948    """Returns a reference to this variable.
1949
1950    You usually do not need to call this method as all ops that need a reference
1951    to the variable call it automatically.
1952
    Returns a `Tensor` which holds a reference to the variable.  You can
1954    assign a new value to the variable by passing the tensor to an assign op.
1955    See `tf.Variable.value` if you want to get the value of the
1956    variable.
1957
1958    Returns:
1959      A `Tensor` that is a reference to the variable.
1960    """
1961    return self._variable
1962
1963  def set_shape(self, shape):
1964    """Overrides the shape for this variable.
1965
1966    Args:
1967      shape: the `TensorShape` representing the overridden shape.
1968    """
1969    self._ref().set_shape(shape)
1970    self.value().set_shape(shape)
1971
1972  @property
1973  def trainable(self):
1974    return self._trainable
1975
1976  @property
1977  def synchronization(self):
1978    return self._synchronization
1979
1980  @property
1981  def aggregation(self):
1982    return self._aggregation
1983
1984  def eval(self, session=None):
1985    """In a session, computes and returns the value of this variable.
1986
1987    This is not a graph construction method, it does not add ops to the graph.
1988
1989    This convenience method requires a session where the graph
1990    containing this variable has been launched. If no session is
1991    passed, the default session is used.  See `tf.compat.v1.Session` for more
1992    information on launching a graph and on sessions.
1993
1994    ```python
1995    v = tf.Variable([1, 2])
1996    init = tf.compat.v1.global_variables_initializer()
1997
1998    with tf.compat.v1.Session() as sess:
1999        sess.run(init)
2000        # Usage passing the session explicitly.
2001        print(v.eval(sess))
2002        # Usage with the default session.  The 'with' block
2003        # above makes 'sess' the default session.
2004        print(v.eval())
2005    ```
2006
2007    Args:
      session: The session to use to evaluate this variable. If `None`, the
2009        default session is used.
2010
2011    Returns:
2012      A numpy `ndarray` with a copy of the value of this variable.
2013    """
2014    return self._variable.eval(session=session)
2015
2016  @property
2017  def initial_value(self):
2018    """Returns the Tensor used as the initial value for the variable.
2019
2020    Note that this is different from `initialized_value()` which runs
2021    the op that initializes the variable before returning its value.
2022    This method returns the tensor that is used by the op that initializes
2023    the variable.
2024
2025    Returns:
2026      A `Tensor`.
2027    """
2028    return self._initial_value
2029
2030  @property
2031  def constraint(self):
2032    """Returns the constraint function associated with this variable.
2033
2034    Returns:
2035      The constraint function that was passed to the variable constructor.
2036      Can be `None` if no constraint was passed.
2037    """
2038    return self._constraint
2039
2040  def assign(self, value, use_locking=False, name=None, read_value=True):
2041    """Assigns a new value to the variable.
2042
2043    This is essentially a shortcut for `assign(self, value)`.
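
    For example (a minimal sketch; the values below are illustrative):

    ```python
    v = tf.Variable([1.0, 2.0])
    new_val = v.assign([3.0, 4.0])  # a Tensor that holds the new value
    assign_op = v.assign([5.0, 6.0], read_value=False)  # just the assign Op
    ```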
2044
2045    Args:
2046      value: A `Tensor`. The new value for this variable.
2047      use_locking: If `True`, use locking during the assignment.
      name: The name of the operation to be created.
      read_value: if `True`, returns a `Tensor` that evaluates to the new
        value of the variable; if `False`, returns the assign `Op`.
2051
2052    Returns:
2053      A `Tensor` that will hold the new value of this variable after
2054      the assignment has completed.
2055    """
2056    assign = state_ops.assign(
2057        self._variable, value, use_locking=use_locking, name=name)
2058    if read_value:
2059      return assign
2060    return assign.op
2061
2062  def assign_add(self, delta, use_locking=False, name=None, read_value=True):
2063    """Adds a value to this variable.
2064
    This is essentially a shortcut for `assign_add(self, delta)`.
2066
2067    Args:
2068      delta: A `Tensor`. The value to add to this variable.
2069      use_locking: If `True`, use locking during the operation.
      name: The name of the operation to be created.
      read_value: if `True`, returns a `Tensor` that evaluates to the new
        value of the variable; if `False`, returns the assign `Op`.
2073
2074    Returns:
2075      A `Tensor` that will hold the new value of this variable after
2076      the addition has completed.
2077    """
2078    assign = state_ops.assign_add(
2079        self._variable, delta, use_locking=use_locking, name=name)
2080    if read_value:
2081      return assign
2082    return assign.op
2083
2084  def assign_sub(self, delta, use_locking=False, name=None, read_value=True):
2085    """Subtracts a value from this variable.
2086
2087    This is essentially a shortcut for `assign_sub(self, delta)`.
2088
2089    Args:
2090      delta: A `Tensor`. The value to subtract from this variable.
2091      use_locking: If `True`, use locking during the operation.
      name: The name of the operation to be created.
      read_value: if `True`, returns a `Tensor` that evaluates to the new
        value of the variable; if `False`, returns the assign `Op`.
2095
2096    Returns:
2097      A `Tensor` that will hold the new value of this variable after
2098      the subtraction has completed.
2099    """
2100    assign = state_ops.assign_sub(
2101        self._variable, delta, use_locking=use_locking, name=name)
2102    if read_value:
2103      return assign
2104    return assign.op
2105
2106  def scatter_sub(self, sparse_delta, use_locking=False, name=None):
2107    """Subtracts `tf.IndexedSlices` from this variable.
2108
2109    Args:
2110      sparse_delta: `tf.IndexedSlices` to be subtracted from this variable.
2111      use_locking: If `True`, use locking during the operation.
2112      name: the name of the operation.
2113
2114    Returns:
2115      A `Tensor` that will hold the new value of this variable after
2116      the scattered subtraction has completed.
2117
2118    Raises:
2119      TypeError: if `sparse_delta` is not an `IndexedSlices`.
2120    """
2121    if not isinstance(sparse_delta, ops.IndexedSlices):
2122      raise TypeError("sparse_delta is not IndexedSlices: %s" % sparse_delta)
2123    return gen_state_ops.scatter_sub(
2124        self._variable,
2125        sparse_delta.indices,
2126        sparse_delta.values,
2127        use_locking=use_locking,
2128        name=name)
2129
2130  def scatter_add(self, sparse_delta, use_locking=False, name=None):
2131    """Adds `tf.IndexedSlices` to this variable.
2132
2133    Args:
2134      sparse_delta: `tf.IndexedSlices` to be added to this variable.
2135      use_locking: If `True`, use locking during the operation.
2136      name: the name of the operation.
2137
2138    Returns:
2139      A `Tensor` that will hold the new value of this variable after
2140      the scattered addition has completed.
2141
2142    Raises:
2143      TypeError: if `sparse_delta` is not an `IndexedSlices`.
2144    """
2145    if not isinstance(sparse_delta, ops.IndexedSlices):
2146      raise TypeError("sparse_delta is not IndexedSlices: %s" % sparse_delta)
2147    return gen_state_ops.scatter_add(
2148        self._variable,
2149        sparse_delta.indices,
2150        sparse_delta.values,
2151        use_locking=use_locking,
2152        name=name)
2153
2154  def scatter_max(self, sparse_delta, use_locking=False, name=None):
2155    """Updates this variable with the max of `tf.IndexedSlices` and itself.
2156
2157    Args:
2158      sparse_delta: `tf.IndexedSlices` to use as an argument of max with this
2159        variable.
2160      use_locking: If `True`, use locking during the operation.
2161      name: the name of the operation.
2162
2163    Returns:
2164      A `Tensor` that will hold the new value of this variable after
2165      the scattered maximization has completed.
2166
2167    Raises:
2168      TypeError: if `sparse_delta` is not an `IndexedSlices`.
2169    """
2170    if not isinstance(sparse_delta, ops.IndexedSlices):
2171      raise TypeError("sparse_delta is not IndexedSlices: %s" % sparse_delta)
2172    return gen_state_ops.scatter_max(
2173        self._variable,
2174        sparse_delta.indices,
2175        sparse_delta.values,
2176        use_locking=use_locking,
2177        name=name)
2178
2179  def scatter_min(self, sparse_delta, use_locking=False, name=None):
2180    """Updates this variable with the min of `tf.IndexedSlices` and itself.
2181
2182    Args:
2183      sparse_delta: `tf.IndexedSlices` to use as an argument of min with this
2184        variable.
2185      use_locking: If `True`, use locking during the operation.
2186      name: the name of the operation.
2187
2188    Returns:
2189      A `Tensor` that will hold the new value of this variable after
2190      the scattered minimization has completed.
2191
2192    Raises:
2193      TypeError: if `sparse_delta` is not an `IndexedSlices`.
2194    """
2195    if not isinstance(sparse_delta, ops.IndexedSlices):
2196      raise TypeError("sparse_delta is not IndexedSlices: %s" % sparse_delta)
2197    return gen_state_ops.scatter_min(
2198        self._variable,
2199        sparse_delta.indices,
2200        sparse_delta.values,
2201        use_locking=use_locking,
2202        name=name)
2203
2204  def scatter_mul(self, sparse_delta, use_locking=False, name=None):
2205    """Multiply this variable by `tf.IndexedSlices`.
2206
2207    Args:
2208      sparse_delta: `tf.IndexedSlices` to multiply this variable by.
2209      use_locking: If `True`, use locking during the operation.
2210      name: the name of the operation.
2211
2212    Returns:
2213      A `Tensor` that will hold the new value of this variable after
2214      the scattered multiplication has completed.
2215
2216    Raises:
2217      TypeError: if `sparse_delta` is not an `IndexedSlices`.
2218    """
2219    if not isinstance(sparse_delta, ops.IndexedSlices):
2220      raise TypeError("sparse_delta is not IndexedSlices: %s" % sparse_delta)
2221    return gen_state_ops.scatter_mul(
2222        self._variable,
2223        sparse_delta.indices,
2224        sparse_delta.values,
2225        use_locking=use_locking,
2226        name=name)
2227
2228  def scatter_div(self, sparse_delta, use_locking=False, name=None):
2229    """Divide this variable by `tf.IndexedSlices`.
2230
2231    Args:
2232      sparse_delta: `tf.IndexedSlices` to divide this variable by.
2233      use_locking: If `True`, use locking during the operation.
2234      name: the name of the operation.
2235
2236    Returns:
2237      A `Tensor` that will hold the new value of this variable after
2238      the scattered division has completed.
2239
2240    Raises:
2241      TypeError: if `sparse_delta` is not an `IndexedSlices`.
2242    """
2243    if not isinstance(sparse_delta, ops.IndexedSlices):
2244      raise TypeError("sparse_delta is not IndexedSlices: %s" % sparse_delta)
2245    return gen_state_ops.scatter_div(
2246        self._variable,
2247        sparse_delta.indices,
2248        sparse_delta.values,
2249        use_locking=use_locking,
2250        name=name)
2251
2252  def scatter_update(self, sparse_delta, use_locking=False, name=None):
2253    """Assigns `tf.IndexedSlices` to this variable.
2254
2255    Args:
2256      sparse_delta: `tf.IndexedSlices` to be assigned to this variable.
2257      use_locking: If `True`, use locking during the operation.
2258      name: the name of the operation.
2259
2260    Returns:
2261      A `Tensor` that will hold the new value of this variable after
2262      the scattered assignment has completed.
2263
2264    Raises:
2265      TypeError: if `sparse_delta` is not an `IndexedSlices`.
2266    """
2267    if not isinstance(sparse_delta, ops.IndexedSlices):
2268      raise TypeError("sparse_delta is not IndexedSlices: %s" % sparse_delta)
2269    return gen_state_ops.scatter_update(
2270        self._variable,
2271        sparse_delta.indices,
2272        sparse_delta.values,
2273        use_locking=use_locking,
2274        name=name)
2275
2276  def batch_scatter_update(self, sparse_delta, use_locking=False, name=None):
2277    """Assigns `tf.IndexedSlices` to this variable batch-wise.
2278
    Analogous to `batch_gather`. This assumes that this variable and the
    `sparse_delta` `IndexedSlices` share the same leading dimensions, and that
    the updates are performed on the last dimension of the indices. In other
    words, the dimensions should be the following:
2283
2284    `num_prefix_dims = sparse_delta.indices.ndims - 1`
2285    `batch_dim = num_prefix_dims + 1`
    `sparse_delta.values.shape = sparse_delta.indices.shape + var.shape[
         batch_dim:]`

    where

    `sparse_delta.values.shape[:num_prefix_dims]`
    `== sparse_delta.indices.shape[:num_prefix_dims]`
    `== var.shape[:num_prefix_dims]`

    And the operation performed can be expressed as:

    `var[i_1, ..., i_n,
         sparse_delta.indices[i_1, ..., i_n, j]] = sparse_delta.values[
            i_1, ..., i_n, j]`
2300
    When `sparse_delta.indices` is a 1D tensor, this operation is equivalent to
    `scatter_update`.

    As an alternative, one could loop over the first `ndims` of the variable
    and use `scatter_update` on the subtensors that result from slicing the
    first dimension. This is a valid option for `ndims = 1`, but less efficient
    than this implementation.
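
    For example, with a batch of two rows (a minimal sketch; the shapes and
    numbers are illustrative):

    ```python
    v = tf.Variable([[1.0, 2.0, 3.0],
                     [4.0, 5.0, 6.0]])
    # For each row, update one column: column 0 in row 0, column 2 in row 1.
    delta = tf.IndexedSlices(values=tf.constant([[9.0], [8.0]]),
                             indices=tf.constant([[0], [2]]))
    op = v.batch_scatter_update(delta)
    # After running 'op', v holds [[9., 2., 3.], [4., 5., 8.]].
    ```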
2308
2309    Args:
2310      sparse_delta: `tf.IndexedSlices` to be assigned to this variable.
2311      use_locking: If `True`, use locking during the operation.
2312      name: the name of the operation.
2313
2314    Returns:
2315      A `Tensor` that will hold the new value of this variable after
2316      the scattered assignment has completed.
2317
2318    Raises:
2319      TypeError: if `sparse_delta` is not an `IndexedSlices`.
2320    """
2321    return state_ops.batch_scatter_update(
2322        self,
2323        sparse_delta.indices,
2324        sparse_delta.values,
2325        use_locking=use_locking,
2326        name=name)
2327
2328  def scatter_nd_sub(self, indices, updates, name=None):
2329    """Applies sparse subtraction to individual values or slices in a Variable.
2330
2331    `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.
2332
    `indices` must be an integer tensor, containing indices into `ref`.
    It must be of shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.

    The innermost dimension of `indices` (with length `K`) corresponds to
    indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
    dimension of `ref`.

    `updates` is a `Tensor` of rank `Q-1+P-K` with shape:
2341
2342    ```
2343    [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
2344    ```
2345
    For example, say we want to subtract 4 scattered elements from a rank-1
    tensor with 8 elements. In Python, that update would look like this:
2348
2349    ```python
2350        ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
2351        indices = tf.constant([[4], [3], [1] ,[7]])
2352        updates = tf.constant([9, 10, 11, 12])
2353        op = ref.scatter_nd_sub(indices, updates)
2354        with tf.compat.v1.Session() as sess:
          print(sess.run(op))
2356    ```
2357
2358    The resulting update to ref would look like this:
2359
2360        [1, -9, 3, -6, -6, 6, 7, -4]
2361
2362    See `tf.scatter_nd` for more details about how to make updates to
2363    slices.
2364
2365    Args:
2366      indices: The indices to be used in the operation.
2367      updates: The values to be used in the operation.
2368      name: the name of the operation.
2369
2370    Returns:
2371      A `Tensor` that will hold the new value of this variable after
2372      the scattered subtraction has completed.
2373    """
2374    return gen_state_ops.scatter_nd_sub(
2375        self._variable, indices, updates, use_locking=True, name=name)
2376
2377  def scatter_nd_add(self, indices, updates, name=None):
2378    """Applies sparse addition to individual values or slices in a Variable.
2379
2380    `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.
2381
    `indices` must be an integer tensor, containing indices into `ref`.
    It must be of shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.

    The innermost dimension of `indices` (with length `K`) corresponds to
    indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
    dimension of `ref`.

    `updates` is a `Tensor` of rank `Q-1+P-K` with shape:
2390
2391    ```
2392    [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
2393    ```
2394
    For example, say we want to add 4 scattered elements to a rank-1 tensor
    with 8 elements. In Python, that update would look like this:
2397
2398    ```python
2399        ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
2400        indices = tf.constant([[4], [3], [1] ,[7]])
2401        updates = tf.constant([9, 10, 11, 12])
2402        add = ref.scatter_nd_add(indices, updates)
2403        with tf.compat.v1.Session() as sess:
          print(sess.run(add))
2405    ```
2406
2407    The resulting update to ref would look like this:
2408
2409        [1, 13, 3, 14, 14, 6, 7, 20]
2410
2411    See `tf.scatter_nd` for more details about how to make updates to
2412    slices.
2413
2414    Args:
2415      indices: The indices to be used in the operation.
2416      updates: The values to be used in the operation.
2417      name: the name of the operation.
2418
2419    Returns:
2420      A `Tensor` that will hold the new value of this variable after
2421      the scattered addition has completed.
2422    """
2423    return gen_state_ops.scatter_nd_add(
2424        self._variable, indices, updates, use_locking=True, name=name)
2425
2426  def scatter_nd_update(self, indices, updates, name=None):
2427    """Applies sparse assignment to individual values or slices in a Variable.
2428
2429    `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.
2430
    `indices` must be an integer tensor, containing indices into `ref`.
    It must be of shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.

    The innermost dimension of `indices` (with length `K`) corresponds to
    indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
    dimension of `ref`.

    `updates` is a `Tensor` of rank `Q-1+P-K` with shape:
2439
2440    ```
2441    [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
2442    ```
2443
    For example, say we want to insert 4 scattered elements in a rank-1 tensor
    with 8 elements. In Python, that update would look like this:
2446
2447    ```python
2448        ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
2449        indices = tf.constant([[4], [3], [1] ,[7]])
2450        updates = tf.constant([9, 10, 11, 12])
2451        op = ref.scatter_nd_update(indices, updates)
2452        with tf.compat.v1.Session() as sess:
          print(sess.run(op))
2454    ```
2455
2456    The resulting update to ref would look like this:
2457
2458        [1, 11, 3, 10, 9, 6, 7, 12]
2459
2460    See `tf.scatter_nd` for more details about how to make updates to
2461    slices.
2462
2463    Args:
2464      indices: The indices to be used in the operation.
2465      updates: The values to be used in the operation.
2466      name: the name of the operation.
2467
2468    Returns:
2469      A `Tensor` that will hold the new value of this variable after
2470      the scattered assignment has completed.
2471    """
2472    return gen_state_ops.scatter_nd_update(
2473        self._variable, indices, updates, use_locking=True, name=name)
2474
2475  def scatter_nd_max(self, indices, updates, name=None):
2476    """Updates this variable with the max of `tf.IndexedSlices` and itself.
2477
2478    `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.
2479
    `indices` must be an integer tensor, containing indices into `ref`.
    It must be of shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.

    The innermost dimension of `indices` (with length `K`) corresponds to
    indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
    dimension of `ref`.

    `updates` is a `Tensor` of rank `Q-1+P-K` with shape:
2488
2489    ```
2490    [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
2491    ```
2492
2493    See `tf.scatter_nd` for more details about how to make updates to
2494    slices.
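
    For example (a minimal sketch; the values are illustrative):

    ```python
        ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
        indices = tf.constant([[4], [3], [1], [7]])
        updates = tf.constant([9, 10, 11, 12])
        op = ref.scatter_nd_max(indices, updates)
        with tf.compat.v1.Session() as sess:
          sess.run(ref.initializer)
          print(sess.run(op))
    ```

    The resulting update to ref would look like this:

        [1, 11, 3, 10, 9, 6, 7, 12]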
2495
2496    Args:
2497      indices: The indices to be used in the operation.
2498      updates: The values to be used in the operation.
2499      name: the name of the operation.
2500
2501    Returns:
2502      A `Tensor` that will hold the new value of this variable after
      the scattered maximization has completed.
2504    """
2505    return gen_state_ops.scatter_nd_max(
2506        self._variable, indices, updates, use_locking=True, name=name)
2507
2508  def scatter_nd_min(self, indices, updates, name=None):
2509    """Updates this variable with the min of `tf.IndexedSlices` and itself.
2510
2511    `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.
2512
    `indices` must be an integer tensor, containing indices into `ref`.
    It must be of shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.

    The innermost dimension of `indices` (with length `K`) corresponds to
    indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
    dimension of `ref`.

    `updates` is a `Tensor` of rank `Q-1+P-K` with shape:
2521
2522    ```
2523    [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
2524    ```
2525
2526    See `tf.scatter_nd` for more details about how to make updates to
2527    slices.
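
    For example (a minimal sketch; the values are illustrative):

    ```python
        ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
        indices = tf.constant([[4], [3], [1], [7]])
        updates = tf.constant([0, 0, 0, 0])
        op = ref.scatter_nd_min(indices, updates)
        with tf.compat.v1.Session() as sess:
          sess.run(ref.initializer)
          print(sess.run(op))
    ```

    The resulting update to ref would look like this:

        [1, 0, 3, 0, 0, 6, 7, 0]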
2528
2529    Args:
2530      indices: The indices to be used in the operation.
2531      updates: The values to be used in the operation.
2532      name: the name of the operation.
2533
2534    Returns:
2535      A `Tensor` that will hold the new value of this variable after
      the scattered minimization has completed.
2537    """
2538    return gen_state_ops.scatter_nd_min(
2539        self._variable, indices, updates, use_locking=True, name=name)
2540
2541  def _strided_slice_assign(self, begin, end, strides, value, name, begin_mask,
2542                            end_mask, ellipsis_mask, new_axis_mask,
2543                            shrink_axis_mask):
2544    return gen_array_ops.strided_slice_assign(
2545        ref=self._ref(),
2546        begin=begin,
2547        end=end,
2548        strides=strides,
2549        value=value,
2550        name=name,
2551        begin_mask=begin_mask,
2552        end_mask=end_mask,
2553        ellipsis_mask=ellipsis_mask,
2554        new_axis_mask=new_axis_mask,
2555        shrink_axis_mask=shrink_axis_mask)
2556
2557  @deprecated(None, "Prefer Dataset.range instead.")
2558  def count_up_to(self, limit):
2559    """Increments this variable until it reaches `limit`.
2560
2561    When that Op is run it tries to increment the variable by `1`. If
2562    incrementing the variable would bring it above `limit` then the Op raises
2563    the exception `OutOfRangeError`.
2564
2565    If no error is raised, the Op outputs the value of the variable before
2566    the increment.
2567
2568    This is essentially a shortcut for `count_up_to(self, limit)`.
2569
2570    Args:
2571      limit: value at which incrementing the variable raises an error.
2572
2573    Returns:
2574      A `Tensor` that will hold the variable value before the increment. If no
2575      other Op modifies this variable, the values produced will all be
2576      distinct.
2577    """
2578    return state_ops.count_up_to(self._variable, limit=limit)
2579
2580  # Conversion to tensor.
2581  @staticmethod
2582  def _TensorConversionFunction(v, dtype=None, name=None, as_ref=False):  # pylint: disable=invalid-name
2583    """Utility function for converting a Variable to a Tensor."""
2584    _ = name
2585    if dtype and not dtype.is_compatible_with(v.dtype):
2586      raise ValueError(
2587          "Incompatible type conversion requested to type '%s' for variable "
2588          "of type '%s'" % (dtype.name, v.dtype.name))
2589    if as_ref:
2590      return v._ref()  # pylint: disable=protected-access
2591    else:
2592      return v.value()
2593
2594  # NOTE(mrry): This enables the Variable's overloaded "right" binary
2595  # operators to run when the left operand is an ndarray, because it
2596  # accords the Variable class higher priority than an ndarray, or a
2597  # numpy matrix.
2598  # TODO(mrry): Convert this to using numpy's __numpy_ufunc__
2599  # mechanism, which allows more control over how Variables interact
2600  # with ndarrays.
2601  __array_priority__ = 100
2602
2603  @property
2604  def name(self):
2605    """The name of this variable."""
2606    return self._name
2607
2608  @property
2609  def initializer(self):
2610    """The initializer operation for this variable."""
2611    return self._initializer_op
2612
2613  @property
2614  def device(self):
2615    """The device of this variable."""
2616    return self._variable.device
2617
2618  @property
2619  def dtype(self):
2620    """The `DType` of this variable."""
2621    return self._variable.dtype
2622
2623  @property
2624  def op(self):
2625    """The `Operation` of this variable."""
2626    return self._variable.op
2627
2628  @property
2629  def graph(self):
2630    """The `Graph` of this variable."""
2631    return self._variable.graph
2632
2633  @property
2634  def _distribute_strategy(self):
2635    """The `tf.distribute.Strategy` that this variable was created under."""
2636    return None  # Ref variables are never created inside a strategy.
2637
2638  @property
2639  def shape(self):
2640    """The `TensorShape` of this variable.
2641
2642    Returns:
2643      A `TensorShape`.
2644    """
2645    return self._variable.get_shape()
2646
2647  def to_proto(self, export_scope=None):
2648    """Converts a `Variable` to a `VariableDef` protocol buffer.
2649
2650    Args:
2651      export_scope: Optional `string`. Name scope to remove.
2652
2653    Returns:
      A `VariableDef` protocol buffer, or `None` if the `Variable` is not
      in the specified name scope.
    """
    if (export_scope is None or self._variable.name.startswith(export_scope)):
      var_def = variable_pb2.VariableDef()
      var_def.variable_name = ops.strip_name_scope(self._variable.name,
                                                   export_scope)
      if self._initial_value is not None:
        # For backwards compatibility.
        var_def.initial_value_name = ops.strip_name_scope(
            self._initial_value.name, export_scope)
      var_def.trainable = self.trainable
      var_def.synchronization = self.synchronization.value
      var_def.aggregation = self.aggregation.value
      var_def.initializer_name = ops.strip_name_scope(self.initializer.name,
                                                      export_scope)
      var_def.snapshot_name = ops.strip_name_scope(self._snapshot.name,
                                                   export_scope)
      if self._save_slice_info:
        var_def.save_slice_info_def.MergeFrom(
            self._save_slice_info.to_proto(export_scope=export_scope))
      return var_def
    else:
      return None

  def __iadd__(self, other):
    logging.log_first_n(
        logging.WARN, "Variable += will be deprecated. Use variable.assign_add"
        " if you want assignment to the variable value or 'x = x + y'"
        " if you want a new python Tensor object.", 1)
    return self + other

  def __isub__(self, other):
    logging.log_first_n(
        logging.WARN, "Variable -= will be deprecated. Use variable.assign_sub"
        " if you want assignment to the variable value or 'x = x - y'"
        " if you want a new python Tensor object.", 1)
    return self - other

  def __imul__(self, other):
    logging.log_first_n(
        logging.WARN,
        "Variable *= will be deprecated. Use `var.assign(var * other)`"
        " if you want assignment to the variable value or `x = x * y`"
        " if you want a new python Tensor object.", 1)
    return self * other

  def __idiv__(self, other):
    logging.log_first_n(
        logging.WARN,
        "Variable /= will be deprecated. Use `var.assign(var / other)`"
        " if you want assignment to the variable value or `x = x / y`"
        " if you want a new python Tensor object.", 1)
    return self / other

  def __itruediv__(self, other):
    logging.log_first_n(
        logging.WARN,
        "Variable /= will be deprecated. Use `var.assign(var / other)`"
        " if you want assignment to the variable value or `x = x / y`"
        " if you want a new python Tensor object.", 1)
    return self / other

  def __irealdiv__(self, other):
    logging.log_first_n(
        logging.WARN,
        "Variable /= will be deprecated. Use `var.assign(var / other)`"
        " if you want assignment to the variable value or `x = x / y`"
        " if you want a new python Tensor object.", 1)
    return self / other

  def __ipow__(self, other):
    logging.log_first_n(
        logging.WARN,
        "Variable **= will be deprecated. Use `var.assign(var ** other)`"
        " if you want assignment to the variable value or `x = x ** y`"
        " if you want a new python Tensor object.", 1)
    return self**other


def _try_guard_against_uninitialized_dependencies(name, initial_value):
  """Attempt to guard against dependencies on uninitialized variables.

  Replace references to variables in `initial_value` with references to the
  variable's initialized values. The initialized values are essentially
  conditional TensorFlow graphs that return a variable's value if it is
  initialized or its `initial_value` if it hasn't been initialized. This
  replacement is done on a best-effort basis:

  - If the `initial_value` graph contains cycles, we don't do any
    replacements for that graph.
  - If the variables that `initial_value` depends on are not present in the
    `GLOBAL_VARIABLES` or `LOCAL_VARIABLES` collections, we don't replace them.

  In these cases, it is up to the caller to ensure that the `initial_value`
  graph uses initialized variables or to guard access to variables using
  their `initialized_value` method.

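  As a rough sketch of the situation this guards against (assuming TF1 graph
  mode via `tf.compat.v1` and legacy reference variables; the variable names
  are illustrative only):

  ```python
  import tensorflow.compat.v1 as tf
  tf.disable_eager_execution()

  w = tf.Variable(tf.random.truncated_normal([10, 40]), name="w",
                  use_resource=False)
  # `v`'s initial value depends on `w`. On a best-effort basis, references to
  # `w` are rewritten to `w.initialized_value()`, so running `v.initializer`
  # works even if `w` has not been initialized yet.
  v = tf.Variable(w * 2.0, name="v", use_resource=False)
  with tf.Session() as sess:
    sess.run(v.initializer)
  ```
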
  Args:
    name: Variable name.
    initial_value: `Tensor`. The initial value.

  Returns:
    A `Tensor` suitable to initialize a variable.
  Raises:
    TypeError: If `initial_value` is not a `Tensor`.
  """
  if not isinstance(initial_value, ops.Tensor):
    raise TypeError("initial_value needs to be a Tensor: %s" % initial_value)

  # Don't modify initial_value if it contains any cyclic dependencies.
  if _has_cycle(initial_value.op, state={}):
    return initial_value
  return _safe_initial_value_from_tensor(name, initial_value, op_cache={})


_UNKNOWN, _STARTED, _FINISHED = range(3)


def _has_cycle(op, state):
  """Detect cycles in the dependencies of `initial_value`."""
  op_state = state.get(op.name, _UNKNOWN)
  if op_state == _STARTED:
    return True
  elif op_state == _FINISHED:
    return False

  state[op.name] = _STARTED
  for i in itertools.chain((i.op for i in op.inputs), op.control_inputs):
    if _has_cycle(i, state):
      return True
  state[op.name] = _FINISHED
  return False


def _safe_initial_value_from_tensor(name, tensor, op_cache):
  """Replace dependencies on variables with their initialized values.

  Args:
    name: Variable name.
    tensor: A `Tensor`. The tensor to replace.
    op_cache: A dict mapping operation names to `Operation`s. Used to memoize
      the results so as to avoid creating redundant operations.

  Returns:
    A `Tensor` compatible with `tensor`. Any inputs that lead to variable
    values will be replaced with a corresponding graph that uses the
    variable's initialized values. This is done on a best-effort basis. If no
    modifications need to be made then `tensor` will be returned unchanged.
  """
  op = tensor.op
  new_op = op_cache.get(op.name)
  if new_op is None:
    new_op = _safe_initial_value_from_op(name, op, op_cache)
    op_cache[op.name] = new_op
  return new_op.outputs[tensor.value_index]


def _safe_initial_value_from_op(name, op, op_cache):
  """Replace dependencies on variables with their initialized values.

  Args:
    name: Variable name.
    op: An `Operation`. The operation to replace.
    op_cache: A dict mapping operation names to `Operation`s. Used to memoize
      the results so as to avoid creating redundant operations.

  Returns:
    An `Operation` compatible with `op`. Any inputs that lead to variable
    values will be replaced with a corresponding graph that uses the
    variable's initialized values. This is done on a best-effort basis. If no
    modifications need to be made then `op` will be returned unchanged.
  """
  op_type = op.node_def.op
  if op_type in ("IsVariableInitialized", "VarIsInitializedOp",
                 "ReadVariableOp", "If"):
    return op

  # Attempt to find the initialized_value of any variable reference / handles.
  # TODO(b/70206927): Fix handling of ResourceVariables.
  if op_type in ("Variable", "VariableV2", "VarHandleOp"):
    initialized_value = _find_initialized_value_for_variable(op)
    return op if initialized_value is None else initialized_value.op

  # Recursively build initializer expressions for inputs.
  modified = False
  new_op_inputs = []
  for op_input in op.inputs:
    new_op_input = _safe_initial_value_from_tensor(name, op_input, op_cache)
    new_op_inputs.append(new_op_input)
    modified = modified or (new_op_input != op_input)

  # If at least one input was modified, replace the op.
  if modified:
    new_op_type = op_type
    if new_op_type == "RefSwitch":
      new_op_type = "Switch"
    new_op_name = op.node_def.name + "_" + name
    new_op_name = new_op_name.replace(":", "_")
    return op.graph.create_op(
        new_op_type,
        new_op_inputs,
        op._output_types,  # pylint: disable=protected-access
        name=new_op_name,
        attrs=op.node_def.attr)

  return op


def _find_initialized_value_for_variable(variable_op):
  """Find the initialized value for a variable op.

  To do so, look up the variable op in the variables collections.

  Args:
    variable_op: A variable `Operation`.

  Returns:
    A `Tensor` representing the initialized value for the variable or `None`
    if the initialized value could not be found.
  """
  try:
    var_names = [variable_op.node_def.name, variable_op.node_def.name + ":0"]
    for collection_name in (ops.GraphKeys.GLOBAL_VARIABLES,
                            ops.GraphKeys.LOCAL_VARIABLES):
      for var in variable_op.graph.get_collection(collection_name):
        if var.name in var_names:
          return var.initialized_value()
  except AttributeError:
    # Return None when an incomplete user-defined variable type was put in
    # the collection.
    return None
  return None


class PartitionedVariable(object):
  """A container for partitioned `Variable` objects.

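  Instances are normally created by the variable partitioning machinery (for
  example `tf.compat.v1.get_variable` with a partitioner) rather than
  constructed directly. A minimal sketch, assuming TF1 graph mode via
  `tf.compat.v1` (the names are illustrative only):

  ```python
  import tensorflow.compat.v1 as tf
  tf.disable_eager_execution()

  with tf.variable_scope("scope", partitioner=tf.fixed_size_partitioner(2)):
    v = tf.get_variable("w", shape=[4, 10])  # a `PartitionedVariable`
  print(len(list(v)))  # 2 -- the underlying `Variable` shards
  ```
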
  @compatibility(eager) `tf.PartitionedVariable` is not compatible with
  eager execution.  Use `tf.Variable` instead, which is compatible
  with both eager execution and graph construction.  See [the
  TensorFlow Eager Execution
  guide](https://www.tensorflow.org/guide/eager#variables_and_optimizers)
  for details on how variables work in eager execution.
  @end_compatibility
  """

  def __init__(self, name, shape, dtype, variable_list, partitions):
    """Creates a new partitioned variable wrapper.

    Variables passed via the variable_list must contain a save_slice_info
    field.  Concatenation and iteration are in lexicographic order according
    to the var_offset property of the save_slice_info.

    Args:
      name: String. Overall name of the variables.
      shape: List of integers.  Overall shape of the variables.
      dtype: Type of the variables.
      variable_list: List of `Variable` objects that comprise this
        partitioned variable.
      partitions: List of integers.  Number of partitions for each dimension.

    Raises:
      TypeError: If `variable_list` is not a list of `Variable` objects, or
        `partitions` is not a list.
      ValueError: If `variable_list` is empty, or the `Variable` shape
        information does not match `shape`, or `partitions` has invalid values.
    """
    if not isinstance(variable_list, (list, tuple)):
      raise TypeError("variable_list is not a list or tuple: %s" %
                      variable_list)
    if not isinstance(partitions, (list, tuple)):
      raise TypeError("partitions is not a list or tuple: %s" % partitions)
    if not all(p >= 1 for p in partitions):
      raise ValueError("partition values must be positive: %s" % partitions)
    if not variable_list:
      raise ValueError("variable_list may not be empty")
    # pylint: disable=protected-access
    for v in variable_list:
      # Sort the variable_list lexicographically according to var offset value.
      if not all(v._get_save_slice_info() is not None for v in variable_list):
        raise ValueError(
            "All variables must have a save_slice_info available: %s" %
            [v.name for v in variable_list])
      if len(shape) != len(partitions):
        raise ValueError("len(shape) != len(partitions): %s vs. %s" %
                         (shape, partitions))
      if v._get_save_slice_info().full_shape != shape:
        raise ValueError("All variables' full shapes must match shape: %s; "
                         "but full shapes were: %s" %
                         (shape, str([v._get_save_slice_info().full_shape])))
    self._variable_list = sorted(
        variable_list, key=lambda v: v._get_save_slice_info().var_offset)
    # pylint: enable=protected-access

    self._name = name
    self._shape = shape
    self._dtype = dtype
    self._partitions = partitions
    self._as_tensor = None

  def __iter__(self):
    """Return an iterable for accessing the underlying partition Variables."""
    return iter(self._variable_list)

  def __len__(self):
    num_partition_axes = len(self._partition_axes())
    if num_partition_axes > 1:
      raise ValueError("Cannot get a length for %d > 1 partition axes" %
                       num_partition_axes)
    return len(self._variable_list)

  def _partition_axes(self):
    if all(p == 1 for p in self._partitions):
      return [0]
    else:
      return [i for i, p in enumerate(self._partitions) if p > 1]

  def _concat(self):
    """Returns the overall concatenated value as a `Tensor`.

    This is different from using the partitioned variable directly as a tensor
    (through tensor conversion and `as_tensor`) in that it creates a new set of
    operations that keeps the control dependencies from its scope.

    Returns:
      `Tensor` containing the concatenated value.
    """
    if len(self._variable_list) == 1:
      with ops.name_scope(None):
        return array_ops.identity(self._variable_list[0], name=self._name)

    partition_axes = self._partition_axes()

    if len(partition_axes) > 1:
      raise NotImplementedError(
          "Cannot concatenate along more than one dimension: %s.  "
          "Multi-axis partition concat is not supported" % str(partition_axes))
    partition_ix = partition_axes[0]

    with ops.name_scope(self._name + "/ConcatPartitions/"):
      concatenated = array_ops.concat(self._variable_list, partition_ix)

    with ops.name_scope(None):
      return array_ops.identity(concatenated, name=self._name)

  def as_tensor(self):
    """Returns the overall concatenated value as a `Tensor`.

    The returned tensor will not inherit the control dependencies from the
    scope where the value is used, which is similar to getting the value of
    `Variable`.

    Returns:
      `Tensor` containing the concatenated value.
    """
    with ops.control_dependencies(None):
      return self._concat()

  @staticmethod
  def _TensorConversionFunction(v, dtype=None, name=None, as_ref=False):
    # pylint: disable=invalid-name
    _ = name
    if dtype is not None and not dtype.is_compatible_with(v.dtype):
      raise ValueError(
          "Incompatible type conversion requested to type '%s' for variable "
          "of type '%s'" % (dtype.name, v.dtype.name))
    if as_ref:
      raise NotImplementedError(
          "PartitionedVariable doesn't support being used as a reference.")
    else:
      return v.as_tensor()

  @property
  def name(self):
    return self._name

  @property
  def dtype(self):
    return self._dtype

  @property
  def shape(self):
    return self.get_shape()

  @property
  def _distribute_strategy(self):
    """The `tf.distribute.Strategy` that this variable was created under."""
    # NOTE(yuefengz): Today, no partitioned variables in a distribute strategy.
    return None

  def get_shape(self):
    return self._shape

  def _get_variable_list(self):
    return self._variable_list

  def _get_partitions(self):
    return self._partitions

  def _apply_assign_fn(self, assign_fn, value):
    partition_axes = self._partition_axes()
    if len(partition_axes) > 1:
      raise NotImplementedError(
          "Cannot do assign action along more than one dimension: %s.  "
          "Multi-axis partition assign action is not supported " %
          str(partition_axes))
    if isinstance(value, list):
      assert len(value) == len(self._variable_list)
      value_list = value
    elif isinstance(value, PartitionedVariable):
      value_list = [var_part for var_part in value]
    else:
      partition_ix = partition_axes[0]
      size_splits_list = [
          tensor_shape.dimension_value(var.shape[partition_ix])
          for var in self._variable_list
      ]
      value_list = array_ops.split(value, size_splits_list, axis=partition_ix)

    op_list = [
        assign_fn(var, value_list[idx])
        for idx, var in enumerate(self._variable_list)
    ]
    return op_list

  def assign(self, value, use_locking=False, name=None, read_value=True):
    assign_fn = lambda var, r_value: var.assign(
        r_value, use_locking=use_locking, name=name, read_value=read_value)
    assign_list = self._apply_assign_fn(assign_fn, value)
    if read_value:
      return assign_list
    return [assign.op for assign in assign_list]

  def assign_add(self, value, use_locking=False, name=None, read_value=True):
    assign_fn = lambda var, r_value: var.assign_add(
        r_value, use_locking=use_locking, name=name, read_value=read_value)
    assign_list = self._apply_assign_fn(assign_fn, value)
    if read_value:
      return assign_list
    return [assign.op for assign in assign_list]

  def assign_sub(self, value, use_locking=False, name=None, read_value=True):
    assign_fn = lambda var, r_value: var.assign_sub(
        r_value, use_locking=use_locking, name=name, read_value=read_value)
    assign_list = self._apply_assign_fn(assign_fn, value)
    if read_value:
      return assign_list
    return [assign.op for assign in assign_list]


# Register a conversion function which reads the value of the variable,
# allowing instances of the class to be used as tensors.
ops.register_tensor_conversion_function(RefVariable,
                                        RefVariable._TensorConversionFunction)  # pylint: disable=protected-access


@tf_export(v1=["global_variables"])
def global_variables(scope=None):
  """Returns global variables.

  Global variables are variables that are shared across machines in a
  distributed environment. The `Variable()` constructor or `get_variable()`
  automatically adds new variables to the graph collection
  `GraphKeys.GLOBAL_VARIABLES`.
  This convenience function returns the contents of that collection.

  An alternative to global variables is local variables. See
  `tf.compat.v1.local_variables`.

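  A small sketch, assuming TF1 graph mode via `tf.compat.v1` (the variable
  names are illustrative only):

  ```python
  import tensorflow.compat.v1 as tf
  tf.disable_eager_execution()

  a = tf.Variable(1.0, name="a")
  b = tf.Variable(2.0, name="b", collections=[tf.GraphKeys.LOCAL_VARIABLES])
  print([v.op.name for v in tf.global_variables()])  # ['a'] -- 'b' is local
  ```
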
  Args:
    scope: (Optional.) A string. If supplied, the resulting list is filtered to
      include only items whose `name` attribute matches `scope` using
      `re.match`. Items without a `name` attribute are never returned if a scope
      is supplied. The choice of `re.match` means that a `scope` without special
      tokens filters by prefix.

  Returns:
    A list of `Variable` objects.
  """
  return ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES, scope)


@tf_export(v1=["all_variables"])
@deprecated("2017-03-02", "Please use tf.global_variables instead.")
def all_variables():
  """Use `tf.compat.v1.global_variables` instead."""
  return global_variables()


def _all_saveable_objects(scope=None):
  """Returns all variables and `SaveableObject`s that must be checkpointed.

  Args:
    scope: (Optional.) A string. If supplied, the resulting list is filtered to
      include only items whose `name` attribute matches `scope` using
      `re.match`. Items without a `name` attribute are never returned if a scope
      is supplied. The choice of `re.match` means that a `scope` without special
      tokens filters by prefix.

  Returns:
    A list of `Variable` and `SaveableObject` objects to be checkpointed.
  """
  # TODO(andreasst): make this function public once things are settled.
  return (ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES, scope) +
          ops.get_collection(ops.GraphKeys.SAVEABLE_OBJECTS, scope))


@tf_export(v1=["local_variables"])
def local_variables(scope=None):
  """Returns local variables.

  Local variables are per-process variables, usually not saved/restored to a
  checkpoint, and used for temporary or intermediate values.
  For example, they can be used as counters for metrics computation or the
  number of epochs this machine has read data.
  The `tf.contrib.framework.local_variable()` function automatically adds the
  new variable to `GraphKeys.LOCAL_VARIABLES`.
  This convenience function returns the contents of that collection.

  An alternative to local variables is global variables. See
  `tf.compat.v1.global_variables`.

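  A small sketch, assuming TF1 graph mode via `tf.compat.v1` (the variable
  names are illustrative only):

  ```python
  import tensorflow.compat.v1 as tf
  tf.disable_eager_execution()

  epochs = tf.Variable(0, name="epochs", trainable=False,
                       collections=[tf.GraphKeys.LOCAL_VARIABLES])
  print([v.op.name for v in tf.local_variables()])  # ['epochs']
  ```
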
  Args:
    scope: (Optional.) A string. If supplied, the resulting list is filtered to
      include only items whose `name` attribute matches `scope` using
      `re.match`. Items without a `name` attribute are never returned if a scope
      is supplied. The choice of `re.match` means that a `scope` without special
      tokens filters by prefix.

  Returns:
    A list of local `Variable` objects.
  """
  return ops.get_collection(ops.GraphKeys.LOCAL_VARIABLES, scope)


@tf_export(v1=["model_variables"])
def model_variables(scope=None):
  """Returns all variables in the MODEL_VARIABLES collection.

  Args:
    scope: (Optional.) A string. If supplied, the resulting list is filtered to
      include only items whose `name` attribute matches `scope` using
      `re.match`. Items without a `name` attribute are never returned if a scope
      is supplied. The choice of `re.match` means that a `scope` without special
      tokens filters by prefix.

  Returns:
    A list of model `Variable` objects.
  """
  return ops.get_collection(ops.GraphKeys.MODEL_VARIABLES, scope)


@tf_export(v1=["trainable_variables"])
def trainable_variables(scope=None):
  """Returns all variables created with `trainable=True`.

  When passed `trainable=True`, the `Variable()` constructor automatically
  adds new variables to the graph collection
  `GraphKeys.TRAINABLE_VARIABLES`. This convenience function returns the
  contents of that collection.

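  A small sketch, assuming TF1 graph mode via `tf.compat.v1` (the variable
  names are illustrative only):

  ```python
  import tensorflow.compat.v1 as tf
  tf.disable_eager_execution()

  w = tf.Variable([1.0], name="w")  # trainable by default
  frozen = tf.Variable([2.0], name="frozen", trainable=False)
  print([v.op.name for v in tf.trainable_variables()])  # ['w']
  ```
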
  Args:
    scope: (Optional.) A string. If supplied, the resulting list is filtered to
      include only items whose `name` attribute matches `scope` using
      `re.match`. Items without a `name` attribute are never returned if a scope
      is supplied. The choice of `re.match` means that a `scope` without special
      tokens filters by prefix.

  Returns:
    A list of Variable objects.
  """
  return ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES, scope)


@tf_export(v1=["moving_average_variables"])
def moving_average_variables(scope=None):
  """Returns all variables that maintain their moving averages.

  If an `ExponentialMovingAverage` object is created and the `apply()`
  method is called on a list of variables, these variables will
  be added to the `GraphKeys.MOVING_AVERAGE_VARIABLES` collection.
  This convenience function returns the contents of that collection.

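  A small sketch, assuming TF1 graph mode via `tf.compat.v1` (the variable
  names are illustrative only):

  ```python
  import tensorflow.compat.v1 as tf
  tf.disable_eager_execution()

  v = tf.Variable(0.0, name="v")
  ema = tf.train.ExponentialMovingAverage(decay=0.99)
  maintain_averages_op = ema.apply([v])
  print([x.op.name for x in tf.moving_average_variables()])  # ['v']
  ```
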
  Args:
    scope: (Optional.) A string. If supplied, the resulting list is filtered to
      include only items whose `name` attribute matches `scope` using
      `re.match`. Items without a `name` attribute are never returned if a scope
      is supplied. The choice of `re.match` means that a `scope` without special
      tokens filters by prefix.

  Returns:
    A list of Variable objects.
  """
  return ops.get_collection(ops.GraphKeys.MOVING_AVERAGE_VARIABLES, scope)


@tf_export(v1=["initializers.variables", "variables_initializer"])
def variables_initializer(var_list, name="init"):
  """Returns an Op that initializes a list of variables.

  After you launch the graph in a session, you can run the returned Op to
  initialize all the variables in `var_list`. This Op runs all the
  initializers of the variables in `var_list` in parallel.

  Calling `initialize_variables()` is equivalent to passing the list of
  initializers to `Group()`.

  If `var_list` is empty, however, the function still returns an Op that can
  be run. That Op just has no effect.

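  A small sketch, assuming TF1 graph mode via `tf.compat.v1` (the variable
  names are illustrative only):

  ```python
  import tensorflow.compat.v1 as tf
  tf.disable_eager_execution()

  v = tf.Variable(1.0, name="v")
  w = tf.Variable(2.0, name="w")
  init_op = tf.variables_initializer([v, w])
  with tf.Session() as sess:
    sess.run(init_op)
    print(sess.run([v, w]))  # [1.0, 2.0]
  ```
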
  Args:
    var_list: List of `Variable` objects to initialize.
    name: Optional name for the returned operation.

  Returns:
    An Op that runs the initializers of all the specified variables.
  """
  if var_list and not context.executing_eagerly():
    return control_flow_ops.group(*[v.initializer for v in var_list], name=name)
  return control_flow_ops.no_op(name=name)


@tf_export(v1=["initialize_variables"])
@tf_should_use.should_use_result
@deprecated("2017-03-02", "Use `tf.variables_initializer` instead.")
def initialize_variables(var_list, name="init"):
  """See `tf.compat.v1.variables_initializer`."""
  return variables_initializer(var_list, name=name)


@tf_export(v1=["initializers.global_variables", "global_variables_initializer"])
def global_variables_initializer():
  """Returns an Op that initializes global variables.

  This is just a shortcut for `variables_initializer(global_variables())`

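  A small sketch, assuming TF1 graph mode via `tf.compat.v1`:

  ```python
  import tensorflow.compat.v1 as tf
  tf.disable_eager_execution()

  v = tf.Variable(1.0, name="v")
  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(v))  # 1.0
  ```
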
  Returns:
    An Op that initializes global variables in the graph.
  """
  if context.executing_eagerly():
    return control_flow_ops.no_op(name="global_variables_initializer")
  return variables_initializer(global_variables())


@tf_export(v1=["initialize_all_variables"])
@tf_should_use.should_use_result
@deprecated("2017-03-02", "Use `tf.global_variables_initializer` instead.")
def initialize_all_variables():
  """See `tf.compat.v1.global_variables_initializer`."""
  return global_variables_initializer()


@tf_export(v1=["initializers.local_variables", "local_variables_initializer"])
def local_variables_initializer():
  """Returns an Op that initializes all local variables.

  This is just a shortcut for `variables_initializer(local_variables())`

  Returns:
    An Op that initializes all local variables in the graph.
  """
  if context.executing_eagerly():
    return control_flow_ops.no_op(name="local_variables_initializer")
  return variables_initializer(local_variables())


@tf_export(v1=["initialize_local_variables"])
@tf_should_use.should_use_result
@deprecated("2017-03-02", "Use `tf.local_variables_initializer` instead.")
def initialize_local_variables():
  """See `tf.compat.v1.local_variables_initializer`."""
  return local_variables_initializer()


@tf_export(v1=["is_variable_initialized"])
@tf_should_use.should_use_result
def is_variable_initialized(variable):
  """Tests if a variable has been initialized.

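  A small sketch, assuming TF1 graph mode via `tf.compat.v1`:

  ```python
  import tensorflow.compat.v1 as tf
  tf.disable_eager_execution()

  v = tf.Variable(1.0, name="v")
  check = tf.is_variable_initialized(v)
  with tf.Session() as sess:
    print(sess.run(check))   # False
    sess.run(v.initializer)
    print(sess.run(check))   # True
  ```
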
  Args:
    variable: A `Variable`.

  Returns:
    A scalar boolean Tensor: `True` if the variable has been
    initialized, `False` otherwise.
  """
  return state_ops.is_variable_initialized(variable)


@tf_export(v1=["assert_variables_initialized"])
@tf_should_use.should_use_result
def assert_variables_initialized(var_list=None):
  """Returns an Op to check if variables are initialized.

  NOTE: This function is obsolete and will be removed in 6 months.  Please
  change your implementation to use `report_uninitialized_variables()`.

  When run, the returned Op will raise the exception `FailedPreconditionError`
  if any of the variables has not yet been initialized.

  Note: This function is implemented by trying to fetch the values of the
  variables. If one of the variables is not initialized a message may be
  logged by the C++ runtime. This is expected.

  Args:
    var_list: List of `Variable` objects to check. Defaults to the value of
      `global_variables()`.

  Returns:
    An Op, or None if there are no variables.
  """
  if var_list is None:
    var_list = global_variables() + local_variables()
  # Backwards compatibility for old-style variables. TODO(touts): remove.
  if not var_list:
    var_list = []
    for op in ops.get_default_graph().get_operations():
      if op.type in ["Variable", "VariableV2", "AutoReloadVariable"]:
        var_list.append(op.outputs[0])
  if not var_list:
    return None
  else:
    ranks = []
    for var in var_list:
      with ops.colocate_with(var.op):
        ranks.append(array_ops.rank_internal(var, optimize=False))
    if len(ranks) == 1:
      return ranks[0]
    else:
      return array_ops.stack(ranks)


@tf_export(v1=["report_uninitialized_variables"])
@tf_should_use.should_use_result
def report_uninitialized_variables(var_list=None,
                                   name="report_uninitialized_variables"):
  """Adds ops to list the names of uninitialized variables.

  When run, it returns a 1-D tensor containing the names of uninitialized
  variables if there are any, or an empty array if there are none.

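  A small sketch, assuming TF1 graph mode via `tf.compat.v1` (the variable
  names are illustrative only):

  ```python
  import tensorflow.compat.v1 as tf
  tf.disable_eager_execution()

  v = tf.Variable(1.0, name="v")
  w = tf.Variable(2.0, name="w")
  with tf.Session() as sess:
    sess.run(v.initializer)
    print(sess.run(tf.report_uninitialized_variables()))  # [b'w']
  ```
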
  Args:
    var_list: List of `Variable` objects to check. Defaults to the value of
      `global_variables() + local_variables()`
    name: Optional name of the `Operation`.

  Returns:
    A 1-D tensor containing names of the uninitialized variables, or an empty
    1-D tensor if there are no variables or no uninitialized variables.
  """
  if var_list is None:
    var_list = global_variables() + local_variables()
    # Backwards compatibility for old-style variables. TODO(touts): remove.
    if not var_list:
      var_list = []
      for op in ops.get_default_graph().get_operations():
        if op.type in ["Variable", "VariableV2", "AutoReloadVariable"]:
          var_list.append(op.outputs[0])
  with ops.name_scope(name):
    # Run all operations on CPU
    if var_list:
      init_vars = [state_ops.is_variable_initialized(v) for v in var_list]
    local_device = os.environ.get(
        "TF_DEVICE_FOR_UNINITIALIZED_VARIABLE_REPORTING", "/cpu:0")
    with ops.device(local_device):
      if not var_list:
        # Return an empty tensor so we only need to check for returned tensor
        # size being 0 as an indication of model ready.
        return array_ops.constant([], dtype=dtypes.string)
      else:
        # Get a 1-D boolean tensor listing whether each variable is initialized.
        variables_mask = math_ops.logical_not(array_ops.stack(init_vars))
        # Get a 1-D string tensor containing all the variable names.
        variable_names_tensor = array_ops.constant(
            [s.op.name for s in var_list])
        # Return a 1-D tensor containing all the names of
        # uninitialized variables.
        return array_ops.boolean_mask(variable_names_tensor, variables_mask)


ops.register_tensor_conversion_function(
    PartitionedVariable, PartitionedVariable._TensorConversionFunction)  # pylint: disable=protected-access
