• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Ops to use variables as resources."""
16
17# pylint: disable=g-bad-name
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22import contextlib
23import functools
24import weakref
25
26import numpy as np
27
28from tensorflow.core.framework import attr_value_pb2
29from tensorflow.core.framework import variable_pb2
30from tensorflow.python.client import pywrap_tf_session
31from tensorflow.python.eager import context
32from tensorflow.python.eager import tape
33from tensorflow.python.framework import auto_control_deps_utils as acd
34from tensorflow.python.framework import constant_op
35from tensorflow.python.framework import cpp_shape_inference_pb2
36from tensorflow.python.framework import dtypes
37from tensorflow.python.framework import errors
38from tensorflow.python.framework import ops
39from tensorflow.python.framework import tensor_shape
40from tensorflow.python.framework import tensor_spec
41from tensorflow.python.ops import array_ops
42from tensorflow.python.ops import gen_array_ops
43from tensorflow.python.ops import gen_resource_variable_ops
44from tensorflow.python.ops import gen_state_ops
45from tensorflow.python.ops import handle_data_util
46from tensorflow.python.ops import math_ops
47from tensorflow.python.ops import state_ops
48from tensorflow.python.ops import variables
49# go/tf-wildcard-import
50# pylint: disable=wildcard-import
51from tensorflow.python.ops.gen_resource_variable_ops import *
52# pylint: enable=wildcard-import
53from tensorflow.python.training.tracking import base as trackable
54from tensorflow.python.types import core
55from tensorflow.python.util import _pywrap_utils
56from tensorflow.python.util import compat
57from tensorflow.python.util.deprecation import deprecated
58from tensorflow.python.util.tf_export import tf_export
59
60acd.register_read_only_resource_op("ReadVariableOp")
61acd.register_read_only_resource_op("VariableShape")
62acd.register_read_only_resource_op("ResourceGather")
63acd.register_read_only_resource_op("ResourceGatherNd")
64acd.register_read_only_resource_op("_ReadVariablesOp")
65
66
67# TODO(allenl): Remove this alias and migrate callers.
68get_resource_handle_data = handle_data_util.get_resource_handle_data
69
70
71def get_eager_safe_handle_data(handle):
72  """Get the data handle from the Tensor `handle`."""
73  assert isinstance(handle, ops.Tensor)
74
75  if isinstance(handle, ops.EagerTensor):
76    return handle._handle_data  # pylint: disable=protected-access
77  else:
78    return get_resource_handle_data(handle)
79
80
81def _set_handle_shapes_and_types(tensor, handle_data, graph_mode):
82  """Sets the shape inference result HandleData on tensor.
83
84  Args:
85    tensor: A `Tensor` or `EagerTensor`.
86    handle_data: A `CppShapeInferenceResult.HandleData`.
87    graph_mode: A python bool.
88  """
89  tensor._handle_data = handle_data  # pylint: disable=protected-access
90  if not graph_mode:
91    return
92
93  # Not an EagerTensor, so a graph tensor.
94  shapes, types = zip(*[(pair.shape, pair.dtype)
95                        for pair in handle_data.shape_and_type])
96  ranks = [len(s.dim) if not s.unknown_rank else -1 for s in shapes]
97  shapes = [
98      [d.size for d in s.dim]  # pylint: disable=g-complex-comprehension
99      if not s.unknown_rank else None for s in shapes
100  ]
101  pywrap_tf_session.TF_GraphSetOutputHandleShapesAndTypes_wrapper(
102      tensor._op._graph._c_graph,  # pylint: disable=protected-access
103      tensor._as_tf_output(),  # pylint: disable=protected-access
104      shapes,
105      ranks,
106      types)
107
108
109def _combine_handle_data(handle, initial_value):
110  """Concats HandleData from tensors `handle` and `initial_value`.
111
112  Args:
113    handle: A `Tensor` of dtype `resource`.
114    initial_value: A `Tensor`.
115
116  Returns:
117    A `CppShapeInferenceResult.HandleData`.  If `initial_value` has dtype
118    `variant`, the `HandleData` contains the concatenation of the shape_and_type
119    from both `handle` and `initial_value`.
120
121  Raises:
122    RuntimeError: If handle, which was returned by VarHandleOp, either has
123      no handle data, or its len(handle_data.shape_and_type) != 1.
124  """
125  assert handle.dtype == dtypes.resource
126
127  variable_handle_data = get_eager_safe_handle_data(handle)
128
129  if initial_value.dtype != dtypes.variant:
130    return variable_handle_data
131
132  extra_handle_data = get_eager_safe_handle_data(initial_value)
133  if extra_handle_data is not None and extra_handle_data.is_set:
134    if (variable_handle_data is None or not variable_handle_data.is_set or
135        len(variable_handle_data.shape_and_type) != 1):
136      raise RuntimeError(
137          "Expected VarHandleOp to return a length==1 shape_and_type, "
138          "but saw: '%s'" % (variable_handle_data,))
139    variable_handle_data.shape_and_type.extend(extra_handle_data.shape_and_type)
140  return variable_handle_data
141
142
143def _variable_handle_from_shape_and_dtype(shape,
144                                          dtype,
145                                          shared_name,
146                                          name,
147                                          graph_mode,
148                                          initial_value=None):
149  """Create a variable handle, copying in handle data from `initial_value`."""
150  container = ops.get_default_graph()._container  # pylint: disable=protected-access
151  if container is None:
152    container = ""
153  shape = tensor_shape.as_shape(shape)
154  dtype = dtypes.as_dtype(dtype)
155  if not graph_mode:
156    if shared_name is not None:
157      raise errors.InternalError(
158          "Using an explicit shared_name is not supported executing eagerly.")
159    shared_name = context.shared_name()
160
161  handle = gen_resource_variable_ops.var_handle_op(
162      shape=shape,
163      dtype=dtype,
164      shared_name=shared_name,
165      name=name,
166      container=container)
167  if initial_value is None:
168    initial_value = handle
169  if graph_mode:
170    full_handle_data = _combine_handle_data(handle, initial_value)
171    _set_handle_shapes_and_types(handle, full_handle_data, graph_mode)
172    return handle
173  else:
174    handle_data = cpp_shape_inference_pb2.CppShapeInferenceResult.HandleData()
175    handle_data.is_set = True
176    handle_data.shape_and_type.append(
177        cpp_shape_inference_pb2.CppShapeInferenceResult.HandleShapeAndType(
178            shape=shape.as_proto(), dtype=dtype.as_datatype_enum))
179
180    if initial_value is not None and initial_value.dtype == dtypes.variant:
181      extra_handle_data = get_eager_safe_handle_data(initial_value)
182      if extra_handle_data is not None and extra_handle_data.is_set:
183        if (not handle_data.is_set or len(handle_data.shape_and_type) != 1):
184          raise RuntimeError(
185              "Expected VarHandleOp to return a length==1 shape_and_type, "
186              "but saw: '%s'" % (handle_data,))
187        handle_data.shape_and_type.extend(extra_handle_data.shape_and_type)
188
189    _set_handle_shapes_and_types(handle, handle_data, graph_mode)
190    return handle
191
192
193def eager_safe_variable_handle(initial_value, shape, shared_name, name,
194                               graph_mode):
195  """Creates a variable handle with information to do shape inference.
196
197  The dtype is read from `initial_value` and stored in the returned
198  resource tensor's handle data.
199
200  If `initial_value.dtype == tf.variant`, we additionally extract the handle
201  data (if any) from `initial_value` and append it to the `handle_data`.
202  In this case, the returned tensor's handle data is in the form
203
204  ```
205  is_set: true
206  shape_and_type {
207    shape {
208      // initial_value.shape
209    }
210    dtype: DT_VARIANT
211  }
212  shape_and_type {
213    // handle_data(initial_value).shape_and_type[0]
214  }
215  shape_and_type {
216    // handle_data(initial_value).shape_and_type[1]
217  }
218  ...
219  ```
220
221  Ops that read from this tensor, such as `ReadVariableOp` and
222  `AssignVariableOp`, know that `handle_data(handle).shape_and_type[1:]`
223  correspond to the handle data of the variant(s) stored in the Variable.
224
225  Args:
226    initial_value: A `Tensor`.
227    shape: The shape of the handle data. Can be `TensorShape(None)` (i.e.
228      unknown shape).
229    shared_name: A string.
230    name: A string.
231    graph_mode: A python bool.
232
233  Returns:
234    The handle, a `Tensor` of type `resource`.
235  """
236  dtype = initial_value.dtype.base_dtype
237  return _variable_handle_from_shape_and_dtype(shape, dtype, shared_name, name,
238                                               graph_mode, initial_value)
239
240
241@contextlib.contextmanager
242def _handle_graph(handle):
243  # Note: might have an eager tensor but not be executing eagerly when building
244  # functions.
245  if (context.executing_eagerly() or isinstance(handle, ops.EagerTensor) or
246      ops.has_default_graph()):
247    yield
248  else:
249    with handle.graph.as_default():
250      yield
251
252
253class EagerResourceDeleter(object):
254  """An object which cleans up a resource handle.
255
256  An alternative to defining a __del__ method on an object. The intended use is
257  that ResourceVariables or other objects with resource handles will maintain a
258  single reference to this object. When the parent object is collected, this
259  object will be too. Even if the parent object is part of a reference cycle,
260  the cycle will be collectable.
261  """
262
263  __slots__ = ["_handle", "_handle_device", "_context"]
264
265  def __init__(self, handle, handle_device):
266    if not isinstance(handle, ops.Tensor):
267      raise ValueError(
268          ("Passed handle=%s to EagerResourceDeleter. Was expecting a handle "
269           "Tensor." % (handle,)))
270    self._handle = handle
271    self._handle_device = handle_device
272    # This is held since the __del__ function runs an op, and if the context()
273    # is collected before this object, there will be a segfault when running the
274    # op.
275    self._context = context.context()
276
277  def __del__(self):
278    # Resources follow object-identity when executing eagerly, so it is safe to
279    # delete the resource we have a handle to.
280    try:
281      # A packed EagerTensor doesn't own any resource.
282      if isinstance(self._handle, ops.EagerTensor) and self._handle.is_packed:
283        return
284      # This resource was created in eager mode. However, this destructor may be
285      # running in graph mode (especially during unit tests). To clean up
286      # successfully, we switch back into eager mode temporarily.
287      with context.eager_mode():
288        with ops.device(self._handle_device):
289          gen_resource_variable_ops.destroy_resource_op(
290              self._handle, ignore_lookup_error=True)
291    except TypeError:
292      # Suppress some exceptions, mainly for the case when we're running on
293      # module deletion. Things that can go wrong include the context module
294      # already being unloaded, self._handle._handle_data no longer being
295      # valid, and so on. Printing warnings in these cases is silly
296      # (exceptions raised from __del__ are printed as warnings to stderr).
297      pass  # 'NoneType' object is not callable when the handle has been
298      # partially unloaded.
299    except AttributeError:
300      pass  # 'NoneType' object has no attribute 'eager_mode' when context has
301      # been unloaded. Will catch other module unloads as well.
302
303
304def shape_safe_assign_variable_handle(handle, shape, value, name=None):
305  """Helper that checks shape compatibility and assigns variable."""
306  with _handle_graph(handle):
307    value_tensor = ops.convert_to_tensor(value)
308  shape.assert_is_compatible_with(value_tensor.shape)
309  return gen_resource_variable_ops.assign_variable_op(
310      handle, value_tensor, name=name)
311
312
313def _maybe_set_handle_data(dtype, handle, tensor):
314  if dtype == dtypes.variant:
315    # For DT_VARIANT types, the handle's shape_and_type[1:] stores the
316    # variant's handle data.  Extract it.
317    handle_data = get_eager_safe_handle_data(handle)
318    if handle_data.is_set and len(handle_data.shape_and_type) > 1:
319      tensor._handle_data = (  # pylint: disable=protected-access
320          cpp_shape_inference_pb2.CppShapeInferenceResult.HandleData(
321              is_set=True, shape_and_type=handle_data.shape_and_type[1:]))
322
323
324def variable_accessed(variable):
325  """Records that `variable` was accessed for the tape and FuncGraph."""
326  if hasattr(ops.get_default_graph(), "watch_variable"):
327    ops.get_default_graph().watch_variable(variable)
328  if variable.trainable:
329    tape.variable_accessed(variable)
330
331
332class BaseResourceVariable(variables.VariableV1, core.Tensor):
333  """A python variable from an existing handle."""
334
335  # TODO(wangpeng): Deprecate `constraint` when callers no long pass it in.
336  def __init__(  # pylint: disable=super-init-not-called
337      self,
338      trainable=None,
339      shape=None,
340      dtype=None,
341      handle=None,
342      constraint=None,
343      synchronization=None,
344      aggregation=None,
345      distribute_strategy=None,
346      name=None,
347      unique_id=None,
348      handle_name=None,
349      graph_element=None,
350      initial_value=None,
351      initializer_op=None,
352      is_initialized_op=None,
353      cached_value=None,
354      save_slice_info=None,
355      handle_deleter=None,
356      caching_device=None,
357      **unused_kwargs):
358    """Creates a variable from a handle.
359
360    Args:
361      trainable: If `True`, GradientTapes automatically watch uses of this
362        Variable.
363      shape: The variable's shape.
364      dtype: The variable's dtype.
365      handle: The variable's handle
366      constraint: An optional projection function to be applied to the variable
367        after being updated by an `Optimizer` (e.g. used to implement norm
368        constraints or value constraints for layer weights). The function must
369        take as input the unprojected Tensor representing the value of the
370        variable and return the Tensor for the projected value (which must have
371        the same shape). Constraints are not safe to use when doing asynchronous
372        distributed training.
373      synchronization: Indicates when a distributed a variable will be
374        aggregated. Accepted values are constants defined in the class
375        `tf.VariableSynchronization`. By default the synchronization is set to
376        `AUTO` and the current `DistributionStrategy` chooses when to
377        synchronize.
378      aggregation: Indicates how a distributed variable will be aggregated.
379        Accepted values are constants defined in the class
380        `tf.VariableAggregation`.
381      distribute_strategy: The distribution strategy this variable was created
382        under.
383      name: The name for this variable.
384      unique_id: Internal. Unique ID for this variable's handle.
385      handle_name: The name for the variable's handle.
386      graph_element: Optional, required only in session.run-mode. Pre-created
387        tensor which reads this variable's value.
388      initial_value: Optional. Variable's initial value.
389      initializer_op: Operation which assigns the variable's initial value.
390      is_initialized_op: Pre-created operation to check whether this variable is
391        initialized.
392      cached_value: Pre-created operation to read this variable in a specific
393        device.
394      save_slice_info: Metadata for variable partitioning.
395      handle_deleter: EagerResourceDeleter responsible for cleaning up the
396        handle.
397      caching_device: Optional device string or function describing where the
398        Variable should be cached for reading.  Defaults to the Variable's
399        device.  If not `None`, caches on another device.  Typical use is to
400        cache on the device where the Ops using the Variable reside, to
401        deduplicate copying through `Switch` and other conditional statements.
402    """
403    with ops.init_scope():
404      self._in_graph_mode = not context.executing_eagerly()
405    synchronization, aggregation, trainable = (
406        variables.validate_synchronization_aggregation_trainable(
407            synchronization, aggregation, trainable, name))
408    self._trainable = trainable
409    self._synchronization = synchronization
410    self._aggregation = aggregation
411    self._save_slice_info = save_slice_info
412    self._initial_value = initial_value
413    self._initializer_op = initializer_op
414    self._is_initialized_op = is_initialized_op
415    self._graph_element = graph_element
416    self._caching_device = caching_device
417    self._cached_value = cached_value
418    self._distribute_strategy = distribute_strategy
419    # Store the graph key so optimizers know how to only retrieve variables from
420    # this graph. Guaranteed to be the same as the eager graph_key.
421    self._graph_key = ops.get_default_graph()._graph_key  # pylint: disable=protected-access
422    self._shape = tensor_shape.as_shape(shape)
423    self._dtype = dtypes.as_dtype(dtype)
424    self._handle = handle
425    self._unique_id = unique_id
426    self._handle_name = handle_name + ":0"
427    self._constraint = constraint
428    # After the handle has been created, set up a way to clean it up when
429    # executing eagerly. We'll hold the only reference to the deleter, so that
430    # when this object is garbage collected the deleter will be too. This
431    # means ResourceVariables can be part of reference cycles without those
432    # cycles being uncollectable.
433    if not self._in_graph_mode:
434      if handle_deleter is None:
435        handle_deleter = EagerResourceDeleter(
436            handle=self._handle, handle_device=self._handle.device)
437    self._handle_deleter = handle_deleter
438    self._cached_shape_as_list = None
439
440  def __repr__(self):
441    if context.executing_eagerly() and not self._in_graph_mode:
442      # If we cannot read the value for any reason, still produce a __repr__.
443      try:
444        value_text = ops.numpy_text(self.read_value(), is_repr=True)
445      except:  # pylint: disable=bare-except
446        value_text = "<unavailable>"
447
448      return "<tf.Variable '%s' shape=%s dtype=%s, numpy=%s>" % (
449          self.name, self.get_shape(), self.dtype.name, value_text)
450    else:
451      return "<tf.Variable '%s' shape=%s dtype=%s>" % (
452          self.name, self.get_shape(), self.dtype.name)
453
454  @contextlib.contextmanager
455  def _assign_dependencies(self):
456    """Makes assignments depend on the cached value, if any.
457
458    This prevents undefined behavior with reads not ordered wrt writes.
459
460    Yields:
461      None.
462    """
463    if self._cached_value is not None:
464      with ops.control_dependencies([self._cached_value]):
465        yield
466    else:
467      yield
468
469  def __array__(self):
470    """Allows direct conversion to a numpy array.
471
472    >>> np.array(tf.Variable([1.0]))
473    array([1.], dtype=float32)
474
475    Returns:
476      The variable value as a numpy array.
477    """
478    # You can't return `self.numpy()` here because for scalars
479    # that raises:
480    #     ValueError: object __array__ method not producing an array
481    # Even `self.read_value().__array__()` and `self.read_value()._numpy()` give
482    # the same error. The `EagerTensor` class must be doing something behind the
483    # scenes to make `np.array(tf.constant(1))` work.
484    return np.asarray(self.numpy())
485
486  def __nonzero__(self):
487    return self.__bool__()
488
489  def __bool__(self):
490    return bool(self.read_value())
491
492  def __copy__(self):
493    return self
494
495  def __deepcopy__(self, memo):
496    if not context.executing_eagerly():
497      raise NotImplementedError(
498          "__deepcopy__() is only available when eager execution is enabled.")
499    copied_variable = ResourceVariable(
500        initial_value=self.read_value(),
501        trainable=self._trainable,
502        constraint=self._constraint,
503        dtype=self._dtype,
504        name=self._shared_name,
505        distribute_strategy=self._distribute_strategy,
506        synchronization=self.synchronization,
507        aggregation=self.aggregation)
508    memo[self._unique_id] = copied_variable
509    return copied_variable
510
511  @property
512  def dtype(self):
513    """The dtype of this variable."""
514    return self._dtype
515
516  @property
517  def device(self):
518    """The device this variable is on."""
519    return self._handle.device
520
521  @property
522  def graph(self):
523    """The `Graph` of this variable."""
524    return self._handle.graph
525
526  @property
527  def name(self):
528    """The name of the handle for this variable."""
529    return self._handle_name
530
531  @property
532  def shape(self):
533    """The shape of this variable."""
534    return self._shape
535
536  def set_shape(self, shape):
537    self._shape = self._shape.merge_with(shape)
538
539  def _shape_as_list(self):
540    if self.shape.ndims is None:
541      return None
542    return [dim.value for dim in self.shape.dims]
543
544  def _shape_tuple(self):
545    shape = self._shape_as_list()
546    if shape is None:
547      return None
548    return tuple(shape)
549
550  @property
551  def create(self):
552    """The op responsible for initializing this variable."""
553    if not self._in_graph_mode:
554      raise RuntimeError("Calling create is not supported when eager execution"
555                         " is enabled.")
556    return self._initializer_op
557
558  @property
559  def handle(self):
560    """The handle by which this variable can be accessed."""
561    return self._handle
562
563  def value(self):
564    """A cached operation which reads the value of this variable."""
565    if self._cached_value is not None:
566      return self._cached_value
567    with ops.colocate_with(None, ignore_existing=True):
568      return self._read_variable_op()
569
570  def _as_graph_element(self):
571    """Conversion function for Graph.as_graph_element()."""
572    return self._graph_element
573
574  @property
575  def initializer(self):
576    """The op responsible for initializing this variable."""
577    return self._initializer_op
578
579  @property
580  def initial_value(self):
581    """Returns the Tensor used as the initial value for the variable."""
582    if context.executing_eagerly():
583      raise RuntimeError("initial_value not supported in EAGER mode.")
584    return self._initial_value
585
586  @property
587  def constraint(self):
588    """Returns the constraint function associated with this variable.
589
590    Returns:
591      The constraint function that was passed to the variable constructor.
592      Can be `None` if no constraint was passed.
593    """
594    return self._constraint
595
596  @property
597  def op(self):
598    """The op for this variable."""
599    return self._handle.op
600
601  @property
602  def trainable(self):
603    return self._trainable
604
605  @property
606  def synchronization(self):
607    return self._synchronization
608
609  @property
610  def aggregation(self):
611    return self._aggregation
612
613  def eval(self, session=None):
614    """Evaluates and returns the value of this variable."""
615    if context.executing_eagerly():
616      raise RuntimeError("Trying to eval in EAGER mode")
617    return self._graph_element.eval(session=session)
618
619  def numpy(self):
620    if context.executing_eagerly():
621      return self.read_value().numpy()
622    raise NotImplementedError(
623        "numpy() is only available when eager execution is enabled.")
624
625  @deprecated(None, "Prefer Dataset.range instead.")
626  def count_up_to(self, limit):
627    """Increments this variable until it reaches `limit`.
628
629    When that Op is run it tries to increment the variable by `1`. If
630    incrementing the variable would bring it above `limit` then the Op raises
631    the exception `OutOfRangeError`.
632
633    If no error is raised, the Op outputs the value of the variable before
634    the increment.
635
636    This is essentially a shortcut for `count_up_to(self, limit)`.
637
638    Args:
639      limit: value at which incrementing the variable raises an error.
640
641    Returns:
642      A `Tensor` that will hold the variable value before the increment. If no
643      other Op modifies this variable, the values produced will all be
644      distinct.
645    """
646    return gen_state_ops.resource_count_up_to(
647        self.handle, limit=limit, T=self.dtype)
648
649  def _map_resources(self, save_options):
650    """For implementing `Trackable`."""
651    new_variable = None
652    if save_options.experimental_variable_policy._save_variable_devices():  # pylint:disable=protected-access
653      with ops.device(self.device):
654        new_variable = copy_to_graph_uninitialized(self)
655    else:
656      new_variable = copy_to_graph_uninitialized(self)
657    obj_map = {self: new_variable}
658    resource_map = {self._handle: new_variable.handle}
659    return obj_map, resource_map
660
661  def _read_variable_op(self):
662    variable_accessed(self)
663
664    def read_and_set_handle():
665      result = gen_resource_variable_ops.read_variable_op(
666          self._handle, self._dtype)
667      _maybe_set_handle_data(self._dtype, self._handle, result)
668      return result
669
670    if getattr(self, "_caching_device", None) is not None:
671      with ops.colocate_with(None, ignore_existing=True):
672        with ops.device(self._caching_device):
673          result = read_and_set_handle()
674    else:
675      result = read_and_set_handle()
676
677    if not context.executing_eagerly():
678      # Note that if a control flow context is active the input of the read op
679      # might not actually be the handle. This line bypasses it.
680      tape.record_operation(
681          "ReadVariableOp", [result], [self._handle],
682          backward_function=lambda x: [x],
683          forward_function=lambda x: [x])
684    return result
685
686  def read_value(self):
687    """Constructs an op which reads the value of this variable.
688
689    Should be used when there are multiple reads, or when it is desirable to
690    read the value only after some condition is true.
691
692    Returns:
693     the read operation.
694    """
695    with ops.name_scope("Read"):
696      value = self._read_variable_op()
697    # Return an identity so it can get placed on whatever device the context
698    # specifies instead of the device where the variable is.
699    return array_ops.identity(value)
700
701  def sparse_read(self, indices, name=None):
702    """Reads the value of this variable sparsely, using `gather`."""
703    with ops.name_scope("Gather" if name is None else name) as name:
704      variable_accessed(self)
705      value = gen_resource_variable_ops.resource_gather(
706          self._handle, indices, dtype=self._dtype, name=name)
707
708      if self._dtype == dtypes.variant:
709        # For DT_VARIANT types, the handle's shape_and_type[1:] stores the
710        # variant's handle data.  Extract it.
711        handle_data = get_eager_safe_handle_data(self._handle)
712        if handle_data.is_set and len(handle_data.shape_and_type) > 1:
713          value._handle_data = (  # pylint: disable=protected-access
714              cpp_shape_inference_pb2.CppShapeInferenceResult.HandleData(
715                  is_set=True, shape_and_type=handle_data.shape_and_type[1:]))
716
717    return array_ops.identity(value)
718
719  def gather_nd(self, indices, name=None):
720    """Reads the value of this variable sparsely, using `gather_nd`."""
721    with ops.name_scope("GatherNd" if name is None else name) as name:
722      if self.trainable:
723        variable_accessed(self)
724      value = gen_resource_variable_ops.resource_gather_nd(
725          self._handle, indices, dtype=self._dtype, name=name)
726
727    return array_ops.identity(value)
728
729  def to_proto(self, export_scope=None):
730    """Converts a `ResourceVariable` to a `VariableDef` protocol buffer.
731
732    Args:
733      export_scope: Optional `string`. Name scope to remove.
734
735    Raises:
736      RuntimeError: If run in EAGER mode.
737
738    Returns:
739      A `VariableDef` protocol buffer, or `None` if the `Variable` is not
740      in the specified name scope.
741    """
742    if context.executing_eagerly():
743      raise RuntimeError("to_proto not supported in EAGER mode.")
744    if export_scope is None or self.handle.name.startswith(export_scope):
745      var_def = variable_pb2.VariableDef()
746      var_def.variable_name = ops.strip_name_scope(self.handle.name,
747                                                   export_scope)
748      if self._initial_value is not None:
749        # This is inside an if-statement for backwards compatibility, since
750        # self._initial_value might be None for variables constructed from old
751        # protos.
752        var_def.initial_value_name = ops.strip_name_scope(
753            self._initial_value.name, export_scope)
754      var_def.initializer_name = ops.strip_name_scope(self.initializer.name,
755                                                      export_scope)
756      if self._cached_value is not None:
757        var_def.snapshot_name = ops.strip_name_scope(self._cached_value.name,
758                                                     export_scope)
759      else:
760        # Store the graph_element here
761        var_def.snapshot_name = ops.strip_name_scope(self._graph_element.name,
762                                                     export_scope)
763      var_def.is_resource = True
764      var_def.trainable = self.trainable
765      var_def.synchronization = self.synchronization.value
766      var_def.aggregation = self.aggregation.value
767      if self._save_slice_info:
768        var_def.save_slice_info_def.MergeFrom(
769            self._save_slice_info.to_proto(export_scope=export_scope))
770      return var_def
771    else:
772      return None
773
774  @staticmethod
775  def from_proto(variable_def, import_scope=None):
776    if context.executing_eagerly():
777      raise RuntimeError("from_proto not supported in EAGER mode.")
778    return ResourceVariable(
779        variable_def=variable_def, import_scope=import_scope)
780
781  __array_priority__ = 100
782
783  def is_initialized(self, name=None):
784    """Checks whether a resource variable has been initialized.
785
786    Outputs boolean scalar indicating whether the tensor has been initialized.
787
788    Args:
789      name: A name for the operation (optional).
790
791    Returns:
792      A `Tensor` of type `bool`.
793    """
794    # TODO(b/169792703): The current device placement logic never overrides an
795    # explicit placement with a custom device, causing `v.is_initalized()` to
796    # fail under a non-custom device context if `v` is in a custom device. The
797    # explicit placement below makes this work, but should not be necessary once
798    # the logic is updated to handle cases like this.
799    with ops.device(self.device):
800      return gen_resource_variable_ops.var_is_initialized_op(self.handle, name)
801
802  def assign_sub(self, delta, use_locking=None, name=None, read_value=True):
803    """Subtracts a value from this variable.
804
805    Args:
806      delta: A `Tensor`. The value to subtract from this variable.
807      use_locking: If `True`, use locking during the operation.
808      name: The name to use for the operation.
809      read_value: A `bool`. Whether to read and return the new value of the
810        variable or not.
811
812    Returns:
813      If `read_value` is `True`, this method will return the new value of the
814      variable after the assignment has completed. Otherwise, when in graph mode
815      it will return the `Operation` that does the assignment, and when in eager
816      mode it will return `None`.
817    """
818    # TODO(apassos): this here and below is not atomic. Consider making it
819    # atomic if there's a way to do so without a performance cost for those who
820    # don't need it.
821    with _handle_graph(self.handle), self._assign_dependencies():
822      assign_sub_op = gen_resource_variable_ops.assign_sub_variable_op(
823          self.handle,
824          ops.convert_to_tensor(delta, dtype=self.dtype),
825          name=name)
826    if read_value:
827      return self._lazy_read(assign_sub_op)
828    return assign_sub_op
829
830  def assign_add(self, delta, use_locking=None, name=None, read_value=True):
831    """Adds a value to this variable.
832
833    Args:
834      delta: A `Tensor`. The value to add to this variable.
835      use_locking: If `True`, use locking during the operation.
836      name: The name to use for the operation.
837      read_value: A `bool`. Whether to read and return the new value of the
838        variable or not.
839
840    Returns:
841      If `read_value` is `True`, this method will return the new value of the
842      variable after the assignment has completed. Otherwise, when in graph mode
843      it will return the `Operation` that does the assignment, and when in eager
844      mode it will return `None`.
845    """
846    with _handle_graph(self.handle), self._assign_dependencies():
847      assign_add_op = gen_resource_variable_ops.assign_add_variable_op(
848          self.handle,
849          ops.convert_to_tensor(delta, dtype=self.dtype),
850          name=name)
851    if read_value:
852      return self._lazy_read(assign_add_op)
853    return assign_add_op
854
855  def _lazy_read(self, op):
856    variable_accessed(self)
857    return _UnreadVariable(
858        handle=self._handle,
859        dtype=self.dtype,
860        shape=self._shape,
861        in_graph_mode=self._in_graph_mode,
862        deleter=self._handle_deleter if not self._in_graph_mode else None,
863        parent_op=op,
864        unique_id=self._unique_id)
865
866  def assign(self, value, use_locking=None, name=None, read_value=True):
867    """Assigns a new value to this variable.
868
869    Args:
870      value: A `Tensor`. The new value for this variable.
871      use_locking: If `True`, use locking during the assignment.
872      name: The name to use for the assignment.
873      read_value: A `bool`. Whether to read and return the new value of the
874        variable or not.
875
876    Returns:
877      If `read_value` is `True`, this method will return the new value of the
878      variable after the assignment has completed. Otherwise, when in graph mode
879      it will return the `Operation` that does the assignment, and when in eager
880      mode it will return `None`.
881    """
882    # Note: not depending on the cached value here since this can be used to
883    # initialize the variable.
884    with _handle_graph(self.handle):
885      value_tensor = ops.convert_to_tensor(value, dtype=self.dtype)
886      if not self._shape.is_compatible_with(value_tensor.shape):
887        if self.name is None:
888          tensor_name = ""
889        else:
890          tensor_name = " " + str(self.name)
891        raise ValueError(
892            ("Cannot assign to variable%s due to variable shape %s and value "
893             "shape %s are incompatible") %
894            (tensor_name, self._shape, value_tensor.shape))
895      assign_op = gen_resource_variable_ops.assign_variable_op(
896          self.handle, value_tensor, name=name)
897      if read_value:
898        return self._lazy_read(assign_op)
899    return assign_op
900
901  def __reduce__(self):
902    # The implementation mirrors that of __deepcopy__.
903    return functools.partial(
904        ResourceVariable,
905        initial_value=self.numpy(),
906        trainable=self.trainable,
907        name=self._shared_name,
908        dtype=self.dtype,
909        constraint=self.constraint,
910        distribute_strategy=self._distribute_strategy), ()
911
912  def scatter_sub(self, sparse_delta, use_locking=False, name=None):
913    """Subtracts `tf.IndexedSlices` from this variable.
914
915    Args:
916      sparse_delta: `tf.IndexedSlices` to be subtracted from this variable.
917      use_locking: If `True`, use locking during the operation.
918      name: the name of the operation.
919
920    Returns:
921      The updated variable.
922
923    Raises:
924      TypeError: if `sparse_delta` is not an `IndexedSlices`.
925    """
926    if not isinstance(sparse_delta, ops.IndexedSlices):
927      raise TypeError("sparse_delta is not IndexedSlices: %s" % sparse_delta)
928    return self._lazy_read(
929        gen_resource_variable_ops.resource_scatter_sub(
930            self.handle,
931            sparse_delta.indices,
932            ops.convert_to_tensor(sparse_delta.values, self.dtype),
933            name=name))
934
935  def scatter_add(self, sparse_delta, use_locking=False, name=None):
936    """Adds `tf.IndexedSlices` to this variable.
937
938    Args:
939      sparse_delta: `tf.IndexedSlices` to be added to this variable.
940      use_locking: If `True`, use locking during the operation.
941      name: the name of the operation.
942
943    Returns:
944      The updated variable.
945
946    Raises:
947      TypeError: if `sparse_delta` is not an `IndexedSlices`.
948    """
949    if not isinstance(sparse_delta, ops.IndexedSlices):
950      raise TypeError("sparse_delta is not IndexedSlices: %s" % sparse_delta)
951    return self._lazy_read(
952        gen_resource_variable_ops.resource_scatter_add(
953            self.handle,
954            sparse_delta.indices,
955            ops.convert_to_tensor(sparse_delta.values, self.dtype),
956            name=name))
957
958  def scatter_max(self, sparse_delta, use_locking=False, name=None):
959    """Updates this variable with the max of `tf.IndexedSlices` and itself.
960
961    Args:
962      sparse_delta: `tf.IndexedSlices` to use as an argument of max with this
963        variable.
964      use_locking: If `True`, use locking during the operation.
965      name: the name of the operation.
966
967    Returns:
968      The updated variable.
969
970    Raises:
971      TypeError: if `sparse_delta` is not an `IndexedSlices`.
972    """
973    if not isinstance(sparse_delta, ops.IndexedSlices):
974      raise TypeError("sparse_delta is not IndexedSlices: %s" % sparse_delta)
975    return self._lazy_read(
976        gen_resource_variable_ops.resource_scatter_max(
977            self.handle,
978            sparse_delta.indices,
979            ops.convert_to_tensor(sparse_delta.values, self.dtype),
980            name=name))
981
982  def scatter_min(self, sparse_delta, use_locking=False, name=None):
983    """Updates this variable with the min of `tf.IndexedSlices` and itself.
984
985    Args:
986      sparse_delta: `tf.IndexedSlices` to use as an argument of min with this
987        variable.
988      use_locking: If `True`, use locking during the operation.
989      name: the name of the operation.
990
991    Returns:
992      The updated variable.
993
994    Raises:
995      TypeError: if `sparse_delta` is not an `IndexedSlices`.
996    """
997    if not isinstance(sparse_delta, ops.IndexedSlices):
998      raise TypeError("sparse_delta is not IndexedSlices: %s" % sparse_delta)
999    return self._lazy_read(
1000        gen_resource_variable_ops.resource_scatter_min(
1001            self.handle,
1002            sparse_delta.indices,
1003            ops.convert_to_tensor(sparse_delta.values, self.dtype),
1004            name=name))
1005
1006  def scatter_mul(self, sparse_delta, use_locking=False, name=None):
1007    """Multiply this variable by `tf.IndexedSlices`.
1008
1009    Args:
1010      sparse_delta: `tf.IndexedSlices` to multiply this variable by.
1011      use_locking: If `True`, use locking during the operation.
1012      name: the name of the operation.
1013
1014    Returns:
1015      The updated variable.
1016
1017    Raises:
1018      TypeError: if `sparse_delta` is not an `IndexedSlices`.
1019    """
1020    if not isinstance(sparse_delta, ops.IndexedSlices):
1021      raise TypeError("sparse_delta is not IndexedSlices: %s" % sparse_delta)
1022    return self._lazy_read(
1023        gen_resource_variable_ops.resource_scatter_mul(
1024            self.handle,
1025            sparse_delta.indices,
1026            ops.convert_to_tensor(sparse_delta.values, self.dtype),
1027            name=name))
1028
1029  def scatter_div(self, sparse_delta, use_locking=False, name=None):
1030    """Divide this variable by `tf.IndexedSlices`.
1031
1032    Args:
1033      sparse_delta: `tf.IndexedSlices` to divide this variable by.
1034      use_locking: If `True`, use locking during the operation.
1035      name: the name of the operation.
1036
1037    Returns:
1038      The updated variable.
1039
1040    Raises:
1041      TypeError: if `sparse_delta` is not an `IndexedSlices`.
1042    """
1043    if not isinstance(sparse_delta, ops.IndexedSlices):
1044      raise TypeError("sparse_delta is not IndexedSlices: %s" % sparse_delta)
1045    return self._lazy_read(
1046        gen_resource_variable_ops.resource_scatter_div(
1047            self.handle,
1048            sparse_delta.indices,
1049            ops.convert_to_tensor(sparse_delta.values, self.dtype),
1050            name=name))
1051
1052  def scatter_update(self, sparse_delta, use_locking=False, name=None):
1053    """Assigns `tf.IndexedSlices` to this variable.
1054
1055    Args:
1056      sparse_delta: `tf.IndexedSlices` to be assigned to this variable.
1057      use_locking: If `True`, use locking during the operation.
1058      name: the name of the operation.
1059
1060    Returns:
1061      The updated variable.
1062
1063    Raises:
1064      TypeError: if `sparse_delta` is not an `IndexedSlices`.
1065    """
1066    if not isinstance(sparse_delta, ops.IndexedSlices):
1067      raise TypeError("sparse_delta is not IndexedSlices: %s" % sparse_delta)
1068    return self._lazy_read(
1069        gen_resource_variable_ops.resource_scatter_update(
1070            self.handle,
1071            sparse_delta.indices,
1072            ops.convert_to_tensor(sparse_delta.values, self.dtype),
1073            name=name))
1074
1075  def batch_scatter_update(self, sparse_delta, use_locking=False, name=None):
1076    """Assigns `tf.IndexedSlices` to this variable batch-wise.
1077
1078    Analogous to `batch_gather`. This assumes that this variable and the
1079    sparse_delta IndexedSlices have a series of leading dimensions that are the
1080    same for all of them, and the updates are performed on the last dimension of
1081    indices. In other words, the dimensions should be the following:
1082
1083    `num_prefix_dims = sparse_delta.indices.ndims - 1`
1084    `batch_dim = num_prefix_dims + 1`
1085    `sparse_delta.updates.shape = sparse_delta.indices.shape + var.shape[
1086         batch_dim:]`
1087
1088    where
1089
1090    `sparse_delta.updates.shape[:num_prefix_dims]`
1091    `== sparse_delta.indices.shape[:num_prefix_dims]`
1092    `== var.shape[:num_prefix_dims]`
1093
1094    And the operation performed can be expressed as:
1095
1096    `var[i_1, ..., i_n,
1097         sparse_delta.indices[i_1, ..., i_n, j]] = sparse_delta.updates[
1098            i_1, ..., i_n, j]`
1099
1100    When sparse_delta.indices is a 1D tensor, this operation is equivalent to
1101    `scatter_update`.
1102
1103    To avoid this operation one can looping over the first `ndims` of the
1104    variable and using `scatter_update` on the subtensors that result of slicing
1105    the first dimension. This is a valid option for `ndims = 1`, but less
1106    efficient than this implementation.
1107
1108    Args:
1109      sparse_delta: `tf.IndexedSlices` to be assigned to this variable.
1110      use_locking: If `True`, use locking during the operation.
1111      name: the name of the operation.
1112
1113    Returns:
1114      The updated variable.
1115
1116    Raises:
1117      TypeError: if `sparse_delta` is not an `IndexedSlices`.
1118    """
1119    if not isinstance(sparse_delta, ops.IndexedSlices):
1120      raise TypeError("sparse_delta is not IndexedSlices: %s" % sparse_delta)
1121    return self._lazy_read(
1122        state_ops.batch_scatter_update(
1123            self,
1124            sparse_delta.indices,
1125            sparse_delta.values,
1126            use_locking=use_locking,
1127            name=name))
1128
1129  def scatter_nd_sub(self, indices, updates, name=None):
1130    """Applies sparse subtraction to individual values or slices in a Variable.
1131
1132    `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.
1133
1134    `indices` must be integer tensor, containing indices into `ref`.
1135    It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.
1136
1137    The innermost dimension of `indices` (with length `K`) corresponds to
1138    indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
1139    dimension of `ref`.
1140
1141    `updates` is `Tensor` of rank `Q-1+P-K` with shape:
1142
1143    ```
1144    [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
1145    ```
1146
1147    For example, say we want to add 4 scattered elements to a rank-1 tensor to
1148    8 elements. In Python, that update would look like this:
1149
1150    ```python
1151        ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
1152        indices = tf.constant([[4], [3], [1] ,[7]])
1153        updates = tf.constant([9, 10, 11, 12])
1154        op = ref.scatter_nd_sub(indices, updates)
1155        with tf.compat.v1.Session() as sess:
1156          print sess.run(op)
1157    ```
1158
1159    The resulting update to ref would look like this:
1160
1161        [1, -9, 3, -6, -6, 6, 7, -4]
1162
1163    See `tf.scatter_nd` for more details about how to make updates to
1164    slices.
1165
1166    Args:
1167      indices: The indices to be used in the operation.
1168      updates: The values to be used in the operation.
1169      name: the name of the operation.
1170
1171    Returns:
1172      The updated variable.
1173    """
1174    return self._lazy_read(
1175        gen_state_ops.resource_scatter_nd_sub(
1176            self.handle,
1177            indices,
1178            ops.convert_to_tensor(updates, self.dtype),
1179            name=name))
1180
1181  def scatter_nd_add(self, indices, updates, name=None):
1182    """Applies sparse addition to individual values or slices in a Variable.
1183
1184    `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.
1185
1186    `indices` must be integer tensor, containing indices into `ref`.
1187    It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.
1188
1189    The innermost dimension of `indices` (with length `K`) corresponds to
1190    indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
1191    dimension of `ref`.
1192
1193    `updates` is `Tensor` of rank `Q-1+P-K` with shape:
1194
1195    ```
1196    [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
1197    ```
1198
1199    For example, say we want to add 4 scattered elements to a rank-1 tensor to
1200    8 elements. In Python, that update would look like this:
1201
1202    ```python
1203        ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
1204        indices = tf.constant([[4], [3], [1] ,[7]])
1205        updates = tf.constant([9, 10, 11, 12])
1206        add = ref.scatter_nd_add(indices, updates)
1207        with tf.compat.v1.Session() as sess:
1208          print sess.run(add)
1209    ```
1210
1211    The resulting update to ref would look like this:
1212
1213        [1, 13, 3, 14, 14, 6, 7, 20]
1214
1215    See `tf.scatter_nd` for more details about how to make updates to
1216    slices.
1217
1218    Args:
1219      indices: The indices to be used in the operation.
1220      updates: The values to be used in the operation.
1221      name: the name of the operation.
1222
1223    Returns:
1224      The updated variable.
1225    """
1226    return self._lazy_read(
1227        gen_state_ops.resource_scatter_nd_add(
1228            self.handle,
1229            indices,
1230            ops.convert_to_tensor(updates, self.dtype),
1231            name=name))
1232
1233  def scatter_nd_update(self, indices, updates, name=None):
1234    """Applies sparse assignment to individual values or slices in a Variable.
1235
1236    `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.
1237
1238    `indices` must be integer tensor, containing indices into `ref`.
1239    It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.
1240
1241    The innermost dimension of `indices` (with length `K`) corresponds to
1242    indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
1243    dimension of `ref`.
1244
1245    `updates` is `Tensor` of rank `Q-1+P-K` with shape:
1246
1247    ```
1248    [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
1249    ```
1250
1251    For example, say we want to add 4 scattered elements to a rank-1 tensor to
1252    8 elements. In Python, that update would look like this:
1253
1254    ```python
1255        ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
1256        indices = tf.constant([[4], [3], [1] ,[7]])
1257        updates = tf.constant([9, 10, 11, 12])
1258        op = ref.scatter_nd_update(indices, updates)
1259        with tf.compat.v1.Session() as sess:
1260          print sess.run(op)
1261    ```
1262
1263    The resulting update to ref would look like this:
1264
1265        [1, 11, 3, 10, 9, 6, 7, 12]
1266
1267    See `tf.scatter_nd` for more details about how to make updates to
1268    slices.
1269
1270    Args:
1271      indices: The indices to be used in the operation.
1272      updates: The values to be used in the operation.
1273      name: the name of the operation.
1274
1275    Returns:
1276      The updated variable.
1277    """
1278    return self._lazy_read(
1279        gen_state_ops.resource_scatter_nd_update(
1280            self.handle,
1281            indices,
1282            ops.convert_to_tensor(updates, self.dtype),
1283            name=name))
1284
1285  def scatter_nd_max(self, indices, updates, name=None):
1286    """Updates this variable with the max of `tf.IndexedSlices` and itself.
1287
1288    `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.
1289
1290    `indices` must be integer tensor, containing indices into `ref`.
1291    It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.
1292
1293    The innermost dimension of `indices` (with length `K`) corresponds to
1294    indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
1295    dimension of `ref`.
1296
1297    `updates` is `Tensor` of rank `Q-1+P-K` with shape:
1298
1299    ```
1300    [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
1301    ```
1302
1303    See `tf.scatter_nd` for more details about how to make updates to
1304    slices.
1305
1306    Args:
1307      indices: The indices to be used in the operation.
1308      updates: The values to be used in the operation.
1309      name: the name of the operation.
1310
1311    Returns:
1312      The updated variable.
1313    """
1314    return self._lazy_read(
1315        gen_state_ops.resource_scatter_nd_max(
1316            self.handle,
1317            indices,
1318            ops.convert_to_tensor(updates, self.dtype),
1319            name=name))
1320
1321  def scatter_nd_min(self, indices, updates, name=None):
1322    """Updates this variable with the min of `tf.IndexedSlices` and itself.
1323
1324    `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.
1325
1326    `indices` must be integer tensor, containing indices into `ref`.
1327    It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.
1328
1329    The innermost dimension of `indices` (with length `K`) corresponds to
1330    indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
1331    dimension of `ref`.
1332
1333    `updates` is `Tensor` of rank `Q-1+P-K` with shape:
1334
1335    ```
1336    [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
1337    ```
1338
1339    See `tf.scatter_nd` for more details about how to make updates to
1340    slices.
1341
1342    Args:
1343      indices: The indices to be used in the operation.
1344      updates: The values to be used in the operation.
1345      name: the name of the operation.
1346
1347    Returns:
1348      The updated variable.
1349    """
1350    return self._lazy_read(
1351        gen_state_ops.resource_scatter_nd_min(
1352            self.handle,
1353            indices,
1354            ops.convert_to_tensor(updates, self.dtype),
1355            name=name))
1356
1357  def _strided_slice_assign(self, begin, end, strides, value, name, begin_mask,
1358                            end_mask, ellipsis_mask, new_axis_mask,
1359                            shrink_axis_mask):
1360    with _handle_graph(self.handle), self._assign_dependencies():
1361      return self._lazy_read(
1362          gen_array_ops.resource_strided_slice_assign(
1363              ref=self.handle,
1364              begin=begin,
1365              end=end,
1366              strides=strides,
1367              value=ops.convert_to_tensor(value, dtype=self.dtype),
1368              name=name,
1369              begin_mask=begin_mask,
1370              end_mask=end_mask,
1371              ellipsis_mask=ellipsis_mask,
1372              new_axis_mask=new_axis_mask,
1373              shrink_axis_mask=shrink_axis_mask))
1374
1375  def __complex__(self):
1376    return complex(self.value().numpy())
1377
1378  def __int__(self):
1379    return int(self.value().numpy())
1380
1381  def __long__(self):
1382    return long(self.value().numpy())
1383
1384  def __float__(self):
1385    return float(self.value().numpy())
1386
1387  def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
1388    del name
1389    if dtype is not None and not dtype.is_compatible_with(self.dtype):
1390      raise ValueError(
1391          "Incompatible type conversion requested to type {!r} for variable "
1392          "of type {!r}".format(dtype.name, self.dtype.name))
1393    if as_ref:
1394      return self.read_value().op.inputs[0]
1395    else:
1396      return self.value()
1397
1398  def __iadd__(self, unused_other):
1399    raise RuntimeError("Variable += value not supported. Use "
1400                       "variable.assign_add(value) to modify the variable "
1401                       "value and variable = variable + value to get a new "
1402                       "Tensor object.")
1403
1404  def __isub__(self, unused_other):
1405    raise RuntimeError("Variable -= value not supported. Use "
1406                       "variable.assign_sub(value) to modify the variable "
1407                       "value and variable = variable - value to get a new "
1408                       "Tensor object.")
1409
1410  def __imul__(self, unused_other):
1411    raise RuntimeError("Variable *= value not supported. Use "
1412                       "`var.assign(var * value)` to modify the variable or "
1413                       "`var = var * value` to get a new Tensor object.")
1414
1415  def __idiv__(self, unused_other):
1416    raise RuntimeError("Variable /= value not supported. Use "
1417                       "`var.assign(var / value)` to modify the variable or "
1418                       "`var = var / value` to get a new Tensor object.")
1419
1420  def __itruediv__(self, unused_other):
1421    raise RuntimeError("Variable /= value not supported. Use "
1422                       "`var.assign(var / value)` to modify the variable or "
1423                       "`var = var / value` to get a new Tensor object.")
1424
1425  def __irealdiv__(self, unused_other):
1426    raise RuntimeError("Variable /= value not supported. Use "
1427                       "`var.assign(var / value)` to modify the variable or "
1428                       "`var = var / value` to get a new Tensor object.")
1429
1430  def __ipow__(self, unused_other):
1431    raise RuntimeError("Variable **= value not supported. Use "
1432                       "`var.assign(var ** value)` to modify the variable or "
1433                       "`var = var ** value` to get a new Tensor object.")
1434
1435
1436class ResourceVariable(BaseResourceVariable):
1437  """Variable based on resource handles.
1438
1439  See the [Variables How To](https://tensorflow.org/guide/variables)
1440  for a high level overview.
1441
1442  A `ResourceVariable` allows you to maintain state across subsequent calls to
1443  session.run.
1444
1445  The `ResourceVariable` constructor requires an initial value for the variable,
1446  which can be a `Tensor` of any type and shape. The initial value defines the
1447  type and shape of the variable. After construction, the type and shape of
1448  the variable are fixed. The value can be changed using one of the assign
1449  methods.
1450
1451  Just like any `Tensor`, variables created with
1452  `tf.Variable(use_resource=True)` can be used as inputs for other Ops in the
1453  graph. Additionally, all the operators overloaded for the `Tensor` class are
1454  carried over to variables, so you can also add nodes to the graph by just
1455  doing arithmetic on variables.
1456
  Unlike ref-based variables, a ResourceVariable has well-defined semantics.
  Each usage of a ResourceVariable in a TensorFlow graph adds a read_value
  operation to the graph. The Tensors returned by a read_value operation are
  guaranteed to see all modifications to the value of the variable which happen
  in any operation on which the read_value depends (either directly,
  indirectly, or via a control dependency), and are guaranteed not to see any
  modification to the value of the variable from operations that depend on the
  read_value operation. Updates from operations that have no dependency
  relationship with the read_value operation might or might not be visible to
  read_value.

  For example, if there is more than one assignment to a ResourceVariable in a
  single session.run call, there is a well-defined value for each operation
  that uses the variable's value as long as the assignments and the read are
  connected by edges in the graph. Consider the following example, in which two
  writes can cause tf.Variable and tf.ResourceVariable to behave differently:
1472
1473  ```python
1474  a = tf.Variable(1.0, use_resource=True)
1475  a.initializer.run()
1476
1477  assign = a.assign(2.0)
1478  with tf.control_dependencies([assign]):
1479    b = a.read_value()
1480  with tf.control_dependencies([b]):
1481    other_assign = a.assign(3.0)
1482  with tf.control_dependencies([other_assign]):
1483    # Will print 2.0 because the value was read before other_assign ran. If
1484    # `a` was a tf.Variable instead, 2.0 or 3.0 could be printed.
1485    tf.compat.v1.Print(b, [b]).eval()
1486  ```
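
  In eager execution the same read-after-assign guarantee holds without
  explicitly constructing a graph. A minimal sketch (TF 2.x, where resource
  variables are the default):

  ```python
  v = tf.Variable(1.0)
  v.assign(2.0)
  b = v.read_value()  # `b` is a Tensor holding 2.0.
  v.assign(3.0)       # Does not change `b`.
  ```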
1487  """
1488
1489  def __init__(
1490      self,  # pylint: disable=super-init-not-called
1491      initial_value=None,
1492      trainable=None,
1493      collections=None,
1494      validate_shape=True,  # pylint: disable=unused-argument
1495      caching_device=None,
1496      name=None,
1497      dtype=None,
1498      variable_def=None,
1499      import_scope=None,
1500      constraint=None,
1501      distribute_strategy=None,
1502      synchronization=None,
1503      aggregation=None,
1504      shape=None):
1505    """Creates a variable.
1506
1507    Args:
1508      initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
1509        which is the initial value for the Variable. Can also be a callable with
1510        no argument that returns the initial value when called. (Note that
1511        initializer functions from init_ops.py must first be bound to a shape
1512        before being used here.)
      trainable: If `True`, also adds the variable to the graph collection
        `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as the default
        list of variables to use by the `Optimizer` classes. Defaults to
        `True`, unless `synchronization` is set to `ON_READ`, in which case it
        defaults to `False`.
1518      collections: List of graph collections keys. The new variable is added to
1519        these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
1520      validate_shape: Ignored. Provided for compatibility with tf.Variable.
1521      caching_device: Optional device string or function describing where the
1522        Variable should be cached for reading.  Defaults to the Variable's
1523        device.  If not `None`, caches on another device.  Typical use is to
1524        cache on the device where the Ops using the Variable reside, to
1525        deduplicate copying through `Switch` and other conditional statements.
1526      name: Optional name for the variable. Defaults to `'Variable'` and gets
1527        uniquified automatically.
1528      dtype: If set, initial_value will be converted to the given type. If None,
1529        either the datatype will be kept (if initial_value is a Tensor) or
1530        float32 will be used (if it is a Python object convertible to a Tensor).
1531      variable_def: `VariableDef` protocol buffer. If not None, recreates the
1532        `ResourceVariable` object with its contents. `variable_def` and other
1533        arguments (except for import_scope) are mutually exclusive.
1534      import_scope: Optional `string`. Name scope to add to the
1535        ResourceVariable. Only used when `variable_def` is provided.
1536      constraint: An optional projection function to be applied to the variable
1537        after being updated by an `Optimizer` (e.g. used to implement norm
1538        constraints or value constraints for layer weights). The function must
1539        take as input the unprojected Tensor representing the value of the
1540        variable and return the Tensor for the projected value (which must have
1541        the same shape). Constraints are not safe to use when doing asynchronous
1542        distributed training.
1543      distribute_strategy: The tf.distribute.Strategy this variable is being
1544        created inside of.
      synchronization: Indicates when a distributed variable will be
1546        aggregated. Accepted values are constants defined in the class
1547        `tf.VariableSynchronization`. By default the synchronization is set to
1548        `AUTO` and the current `DistributionStrategy` chooses when to
1549        synchronize.
1550      aggregation: Indicates how a distributed variable will be aggregated.
1551        Accepted values are constants defined in the class
1552        `tf.VariableAggregation`.
1553      shape: (optional) The shape of this variable. If None, the shape of
1554        `initial_value` will be used. When setting this argument to
1555        `tf.TensorShape(None)` (representing an unspecified shape), the variable
        can be assigned with values of different shapes (see example below).
1557
1558    Raises:
1559      ValueError: If the initial value is not specified, or does not have a
1560        shape and `validate_shape` is `True`.
1561
1562    @compatibility(eager)
1563    When Eager Execution is enabled, the default for the `collections` argument
1564    is `None`, which signifies that this `Variable` will not be added to any
1565    collections.
1566    @end_compatibility
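
    For example, a variable declared with an unspecified shape can later be
    assigned values of different shapes (a brief sketch):

    ```python
    v = tf.Variable(tf.zeros([2, 3]), shape=tf.TensorShape(None))
    v.assign(tf.ones([4]))  # OK: the declared shape is unspecified.
    ```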
1567    """
1568    if variable_def:
1569      if initial_value is not None:
1570        raise ValueError("variable_def and initial_value are mutually "
1571                         "exclusive.")
1572      if context.executing_eagerly():
1573        raise ValueError("Creating ResourceVariable from variable_def is "
1574                         "not supported when eager execution is enabled.")
1575      self._init_from_proto(variable_def, import_scope=import_scope)
1576    else:
1577      self._init_from_args(
1578          initial_value=initial_value,
1579          trainable=trainable,
1580          collections=collections,
1581          caching_device=caching_device,
1582          name=name,
1583          dtype=dtype,
1584          constraint=constraint,
1585          synchronization=synchronization,
1586          aggregation=aggregation,
1587          shape=shape,
1588          distribute_strategy=distribute_strategy)
1589
1590  def _init_from_args(self,
1591                      initial_value=None,
1592                      trainable=None,
1593                      collections=None,
1594                      caching_device=None,
1595                      name=None,
1596                      dtype=None,
1597                      constraint=None,
1598                      synchronization=None,
1599                      aggregation=None,
1600                      distribute_strategy=None,
1601                      shape=None):
1602    """Creates a variable.
1603
1604    Args:
1605      initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
1606        which is the initial value for the Variable. The initial value must have
1607        a shape specified unless `validate_shape` is set to False. Can also be a
1608        callable with no argument that returns the initial value when called.
1609        (Note that initializer functions from init_ops.py must first be bound to
1610        a shape before being used here.)
      trainable: If `True`, also adds the variable to the graph collection
        `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as the default
        list of variables to use by the `Optimizer` classes. Defaults to
        `True`, unless `synchronization` is set to `ON_READ`, in which case it
        defaults to `False`.
1616      collections: List of graph collections keys. The new variable is added to
1617        these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
1618      caching_device: Optional device string or function describing where the
1619        Variable should be cached for reading.  Defaults to the Variable's
1620        device.  If not `None`, caches on another device.  Typical use is to
1621        cache on the device where the Ops using the Variable reside, to
1622        deduplicate copying through `Switch` and other conditional statements.
1623      name: Optional name for the variable. Defaults to `'Variable'` and gets
1624        uniquified automatically.
1625      dtype: If set, initial_value will be converted to the given type. If None,
1626        either the datatype will be kept (if initial_value is a Tensor) or
1627        float32 will be used (if it is a Python object convertible to a Tensor).
1628      constraint: An optional projection function to be applied to the variable
1629        after being updated by an `Optimizer` (e.g. used to implement norm
1630        constraints or value constraints for layer weights). The function must
1631        take as input the unprojected Tensor representing the value of the
1632        variable and return the Tensor for the projected value (which must have
1633        the same shape). Constraints are not safe to use when doing asynchronous
1634        distributed training.
      synchronization: Indicates when a distributed variable will be
1636        aggregated. Accepted values are constants defined in the class
1637        `tf.VariableSynchronization`. By default the synchronization is set to
1638        `AUTO` and the current `DistributionStrategy` chooses when to
1639        synchronize.
1640      aggregation: Indicates how a distributed variable will be aggregated.
1641        Accepted values are constants defined in the class
1642        `tf.VariableAggregation`.
1643      distribute_strategy: DistributionStrategy under which this variable was
1644        created.
1645      shape: (optional) The shape of this variable. If None, the shape of
1646        `initial_value` will be used. When setting this argument to
1647        `tf.TensorShape(None)` (representing an unspecified shape), the variable
1648        can be assigned with values of different shapes.
1649
1650    Raises:
1651      ValueError: If the initial value is not specified, or does not have a
1652        shape and `validate_shape` is `True`.
1653
1654    @compatibility(eager)
1655    When Eager Execution is enabled, variables are never added to collections.
1656    It is not implicitly added to the `GLOBAL_VARIABLES` or
1657    `TRAINABLE_VARIABLES` collections, and the `collections` argument is
1658    ignored.
1659    @end_compatibility
1660    """
1661    synchronization, aggregation, trainable = (
1662        variables.validate_synchronization_aggregation_trainable(
1663            synchronization, aggregation, trainable, name))
1664    if initial_value is None:
1665      raise ValueError("initial_value must be specified.")
1666    init_from_fn = callable(initial_value)
1667
1668    if isinstance(initial_value, ops.Tensor) and hasattr(
1669        initial_value, "graph") and initial_value.graph.building_function:
1670      raise ValueError("Tensor-typed variable initializers must either be "
1671                       "wrapped in an init_scope or callable "
1672                       "(e.g., `tf.Variable(lambda : "
1673                       "tf.truncated_normal([10, 40]))`) when building "
1674                       "functions. Please file a feature request if this "
1675                       "restriction inconveniences you.")
1676
1677    if collections is None:
1678      collections = [ops.GraphKeys.GLOBAL_VARIABLES]
1679    if not isinstance(collections, (list, tuple, set)):
1680      raise ValueError(
1681          "collections argument to Variable constructor must be a list, tuple, "
1682          "or set. Got %s of type %s" % (collections, type(collections)))
1683    if constraint is not None and not callable(constraint):
1684      raise ValueError("The `constraint` argument must be a callable.")
1685
1686    if trainable and ops.GraphKeys.TRAINABLE_VARIABLES not in collections:
1687      collections = list(collections) + [ops.GraphKeys.TRAINABLE_VARIABLES]
1688    with ops.init_scope():
1689      self._in_graph_mode = not context.executing_eagerly()
1690      with ops.name_scope(
1691          name,
1692          "Variable", [] if init_from_fn else [initial_value],
1693          skip_on_eager=False) as name:
1694        # pylint: disable=protected-access
1695        handle_name = ops.name_from_scope_name(name)
1696        if self._in_graph_mode:
1697          shared_name = handle_name
1698          unique_id = shared_name
1699        else:
1700          # When in eager mode use a uid for the shared_name, to prevent
1701          # accidental sharing.
1702          unique_id = "%s_%d" % (handle_name, ops.uid())
1703          shared_name = None  # Never shared
1704        # Use attr_scope and device(None) to simulate the behavior of
1705        # colocate_with when the variable we want to colocate with doesn't
1706        # yet exist.
1707        device_context_manager = (
1708            ops.device if self._in_graph_mode else ops.NullContextmanager)
1709        attr = attr_value_pb2.AttrValue(
1710            list=attr_value_pb2.AttrValue.ListValue(
1711                s=[compat.as_bytes("loc:@%s" % handle_name)]))
1712        with ops.get_default_graph()._attr_scope({"_class": attr}):
1713          with ops.name_scope("Initializer"), device_context_manager(None):
1714            if init_from_fn:
1715              initial_value = initial_value()
1716            if isinstance(initial_value, trackable.CheckpointInitialValue):
1717              self._maybe_initialize_trackable()
1718              self._update_uid = initial_value.checkpoint_position.restore_uid
1719              initial_value = initial_value.wrapped_value
1720            initial_value = ops.convert_to_tensor(initial_value,
1721                                                  name="initial_value",
1722                                                  dtype=dtype)
1723          if shape is not None:
1724            if not initial_value.shape.is_compatible_with(shape):
1725              raise ValueError(
1726                  "The initial value's shape (%s) is not compatible with "
1727                  "the explicitly supplied `shape` argument (%s)." %
1728                  (initial_value.shape, shape))
1729          else:
1730            shape = initial_value.shape
1731          handle = eager_safe_variable_handle(
1732              initial_value=initial_value,
1733              shape=shape,
1734              shared_name=shared_name,
1735              name=name,
1736              graph_mode=self._in_graph_mode)
1737        # pylint: disable=protected-access
1738        if (self._in_graph_mode and initial_value is not None and
1739            initial_value.op._get_control_flow_context() is not None):
1740          raise ValueError(
1741              "Initializer for variable %s is from inside a control-flow "
1742              "construct, such as a loop or conditional. When creating a "
1743              "variable inside a loop or conditional, use a lambda as the "
1744              "initializer." % name)
1745        # pylint: enable=protected-access
1746        dtype = initial_value.dtype.base_dtype
1747
1748        if self._in_graph_mode:
1749          with ops.name_scope("IsInitialized"):
1750            is_initialized_op = (
1751                gen_resource_variable_ops.var_is_initialized_op(handle))
1752          if initial_value is not None:
1753            # pylint: disable=g-backslash-continuation
1754            with ops.name_scope("Assign") as n, \
1755                 ops.colocate_with(None, ignore_existing=True), \
1756                 ops.device(handle.device):
1757              # pylint: disable=protected-access
1758              initializer_op = (
1759                  gen_resource_variable_ops.assign_variable_op(
1760                      handle,
1761                      variables._try_guard_against_uninitialized_dependencies(
1762                          name, initial_value),
1763                      name=n))
1764              # pylint: enable=protected-access
1765            # pylint: enable=g-backslash-continuation
1766          with ops.name_scope("Read"):
1767            # Manually assign reads to the handle's device to avoid log
1768            # messages.
1769            with ops.device(handle.device):
1770              value = gen_resource_variable_ops.read_variable_op(handle, dtype)
1771              _maybe_set_handle_data(dtype, handle, value)
1772            graph_element = value
1773            if caching_device is not None:
1774              # Variables may be created in a tf.device() or ops.colocate_with()
1775              # context. At the same time, users would expect caching device to
1776              # be independent of this context, and/or would not expect the
1777              # current device context to be merged with the caching device
1778              # spec.  Therefore we reset the colocation stack before creating
1779              # the cached value. Note that resetting the colocation stack will
1780              # also reset the device stack.
1781              with ops.colocate_with(None, ignore_existing=True):
1782                with ops.device(caching_device):
1783                  cached_value = array_ops.identity(value)
1784            else:
1785              cached_value = None
1786        else:
1787          gen_resource_variable_ops.assign_variable_op(handle, initial_value)
1788          is_initialized_op = None
1789          initializer_op = None
1790          graph_element = None
1791          if caching_device:
1792            with ops.device(caching_device):
1793              cached_value = gen_resource_variable_ops.read_variable_op(
1794                  handle, dtype)
1795              _maybe_set_handle_data(dtype, handle, cached_value)
1796          else:
1797            cached_value = None
1798
1799        if cached_value is not None:
1800          # Store the variable object so that the original variable can be
1801          # accessed to generate functions that are compatible with SavedModel.
1802          cached_value._cached_variable = weakref.ref(self)  # pylint: disable=protected-access
1803
1804        if not context.executing_eagerly():
1805          # Eager variables are only added to collections if they are part of an
1806          # eager variable store (otherwise in an interactive session they would
1807          # hog memory and cause OOM). This is done in ops/variable_scope.py.
1808          ops.add_to_collections(collections, self)
1809        elif ops.GraphKeys.GLOBAL_STEP in collections:
1810          ops.add_to_collections(ops.GraphKeys.GLOBAL_STEP, self)
1811      initial_value = initial_value if self._in_graph_mode else None
1812      super(ResourceVariable, self).__init__(
1813          trainable=trainable,
1814          shape=shape,
1815          dtype=dtype,
1816          handle=handle,
1817          synchronization=synchronization,
1818          constraint=constraint,
1819          aggregation=aggregation,
1820          distribute_strategy=distribute_strategy,
1821          name=name,
1822          unique_id=unique_id,
1823          handle_name=handle_name,
1824          graph_element=graph_element,
1825          initial_value=initial_value,
1826          initializer_op=initializer_op,
1827          is_initialized_op=is_initialized_op,
1828          cached_value=cached_value,
1829          caching_device=caching_device)
1830
1831  def _init_from_proto(self, variable_def, import_scope=None):
1832    """Initializes from `VariableDef` proto."""
1833    # Note that init_from_proto is currently not supported in Eager mode.
1834    assert not context.executing_eagerly()
1835    self._in_graph_mode = True
1836    assert isinstance(variable_def, variable_pb2.VariableDef)
1837    if not variable_def.is_resource:
1838      raise ValueError("Trying to restore Variable as ResourceVariable.")
1839
1840    # Create from variable_def.
1841    g = ops.get_default_graph()
1842    self._handle = g.as_graph_element(
1843        ops.prepend_name_scope(
1844            variable_def.variable_name, import_scope=import_scope))
1845    self._shape = tensor_shape.TensorShape(self._handle.op.get_attr("shape"))
1846    self._handle_name = self._handle.name
1847    self._unique_id = self._handle_name
1848    self._initializer_op = g.as_graph_element(
1849        ops.prepend_name_scope(
1850            variable_def.initializer_name, import_scope=import_scope))
1851    # Check whether initial_value_name exists for backwards compatibility.
1852    if (hasattr(variable_def, "initial_value_name") and
1853        variable_def.initial_value_name):
1854      self._initial_value = g.as_graph_element(
1855          ops.prepend_name_scope(
1856              variable_def.initial_value_name, import_scope=import_scope))
1857    else:
1858      self._initial_value = None
1859    synchronization, aggregation, trainable = (
1860        variables.validate_synchronization_aggregation_trainable(
1861            variable_def.synchronization, variable_def.aggregation,
1862            variable_def.trainable, variable_def.variable_name))
1863    self._synchronization = synchronization
1864    self._aggregation = aggregation
1865    self._trainable = trainable
1866    if variable_def.snapshot_name:
1867      snapshot = g.as_graph_element(
1868          ops.prepend_name_scope(
1869              variable_def.snapshot_name, import_scope=import_scope))
1870      if snapshot.op.type != "ReadVariableOp":
1871        self._cached_value = snapshot
1872      else:
1873        self._cached_value = None
1874      while snapshot.op.type != "ReadVariableOp":
1875        snapshot = snapshot.op.inputs[0]
1876      self._graph_element = snapshot
1877    else:
1878      self._cached_value = None
1879      # Legacy case for protos without the snapshot name; assume it's the
1880      # following.
1881      self._graph_element = g.get_tensor_by_name(self._handle.op.name +
1882                                                 "/Read/ReadVariableOp:0")
1883    if variable_def.HasField("save_slice_info_def"):
1884      self._save_slice_info = variables.Variable.SaveSliceInfo(
1885          save_slice_info_def=variable_def.save_slice_info_def,
1886          import_scope=import_scope)
1887    else:
1888      self._save_slice_info = None
1889    self._caching_device = None
1890    self._dtype = dtypes.as_dtype(self._handle.op.get_attr("dtype"))
1891    self._constraint = None
1892
1893
1894class UninitializedVariable(BaseResourceVariable):
1895  """A variable with no initializer."""
1896
1897  def __init__(  # pylint: disable=super-init-not-called
1898      self,
1899      trainable=None,
1900      caching_device=None,
1901      name=None,
1902      shape=None,
1903      dtype=None,
1904      constraint=None,
1905      synchronization=None,
1906      aggregation=None,
1907      extra_handle_data=None,
1908      distribute_strategy=None,
1909      **unused_kwargs):
1910    """Creates the variable handle.
1911
1912    Args:
1913      trainable: If `True`, GradientTapes automatically watch uses of this
1914        Variable.
1915      caching_device: Optional device string or function describing where the
1916        Variable should be cached for reading.  Defaults to the Variable's
1917        device.  If not `None`, caches on another device.  Typical use is to
1918        cache on the device where the Ops using the Variable reside, to
1919        deduplicate copying through `Switch` and other conditional statements.
1920      name: Optional name for the variable. Defaults to `'Variable'` and gets
1921        uniquified automatically.
1922      shape: The variable's shape.
1923      dtype: The variable's dtype.
1924      constraint: An optional projection function to be applied to the variable
1925        after being updated by an `Optimizer` (e.g. used to implement norm
1926        constraints or value constraints for layer weights). The function must
1927        take as input the unprojected Tensor representing the value of the
1928        variable and return the Tensor for the projected value (which must have
1929        the same shape). Constraints are not safe to use when doing asynchronous
1930        distributed training.
      synchronization: Indicates when a distributed variable will be
1932        aggregated. Accepted values are constants defined in the class
1933        `tf.VariableSynchronization`. By default the synchronization is set to
1934        `AUTO` and the current `DistributionStrategy` chooses when to
1935        synchronize.
1936      aggregation: Indicates how a distributed variable will be aggregated.
1937        Accepted values are constants defined in the class
1938        `tf.VariableAggregation`.
1939      extra_handle_data: Optional, another resource handle or Tensor with handle
1940        data to merge with `shape` and `dtype`.
1941      distribute_strategy: The tf.distribute.Strategy this variable is being
1942        created inside of.
1943    """
1944    with ops.init_scope():
1945      self._in_graph_mode = not context.executing_eagerly()
1946    with ops.init_scope():
1947      with ops.name_scope(name, "Variable", skip_on_eager=False) as name:
1948        handle_name = ops.name_from_scope_name(name)
1949        if self._in_graph_mode:
1950          shared_name = handle_name
1951          unique_id = shared_name
1952        else:
1953          unique_id = "%s_%d" % (handle_name, ops.uid())
1954          shared_name = None  # Never shared
1955        handle = _variable_handle_from_shape_and_dtype(
1956            shape=shape,
1957            dtype=dtype,
1958            shared_name=shared_name,
1959            name=name,
1960            graph_mode=self._in_graph_mode,
1961            initial_value=extra_handle_data)
1962        if not context.executing_eagerly():
1963          with ops.name_scope("Read"):
1964            # Manually assign reads to the handle's device to avoid log
1965            # messages.
1966            with ops.device(handle.device):
1967              value = gen_resource_variable_ops.read_variable_op(handle, dtype)
1968              _maybe_set_handle_data(dtype, handle, value)
1969            graph_element = value
1970          ops.add_to_collection(ops.GraphKeys.GLOBAL_VARIABLES, self)
1971          # Do *not* add to TRAINABLE_VARIABLES here, even if self._trainable,
1972          # because retraining or frozen use of imported SavedModels is
1973          # controlled at higher levels of model building.
1974        else:
1975          graph_element = None
1976    super(UninitializedVariable, self).__init__(
1977        distribute_strategy=distribute_strategy,
1978        shape=shape,
1979        dtype=dtype,
1980        unique_id=unique_id,
1981        handle_name=handle_name,
1982        constraint=constraint,
1983        handle=handle,
1984        graph_element=graph_element,
1985        trainable=trainable,
1986        synchronization=synchronization,
1987        aggregation=aggregation)
1988
1989
1990_pywrap_utils.RegisterType("ResourceVariable", ResourceVariable)
1991math_ops._resource_variable_type = ResourceVariable  # pylint: disable=protected-access
1992
1993
1994def _dense_var_to_tensor(var, dtype=None, name=None, as_ref=False):
1995  return var._dense_var_to_tensor(dtype=dtype, name=name, as_ref=as_ref)  # pylint: disable=protected-access
1996
1997
1998# Register a conversion function which reads the value of the variable,
1999# allowing instances of the class to be used as tensors.
2000ops.register_tensor_conversion_function(BaseResourceVariable,
2001                                        _dense_var_to_tensor)
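# For example (a minimal sketch): once the conversion function is registered,
# a resource variable can be passed directly to ops that expect a `Tensor`,
# and the conversion reads the variable's current value:
#
#   v = tf.Variable([1.0, 2.0])
#   y = tf.reduce_sum(v)  # `v` is converted via `_dense_var_to_tensor`.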
2002
2003
2004class _UnreadVariable(BaseResourceVariable):
2005  """Represents a future for a read of a variable.
2006
2007  Pretends to be the tensor if anyone looks.
2008  """
2009
2010  def __init__(self, handle, dtype, shape, in_graph_mode, deleter, parent_op,
2011               unique_id):
2012    if isinstance(handle, ops.EagerTensor):
2013      handle_name = ""
2014    else:
2015      handle_name = handle.name
2016    # Only create a graph_element if we're in session.run-land as only
2017    # session.run requires a preexisting tensor to evaluate. Otherwise we can
2018    # avoid accidentally reading the variable.
2019    if context.executing_eagerly() or ops.inside_function():
2020      graph_element = None
2021    else:
2022      with ops.control_dependencies([parent_op]):
2023        graph_element = gen_resource_variable_ops.read_variable_op(
2024            handle, dtype)
2025        _maybe_set_handle_data(dtype, handle, graph_element)
2026    super(_UnreadVariable, self).__init__(
2027        handle=handle,
2028        shape=shape,
2029        handle_name=handle_name,
2030        unique_id=unique_id,
2031        dtype=dtype,
2032        handle_deleter=deleter,
2033        graph_element=graph_element)
2034    self._parent_op = parent_op
2035
2036  @property
2037  def name(self):
2038    if self._in_graph_mode:
2039      return self._parent_op.name
2040    else:
2041      return "UnreadVariable"
2042
2043  def value(self):
2044    return self._read_variable_op()
2045
2046  def read_value(self):
2047    return self._read_variable_op()
2048
2049  def _read_variable_op(self):
2050    with ops.control_dependencies([self._parent_op]):
2051      result = gen_resource_variable_ops.read_variable_op(
2052          self._handle, self._dtype)
2053      _maybe_set_handle_data(self._dtype, self._handle, result)
2054      return result
2055
2056  def assign_sub(self, delta, use_locking=None, name=None, read_value=True):
2057    with ops.control_dependencies([self._parent_op]):
2058      return super(_UnreadVariable, self).assign_sub(delta, use_locking, name,
2059                                                     read_value)
2060
2061  def assign_add(self, delta, use_locking=None, name=None, read_value=True):
2062    with ops.control_dependencies([self._parent_op]):
2063      return super(_UnreadVariable, self).assign_add(delta, use_locking, name,
2064                                                     read_value)
2065
2066  def assign(self, value, use_locking=None, name=None, read_value=True):
2067    with ops.control_dependencies([self._parent_op]):
2068      return super(_UnreadVariable, self).assign(value, use_locking, name,
2069                                                 read_value)
2070
2071  def scatter_sub(self, sparse_delta, use_locking=False, name=None):
2072    with ops.control_dependencies([self._parent_op]):
2073      return super(_UnreadVariable, self).scatter_sub(sparse_delta, use_locking,
2074                                                      name)
2075
2076  def scatter_add(self, sparse_delta, use_locking=False, name=None):
2077    with ops.control_dependencies([self._parent_op]):
2078      return super(_UnreadVariable, self).scatter_add(sparse_delta, use_locking,
2079                                                      name)
2080
2081  def scatter_max(self, sparse_delta, use_locking=False, name=None):
2082    with ops.control_dependencies([self._parent_op]):
2083      return super(_UnreadVariable, self).scatter_max(sparse_delta, use_locking,
2084                                                      name)
2085
2086  def scatter_min(self, sparse_delta, use_locking=False, name=None):
2087    with ops.control_dependencies([self._parent_op]):
2088      return super(_UnreadVariable, self).scatter_min(sparse_delta, use_locking,
2089                                                      name)
2090
2091  def scatter_mul(self, sparse_delta, use_locking=False, name=None):
2092    with ops.control_dependencies([self._parent_op]):
2093      return super(_UnreadVariable, self).scatter_mul(sparse_delta, use_locking,
2094                                                      name)
2095
2096  def scatter_div(self, sparse_delta, use_locking=False, name=None):
2097    with ops.control_dependencies([self._parent_op]):
2098      return super(_UnreadVariable, self).scatter_div(sparse_delta, use_locking,
2099                                                      name)
2100
2101  def scatter_update(self, sparse_delta, use_locking=False, name=None):
2102    with ops.control_dependencies([self._parent_op]):
2103      return super(_UnreadVariable,
2104                   self).scatter_update(sparse_delta, use_locking, name)
2105
2106  def batch_scatter_update(self, sparse_delta, use_locking=False, name=None):
2107    with ops.control_dependencies([self._parent_op]):
2108      return super(_UnreadVariable,
2109                   self).batch_scatter_update(sparse_delta, use_locking, name)
2110
2111  def scatter_nd_sub(self, indices, updates, name=None):
2112    with ops.control_dependencies([self._parent_op]):
2113      return super(_UnreadVariable, self).scatter_nd_sub(indices, updates, name)
2114
2115  def scatter_nd_add(self, indices, updates, name=None):
2116    with ops.control_dependencies([self._parent_op]):
2117      return super(_UnreadVariable, self).scatter_nd_add(indices, updates, name)
2118
2119  def scatter_nd_update(self, indices, updates, name=None):
2120    with ops.control_dependencies([self._parent_op]):
2121      return super(_UnreadVariable,
2122                   self).scatter_nd_update(indices, updates, name)
2123
2124  def scatter_nd_max(self, indices, updates, name=None):
2125    with ops.control_dependencies([self._parent_op]):
2126      return super(_UnreadVariable, self).scatter_nd_max(indices, updates, name)
2127
2128  def scatter_nd_min(self, indices, updates, name=None):
2129    with ops.control_dependencies([self._parent_op]):
2130      return super(_UnreadVariable, self).scatter_nd_min(indices, updates, name)
2131
2132  @property
2133  def op(self):
2134    """The op for this variable."""
2135    return self._parent_op
2136
2137
2138@ops.RegisterGradient("ReadVariableOp")
2139def _ReadGrad(_, grad):
2140  """Gradient for read op."""
2141  return grad
2142
2143
2144def variable_shape(handle, out_type=dtypes.int32):
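  """Returns the shape of the resource variable referenced by `handle`.

  If fully-defined static shape information is available on the handle, the
  shape is returned as a constant; otherwise this falls back to the
  `VariableShape` op.
  """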
2145  if getattr(handle, "_handle_data",
2146             None) is None or not handle._handle_data.is_set:  # pylint: disable=protected-access
2147    return gen_resource_variable_ops.variable_shape(handle, out_type=out_type)
2148  shape_proto = handle._handle_data.shape_and_type[0].shape  # pylint: disable=protected-access
2149  if shape_proto.unknown_rank or any(x.size == -1 for x in shape_proto.dim):
2150    return gen_resource_variable_ops.variable_shape(handle, out_type=out_type)
2151  return constant_op.constant([x.size for x in shape_proto.dim], dtype=out_type)
2152
2153
2154@ops.RegisterGradient("ResourceGather")
2155def _GatherGrad(op, grad):
2156  """Gradient for gather op."""
2157  # Build appropriately shaped IndexedSlices
2158  handle = op.inputs[0]
2159  indices = op.inputs[1]
2160  params_shape = variable_shape(handle)
2161  size = array_ops.expand_dims(array_ops.size(indices), 0)
2162  values_shape = array_ops.concat([size, params_shape[1:]], 0)
2163  values = array_ops.reshape(grad, values_shape)
2164  indices = array_ops.reshape(indices, size)
2165  return (ops.IndexedSlices(values, indices, params_shape), None)
2166
2167
2168def _to_proto_fn(v, export_scope=None):
2169  """Converts Variable and ResourceVariable to VariableDef for collections."""
2170  return v.to_proto(export_scope=export_scope)
2171
2172
2173def _from_proto_fn(v, import_scope=None):
2174  """Creates Variable or ResourceVariable from VariableDef as needed."""
2175  if v.is_resource:
2176    return ResourceVariable.from_proto(v, import_scope=import_scope)
2177  return variables.Variable.from_proto(v, import_scope=import_scope)
2178
2179
2180ops.register_proto_function(
2181    ops.GraphKeys.GLOBAL_VARIABLES,
2182    proto_type=variable_pb2.VariableDef,
2183    to_proto=_to_proto_fn,
2184    from_proto=_from_proto_fn)
2185ops.register_proto_function(
2186    ops.GraphKeys.TRAINABLE_VARIABLES,
2187    proto_type=variable_pb2.VariableDef,
2188    to_proto=_to_proto_fn,
2189    from_proto=_from_proto_fn)
2190ops.register_proto_function(
2191    ops.GraphKeys.MOVING_AVERAGE_VARIABLES,
2192    proto_type=variable_pb2.VariableDef,
2193    to_proto=_to_proto_fn,
2194    from_proto=_from_proto_fn)
2195ops.register_proto_function(
2196    ops.GraphKeys.LOCAL_VARIABLES,
2197    proto_type=variable_pb2.VariableDef,
2198    to_proto=_to_proto_fn,
2199    from_proto=_from_proto_fn)
2200ops.register_proto_function(
2201    ops.GraphKeys.MODEL_VARIABLES,
2202    proto_type=variable_pb2.VariableDef,
2203    to_proto=_to_proto_fn,
2204    from_proto=_from_proto_fn)
2205ops.register_proto_function(
2206    ops.GraphKeys.GLOBAL_STEP,
2207    proto_type=variable_pb2.VariableDef,
2208    to_proto=_to_proto_fn,
2209    from_proto=_from_proto_fn)
2210ops.register_proto_function(
2211    ops.GraphKeys.METRIC_VARIABLES,
2212    proto_type=variable_pb2.VariableDef,
2213    to_proto=_to_proto_fn,
2214    from_proto=_from_proto_fn)
2215
2216
2217@tf_export("__internal__.ops.is_resource_variable", v1=[])
2218def is_resource_variable(var):
2219  """"Returns True if `var` is to be considered a ResourceVariable."""
2220  return isinstance(var, BaseResourceVariable) or hasattr(
2221      var, "_should_act_as_resource_variable")
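# A quick sketch of the check above: genuine resource variables (for example,
# `tf.Variable(1.0)` under TF 2.x defaults) return True, as does any object
# that sets a `_should_act_as_resource_variable` attribute.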
2222
2223
2224def copy_to_graph_uninitialized(var):
2225  """Copies an existing variable to a new graph, with no initializer."""
2226  # Like ResourceVariable.__deepcopy__, but does not set an initializer on the
2227  # new variable.
2228  # pylint: disable=protected-access
2229  new_variable = UninitializedVariable(
2230      trainable=var.trainable,
2231      constraint=var._constraint,
2232      shape=var.shape,
2233      dtype=var.dtype,
2234      name=var._shared_name,
2235      synchronization=var.synchronization,
2236      aggregation=var.aggregation,
2237      extra_handle_data=var.handle)
2238  new_variable._maybe_initialize_trackable()
2239  # pylint: enable=protected-access
2240  return new_variable
2241
2242
2243ops.NotDifferentiable("Assert")
2244ops.NotDifferentiable("VarIsInitializedOp")
2245ops.NotDifferentiable("VariableShape")
2246
2247
2248class VariableSpec(tensor_spec.DenseSpec):
2249  """Describes a tf.Variable."""
2250
2251  __slots__ = []
2252
2253  value_type = property(lambda self: BaseResourceVariable)
2254
2255  def _to_components(self, value):
2256    raise NotImplementedError
2257
2258  def _from_components(self, components):
2259    raise NotImplementedError
2260
2261  def _from_compatible_tensor_list(self, tensor_list):
2262    assert len(tensor_list) == 1
2263    return tensor_list[0]
2264
2265
2266_pywrap_utils.RegisterType("VariableSpec", VariableSpec)
2267