1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Ops to use variables as resources."""
16
17# pylint: disable=g-bad-name
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22import contextlib
23import functools
24import weakref
25
26import numpy as np
27
28from tensorflow.core.framework import attr_value_pb2
29from tensorflow.core.framework import variable_pb2
30from tensorflow.python.client import pywrap_tf_session
31from tensorflow.python.eager import context
32from tensorflow.python.eager import tape
33from tensorflow.python.framework import auto_control_deps_utils as acd
34from tensorflow.python.framework import constant_op
35from tensorflow.python.framework import cpp_shape_inference_pb2
36from tensorflow.python.framework import dtypes
37from tensorflow.python.framework import errors
38from tensorflow.python.framework import meta_graph
39from tensorflow.python.framework import ops
40from tensorflow.python.framework import tensor_shape
41from tensorflow.python.framework import tensor_spec
42from tensorflow.python.ops import array_ops
43from tensorflow.python.ops import gen_array_ops
44from tensorflow.python.ops import gen_resource_variable_ops
45from tensorflow.python.ops import gen_state_ops
46from tensorflow.python.ops import handle_data_util
47from tensorflow.python.ops import math_ops
48from tensorflow.python.ops import state_ops
49from tensorflow.python.ops import variables
50# go/tf-wildcard-import
51# pylint: disable=wildcard-import
52from tensorflow.python.ops.gen_resource_variable_ops import *
53# pylint: enable=wildcard-import
54from tensorflow.python.training.tracking import base as trackable
55from tensorflow.python.types import core
56from tensorflow.python.util import _pywrap_utils
57from tensorflow.python.util import compat
58from tensorflow.python.util.deprecation import deprecated
59from tensorflow.python.util.tf_export import tf_export
60
61acd.register_read_only_resource_op("ReadVariableOp")
62acd.register_read_only_resource_op("VariableShape")
63acd.register_read_only_resource_op("ResourceGather")
64acd.register_read_only_resource_op("ResourceGatherNd")
65acd.register_read_only_resource_op("_ReadVariablesOp")
66
67
68# TODO(allenl): Remove this alias and migrate callers.
69get_resource_handle_data = handle_data_util.get_resource_handle_data
70
71
72def get_eager_safe_handle_data(handle):
73  """Get the handle data from the Tensor `handle`."""
74  assert isinstance(handle, ops.Tensor)
75
76  if isinstance(handle, ops.EagerTensor):
77    return handle._handle_data  # pylint: disable=protected-access
78  else:
79    return get_resource_handle_data(handle)
80
81
82def _set_handle_shapes_and_types(tensor, handle_data, graph_mode):
83  """Sets the shape inference result HandleData on tensor.
84
85  Args:
86    tensor: A `Tensor` or `EagerTensor`.
87    handle_data: A `CppShapeInferenceResult.HandleData`.
88    graph_mode: A python bool.
89  """
90  tensor._handle_data = handle_data  # pylint: disable=protected-access
91  if not graph_mode:
92    return
93
94  # Not an EagerTensor, so a graph tensor.
95  shapes, types = zip(*[(pair.shape, pair.dtype)
96                        for pair in handle_data.shape_and_type])
97  ranks = [len(s.dim) if not s.unknown_rank else -1 for s in shapes]
98  shapes = [
99      [d.size for d in s.dim]  # pylint: disable=g-complex-comprehension
100      if not s.unknown_rank else None for s in shapes
101  ]
102  pywrap_tf_session.TF_GraphSetOutputHandleShapesAndTypes_wrapper(
103      tensor._op._graph._c_graph,  # pylint: disable=protected-access
104      tensor._as_tf_output(),  # pylint: disable=protected-access
105      shapes,
106      ranks,
107      types)
108
109
110def _combine_handle_data(handle, initial_value):
111  """Concats HandleData from tensors `handle` and `initial_value`.
112
113  Args:
114    handle: A `Tensor` of dtype `resource`.
115    initial_value: A `Tensor`.
116
117  Returns:
118    A `CppShapeInferenceResult.HandleData`.  If `initial_value` has dtype
119    `variant`, the `HandleData` contains the concatenation of the shape_and_type
120    from both `handle` and `initial_value`.
121
122  Raises:
123    RuntimeError: If handle, which was returned by VarHandleOp, either has
124      no handle data, or its len(handle_data.shape_and_type) != 1.
125  """
126  assert handle.dtype == dtypes.resource
127
128  variable_handle_data = get_eager_safe_handle_data(handle)
129
130  if initial_value.dtype != dtypes.variant:
131    return variable_handle_data
132
133  extra_handle_data = get_eager_safe_handle_data(initial_value)
134  if extra_handle_data is not None and extra_handle_data.is_set:
135    if (variable_handle_data is None or not variable_handle_data.is_set or
136        len(variable_handle_data.shape_and_type) != 1):
137      raise RuntimeError(
138          "Expected VarHandleOp to return a length==1 shape_and_type, "
139          f"but saw: '{variable_handle_data}'")
140    variable_handle_data.shape_and_type.extend(extra_handle_data.shape_and_type)
141  return variable_handle_data
142
143
144def _variable_handle_from_shape_and_dtype(shape,
145                                          dtype,
146                                          shared_name,
147                                          name,
148                                          graph_mode,
149                                          initial_value=None):
150  """Create a variable handle, copying in handle data from `initial_value`."""
151  container = ops.get_default_graph()._container  # pylint: disable=protected-access
152  if container is None:
153    container = ""
154  shape = tensor_shape.as_shape(shape)
155  dtype = dtypes.as_dtype(dtype)
156  if not graph_mode:
157    if shared_name is not None:
158      raise errors.InternalError(  # pylint: disable=no-value-for-parameter
159          "Using an explicit shared_name is not supported when executing eagerly.")
160    shared_name = context.shared_name()
161
162  handle = gen_resource_variable_ops.var_handle_op(
163      shape=shape,
164      dtype=dtype,
165      shared_name=shared_name,
166      name=name,
167      container=container)
168  if initial_value is None:
169    initial_value = handle
170  if graph_mode:
171    full_handle_data = _combine_handle_data(handle, initial_value)
172    _set_handle_shapes_and_types(handle, full_handle_data, graph_mode)
173    return handle
174  else:
175    handle_data = cpp_shape_inference_pb2.CppShapeInferenceResult.HandleData()
176    handle_data.is_set = True
177    handle_data.shape_and_type.append(
178        cpp_shape_inference_pb2.CppShapeInferenceResult.HandleShapeAndType(
179            shape=shape.as_proto(), dtype=dtype.as_datatype_enum))
180
181    if initial_value is not None and initial_value.dtype == dtypes.variant:
182      extra_handle_data = get_eager_safe_handle_data(initial_value)
183      if extra_handle_data is not None and extra_handle_data.is_set:
184        if (not handle_data.is_set or len(handle_data.shape_and_type) != 1):
185          raise RuntimeError(
186              "Expected VarHandleOp to return a length==1 shape_and_type, "
187              f"but saw: '{handle_data}'")
188        handle_data.shape_and_type.extend(extra_handle_data.shape_and_type)
189
190    _set_handle_shapes_and_types(handle, handle_data, graph_mode)
191    return handle
192
193
194def eager_safe_variable_handle(initial_value, shape, shared_name, name,
195                               graph_mode):
196  """Creates a variable handle with information to do shape inference.
197
198  The dtype is read from `initial_value` and stored in the returned
199  resource tensor's handle data.
200
201  If `initial_value.dtype == tf.variant`, we additionally extract the handle
202  data (if any) from `initial_value` and append it to the `handle_data`.
203  In this case, the returned tensor's handle data is in the form
204
205  ```
206  is_set: true
207  shape_and_type {
208    shape {
209      // initial_value.shape
210    }
211    dtype: DT_VARIANT
212  }
213  shape_and_type {
214    // handle_data(initial_value).shape_and_type[0]
215  }
216  shape_and_type {
217    // handle_data(initial_value).shape_and_type[1]
218  }
219  ...
220  ```
221
222  Ops that read from this tensor, such as `ReadVariableOp` and
223  `AssignVariableOp`, know that `handle_data(handle).shape_and_type[1:]`
224  correspond to the handle data of the variant(s) stored in the Variable.
225
226  Args:
227    initial_value: A `Tensor`.
228    shape: The shape of the handle data. Can be `TensorShape(None)` (i.e.
229      unknown shape).
230    shared_name: A string.
231    name: A string.
232    graph_mode: A python bool.
233
234  Returns:
235    The handle, a `Tensor` of type `resource`.
236  """
237  dtype = initial_value.dtype.base_dtype
238  return _variable_handle_from_shape_and_dtype(shape, dtype, shared_name, name,
239                                               graph_mode, initial_value)
240
241
242@contextlib.contextmanager
243def _handle_graph(handle):
244  # Note: might have an eager tensor but not be executing eagerly when building
245  # functions.
246  if (context.executing_eagerly() or isinstance(handle, ops.EagerTensor) or
247      ops.has_default_graph()):
248    yield
249  else:
250    with handle.graph.as_default():
251      yield
252
253
254class EagerResourceDeleter(object):
255  """An object which cleans up a resource handle.
256
257  An alternative to defining a __del__ method on an object. The intended use is
258  that ResourceVariables or other objects with resource handles will maintain a
259  single reference to this object. When the parent object is collected, this
260  object will be too. Even if the parent object is part of a reference cycle,
261  the cycle will be collectable.
262  """
263
264  __slots__ = ["_handle", "_handle_device", "_context"]
265
266  def __init__(self, handle, handle_device):
267    if not isinstance(handle, ops.Tensor):
268      raise ValueError(
269          (f"Passed handle={handle} to EagerResourceDeleter. Was expecting "
270           f"the handle to be a `tf.Tensor`."))
271    self._handle = handle
272    self._handle_device = handle_device
273    # This is held since the __del__ function runs an op, and if the context()
274    # is collected before this object, there will be a segfault when running the
275    # op.
276    self._context = context.context()
277
278  def __del__(self):
279    # Resources follow object-identity when executing eagerly, so it is safe to
280    # delete the resource we have a handle to.
281    try:
282      # A packed EagerTensor doesn't own any resource.
283      if isinstance(self._handle, ops.EagerTensor) and self._handle.is_packed:
284        return
285      # This resource was created in eager mode. However, this destructor may be
286      # running in graph mode (especially during unit tests). To clean up
287      # successfully, we switch back into eager mode temporarily.
288      with context.eager_mode():
289        with ops.device(self._handle_device):
290          gen_resource_variable_ops.destroy_resource_op(
291              self._handle, ignore_lookup_error=True)
292    except TypeError:
293      # Suppress some exceptions, mainly for the case when we're running on
294      # module deletion. Things that can go wrong include the context module
295      # already being unloaded, self._handle._handle_data no longer being
296      # valid, and so on. Printing warnings in these cases is silly
297      # (exceptions raised from __del__ are printed as warnings to stderr).
298      pass  # 'NoneType' object is not callable when the handle has been
299      # partially unloaded.
300    except AttributeError:
301      pass  # 'NoneType' object has no attribute 'eager_mode' when context has
302      # been unloaded. Will catch other module unloads as well.
303
304
305def shape_safe_assign_variable_handle(handle, shape, value, name=None):
306  """Helper that checks shape compatibility and assigns variable."""
307  with _handle_graph(handle):
308    value_tensor = ops.convert_to_tensor(value)
309  shape.assert_is_compatible_with(value_tensor.shape)
310  return gen_resource_variable_ops.assign_variable_op(
311      handle, value_tensor, name=name)
312
313
314def _maybe_set_handle_data(dtype, handle, tensor):
315  if dtype == dtypes.variant:
316    # For DT_VARIANT types, the handle's shape_and_type[1:] stores the
317    # variant's handle data.  Extract it.
318    handle_data = get_eager_safe_handle_data(handle)
319    if handle_data.is_set and len(handle_data.shape_and_type) > 1:
320      tensor._handle_data = (  # pylint: disable=protected-access
321          cpp_shape_inference_pb2.CppShapeInferenceResult.HandleData(
322              is_set=True, shape_and_type=handle_data.shape_and_type[1:]))
323
324
325def variable_accessed(variable):
326  """Records that `variable` was accessed for the tape and FuncGraph."""
327  if hasattr(ops.get_default_graph(), "watch_variable"):
328    ops.get_default_graph().watch_variable(variable)
329  if variable.trainable:
330    tape.variable_accessed(variable)
331
332
333class BaseResourceVariable(variables.VariableV1, core.Tensor):
334  """A python variable from an existing handle."""
335
336  # TODO(wangpeng): Deprecate `constraint` when callers no longer pass it in.
337  def __init__(  # pylint: disable=super-init-not-called
338      self,
339      trainable=None,
340      shape=None,
341      dtype=None,
342      handle=None,
343      constraint=None,
344      synchronization=None,
345      aggregation=None,
346      distribute_strategy=None,
347      name=None,
348      unique_id=None,
349      handle_name=None,
350      graph_element=None,
351      initial_value=None,
352      initializer_op=None,
353      is_initialized_op=None,
354      cached_value=None,
355      save_slice_info=None,
356      handle_deleter=None,
357      caching_device=None,
358      in_graph_mode=None,
359      **unused_kwargs):
360    """Creates a variable from a handle.
361
362    Args:
363      trainable: If `True`, GradientTapes automatically watch uses of this
364        Variable.
365      shape: The variable's shape.
366      dtype: The variable's dtype.
367      handle: The variable's handle.
368      constraint: An optional projection function to be applied to the variable
369        after being updated by an `Optimizer` (e.g. used to implement norm
370        constraints or value constraints for layer weights). The function must
371        take as input the unprojected Tensor representing the value of the
372        variable and return the Tensor for the projected value (which must have
373        the same shape). Constraints are not safe to use when doing asynchronous
374        distributed training.
375      synchronization: Indicates when a distributed variable will be
376        aggregated. Accepted values are constants defined in the class
377        `tf.VariableSynchronization`. By default the synchronization is set to
378        `AUTO` and the current `DistributionStrategy` chooses when to
379        synchronize.
380      aggregation: Indicates how a distributed variable will be aggregated.
381        Accepted values are constants defined in the class
382        `tf.VariableAggregation`.
383      distribute_strategy: The distribution strategy this variable was created
384        under.
385      name: The name for this variable.
386      unique_id: Internal. Unique ID for this variable's handle.
387      handle_name: The name for the variable's handle.
388      graph_element: Optional, required only in session.run-mode. Pre-created
389        tensor which reads this variable's value.
390      initial_value: Optional. Variable's initial value.
391      initializer_op: Operation which assigns the variable's initial value.
392      is_initialized_op: Pre-created operation to check whether this variable is
393        initialized.
394      cached_value: Pre-created operation to read this variable on a specific
395        device.
396      save_slice_info: Metadata for variable partitioning.
397      handle_deleter: EagerResourceDeleter responsible for cleaning up the
398        handle.
399      caching_device: Optional device string or function describing where the
400        Variable should be cached for reading.  Defaults to the Variable's
401        device.  If not `None`, caches on another device.  Typical use is to
402        cache on the device where the Ops using the Variable reside, to
403        deduplicate copying through `Switch` and other conditional statements.
404      in_graph_mode: whether we are executing in TF1 graph mode. If None, will
405        detect within the function. This is to avoid repeated init_scope()
406        context entrances, which can add up.
407    """
408    if in_graph_mode is None:
409      with ops.init_scope():
410        self._in_graph_mode = not context.executing_eagerly()
411    else:
412      self._in_graph_mode = in_graph_mode
413    synchronization, aggregation, trainable = (
414        variables.validate_synchronization_aggregation_trainable(
415            synchronization, aggregation, trainable, name))
416    self._trainable = trainable
417    self._synchronization = synchronization
418    self._aggregation = aggregation
419    self._save_slice_info = save_slice_info
420    self._initial_value = initial_value
421    self._initializer_op = initializer_op
422    self._is_initialized_op = is_initialized_op
423    self._graph_element = graph_element
424    self._caching_device = caching_device
425    self._cached_value = cached_value
426    self._distribute_strategy = distribute_strategy
427    # Store the graph key so optimizers know how to only retrieve variables from
428    # this graph. Guaranteed to be the same as the eager graph_key.
429    self._graph_key = ops.get_default_graph()._graph_key  # pylint: disable=protected-access
430    self._shape = tensor_shape.as_shape(shape)
431    self._dtype = dtypes.as_dtype(dtype)
432    self._handle = handle
433    self._unique_id = unique_id
434    self._handle_name = handle_name + ":0"
435    self._constraint = constraint
436    # After the handle has been created, set up a way to clean it up when
437    # executing eagerly. We'll hold the only reference to the deleter, so that
438    # when this object is garbage collected the deleter will be too. This
439    # means ResourceVariables can be part of reference cycles without those
440    # cycles being uncollectable.
441    if not self._in_graph_mode:
442      if handle_deleter is None:
443        handle_deleter = EagerResourceDeleter(
444            handle=self._handle, handle_device=self._handle.device)
445    self._handle_deleter = handle_deleter
446    self._cached_shape_as_list = None
447
448  def __repr__(self):
449    if context.executing_eagerly() and not self._in_graph_mode:
450      # If we cannot read the value for any reason, still produce a __repr__.
451      try:
452        value_text = ops.numpy_text(self.read_value(), is_repr=True)
453      except:  # pylint: disable=bare-except
454        value_text = "<unavailable>"
455
456      return "<tf.Variable '%s' shape=%s dtype=%s, numpy=%s>" % (
457          self.name, self.get_shape(), self.dtype.name, value_text)
458    else:
459      return "<tf.Variable '%s' shape=%s dtype=%s>" % (
460          self.name, self.get_shape(), self.dtype.name)
461
462  @contextlib.contextmanager
463  def _assign_dependencies(self):
464    """Makes assignments depend on the cached value, if any.
465
466    This prevents undefined behavior with reads not ordered wrt writes.
467
468    Yields:
469      None.
470    """
471    if self._cached_value is not None:
472      with ops.control_dependencies([self._cached_value]):
473        yield
474    else:
475      yield
476
477  def __array__(self):
478    """Allows direct conversion to a numpy array.
479
480    >>> np.array(tf.Variable([1.0]))
481    array([1.], dtype=float32)
482
483    Returns:
484      The variable value as a numpy array.
485    """
486    # You can't return `self.numpy()` here because for scalars
487    # that raises:
488    #     ValueError: object __array__ method not producing an array
489    # Even `self.read_value().__array__()` and `self.read_value()._numpy()` give
490    # the same error. The `EagerTensor` class must be doing something behind the
491    # scenes to make `np.array(tf.constant(1))` work.
492    return np.asarray(self.numpy())
493
494  def __nonzero__(self):
495    return self.__bool__()
496
497  def __bool__(self):
498    return bool(self.read_value())
499
500  def __copy__(self):
501    return self
502
503  def __deepcopy__(self, memo):
504    if not context.executing_eagerly():
505      raise NotImplementedError(
506          "__deepcopy__() is only available when eager execution is enabled.")
507    copied_variable = ResourceVariable(
508        initial_value=self.read_value(),
509        trainable=self._trainable,
510        constraint=self._constraint,
511        dtype=self._dtype,
512        name=self._shared_name,
513        distribute_strategy=self._distribute_strategy,
514        synchronization=self.synchronization,
515        aggregation=self.aggregation)
516    memo[self._unique_id] = copied_variable
517    return copied_variable
518
519  @property
520  def dtype(self):
521    """The dtype of this variable."""
522    return self._dtype
523
524  @property
525  def device(self):
526    """The device this variable is on."""
527    return self.handle.device
528
529  @property
530  def graph(self):
531    """The `Graph` of this variable."""
532    return self.handle.graph
533
534  @property
535  def name(self):
536    """The name of the handle for this variable."""
537    return self._handle_name
538
539  @property
540  def shape(self):
541    """The shape of this variable."""
542    return self._shape
543
544  def set_shape(self, shape):
545    self._shape = self._shape.merge_with(shape)
546
547  def _shape_as_list(self):
548    if self.shape.ndims is None:
549      return None
550    return [dim.value for dim in self.shape.dims]
551
552  def _shape_tuple(self):
553    shape = self._shape_as_list()
554    if shape is None:
555      return None
556    return tuple(shape)
557
558  @property
559  def create(self):
560    """The op responsible for initializing this variable."""
561    if not self._in_graph_mode:
562      raise RuntimeError("This operation is not supported "
563                         "when eager execution is enabled.")
564    return self._initializer_op
565
566  @property
567  def handle(self):
568    """The handle by which this variable can be accessed."""
569    return self._handle
570
571  def value(self):
572    """A cached operation which reads the value of this variable."""
573    if self._cached_value is not None:
574      return self._cached_value
575    with ops.colocate_with(None, ignore_existing=True):
576      return self._read_variable_op()
577
578  def _as_graph_element(self):
579    """Conversion function for Graph.as_graph_element()."""
580    return self._graph_element
581
582  @property
583  def initializer(self):
584    """The op responsible for initializing this variable."""
585    return self._initializer_op
586
587  @property
588  def initial_value(self):
589    """Returns the Tensor used as the initial value for the variable."""
590    if context.executing_eagerly():
591      raise RuntimeError("This property is not supported "
592                         "when eager execution is enabled.")
593    return self._initial_value
594
595  @property
596  def constraint(self):
597    """Returns the constraint function associated with this variable.
598
599    Returns:
600      The constraint function that was passed to the variable constructor.
601      Can be `None` if no constraint was passed.
602    """
603    return self._constraint
604
605  @property
606  def op(self):
607    """The op for this variable."""
608    return self.handle.op
609
610  @property
611  def trainable(self):
612    return self._trainable
613
614  @property
615  def synchronization(self):
616    return self._synchronization
617
618  @property
619  def aggregation(self):
620    return self._aggregation
621
622  def eval(self, session=None):
623    """Evaluates and returns the value of this variable."""
624    if context.executing_eagerly():
625      raise RuntimeError("This operation is not supported "
626                         "when eager execution is enabled.")
627    return self._graph_element.eval(session=session)
628
629  def numpy(self):
630    if context.executing_eagerly():
631      return self.read_value().numpy()
632    raise NotImplementedError(
633        "numpy() is only available when eager execution is enabled.")
634
635  @deprecated(None, "Prefer Dataset.range instead.")
636  def count_up_to(self, limit):
637    """Increments this variable until it reaches `limit`.
638
639    When that Op is run it tries to increment the variable by `1`. If
640    incrementing the variable would bring it above `limit` then the Op raises
641    the exception `OutOfRangeError`.
642
643    If no error is raised, the Op outputs the value of the variable before
644    the increment.
645
646    This is essentially a shortcut for `count_up_to(self, limit)`.
647
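    For illustration, a minimal eager-mode sketch (assuming the usual
    `import tensorflow as tf`):

    ```python
        # Assumes eager execution and `import tensorflow as tf`.
        v = tf.Variable(0)
        v.count_up_to(2)   # 0 (the variable becomes 1)
        v.count_up_to(2)   # 1 (the variable becomes 2)
        v.count_up_to(2)   # raises tf.errors.OutOfRangeError
    ```
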
648    Args:
649      limit: value at which incrementing the variable raises an error.
650
651    Returns:
652      A `Tensor` that will hold the variable value before the increment. If no
653      other Op modifies this variable, the values produced will all be
654      distinct.
655    """
656    return gen_state_ops.resource_count_up_to(
657        self.handle, limit=limit, T=self.dtype)
658
659  def _map_resources(self, save_options):
660    """For implementing `Trackable`."""
661    new_variable = None
662    if save_options.experimental_variable_policy._save_variable_devices():  # pylint:disable=protected-access
663      with ops.device(self.device):
664        new_variable = copy_to_graph_uninitialized(self)
665    else:
666      new_variable = copy_to_graph_uninitialized(self)
667    obj_map = {self: new_variable}
668    resource_map = {self.handle: new_variable.handle}
669    return obj_map, resource_map
670
671  def _read_variable_op(self):
672    variable_accessed(self)
673
674    def read_and_set_handle():
675      result = gen_resource_variable_ops.read_variable_op(
676          self.handle, self._dtype)
677      _maybe_set_handle_data(self._dtype, self.handle, result)
678      return result
679
680    if getattr(self, "_caching_device", None) is not None:
681      with ops.colocate_with(None, ignore_existing=True):
682        with ops.device(self._caching_device):
683          result = read_and_set_handle()
684    else:
685      result = read_and_set_handle()
686
687    if not context.executing_eagerly():
688      # Note that if a control flow context is active the input of the read op
689      # might not actually be the handle. This line bypasses it.
690      tape.record_operation(
691          "ReadVariableOp", [result], [self.handle],
692          backward_function=lambda x: [x],
693          forward_function=lambda x: [x])
694    return result
695
696  def read_value(self):
697    """Constructs an op which reads the value of this variable.
698
699    Should be used when there are multiple reads, or when it is desirable to
700    read the value only after some condition is true.
701
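    For illustration, a minimal eager-mode sketch (assuming the usual
    `import tensorflow as tf`):

    ```python
        # Assumes eager execution and `import tensorflow as tf`.
        v = tf.Variable(3.0)
        t = v.read_value()   # a Tensor holding 3.0 at the time of the read
    ```
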
702    Returns:
703     the read operation.
704    """
705    with ops.name_scope("Read"):
706      value = self._read_variable_op()
707    # Return an identity so it can get placed on whatever device the context
708    # specifies instead of the device where the variable is.
709    return array_ops.identity(value)
710
711  def sparse_read(self, indices, name=None):
712    """Reads the value of this variable sparsely, using `gather`."""
713    with ops.name_scope("Gather" if name is None else name) as name:
714      variable_accessed(self)
715      value = gen_resource_variable_ops.resource_gather(
716          self.handle, indices, dtype=self._dtype, name=name)
717
718      if self._dtype == dtypes.variant:
719        # For DT_VARIANT types, the handle's shape_and_type[1:] stores the
720        # variant's handle data.  Extract it.
721        handle_data = get_eager_safe_handle_data(self.handle)
722        if handle_data.is_set and len(handle_data.shape_and_type) > 1:
723          value._handle_data = (  # pylint: disable=protected-access
724              cpp_shape_inference_pb2.CppShapeInferenceResult.HandleData(
725                  is_set=True, shape_and_type=handle_data.shape_and_type[1:]))
726
727    return array_ops.identity(value)
728
729  def gather_nd(self, indices, name=None):
730    """Reads the value of this variable sparsely, using `gather_nd`."""
731    with ops.name_scope("GatherNd" if name is None else name) as name:
732      if self.trainable:
733        variable_accessed(self)
734      value = gen_resource_variable_ops.resource_gather_nd(
735          self.handle, indices, dtype=self._dtype, name=name)
736
737    return array_ops.identity(value)
738
739  def to_proto(self, export_scope=None):
740    """Converts a `ResourceVariable` to a `VariableDef` protocol buffer.
741
742    Args:
743      export_scope: Optional `string`. Name scope to remove.
744
745    Raises:
746      RuntimeError: If run in EAGER mode.
747
748    Returns:
749      A `VariableDef` protocol buffer, or `None` if the `Variable` is not
750      in the specified name scope.
751    """
752    if context.executing_eagerly():
753      raise RuntimeError("This operation is not supported "
754                         "when eager execution is enabled.")
755    if export_scope is None or self.handle.name.startswith(export_scope):
756      var_def = variable_pb2.VariableDef()
757      var_def.variable_name = ops.strip_name_scope(self.handle.name,
758                                                   export_scope)
759      if self._initial_value is not None:
760        # This is inside an if-statement for backwards compatibility, since
761        # self._initial_value might be None for variables constructed from old
762        # protos.
763        var_def.initial_value_name = ops.strip_name_scope(
764            self._initial_value.name, export_scope)
765      var_def.initializer_name = ops.strip_name_scope(self.initializer.name,
766                                                      export_scope)
767      if self._cached_value is not None:
768        var_def.snapshot_name = ops.strip_name_scope(self._cached_value.name,
769                                                     export_scope)
770      else:
771        # Store the graph_element here
772        var_def.snapshot_name = ops.strip_name_scope(self._graph_element.name,
773                                                     export_scope)
774      var_def.is_resource = True
775      var_def.trainable = self.trainable
776      var_def.synchronization = self.synchronization.value
777      var_def.aggregation = self.aggregation.value
778      if self._save_slice_info:
779        var_def.save_slice_info_def.MergeFrom(
780            self._save_slice_info.to_proto(export_scope=export_scope))
781      return var_def
782    else:
783      return None
784
785  @staticmethod
786  def from_proto(variable_def, import_scope=None):
787    if context.executing_eagerly():
788      raise RuntimeError("This operation is not supported "
789                         "when eager execution is enabled.")
790    return ResourceVariable(
791        variable_def=variable_def, import_scope=import_scope)
792
793  __array_priority__ = 100
794
795  def is_initialized(self, name=None):
796    """Checks whether a resource variable has been initialized.
797
798    Outputs a boolean scalar indicating whether the tensor has been initialized.
799
800    Args:
801      name: A name for the operation (optional).
802
803    Returns:
804      A `Tensor` of type `bool`.
805    """
806    return gen_resource_variable_ops.var_is_initialized_op(self.handle, name)
807
808  def assign_sub(self, delta, use_locking=None, name=None, read_value=True):
809    """Subtracts a value from this variable.
810
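    A minimal eager-mode sketch for illustration (assuming the usual
    `import tensorflow as tf`):

    ```python
        # Assumes eager execution and `import tensorflow as tf`.
        v = tf.Variable(10.0)
        v.assign_sub(3.0)
        v.numpy()   # 7.0
    ```
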
811    Args:
812      delta: A `Tensor`. The value to subtract from this variable.
813      use_locking: If `True`, use locking during the operation.
814      name: The name to use for the operation.
815      read_value: A `bool`. Whether to read and return the new value of the
816        variable or not.
817
818    Returns:
819      If `read_value` is `True`, this method will return the new value of the
820      variable after the assignment has completed. Otherwise, when in graph mode
821      it will return the `Operation` that does the assignment, and when in eager
822      mode it will return `None`.
823    """
824    # TODO(apassos): this here and below is not atomic. Consider making it
825    # atomic if there's a way to do so without a performance cost for those who
826    # don't need it.
827    with _handle_graph(self.handle), self._assign_dependencies():
828      assign_sub_op = gen_resource_variable_ops.assign_sub_variable_op(
829          self.handle,
830          ops.convert_to_tensor(delta, dtype=self.dtype),
831          name=name)
832    if read_value:
833      return self._lazy_read(assign_sub_op)
834    return assign_sub_op
835
836  def assign_add(self, delta, use_locking=None, name=None, read_value=True):
837    """Adds a value to this variable.
838
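    A minimal eager-mode sketch for illustration (assuming the usual
    `import tensorflow as tf`):

    ```python
        # Assumes eager execution and `import tensorflow as tf`.
        v = tf.Variable(10.0)
        v.assign_add(5.0)
        v.numpy()   # 15.0
    ```
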
839    Args:
840      delta: A `Tensor`. The value to add to this variable.
841      use_locking: If `True`, use locking during the operation.
842      name: The name to use for the operation.
843      read_value: A `bool`. Whether to read and return the new value of the
844        variable or not.
845
846    Returns:
847      If `read_value` is `True`, this method will return the new value of the
848      variable after the assignment has completed. Otherwise, when in graph mode
849      it will return the `Operation` that does the assignment, and when in eager
850      mode it will return `None`.
851    """
852    with _handle_graph(self.handle), self._assign_dependencies():
853      assign_add_op = gen_resource_variable_ops.assign_add_variable_op(
854          self.handle,
855          ops.convert_to_tensor(delta, dtype=self.dtype),
856          name=name)
857    if read_value:
858      return self._lazy_read(assign_add_op)
859    return assign_add_op
860
861  def _lazy_read(self, op):
862    variable_accessed(self)
863    return _UnreadVariable(
864        handle=self.handle,
865        dtype=self.dtype,
866        shape=self._shape,
867        in_graph_mode=self._in_graph_mode,
868        deleter=self._handle_deleter if not self._in_graph_mode else None,
869        parent_op=op,
870        unique_id=self._unique_id)
871
872  def assign(self, value, use_locking=None, name=None, read_value=True):
873    """Assigns a new value to this variable.
874
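    For illustration, a minimal eager-mode sketch (assuming the usual
    `import tensorflow as tf`):

    ```python
        # Assumes eager execution and `import tensorflow as tf`.
        v = tf.Variable(1.0)
        v.assign(2.0)   # with read_value=True this also returns the new value
        v.numpy()       # 2.0
    ```
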
875    Args:
876      value: A `Tensor`. The new value for this variable.
877      use_locking: If `True`, use locking during the assignment.
878      name: The name to use for the assignment.
879      read_value: A `bool`. Whether to read and return the new value of the
880        variable or not.
881
882    Returns:
883      If `read_value` is `True`, this method will return the new value of the
884      variable after the assignment has completed. Otherwise, when in graph mode
885      it will return the `Operation` that does the assignment, and when in eager
886      mode it will return `None`.
887    """
888    # Note: not depending on the cached value here since this can be used to
889    # initialize the variable.
890    with _handle_graph(self.handle):
891      value_tensor = ops.convert_to_tensor(value, dtype=self.dtype)
892      if not self._shape.is_compatible_with(value_tensor.shape):
893        if self.name is None:
894          tensor_name = ""
895        else:
896          tensor_name = " " + str(self.name)
897        raise ValueError(
898            (f"Cannot assign value to variable '{tensor_name}': Shape mismatch. "
899             f"The variable shape {self._shape}, and the "
900             f"assigned value shape {value_tensor.shape} are incompatible."))
901      assign_op = gen_resource_variable_ops.assign_variable_op(
902          self.handle, value_tensor, name=name)
903      if read_value:
904        return self._lazy_read(assign_op)
905    return assign_op
906
907  def __reduce__(self):
908    # The implementation mirrors that of __deepcopy__.
909    return functools.partial(
910        ResourceVariable,
911        initial_value=self.numpy(),
912        trainable=self.trainable,
913        name=self._shared_name,
914        dtype=self.dtype,
915        constraint=self.constraint,
916        distribute_strategy=self._distribute_strategy), ()
917
918  def scatter_sub(self, sparse_delta, use_locking=False, name=None):
919    """Subtracts `tf.IndexedSlices` from this variable.
920
921    Args:
922      sparse_delta: `tf.IndexedSlices` to be subtracted from this variable.
923      use_locking: If `True`, use locking during the operation.
924      name: the name of the operation.
925
926    Returns:
927      The updated variable.
928
929    Raises:
930      TypeError: if `sparse_delta` is not an `IndexedSlices`.
931    """
932    if not isinstance(sparse_delta, ops.IndexedSlices):
933      raise TypeError(f"Argument `sparse_delta` must be a "
934                      f"`tf.IndexedSlices`. Received arg: {sparse_delta}")
935    return self._lazy_read(
936        gen_resource_variable_ops.resource_scatter_sub(
937            self.handle,
938            sparse_delta.indices,
939            ops.convert_to_tensor(sparse_delta.values, self.dtype),
940            name=name))
941
942  def scatter_add(self, sparse_delta, use_locking=False, name=None):
943    """Adds `tf.IndexedSlices` to this variable.
944
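    For illustration, a minimal eager-mode sketch (assuming the usual
    `import tensorflow as tf`):

    ```python
        # Assumes eager execution and `import tensorflow as tf`.
        v = tf.Variable([1.0, 2.0, 3.0, 4.0])
        delta = tf.IndexedSlices(values=tf.constant([10.0, 20.0]),
                                 indices=tf.constant([0, 2]))
        v.scatter_add(delta)
        v.numpy()   # [11., 2., 23., 4.]
    ```
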
945    Args:
946      sparse_delta: `tf.IndexedSlices` to be added to this variable.
947      use_locking: If `True`, use locking during the operation.
948      name: the name of the operation.
949
950    Returns:
951      The updated variable.
952
953    Raises:
954      TypeError: if `sparse_delta` is not an `IndexedSlices`.
955    """
956    if not isinstance(sparse_delta, ops.IndexedSlices):
957      raise TypeError(f"Argument `sparse_delta` must be a "
958                      f"`tf.IndexedSlices`. Received arg: {sparse_delta}")
959    return self._lazy_read(
960        gen_resource_variable_ops.resource_scatter_add(
961            self.handle,
962            sparse_delta.indices,
963            ops.convert_to_tensor(sparse_delta.values, self.dtype),
964            name=name))
965
966  def scatter_max(self, sparse_delta, use_locking=False, name=None):
967    """Updates this variable with the max of `tf.IndexedSlices` and itself.
968
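    For illustration, a minimal eager-mode sketch (assuming the usual
    `import tensorflow as tf`):

    ```python
        # Assumes eager execution and `import tensorflow as tf`.
        v = tf.Variable([1.0, 5.0, 3.0])
        delta = tf.IndexedSlices(values=tf.constant([4.0, 4.0]),
                                 indices=tf.constant([0, 1]))
        v.scatter_max(delta)
        v.numpy()   # [4., 5., 3.]
    ```
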
969    Args:
970      sparse_delta: `tf.IndexedSlices` to use as an argument of max with this
971        variable.
972      use_locking: If `True`, use locking during the operation.
973      name: the name of the operation.
974
975    Returns:
976      The updated variable.
977
978    Raises:
979      TypeError: if `sparse_delta` is not an `IndexedSlices`.
980    """
981    if not isinstance(sparse_delta, ops.IndexedSlices):
982      raise TypeError(f"Argument `sparse_delta` must be a "
983                      f"`tf.IndexedSlices`. Received arg: {sparse_delta}")
984    return self._lazy_read(
985        gen_resource_variable_ops.resource_scatter_max(
986            self.handle,
987            sparse_delta.indices,
988            ops.convert_to_tensor(sparse_delta.values, self.dtype),
989            name=name))
990
991  def scatter_min(self, sparse_delta, use_locking=False, name=None):
992    """Updates this variable with the min of `tf.IndexedSlices` and itself.
993
994    Args:
995      sparse_delta: `tf.IndexedSlices` to use as an argument of min with this
996        variable.
997      use_locking: If `True`, use locking during the operation.
998      name: the name of the operation.
999
1000    Returns:
1001      The updated variable.
1002
1003    Raises:
1004      TypeError: if `sparse_delta` is not an `IndexedSlices`.
1005    """
1006    if not isinstance(sparse_delta, ops.IndexedSlices):
1007      raise TypeError(f"Argument `sparse_delta` must be a "
1008                      f"`tf.IndexedSlices`. Received arg: {sparse_delta}")
1009    return self._lazy_read(
1010        gen_resource_variable_ops.resource_scatter_min(
1011            self.handle,
1012            sparse_delta.indices,
1013            ops.convert_to_tensor(sparse_delta.values, self.dtype),
1014            name=name))
1015
1016  def scatter_mul(self, sparse_delta, use_locking=False, name=None):
1017    """Multiply this variable by `tf.IndexedSlices`.
1018
1019    Args:
1020      sparse_delta: `tf.IndexedSlices` to multiply this variable by.
1021      use_locking: If `True`, use locking during the operation.
1022      name: the name of the operation.
1023
1024    Returns:
1025      The updated variable.
1026
1027    Raises:
1028      TypeError: if `sparse_delta` is not an `IndexedSlices`.
1029    """
1030    if not isinstance(sparse_delta, ops.IndexedSlices):
1031      raise TypeError(f"Argument `sparse_delta` must be a "
1032                      f"`tf.IndexedSlices`. Received arg: {sparse_delta}")
1033    return self._lazy_read(
1034        gen_resource_variable_ops.resource_scatter_mul(
1035            self.handle,
1036            sparse_delta.indices,
1037            ops.convert_to_tensor(sparse_delta.values, self.dtype),
1038            name=name))
1039
1040  def scatter_div(self, sparse_delta, use_locking=False, name=None):
1041    """Divide this variable by `tf.IndexedSlices`.
1042
1043    Args:
1044      sparse_delta: `tf.IndexedSlices` to divide this variable by.
1045      use_locking: If `True`, use locking during the operation.
1046      name: the name of the operation.
1047
1048    Returns:
1049      The updated variable.
1050
1051    Raises:
1052      TypeError: if `sparse_delta` is not an `IndexedSlices`.
1053    """
1054    if not isinstance(sparse_delta, ops.IndexedSlices):
1055      raise TypeError(f"Argument `sparse_delta` must be a "
1056                      f"`tf.IndexedSlices`. Received arg: {sparse_delta}")
1057    return self._lazy_read(
1058        gen_resource_variable_ops.resource_scatter_div(
1059            self.handle,
1060            sparse_delta.indices,
1061            ops.convert_to_tensor(sparse_delta.values, self.dtype),
1062            name=name))
1063
1064  def scatter_update(self, sparse_delta, use_locking=False, name=None):
1065    """Assigns `tf.IndexedSlices` to this variable.
1066
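    For illustration, a minimal eager-mode sketch (assuming the usual
    `import tensorflow as tf`):

    ```python
        # Assumes eager execution and `import tensorflow as tf`.
        v = tf.Variable([1.0, 2.0, 3.0, 4.0])
        delta = tf.IndexedSlices(values=tf.constant([9.0]),
                                 indices=tf.constant([1]))
        v.scatter_update(delta)
        v.numpy()   # [1., 9., 3., 4.]
    ```
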
1067    Args:
1068      sparse_delta: `tf.IndexedSlices` to be assigned to this variable.
1069      use_locking: If `True`, use locking during the operation.
1070      name: the name of the operation.
1071
1072    Returns:
1073      The updated variable.
1074
1075    Raises:
1076      TypeError: if `sparse_delta` is not an `IndexedSlices`.
1077    """
1078    if not isinstance(sparse_delta, ops.IndexedSlices):
1079      raise TypeError(f"Argument `sparse_delta` must be a "
1080                      f"`tf.IndexedSlices`. Received arg: {sparse_delta}")
1081    return self._lazy_read(
1082        gen_resource_variable_ops.resource_scatter_update(
1083            self.handle,
1084            sparse_delta.indices,
1085            ops.convert_to_tensor(sparse_delta.values, self.dtype),
1086            name=name))
1087
1088  def batch_scatter_update(self, sparse_delta, use_locking=False, name=None):
1089    """Assigns `tf.IndexedSlices` to this variable batch-wise.
1090
1091    Analogous to `batch_gather`. This assumes that this variable and the
1092    sparse_delta IndexedSlices have a series of leading dimensions that are the
1093    same for all of them, and the updates are performed on the last dimension of
1094    indices. In other words, the dimensions should be the following:
1095
1096    `num_prefix_dims = sparse_delta.indices.ndims - 1`
1097    `batch_dim = num_prefix_dims + 1`
1098    `sparse_delta.updates.shape = sparse_delta.indices.shape + var.shape[
1099         batch_dim:]`
1100
1101    where
1102
1103    `sparse_delta.updates.shape[:num_prefix_dims]`
1104    `== sparse_delta.indices.shape[:num_prefix_dims]`
1105    `== var.shape[:num_prefix_dims]`
1106
1107    And the operation performed can be expressed as:
1108
1109    `var[i_1, ..., i_n,
1110         sparse_delta.indices[i_1, ..., i_n, j]] = sparse_delta.updates[
1111            i_1, ..., i_n, j]`
1112
1113    When sparse_delta.indices is a 1D tensor, this operation is equivalent to
1114    `scatter_update`.
1115
1116    To avoid this operation one can loop over the first `ndims` of the
1117    variable and use `scatter_update` on the subtensors that result from
1118    slicing the first dimension. This is a valid option for `ndims = 1`, but
1119    less efficient than this implementation.
1120
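    For illustration, a minimal eager-mode sketch of the 2-D case described
    above (assuming the usual `import tensorflow as tf`; the expected effect
    follows from the formula above):

    ```python
        # Assumes eager execution and `import tensorflow as tf`.
        v = tf.Variable([[1.0, 2.0, 3.0],
                         [4.0, 5.0, 6.0]])
        # One update per batch row: column 2 of row 0 and column 0 of row 1.
        delta = tf.IndexedSlices(values=tf.constant([[10.0], [20.0]]),
                                 indices=tf.constant([[2], [0]]))
        v.batch_scatter_update(delta)
        # Per the formula: v[0, 2] == 10.0 and v[1, 0] == 20.0.
    ```
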
1121    Args:
1122      sparse_delta: `tf.IndexedSlices` to be assigned to this variable.
1123      use_locking: If `True`, use locking during the operation.
1124      name: the name of the operation.
1125
1126    Returns:
1127      The updated variable.
1128
1129    Raises:
1130      TypeError: if `sparse_delta` is not an `IndexedSlices`.
1131    """
1132    if not isinstance(sparse_delta, ops.IndexedSlices):
1133      raise TypeError(f"Argument `sparse_delta` must be a "
1134                      f"`tf.IndexedSlices`. Received arg: {sparse_delta}")
1135    return self._lazy_read(
1136        state_ops.batch_scatter_update(
1137            self,
1138            sparse_delta.indices,
1139            sparse_delta.values,
1140            use_locking=use_locking,
1141            name=name))
1142
1143  def scatter_nd_sub(self, indices, updates, name=None):
1144    """Applies sparse subtraction to individual values or slices in a Variable.
1145
1146    `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.
1147
1148    `indices` must be integer tensor, containing indices into `ref`.
1149    It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.
1150
1151    The innermost dimension of `indices` (with length `K`) corresponds to
1152    indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
1153    dimension of `ref`.
1154
1155    `updates` is `Tensor` of rank `Q-1+P-K` with shape:
1156
1157    ```
1158    [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
1159    ```
1160
1161    For example, say we want to subtract 4 scattered elements from a rank-1
1162    tensor with 8 elements. In Python, that update would look like this:
1163
1164    ```python
1165        ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
1166        indices = tf.constant([[4], [3], [1] ,[7]])
1167        updates = tf.constant([9, 10, 11, 12])
1168        op = ref.scatter_nd_sub(indices, updates)
1169        with tf.compat.v1.Session() as sess:
1170          print(sess.run(op))
1171    ```
1172
1173    The resulting update to ref would look like this:
1174
1175        [1, -9, 3, -6, -4, 6, 7, -4]
1176
1177    See `tf.scatter_nd` for more details about how to make updates to
1178    slices.
1179
1180    Args:
1181      indices: The indices to be used in the operation.
1182      updates: The values to be used in the operation.
1183      name: the name of the operation.
1184
1185    Returns:
1186      The updated variable.
1187    """
1188    return self._lazy_read(
1189        gen_state_ops.resource_scatter_nd_sub(
1190            self.handle,
1191            indices,
1192            ops.convert_to_tensor(updates, self.dtype),
1193            name=name))
1194
1195  def scatter_nd_add(self, indices, updates, name=None):
1196    """Applies sparse addition to individual values or slices in a Variable.
1197
1198    `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.
1199
1200    `indices` must be integer tensor, containing indices into `ref`.
1201    It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.
1202
1203    The innermost dimension of `indices` (with length `K`) corresponds to
1204    indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
1205    dimension of `ref`.
1206
1207    `updates` is `Tensor` of rank `Q-1+P-K` with shape:
1208
1209    ```
1210    [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
1211    ```
1212
1213    For example, say we want to add 4 scattered elements to a rank-1 tensor
1214    with 8 elements. In Python, that update would look like this:
1215
1216    ```python
1217        ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
1218        indices = tf.constant([[4], [3], [1] ,[7]])
1219        updates = tf.constant([9, 10, 11, 12])
1220        add = ref.scatter_nd_add(indices, updates)
1221        with tf.compat.v1.Session() as sess:
1222          print(sess.run(add))
1223    ```
1224
1225    The resulting update to ref would look like this:
1226
1227        [1, 13, 3, 14, 14, 6, 7, 20]
1228
1229    See `tf.scatter_nd` for more details about how to make updates to
1230    slices.
1231
1232    Args:
1233      indices: The indices to be used in the operation.
1234      updates: The values to be used in the operation.
1235      name: the name of the operation.
1236
1237    Returns:
1238      The updated variable.
1239    """
1240    return self._lazy_read(
1241        gen_state_ops.resource_scatter_nd_add(
1242            self.handle,
1243            indices,
1244            ops.convert_to_tensor(updates, self.dtype),
1245            name=name))
1246
1247  def scatter_nd_update(self, indices, updates, name=None):
1248    """Applies sparse assignment to individual values or slices in a Variable.
1249
1250    `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.
1251
1252    `indices` must be integer tensor, containing indices into `ref`.
1253    It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.
1254
1255    The innermost dimension of `indices` (with length `K`) corresponds to
1256    indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
1257    dimension of `ref`.
1258
1259    `updates` is `Tensor` of rank `Q-1+P-K` with shape:
1260
1261    ```
1262    [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
1263    ```
1264
1265    For example, say we want to update 4 scattered elements in a rank-1 tensor
1266    with 8 elements. In Python, that update would look like this:
1267
1268    ```python
1269        ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
1270        indices = tf.constant([[4], [3], [1] ,[7]])
1271        updates = tf.constant([9, 10, 11, 12])
1272        op = ref.scatter_nd_update(indices, updates)
1273        with tf.compat.v1.Session() as sess:
1274          print(sess.run(op))
1275    ```
1276
1277    The resulting update to ref would look like this:
1278
1279        [1, 11, 3, 10, 9, 6, 7, 12]
1280
1281    See `tf.scatter_nd` for more details about how to make updates to
1282    slices.
1283
1284    Args:
1285      indices: The indices to be used in the operation.
1286      updates: The values to be used in the operation.
1287      name: the name of the operation.
1288
1289    Returns:
1290      The updated variable.
1291    """
1292    return self._lazy_read(
1293        gen_state_ops.resource_scatter_nd_update(
1294            self.handle,
1295            indices,
1296            ops.convert_to_tensor(updates, self.dtype),
1297            name=name))
1298
1299  def scatter_nd_max(self, indices, updates, name=None):
1300    """Updates this variable with the max of sparse `updates` and itself.
1301
1302    `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.
1303
1304    `indices` must be integer tensor, containing indices into `ref`.
1305    It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.
1306
1307    The innermost dimension of `indices` (with length `K`) corresponds to
1308    indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
1309    dimension of `ref`.
1310
1311    `updates` is `Tensor` of rank `Q-1+P-K` with shape:
1312
1313    ```
1314    [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
1315    ```
1316
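    For illustration, a minimal eager-mode sketch (assuming the usual
    `import tensorflow as tf`):

    ```python
        # Assumes eager execution and `import tensorflow as tf`.
        v = tf.Variable([1, 2, 3, 4])
        v.scatter_nd_max(indices=[[1], [3]], updates=[10, 0])
        v.numpy()   # [1, 10, 3, 4]
    ```
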
1317    See `tf.scatter_nd` for more details about how to make updates to
1318    slices.
1319
1320    Args:
1321      indices: The indices to be used in the operation.
1322      updates: The values to be used in the operation.
1323      name: the name of the operation.
1324
1325    Returns:
1326      The updated variable.
1327    """
1328    return self._lazy_read(
1329        gen_state_ops.resource_scatter_nd_max(
1330            self.handle,
1331            indices,
1332            ops.convert_to_tensor(updates, self.dtype),
1333            name=name))
1334
1335  def scatter_nd_min(self, indices, updates, name=None):
1336    """Updates this variable with the min of sparse `updates` and itself.
1337
1338    `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.
1339
1340    `indices` must be integer tensor, containing indices into `ref`.
1341    It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.
1342
1343    The innermost dimension of `indices` (with length `K`) corresponds to
1344    indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
1345    dimension of `ref`.
1346
1347    `updates` is `Tensor` of rank `Q-1+P-K` with shape:
1348
1349    ```
1350    [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
1351    ```
1352
1353    See `tf.scatter_nd` for more details about how to make updates to
1354    slices.
1355
1356    Args:
1357      indices: The indices to be used in the operation.
1358      updates: The values to be used in the operation.
1359      name: the name of the operation.
1360
1361    Returns:
1362      The updated variable.
1363    """
1364    return self._lazy_read(
1365        gen_state_ops.resource_scatter_nd_min(
1366            self.handle,
1367            indices,
1368            ops.convert_to_tensor(updates, self.dtype),
1369            name=name))
1370
1371  def _write_object_proto(self, proto, options):
1372    """Writes additional information of the variable into the SavedObject proto.
1373
1374    Subclasses of ResourceVariables could choose to override this method to
1375    customize extra information to provide when saving a SavedModel.
1376
1377    Ideally, this should contain the logic in
1378    write_object_proto_for_resource_variable but `DistributedValue` is an
1379    outlier at the moment. Once `DistributedValue` becomes a proper
1380    ResourceVariable, we should remove the helper method below.
1381
1382    Args:
1383      proto: `SavedObject` proto to update.
1384      options: A `SaveOption` instance that configures save behavior.
1385    """
1386    write_object_proto_for_resource_variable(self, proto, options)
1387
1388  def _strided_slice_assign(self, begin, end, strides, value, name, begin_mask,
1389                            end_mask, ellipsis_mask, new_axis_mask,
1390                            shrink_axis_mask):
1391    with _handle_graph(self.handle), self._assign_dependencies():
1392      return self._lazy_read(
1393          gen_array_ops.resource_strided_slice_assign(
1394              ref=self.handle,
1395              begin=begin,
1396              end=end,
1397              strides=strides,
1398              value=ops.convert_to_tensor(value, dtype=self.dtype),
1399              name=name,
1400              begin_mask=begin_mask,
1401              end_mask=end_mask,
1402              ellipsis_mask=ellipsis_mask,
1403              new_axis_mask=new_axis_mask,
1404              shrink_axis_mask=shrink_axis_mask))
1405
1406  def __complex__(self):
1407    return complex(self.value().numpy())
1408
1409  def __int__(self):
1410    return int(self.value().numpy())
1411
1412  def __long__(self):
1413    return long(self.value().numpy())
1414
1415  def __float__(self):
1416    return float(self.value().numpy())
1417
1418  def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
1419    del name
1420    if dtype is not None and not dtype.is_compatible_with(self.dtype):
1421      raise ValueError(
1422          f"Incompatible type conversion requested to type {dtype.name} for "
1423          f"a `tf.Variable` of type {self.dtype.name}. (Variable: {self})")
1424    if as_ref:
1425      return self.read_value().op.inputs[0]
1426    else:
1427      return self.value()
1428
1429  def __iadd__(self, unused_other):
1430    raise RuntimeError("`variable += value` with `tf.Variable`s is not "
1431                       "supported. Use `variable.assign_add(value)` to modify "
1432                       "the variable, or `out = variable + value` if you "
1433                       "need to get a new output Tensor.")
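
  # Note: these in-place operators intentionally raise. For example
  # (illustrative), prefer `v.assign_add(1.0)` to update `v` in place, or
  # `out = v + 1.0` to produce a new Tensor.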
1434
1435  def __isub__(self, unused_other):
1436    raise RuntimeError("`variable -= value` with `tf.Variable`s is not "
1437                       "supported. Use `variable.assign_sub(value)` to modify "
1438                       "the variable, or `out = variable - value` if you "
1439                       "need to get a new output Tensor.")
1440
1441  def __imul__(self, unused_other):
1442    raise RuntimeError("`var *= value` with `tf.Variable`s is not "
1443                       "supported. Use `var.assign(var * value)` to modify "
1444                       "the variable, or `out = var * value` if you "
1445                       "need to get a new output Tensor.")
1446
1447  def __idiv__(self, unused_other):
1448    raise RuntimeError("`var /= value` with `tf.Variable`s is not "
1449                       "supported. Use `var.assign(var / value)` to modify "
1450                       "the variable, or `out = var / value` if you "
1451                       "need to get a new output Tensor.")
1452
1453  def __itruediv__(self, unused_other):
1454    raise RuntimeError("`var /= value` with `tf.Variable`s is not "
1455                       "supported. Use `var.assign(var / value)` to modify "
1456                       "the variable, or `out = var / value` if you "
1457                       "need to get a new output Tensor.")
1458
1459  def __irealdiv__(self, unused_other):
1460    raise RuntimeError("`var /= value` with `tf.Variable`s is not "
1461                       "supported. Use `var.assign(var / value)` to modify "
1462                       "the variable, or `out = var / value` if you "
1463                       "need to get a new output Tensor.")
1464
1465  def __ipow__(self, unused_other):
1466    raise RuntimeError("`var **= value` with `tf.Variable`s is not "
1467                       "supported. Use `var.assign(var ** value)` to modify "
1468                       "the variable, or `out = var ** value` if you "
1469                       "need to get a new output Tensor.")
1470
1471
1472class ResourceVariable(BaseResourceVariable):
1473  """Variable based on resource handles.
1474
1475  See the [Variables How To](https://tensorflow.org/guide/variables)
1476  for a high level overview.
1477
1478  A `ResourceVariable` allows you to maintain state across subsequent calls to
1479  session.run.
1480
1481  The `ResourceVariable` constructor requires an initial value for the variable,
1482  which can be a `Tensor` of any type and shape. The initial value defines the
1483  type and shape of the variable. After construction, the type and shape of
1484  the variable are fixed. The value can be changed using one of the assign
1485  methods.
1486
1487  Just like any `Tensor`, variables created with
1488  `tf.Variable(use_resource=True)` can be used as inputs for other Ops in the
1489  graph. Additionally, all the operators overloaded for the `Tensor` class are
1490  carried over to variables, so you can also add nodes to the graph by just
1491  doing arithmetic on variables.
1492
1493  Unlike ref-based variables, a ResourceVariable has well-defined semantics. Each
1494  usage of a ResourceVariable in a TensorFlow graph adds a read_value operation
1495  to the graph. The Tensors returned by a read_value operation are guaranteed to
1496  see all modifications to the value of the variable which happen in any
1497  operation on which the read_value depends (either directly, indirectly, or
1498  via a control dependency) and guaranteed to not see any modification to the
1499  value of the variable from operations that depend on the read_value operation.
1500  Updates from operations that have no dependency relationship to the read_value
1501  operation might or might not be visible to read_value.
1502
1503  For example, if there is more than one assignment to a ResourceVariable in
1504  a single session.run call, there is a well-defined value for each operation
1505  which uses the variable's value if the assignments and the read are connected
1506  by edges in the graph. Consider the following example, in which two writes
1507  can cause tf.Variable and tf.ResourceVariable to behave differently:
1508
1509  ```python
1510  a = tf.Variable(1.0, use_resource=True)
1511  a.initializer.run()
1512
1513  assign = a.assign(2.0)
1514  with tf.control_dependencies([assign]):
1515    b = a.read_value()
1516  with tf.control_dependencies([b]):
1517    other_assign = a.assign(3.0)
1518  with tf.control_dependencies([other_assign]):
1519    # Will print 2.0 because the value was read before other_assign ran. If
1520    # `a` was a tf.Variable instead, 2.0 or 3.0 could be printed.
1521    tf.compat.v1.Print(b, [b]).eval()
1522  ```
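
  The same read-ordering guarantee is visible under eager execution (a minimal
  sketch, assuming TF2 defaults):

  ```python
  a = tf.Variable(1.0)
  a.assign(2.0)
  b = a.read_value()  # Captures 2.0.
  a.assign(3.0)
  print(b.numpy())  # Prints 2.0; the earlier read is unaffected by the write.
  ```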
1523  """
1524
1525  def __init__(
1526      self,  # pylint: disable=super-init-not-called
1527      initial_value=None,
1528      trainable=None,
1529      collections=None,
1530      validate_shape=True,  # pylint: disable=unused-argument
1531      caching_device=None,
1532      name=None,
1533      dtype=None,
1534      variable_def=None,
1535      import_scope=None,
1536      constraint=None,
1537      distribute_strategy=None,
1538      synchronization=None,
1539      aggregation=None,
1540      shape=None):
1541    """Creates a variable.
1542
1543    Args:
1544      initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
1545        which is the initial value for the Variable. Can also be a callable with
1546        no argument that returns the initial value when called. (Note that
1547        initializer functions from init_ops.py must first be bound to a shape
1548        before being used here.)
1549      trainable: If `True`, also adds the variable to the graph
1550        collection `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as
1551        the default list of variables to use by the `Optimizer` classes.
1552        Defaults to `True`, unless `synchronization` is set to `ON_READ`, in
1553        which case it defaults to `False`.
1554      collections: List of graph collections keys. The new variable is added to
1555        these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
1556      validate_shape: Ignored. Provided for compatibility with tf.Variable.
1557      caching_device: Optional device string or function describing where the
1558        Variable should be cached for reading.  Defaults to the Variable's
1559        device.  If not `None`, caches on another device.  Typical use is to
1560        cache on the device where the Ops using the Variable reside, to
1561        deduplicate copying through `Switch` and other conditional statements.
1562      name: Optional name for the variable. Defaults to `'Variable'` and gets
1563        uniquified automatically.
1564      dtype: If set, initial_value will be converted to the given type. If None,
1565        either the datatype will be kept (if initial_value is a Tensor) or
1566        float32 will be used (if it is a Python object convertible to a Tensor).
1567      variable_def: `VariableDef` protocol buffer. If not None, recreates the
1568        `ResourceVariable` object with its contents. `variable_def` and other
1569        arguments (except for import_scope) are mutually exclusive.
1570      import_scope: Optional `string`. Name scope to add to the
1571        ResourceVariable. Only used when `variable_def` is provided.
1572      constraint: An optional projection function to be applied to the variable
1573        after being updated by an `Optimizer` (e.g. used to implement norm
1574        constraints or value constraints for layer weights). The function must
1575        take as input the unprojected Tensor representing the value of the
1576        variable and return the Tensor for the projected value (which must have
1577        the same shape). Constraints are not safe to use when doing asynchronous
1578        distributed training.
1579      distribute_strategy: The tf.distribute.Strategy this variable is being
1580        created inside of.
1581      synchronization: Indicates when a distributed variable will be
1582        aggregated. Accepted values are constants defined in the class
1583        `tf.VariableSynchronization`. By default the synchronization is set to
1584        `AUTO` and the current `DistributionStrategy` chooses when to
1585        synchronize.
1586      aggregation: Indicates how a distributed variable will be aggregated.
1587        Accepted values are constants defined in the class
1588        `tf.VariableAggregation`.
1589      shape: (optional) The shape of this variable. If None, the shape of
1590        `initial_value` will be used. When setting this argument to
1591        `tf.TensorShape(None)` (representing an unspecified shape), the variable
1592        can be assigned with values of different shapes.
1593
1594    Raises:
1595      ValueError: If the initial value is not specified, or does not have a
1596        shape and `validate_shape` is `True`.
1597
1598    @compatibility(eager)
1599    When Eager Execution is enabled, the default for the `collections` argument
1600    is `None`, which signifies that this `Variable` will not be added to any
1601    collections.
1602    @end_compatibility
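
    For example, a minimal sketch of the `shape=tf.TensorShape(None)` case
    described above (values are illustrative only):

    ```python
    v = tf.Variable(1.0, shape=tf.TensorShape(None))
    v.assign([[1.0, 2.0]])  # A value with a different shape is accepted.
    ```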
1603    """
1604    if variable_def:
1605      if initial_value is not None:
1606        raise ValueError(f"The variable_def and initial_value args to "
1607                         f"`tf.Variable` are mutually exclusive, but got both: "
1608                         f"variable_def={variable_def},\n"
1609                         f"initial_value={initial_value}")
1610      if context.executing_eagerly():
1611        raise ValueError(f"Creating a `tf.Variable` with a `variable_def` arg "
1612                         f"is not supported when eager execution is enabled. "
1613                         f"Got: variable_def={variable_def}")
1614      self._init_from_proto(variable_def, import_scope=import_scope)
1615    else:
1616      self._init_from_args(
1617          initial_value=initial_value,
1618          trainable=trainable,
1619          collections=collections,
1620          caching_device=caching_device,
1621          name=name,
1622          dtype=dtype,
1623          constraint=constraint,
1624          synchronization=synchronization,
1625          aggregation=aggregation,
1626          shape=shape,
1627          distribute_strategy=distribute_strategy)
1628
1629  def _init_from_args(self,
1630                      initial_value=None,
1631                      trainable=None,
1632                      collections=None,
1633                      caching_device=None,
1634                      name=None,
1635                      dtype=None,
1636                      constraint=None,
1637                      synchronization=None,
1638                      aggregation=None,
1639                      distribute_strategy=None,
1640                      shape=None):
1641    """Creates a variable.
1642
1643    Args:
1644      initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
1645        which is the initial value for the Variable. The initial value must have
1646        a shape specified unless `validate_shape` is set to False. Can also be a
1647        callable with no argument that returns the initial value when called.
1648        (Note that initializer functions from init_ops.py must first be bound to
1649        a shape before being used here.)
1650      trainable: If `True`, also adds the variable to the graph
1651        collection `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as
1652        the default list of variables to use by the `Optimizer` classes.
1653        Defaults to `True`, unless `synchronization` is set to `ON_READ`, in
1654        which case it defaults to `False`.
1655      collections: List of graph collections keys. The new variable is added to
1656        these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
1657      caching_device: Optional device string or function describing where the
1658        Variable should be cached for reading.  Defaults to the Variable's
1659        device.  If not `None`, caches on another device.  Typical use is to
1660        cache on the device where the Ops using the Variable reside, to
1661        deduplicate copying through `Switch` and other conditional statements.
1662      name: Optional name for the variable. Defaults to `'Variable'` and gets
1663        uniquified automatically.
1664      dtype: If set, initial_value will be converted to the given type. If None,
1665        either the datatype will be kept (if initial_value is a Tensor) or
1666        float32 will be used (if it is a Python object convertible to a Tensor).
1667      constraint: An optional projection function to be applied to the variable
1668        after being updated by an `Optimizer` (e.g. used to implement norm
1669        constraints or value constraints for layer weights). The function must
1670        take as input the unprojected Tensor representing the value of the
1671        variable and return the Tensor for the projected value (which must have
1672        the same shape). Constraints are not safe to use when doing asynchronous
1673        distributed training.
1674      synchronization: Indicates when a distributed variable will be
1675        aggregated. Accepted values are constants defined in the class
1676        `tf.VariableSynchronization`. By default the synchronization is set to
1677        `AUTO` and the current `DistributionStrategy` chooses when to
1678        synchronize.
1679      aggregation: Indicates how a distributed variable will be aggregated.
1680        Accepted values are constants defined in the class
1681        `tf.VariableAggregation`.
1682      distribute_strategy: DistributionStrategy under which this variable was
1683        created.
1684      shape: (optional) The shape of this variable. If None, the shape of
1685        `initial_value` will be used. When setting this argument to
1686        `tf.TensorShape(None)` (representing an unspecified shape), the variable
1687        can be assigned with values of different shapes.
1688
1689    Raises:
1690      ValueError: If the initial value is not specified, or does not have a
1691        shape and `validate_shape` is `True`.
1692
1693    @compatibility(eager)
1694    When Eager Execution is enabled, variables are never added to collections.
1695    They are not implicitly added to the `GLOBAL_VARIABLES` or
1696    `TRAINABLE_VARIABLES` collections, and the `collections` argument is
1697    ignored.
1698    @end_compatibility
1699    """
1700    synchronization, aggregation, trainable = (
1701        variables.validate_synchronization_aggregation_trainable(
1702            synchronization, aggregation, trainable, name))
1703    if initial_value is None:
1704      raise ValueError("The `initial_value` arg to `tf.Variable` must "
1705                       "be specified unless a `variable_def` is provided. "
1706                       "You provided neither.")
1707    init_from_fn = callable(initial_value)
1708
1709    if isinstance(initial_value, ops.Tensor) and hasattr(
1710        initial_value, "graph") and initial_value.graph.building_function:
1711      raise ValueError(f"Argument `initial_value` ({initial_value}) could not "
1712                       "be lifted out of a `tf.function`. "
1713                       f"(Tried to create variable with name='{name}'). "
1714                       "To avoid this error, when constructing `tf.Variable`s "
1715                       "inside of `tf.function` you can create the "
1716                       "`initial_value` tensor in a "
1717                       "`tf.init_scope` or pass a callable `initial_value` "
1718                       "(e.g., `tf.Variable(lambda : "
1719                       "tf.truncated_normal([10, 40]))`). "
1720                       "Please file a feature request if this "
1721                       "restriction inconveniences you.")
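
      # For example (illustrative only), passing a callable initial value such
      # as `tf.Variable(lambda: tf.random.truncated_normal([10, 40]))` lets the
      # initial-value tensor be created in the variable's `init_scope`, which
      # avoids this error.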
1722
1723    if collections is None:
1724      collections = [ops.GraphKeys.GLOBAL_VARIABLES]
1725    if not isinstance(collections, (list, tuple, set)):
1726      raise ValueError(
1727          f"collections argument to Variable constructor must be a list, "
1728          f"tuple, or set. Got {collections} of type {type(collections)}")
1729    if constraint is not None and not callable(constraint):
1730      raise ValueError(f"Argument `constraint` must be None or a callable. "
1731                       f"Got a {type(constraint)}: {constraint}")
1732
1733    if trainable and ops.GraphKeys.TRAINABLE_VARIABLES not in collections:
1734      collections = list(collections) + [ops.GraphKeys.TRAINABLE_VARIABLES]
1735    with ops.init_scope():
1736      self._in_graph_mode = not context.executing_eagerly()
1737      with ops.name_scope(
1738          name,
1739          "Variable", [] if init_from_fn else [initial_value],
1740          skip_on_eager=False) as name:
1741        # pylint: disable=protected-access
1742        handle_name = ops.name_from_scope_name(name)
1743        if self._in_graph_mode:
1744          shared_name = handle_name
1745          unique_id = shared_name
1746        else:
1747          # When in eager mode use a uid for the shared_name, to prevent
1748          # accidental sharing.
1749          unique_id = "%s_%d" % (handle_name, ops.uid())
1750          shared_name = None  # Never shared
1751        # Use attr_scope and device(None) to simulate the behavior of
1752        # colocate_with when the variable we want to colocate with doesn't
1753        # yet exist.
1754        device_context_manager = (
1755            ops.device if self._in_graph_mode else ops.NullContextmanager)
1756        attr = attr_value_pb2.AttrValue(
1757            list=attr_value_pb2.AttrValue.ListValue(
1758                s=[compat.as_bytes("loc:@%s" % handle_name)]))
1759        with ops.get_default_graph()._attr_scope({"_class": attr}):
1760          with ops.name_scope("Initializer"), device_context_manager(None):
1761            if init_from_fn:
1762              initial_value = initial_value()
1763            if isinstance(initial_value, trackable.CheckpointInitialValue):
1764              self._maybe_initialize_trackable()
1765              self._update_uid = initial_value.checkpoint_position.restore_uid
1766              initial_value = initial_value.wrapped_value
1767            initial_value = ops.convert_to_tensor(initial_value,
1768                                                  name="initial_value",
1769                                                  dtype=dtype)
1770          if shape is not None:
1771            if not initial_value.shape.is_compatible_with(shape):
1772              raise ValueError(
1773                  f"In this `tf.Variable` creation, the initial value's shape "
1774                  f"({initial_value.shape}) is not compatible with "
1775                  f"the explicitly supplied `shape` argument ({shape}).")
1776          else:
1777            shape = initial_value.shape
1778          handle = eager_safe_variable_handle(
1779              initial_value=initial_value,
1780              shape=shape,
1781              shared_name=shared_name,
1782              name=name,
1783              graph_mode=self._in_graph_mode)
1784        # pylint: disable=protected-access
1785        if (self._in_graph_mode and initial_value is not None and
1786            initial_value.op._get_control_flow_context() is not None):
1787          raise ValueError(
1788              f"The `initial_value` passed to `tf.Variable` {name} is from "
1789              f"inside a control-flow construct, such as a loop or "
1790              f"conditional. When creating a "
1791              f"`tf.Variable` inside a loop or conditional, use a lambda as "
1792              f"the `initial_value`. Got: initial_value=({initial_value})")
1793        # pylint: enable=protected-access
1794        dtype = initial_value.dtype.base_dtype
1795
1796        if self._in_graph_mode:
1797          with ops.name_scope("IsInitialized"):
1798            is_initialized_op = (
1799                gen_resource_variable_ops.var_is_initialized_op(handle))
1800          if initial_value is not None:
1801            # pylint: disable=g-backslash-continuation
1802            with ops.name_scope("Assign") as n, \
1803                 ops.colocate_with(None, ignore_existing=True), \
1804                 ops.device(handle.device):
1805              # pylint: disable=protected-access
1806              initializer_op = (
1807                  gen_resource_variable_ops.assign_variable_op(
1808                      handle,
1809                      variables._try_guard_against_uninitialized_dependencies(
1810                          name, initial_value),
1811                      name=n))
1812              # pylint: enable=protected-access
1813            # pylint: enable=g-backslash-continuation
1814          with ops.name_scope("Read"):
1815            # Manually assign reads to the handle's device to avoid log
1816            # messages.
1817            with ops.device(handle.device):
1818              value = gen_resource_variable_ops.read_variable_op(handle, dtype)
1819              _maybe_set_handle_data(dtype, handle, value)
1820            graph_element = value
1821            if caching_device is not None:
1822              # Variables may be created in a tf.device() or ops.colocate_with()
1823              # context. At the same time, users would expect caching device to
1824              # be independent of this context, and/or would not expect the
1825              # current device context to be merged with the caching device
1826              # spec.  Therefore we reset the colocation stack before creating
1827              # the cached value. Note that resetting the colocation stack will
1828              # also reset the device stack.
1829              with ops.colocate_with(None, ignore_existing=True):
1830                with ops.device(caching_device):
1831                  cached_value = array_ops.identity(value)
1832            else:
1833              cached_value = None
1834        else:
1835          gen_resource_variable_ops.assign_variable_op(handle, initial_value)
1836          is_initialized_op = None
1837          initializer_op = None
1838          graph_element = None
1839          if caching_device:
1840            with ops.device(caching_device):
1841              cached_value = gen_resource_variable_ops.read_variable_op(
1842                  handle, dtype)
1843              _maybe_set_handle_data(dtype, handle, cached_value)
1844          else:
1845            cached_value = None
1846
1847        if cached_value is not None:
1848          # Store the variable object so that the original variable can be
1849          # accessed to generate functions that are compatible with SavedModel.
1850          cached_value._cached_variable = weakref.ref(self)  # pylint: disable=protected-access
1851
1852        if not context.executing_eagerly():
1853          # Eager variables are only added to collections if they are part of an
1854          # eager variable store (otherwise in an interactive session they would
1855          # hog memory and cause OOM). This is done in ops/variable_scope.py.
1856          ops.add_to_collections(collections, self)
1857        elif ops.GraphKeys.GLOBAL_STEP in collections:
1858          ops.add_to_collections(ops.GraphKeys.GLOBAL_STEP, self)
1859      initial_value = initial_value if self._in_graph_mode else None
1860      super(ResourceVariable, self).__init__(
1861          trainable=trainable,
1862          shape=shape,
1863          dtype=dtype,
1864          handle=handle,
1865          synchronization=synchronization,
1866          constraint=constraint,
1867          aggregation=aggregation,
1868          distribute_strategy=distribute_strategy,
1869          name=name,
1870          unique_id=unique_id,
1871          handle_name=handle_name,
1872          graph_element=graph_element,
1873          initial_value=initial_value,
1874          initializer_op=initializer_op,
1875          is_initialized_op=is_initialized_op,
1876          cached_value=cached_value,
1877          caching_device=caching_device)
1878
1879  def _init_from_proto(self, variable_def, import_scope=None):
1880    """Initializes from `VariableDef` proto."""
1881    # Note that init_from_proto is currently not supported in Eager mode.
1882    assert not context.executing_eagerly()
1883    self._in_graph_mode = True
1884    assert isinstance(variable_def, variable_pb2.VariableDef)
1885    if not variable_def.is_resource:
1886      raise ValueError(f"The `variable_def` you passed to `tf.Variable` is "
1887                       f"not a resource variable. Trying to restore a TF 1.x "
1888                       f"Reference Variable as a TF 2.x ResourceVariable is "
1889                       f"unsupported. Got: variable_def={variable_def}")
1890
1891    # Create from variable_def.
1892    g = ops.get_default_graph()
1893    self._handle = g.as_graph_element(
1894        ops.prepend_name_scope(
1895            variable_def.variable_name, import_scope=import_scope))
1896    self._shape = tensor_shape.TensorShape(self._handle.op.get_attr("shape"))
1897    self._handle_name = self._handle.name
1898    self._unique_id = self._handle_name
1899    self._initializer_op = g.as_graph_element(
1900        ops.prepend_name_scope(
1901            variable_def.initializer_name, import_scope=import_scope))
1902    # Check whether initial_value_name exists for backwards compatibility.
1903    if (hasattr(variable_def, "initial_value_name") and
1904        variable_def.initial_value_name):
1905      self._initial_value = g.as_graph_element(
1906          ops.prepend_name_scope(
1907              variable_def.initial_value_name, import_scope=import_scope))
1908    else:
1909      self._initial_value = None
1910    synchronization, aggregation, trainable = (
1911        variables.validate_synchronization_aggregation_trainable(
1912            variable_def.synchronization, variable_def.aggregation,
1913            variable_def.trainable, variable_def.variable_name))
1914    self._synchronization = synchronization
1915    self._aggregation = aggregation
1916    self._trainable = trainable
1917    if variable_def.snapshot_name:
1918      snapshot = g.as_graph_element(
1919          ops.prepend_name_scope(
1920              variable_def.snapshot_name, import_scope=import_scope))
1921      if snapshot.op.type != "ReadVariableOp":
1922        self._cached_value = snapshot
1923      else:
1924        self._cached_value = None
1925      while snapshot.op.type != "ReadVariableOp":
1926        snapshot = snapshot.op.inputs[0]
1927      self._graph_element = snapshot
1928    else:
1929      self._cached_value = None
1930      # Legacy case for protos without the snapshot name; assume it's the
1931      # following.
1932      self._graph_element = g.get_tensor_by_name(self._handle.op.name +
1933                                                 "/Read/ReadVariableOp:0")
1934    if variable_def.HasField("save_slice_info_def"):
1935      self._save_slice_info = variables.Variable.SaveSliceInfo(
1936          save_slice_info_def=variable_def.save_slice_info_def,
1937          import_scope=import_scope)
1938    else:
1939      self._save_slice_info = None
1940    self._caching_device = None
1941    self._dtype = dtypes.as_dtype(self._handle.op.get_attr("dtype"))
1942    self._constraint = None
1943
1944
1945class UninitializedVariable(BaseResourceVariable):
1946  """A variable with no initializer."""
1947
1948  def __init__(  # pylint: disable=super-init-not-called
1949      self,
1950      trainable=None,
1951      caching_device=None,
1952      name=None,
1953      shape=None,
1954      dtype=None,
1955      constraint=None,
1956      synchronization=None,
1957      aggregation=None,
1958      extra_handle_data=None,
1959      distribute_strategy=None,
1960      **unused_kwargs):
1961    """Creates the variable handle.
1962
1963    Args:
1964      trainable: If `True`, GradientTapes automatically watch uses of this
1965        Variable.
1966      caching_device: Optional device string or function describing where the
1967        Variable should be cached for reading.  Defaults to the Variable's
1968        device.  If not `None`, caches on another device.  Typical use is to
1969        cache on the device where the Ops using the Variable reside, to
1970        deduplicate copying through `Switch` and other conditional statements.
1971      name: Optional name for the variable. Defaults to `'Variable'` and gets
1972        uniquified automatically.
1973      shape: The variable's shape.
1974      dtype: The variable's dtype.
1975      constraint: An optional projection function to be applied to the variable
1976        after being updated by an `Optimizer` (e.g. used to implement norm
1977        constraints or value constraints for layer weights). The function must
1978        take as input the unprojected Tensor representing the value of the
1979        variable and return the Tensor for the projected value (which must have
1980        the same shape). Constraints are not safe to use when doing asynchronous
1981        distributed training.
1982      synchronization: Indicates when a distributed variable will be
1983        aggregated. Accepted values are constants defined in the class
1984        `tf.VariableSynchronization`. By default the synchronization is set to
1985        `AUTO` and the current `DistributionStrategy` chooses when to
1986        synchronize.
1987      aggregation: Indicates how a distributed variable will be aggregated.
1988        Accepted values are constants defined in the class
1989        `tf.VariableAggregation`.
1990      extra_handle_data: Optional, another resource handle or Tensor with handle
1991        data to merge with `shape` and `dtype`.
1992      distribute_strategy: The tf.distribute.Strategy this variable is being
1993        created inside of.
1994    """
1995    with ops.init_scope():
1996      # Here we are detecting eagerness within an init_scope, so this will only
1997      # be true when we are running in TF1 graph mode.
1998      self._in_graph_mode = not context.executing_eagerly()
1999      with ops.name_scope(name, "Variable", skip_on_eager=False) as name:
2000        handle_name = ops.name_from_scope_name(name)
2001        if self._in_graph_mode:
2002          shared_name = handle_name
2003          unique_id = shared_name
2004        else:
2005          unique_id = "%s_%d" % (handle_name, ops.uid())
2006          shared_name = None  # Never shared
2007        handle = _variable_handle_from_shape_and_dtype(
2008            shape=shape,
2009            dtype=dtype,
2010            shared_name=shared_name,
2011            name=name,
2012            graph_mode=self._in_graph_mode,
2013            initial_value=extra_handle_data)
2014        if self._in_graph_mode:
2015          # We only need to add the read_variable_op in TF1.
2016          with ops.name_scope("Read"):
2017            # Manually assign reads to the handle's device to avoid log
2018            # messages.
2019            with ops.device(handle.device):
2020              value = gen_resource_variable_ops.read_variable_op(handle, dtype)
2021              _maybe_set_handle_data(dtype, handle, value)
2022            graph_element = value
2023          ops.add_to_collection(ops.GraphKeys.GLOBAL_VARIABLES, self)
2024          # Do *not* add to TRAINABLE_VARIABLES here, even if self._trainable,
2025          # because retraining or frozen use of imported SavedModels is
2026          # controlled at higher levels of model building.
2027        else:
2028          graph_element = None
2029    super(UninitializedVariable, self).__init__(
2030        distribute_strategy=distribute_strategy,
2031        shape=shape,
2032        dtype=dtype,
2033        unique_id=unique_id,
2034        handle_name=handle_name,
2035        constraint=constraint,
2036        handle=handle,
2037        graph_element=graph_element,
2038        trainable=trainable,
2039        synchronization=synchronization,
2040        aggregation=aggregation,
2041        in_graph_mode=self._in_graph_mode)
2042
2043
2044_pywrap_utils.RegisterType("ResourceVariable", ResourceVariable)
2045math_ops._resource_variable_type = ResourceVariable  # pylint: disable=protected-access
2046
2047
2048def _dense_var_to_tensor(var, dtype=None, name=None, as_ref=False):
2049  return var._dense_var_to_tensor(dtype=dtype, name=name, as_ref=as_ref)  # pylint: disable=protected-access
2050
2051
2052# Register a conversion function which reads the value of the variable,
2053# allowing instances of the class to be used as tensors.
2054ops.register_tensor_conversion_function(BaseResourceVariable,
2055                                        _dense_var_to_tensor)
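# For example (illustrative only): once this conversion is registered, ops such
# as `tf.matmul(v, x)` accept a resource variable `v` directly, because it is
# implicitly converted to a dense Tensor via `v.value()`.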
2056
2057
2058class _UnreadVariable(BaseResourceVariable):
2059  """Represents a future for a read of a variable.
2060
2061  Pretends to be the tensor if anyone looks.
2062  """
2063
2064  def __init__(self, handle, dtype, shape, in_graph_mode, deleter, parent_op,
2065               unique_id):
2066    if isinstance(handle, ops.EagerTensor):
2067      handle_name = ""
2068    else:
2069      handle_name = handle.name
2070    # Only create a graph_element if we're in session.run-land as only
2071    # session.run requires a preexisting tensor to evaluate. Otherwise we can
2072    # avoid accidentally reading the variable.
2073    if context.executing_eagerly() or ops.inside_function():
2074      graph_element = None
2075    else:
2076      with ops.control_dependencies([parent_op]):
2077        graph_element = gen_resource_variable_ops.read_variable_op(
2078            handle, dtype)
2079        _maybe_set_handle_data(dtype, handle, graph_element)
2080    super(_UnreadVariable, self).__init__(
2081        handle=handle,
2082        shape=shape,
2083        handle_name=handle_name,
2084        unique_id=unique_id,
2085        dtype=dtype,
2086        handle_deleter=deleter,
2087        graph_element=graph_element)
2088    self._parent_op = parent_op
2089
2090  @property
2091  def name(self):
2092    if self._in_graph_mode:
2093      return self._parent_op.name
2094    else:
2095      return "UnreadVariable"
2096
2097  def value(self):
2098    return self._read_variable_op()
2099
2100  def read_value(self):
2101    return self._read_variable_op()
2102
2103  def _read_variable_op(self):
2104    with ops.control_dependencies([self._parent_op]):
2105      result = gen_resource_variable_ops.read_variable_op(
2106          self._handle, self._dtype)
2107      _maybe_set_handle_data(self._dtype, self._handle, result)
2108      return result
2109
2110  def assign_sub(self, delta, use_locking=None, name=None, read_value=True):
2111    with ops.control_dependencies([self._parent_op]):
2112      return super(_UnreadVariable, self).assign_sub(delta, use_locking, name,
2113                                                     read_value)
2114
2115  def assign_add(self, delta, use_locking=None, name=None, read_value=True):
2116    with ops.control_dependencies([self._parent_op]):
2117      return super(_UnreadVariable, self).assign_add(delta, use_locking, name,
2118                                                     read_value)
2119
2120  def assign(self, value, use_locking=None, name=None, read_value=True):
2121    with ops.control_dependencies([self._parent_op]):
2122      return super(_UnreadVariable, self).assign(value, use_locking, name,
2123                                                 read_value)
2124
2125  def scatter_sub(self, sparse_delta, use_locking=False, name=None):
2126    with ops.control_dependencies([self._parent_op]):
2127      return super(_UnreadVariable, self).scatter_sub(sparse_delta, use_locking,
2128                                                      name)
2129
2130  def scatter_add(self, sparse_delta, use_locking=False, name=None):
2131    with ops.control_dependencies([self._parent_op]):
2132      return super(_UnreadVariable, self).scatter_add(sparse_delta, use_locking,
2133                                                      name)
2134
2135  def scatter_max(self, sparse_delta, use_locking=False, name=None):
2136    with ops.control_dependencies([self._parent_op]):
2137      return super(_UnreadVariable, self).scatter_max(sparse_delta, use_locking,
2138                                                      name)
2139
2140  def scatter_min(self, sparse_delta, use_locking=False, name=None):
2141    with ops.control_dependencies([self._parent_op]):
2142      return super(_UnreadVariable, self).scatter_min(sparse_delta, use_locking,
2143                                                      name)
2144
2145  def scatter_mul(self, sparse_delta, use_locking=False, name=None):
2146    with ops.control_dependencies([self._parent_op]):
2147      return super(_UnreadVariable, self).scatter_mul(sparse_delta, use_locking,
2148                                                      name)
2149
2150  def scatter_div(self, sparse_delta, use_locking=False, name=None):
2151    with ops.control_dependencies([self._parent_op]):
2152      return super(_UnreadVariable, self).scatter_div(sparse_delta, use_locking,
2153                                                      name)
2154
2155  def scatter_update(self, sparse_delta, use_locking=False, name=None):
2156    with ops.control_dependencies([self._parent_op]):
2157      return super(_UnreadVariable,
2158                   self).scatter_update(sparse_delta, use_locking, name)
2159
2160  def batch_scatter_update(self, sparse_delta, use_locking=False, name=None):
2161    with ops.control_dependencies([self._parent_op]):
2162      return super(_UnreadVariable,
2163                   self).batch_scatter_update(sparse_delta, use_locking, name)
2164
2165  def scatter_nd_sub(self, indices, updates, name=None):
2166    with ops.control_dependencies([self._parent_op]):
2167      return super(_UnreadVariable, self).scatter_nd_sub(indices, updates, name)
2168
2169  def scatter_nd_add(self, indices, updates, name=None):
2170    with ops.control_dependencies([self._parent_op]):
2171      return super(_UnreadVariable, self).scatter_nd_add(indices, updates, name)
2172
2173  def scatter_nd_update(self, indices, updates, name=None):
2174    with ops.control_dependencies([self._parent_op]):
2175      return super(_UnreadVariable,
2176                   self).scatter_nd_update(indices, updates, name)
2177
2178  def scatter_nd_max(self, indices, updates, name=None):
2179    with ops.control_dependencies([self._parent_op]):
2180      return super(_UnreadVariable, self).scatter_nd_max(indices, updates, name)
2181
2182  def scatter_nd_min(self, indices, updates, name=None):
2183    with ops.control_dependencies([self._parent_op]):
2184      return super(_UnreadVariable, self).scatter_nd_min(indices, updates, name)
2185
2186  @property
2187  def op(self):
2188    """The op for this variable."""
2189    return self._parent_op
2190
2191
2192@ops.RegisterGradient("ReadVariableOp")
2193def _ReadGrad(_, grad):
2194  """Gradient for read op."""
2195  return grad
2196
2197
2198def variable_shape(handle, out_type=dtypes.int32):
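  """Returns the shape of the variable that `handle` refers to.

  Uses statically known handle data when available; otherwise falls back to a
  runtime `VariableShape` op.
  """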
2199  handle_data = get_eager_safe_handle_data(handle)
2200  if handle_data is None or not handle_data.is_set:
2201    return gen_resource_variable_ops.variable_shape(handle, out_type=out_type)
2202  shape_proto = handle_data.shape_and_type[0].shape
2203  if shape_proto.unknown_rank or any(x.size == -1 for x in shape_proto.dim):
2204    return gen_resource_variable_ops.variable_shape(handle, out_type=out_type)
2205  return constant_op.constant([x.size for x in shape_proto.dim], dtype=out_type)
2206
2207
2208@ops.RegisterGradient("ResourceGather")
2209def _GatherGrad(op, grad):
2210  """Gradient for gather op."""
2211  # Build appropriately shaped IndexedSlices
2212  handle = op.inputs[0]
2213  indices = op.inputs[1]
2214  params_shape = variable_shape(handle)
2215  size = array_ops.expand_dims(array_ops.size(indices), 0)
2216  values_shape = array_ops.concat([size, params_shape[1:]], 0)
2217  values = array_ops.reshape(grad, values_shape)
2218  indices = array_ops.reshape(indices, size)
2219  return (ops.IndexedSlices(values, indices, params_shape), None)
2220
2221
2222def _to_proto_fn(v, export_scope=None):
2223  """Converts Variable and ResourceVariable to VariableDef for collections."""
2224  return v.to_proto(export_scope=export_scope)
2225
2226
2227def _from_proto_fn(v, import_scope=None):
2228  """Creates Variable or ResourceVariable from VariableDef as needed."""
2229  if v.is_resource:
2230    return ResourceVariable.from_proto(v, import_scope=import_scope)
2231  return variables.Variable.from_proto(v, import_scope=import_scope)
2232
2233
2234ops.register_proto_function(
2235    ops.GraphKeys.GLOBAL_VARIABLES,
2236    proto_type=variable_pb2.VariableDef,
2237    to_proto=_to_proto_fn,
2238    from_proto=_from_proto_fn)
2239ops.register_proto_function(
2240    ops.GraphKeys.TRAINABLE_VARIABLES,
2241    proto_type=variable_pb2.VariableDef,
2242    to_proto=_to_proto_fn,
2243    from_proto=_from_proto_fn)
2244ops.register_proto_function(
2245    ops.GraphKeys.MOVING_AVERAGE_VARIABLES,
2246    proto_type=variable_pb2.VariableDef,
2247    to_proto=_to_proto_fn,
2248    from_proto=_from_proto_fn)
2249ops.register_proto_function(
2250    ops.GraphKeys.LOCAL_VARIABLES,
2251    proto_type=variable_pb2.VariableDef,
2252    to_proto=_to_proto_fn,
2253    from_proto=_from_proto_fn)
2254ops.register_proto_function(
2255    ops.GraphKeys.MODEL_VARIABLES,
2256    proto_type=variable_pb2.VariableDef,
2257    to_proto=_to_proto_fn,
2258    from_proto=_from_proto_fn)
2259ops.register_proto_function(
2260    ops.GraphKeys.GLOBAL_STEP,
2261    proto_type=variable_pb2.VariableDef,
2262    to_proto=_to_proto_fn,
2263    from_proto=_from_proto_fn)
2264ops.register_proto_function(
2265    ops.GraphKeys.METRIC_VARIABLES,
2266    proto_type=variable_pb2.VariableDef,
2267    to_proto=_to_proto_fn,
2268    from_proto=_from_proto_fn)
2269
2270
2271@tf_export("__internal__.ops.is_resource_variable", v1=[])
2272def is_resource_variable(var):
2273  """"Returns True if `var` is to be considered a ResourceVariable."""
2274  return isinstance(var, BaseResourceVariable) or hasattr(
2275      var, "_should_act_as_resource_variable")
2276
2277
2278def copy_to_graph_uninitialized(var):
2279  """Copies an existing variable to a new graph, with no initializer."""
2280  # Like ResourceVariable.__deepcopy__, but does not set an initializer on the
2281  # new variable.
2282  # pylint: disable=protected-access
2283  new_variable = UninitializedVariable(
2284      trainable=var.trainable,
2285      constraint=var._constraint,
2286      shape=var.shape,
2287      dtype=var.dtype,
2288      name=var._shared_name,
2289      synchronization=var.synchronization,
2290      aggregation=var.aggregation,
2291      extra_handle_data=var.handle)
2292  new_variable._maybe_initialize_trackable()
2293  # pylint: enable=protected-access
2294  return new_variable
2295
2296
2297ops.NotDifferentiable("Assert")
2298ops.NotDifferentiable("VarIsInitializedOp")
2299ops.NotDifferentiable("VariableShape")
2300
2301
2302class VariableSpec(tensor_spec.DenseSpec):
2303  """Describes a tf.Variable."""
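
  # A minimal usage sketch (names/values illustrative only):
  #   spec = VariableSpec(shape=[None, 3], dtype=dtypes.float32, trainable=False)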
2304
2305  __slots__ = ["trainable"]
2306
2307  value_type = property(lambda self: BaseResourceVariable)
2308
2309  def __init__(self, shape, dtype=dtypes.float32,
2310               name=None, trainable=True):
2311    super(VariableSpec, self).__init__(shape, dtype=dtype, name=name)
2312    self.trainable = trainable
2313
2314  def _to_components(self, value):
2315    raise NotImplementedError
2316
2317  def _from_components(self, components):
2318    raise NotImplementedError
2319
2320  def _from_compatible_tensor_list(self, tensor_list):
2321    assert len(tensor_list) == 1
2322    return tensor_list[0]
2323
2324
2325_pywrap_utils.RegisterType("VariableSpec", VariableSpec)
2326
2327
2328def write_object_proto_for_resource_variable(resource_variable, proto, options):
2329  """Writes additional information of the variable into the SavedObject proto.
2330
2331  This allows users to define a `hook` to provide extra information of the
2332  variable to the SavedObject.
2333
2334  For example, DistritubtedVariable class would fill in components in the
2335  distributed context.
2336
2337  Args:
2338    resource_variable: A `ResourceVariable` or `DistributedValue` that has the
2339      information to be saved into the proto.
2340    proto: `SavedObject` proto to update.
2341    options: A `SaveOption` instance that configures save behavior.
2342  """
2343  proto.variable.SetInParent()
2344  if not resource_variable.name.endswith(":0"):
2345    raise ValueError(f"Cowardly refusing to save variable "
2346                     f"{resource_variable.name} because its name does not "
2347                     f"end with the expected ':0' suffix, which means it "
2348                     f"could not be restored correctly.")
2349  proto.variable.name = meta_graph._op_name(resource_variable.name)  # pylint: disable=protected-access
2350  proto.variable.trainable = resource_variable.trainable
2351  proto.variable.dtype = resource_variable.dtype.as_datatype_enum
2352  proto.variable.synchronization = resource_variable.synchronization.value
2353  proto.variable.aggregation = resource_variable.aggregation.value
2354  proto.variable.shape.CopyFrom(resource_variable.shape.as_proto())
2355  if options.experimental_variable_policy._save_variable_devices(  # pylint: disable=protected-access
2356  ):
2357    if hasattr(resource_variable, "device"):
2358      proto.variable.device = resource_variable.device
2359