# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Ops to use variables as resources."""

# pylint: disable=g-bad-name
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import variable_pb2
from tensorflow.python.eager import context
from tensorflow.python.eager import tape
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import gen_resource_variable_ops
from tensorflow.python.ops import gen_state_ops
from tensorflow.python.ops import variables
# go/tf-wildcard-import
# pylint: disable=wildcard-import
from tensorflow.python.ops.gen_resource_variable_ops import *
# pylint: enable=wildcard-import
from tensorflow.python.training import checkpointable
from tensorflow.python.util import compat


def _eager_safe_variable_handle(shape, dtype, shared_name, name, graph_mode):
43  """Creates a variable handle with information to do shape inference."""
44  container = ops.get_default_graph()._container  # pylint: disable=protected-access
45  if container is None:
46    container = ""
47  if not graph_mode:
48    # When in eager mode use a uid for the shared_name, to prevent accidental
49    # sharing.
50    shared_name = str(ops.uid())
51  handle = gen_resource_variable_ops.var_handle_op(shape=shape, dtype=dtype,
52                                                   shared_name=shared_name,
53                                                   name=name,
54                                                   container=container)
55  if graph_mode:
56    return handle
57
58  # We do not want two distinct ResourceVariable objects for the same
59  # underlying resource in the runtime.
60  # When in eager mode, explicitly ensure so here. When in graph mode, it's
61  # ensured by always generating different variable names.
62  exists = gen_resource_variable_ops.var_is_initialized_op(handle)
63  if exists:
64    raise ValueError("variable object with name '%s' already created. Use "
65                     "get_variable() if reuse is desired." %
66                     shared_name)
67  with context.graph_mode(), ops.Graph().as_default() as graph:
68    h = gen_resource_variable_ops.var_handle_op(shape=shape, dtype=dtype,
69                                                shared_name=shared_name,
70                                                name=name,
71                                                container=container)
72
73    # Tensor._handle_data contains information for the shape-inference code to
74    # know the shape and dtype of the variable pointed to by a handle. Since
75    # shape inference doesn't run in eager mode we copy this data here for when
76    # the handle is captured by an eager mode function.
77    handle._handle_data = h._handle_data  # pylint: disable=protected-access
78  # Clean up our reference cycles to avoid making the garbage collector run.
79  # pylint: disable=protected-access
80  # OrderedDict, constructed on Graph creation, makes a simple reference loop
81  # and hides it in an __attribute in some Python versions. We don't need to
82  # throw an error if we can't find it, but if we do find it we can break the
83  # loop to avoid creating work for the garbage collector.
84  problematic_cycle = graph._functions.__dict__.get("_OrderedDict__root", None)
85  # pylint: enable=protected-access
86  if problematic_cycle:
87    try:
88      del problematic_cycle[0][:]
89    except TypeError:
90      # This is probably not one of the problematic Python versions. Continue
91      # with the rest of our cleanup.
92      pass
93  # Now clean up our own reference cycles by clearing all of the attributes for
94  # the Graph and op we created.
95  h.__dict__ = {}
96  graph.__dict__ = {}
97  return handle
98
99
100class EagerResourceDeleter(object):
101  """An object which cleans up a resource handle.
102
103  An alternative to defining a __del__ method on an object. The intended use is
104  that ResourceVariables or other objects with resource handles will maintain a
105  single reference to this object. When the parent object is collected, this
106  object will be too. Even if the parent object is part of a reference cycle,
107  the cycle will be collectable.
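
  A minimal sketch of the intended pattern (`MyResourceHolder` is a
  hypothetical name used only for illustration, not part of this module):

  ```python
  class MyResourceHolder(object):

    def __init__(self, handle):
      self._handle = handle
      # Hold the only reference to the deleter. When MyResourceHolder is
      # garbage collected (even as part of a reference cycle), the deleter is
      # collected too and destroys the underlying resource.
      self._deleter = EagerResourceDeleter(
          handle=handle, handle_device=handle.device)
  ```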
108  """
109
110  def __init__(self, handle, handle_device):
111    if not isinstance(handle, ops.Tensor):
112      raise ValueError(
113          ("Passed handle=%s to EagerResourceDeleter. Was expecting a handle "
114           "Tensor." % (handle,)))
115    self._handle = handle
116    self._handle_device = handle_device
117
118  def __del__(self):
119    # Resources follow object-identity when executing eagerly, so it is safe to
120    # delete the resource we have a handle to. Each Graph has a unique container
121    # name, which prevents resource sharing.
122    try:
123      # This resource was created in eager mode. However, this destructor may be
124      # running in graph mode (especially during unit tests). To clean up
125      # successfully, we switch back into eager mode temporarily.
126      with context.eager_mode():
127        with ops.device(self._handle_device):
128          gen_resource_variable_ops.destroy_resource_op(
129              self._handle, ignore_lookup_error=True)
130    except TypeError:
131      # Suppress some exceptions, mainly for the case when we're running on
132      # module deletion. Things that can go wrong include the context module
133      # already being unloaded, self._handle._handle_data no longer being
134      # valid, and so on. Printing warnings in these cases is silly
135      # (exceptions raised from __del__ are printed as warnings to stderr).
136      pass  # 'NoneType' object is not callable when the handle has been
137            # partially unloaded.
138    except AttributeError:
139      pass  # 'NoneType' object has no attribute 'eager_mode' when context has
140            # been unloaded. Will catch other module unloads as well.
141
142
143def shape_safe_assign_variable_handle(handle, shape, value, name=None):
144  """Helper that checks shape compatibility and assigns variable."""
145  value_tensor = ops.convert_to_tensor(value)
146  shape.assert_is_compatible_with(value_tensor.shape)
147  return gen_resource_variable_ops.assign_variable_op(handle,
148                                                      value_tensor,
149                                                      name=name)
150
151
152class ResourceVariable(variables.Variable):
153  """Variable based on resource handles.
154
155  See the ${variables} documentation for more details.
156
157  A `ResourceVariable` allows you to maintain state across subsequent calls to
158  session.run.
159
160  The `ResourceVariable` constructor requires an initial value for the variable,
161  which can be a `Tensor` of any type and shape. The initial value defines the
162  type and shape of the variable. After construction, the type and shape of
163  the variable are fixed. The value can be changed using one of the assign
164  methods.
165
166  Just like any `Tensor`, variables created with `ResourceVariable()` can be
167  used as inputs for other Ops in the graph. Additionally, all the operators
168  overloaded for the `Tensor` class are carried over to variables, so you can
169  also add nodes to the graph by just doing arithmetic on variables.
170
171  Unlike tf.Variable, a tf.ResourceVariable has well-defined semantics. Each
172  usage of a ResourceVariable in a TensorFlow graph adds a read_value operation
173  to the graph. The Tensors returned by a read_value operation are guaranteed
174  to see all modifications to the value of the variable which happen in any
175  operation on which the read_value depends on (either directly, indirectly, or
176  via a control dependency) and guaranteed to not see any modification to the
177  value of the variable on which the read_value operation does not depend on.
178
179  For example, if there is more than one assignment to a ResourceVariable in
180  a single session.run call there is a well-defined value for each operation
181  which uses the variable's value if the assignments and the read are connected
182  by edges in the graph. Consider the following example, in which two writes
183  can cause tf.Variable and tf.ResourceVariable to behave differently:
184
185   ```python
186    a = tf.ResourceVariable(1.0)
187    a.initializer.run()
188
189    assign = a.assign(2.0)
190    with tf.control_dependencies([assign]):
191      b = a.read_value()
192    with tf.control_dependencies([b]):
193      other_assign = a.assign(3.0)
194    with tf.control_dependencies([other_assign]):
195      # Will print 2.0 because the value was read before other_assign ran. If
196      # `a` was a tf.Variable instead, 2.0 or 3.0 could be printed.
197      tf.Print(b, [b]).eval()
198  ```
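
  A `ResourceVariable` can also be constructed with a callable initial value,
  or requested from `tf.get_variable` with `use_resource=True`. A small
  illustrative sketch (assumes the standard `tf` namespace; the variable names
  are placeholders):

  ```python
  # A callable initial value is invoked at initialization time.
  v = tf.ResourceVariable(lambda: tf.ones([2, 3]), name="v")
  # get_variable can be asked to return a resource variable directly.
  w = tf.get_variable("w", shape=[2, 3], use_resource=True)
  ```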

  To enforce these consistency properties, tf.ResourceVariable might make more
  copies than an equivalent tf.Variable under the hood, so tf.Variable is still
  not deprecated.
  """

  def __init__(self,
               initial_value=None,
               trainable=True,
               collections=None,
               validate_shape=True,
               caching_device=None,
               name=None,
               dtype=None,
               variable_def=None,
               import_scope=None,
               constraint=None):
    """Creates a variable.

    Args:
      initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
        which is the initial value for the Variable. The initial value must have
        a shape specified unless `validate_shape` is set to False. Can also be a
        callable with no argument that returns the initial value when called.
        (Note that initializer functions from init_ops.py must first be bound
         to a shape before being used here.)
      trainable: If `True`, the default, also adds the variable to the graph
        collection `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as
        the default list of variables to use by the `Optimizer` classes.
      collections: List of graph collections keys. The new variable is added to
        these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
      validate_shape: Ignored. Provided for compatibility with tf.Variable.
      caching_device: Optional device string or function describing where the
        Variable should be cached for reading.  Defaults to the Variable's
        device.  If not `None`, caches on another device.  Typical use is to
        cache on the device where the Ops using the Variable reside, to
        deduplicate copying through `Switch` and other conditional statements.
      name: Optional name for the variable. Defaults to `'Variable'` and gets
        uniquified automatically.
      dtype: If set, initial_value will be converted to the given type.
        If None, either the datatype will be kept (if initial_value is
        a Tensor) or float32 will be used (if it is a Python object convertible
        to a Tensor).
      variable_def: `VariableDef` protocol buffer. If not None, recreates the
        `ResourceVariable` object with its contents. `variable_def` and other
        arguments (except for import_scope) are mutually exclusive.
      import_scope: Optional `string`. Name scope to add to the
        ResourceVariable. Only used when `variable_def` is provided.
      constraint: An optional projection function to be applied to the variable
        after being updated by an `Optimizer` (e.g. used to implement norm
        constraints or value constraints for layer weights). The function must
        take as input the unprojected Tensor representing the value of the
        variable and return the Tensor for the projected value
        (which must have the same shape). Constraints are not safe to
        use when doing asynchronous distributed training.

    Raises:
      ValueError: If the initial value is not specified, or does not have a
        shape and `validate_shape` is `True`.

    @compatibility(eager)
    When Eager Execution is enabled, the default for the `collections` argument
    is `None`, which signifies that this `Variable` will not be added to any
    collections.
    @end_compatibility
    """
    if variable_def:
      if initial_value is not None:
        raise ValueError("variable_def and initial_value are mutually "
                         "exclusive.")
      if not context.in_graph_mode():
        raise ValueError("Creating ResourceVariable from variable_def"
                         " only supported in GRAPH mode.")
      self._init_from_proto(variable_def, import_scope=import_scope)
    else:
      self._init_from_args(
          initial_value=initial_value,
          trainable=trainable,
          collections=collections,
          validate_shape=validate_shape,
          caching_device=caching_device,
          name=name,
          dtype=dtype,
          constraint=constraint)

  # pylint: disable=unused-argument
  def _init_from_args(self,
                      initial_value=None,
                      trainable=True,
                      collections=None,
                      validate_shape=True,
                      caching_device=None,
                      name=None,
                      dtype=None,
                      constraint=None):
    """Creates a variable.

    Args:
      initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
        which is the initial value for the Variable. The initial value must have
        a shape specified unless `validate_shape` is set to False. Can also be a
        callable with no argument that returns the initial value when called.
        (Note that initializer functions from init_ops.py must first be bound
         to a shape before being used here.)
      trainable: If `True`, the default, also adds the variable to the graph
        collection `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as
        the default list of variables to use by the `Optimizer` classes.
      collections: List of graph collections keys. The new variable is added to
        these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
      validate_shape: Ignored. Provided for compatibility with tf.Variable.
      caching_device: Optional device string or function describing where the
        Variable should be cached for reading.  Defaults to the Variable's
        device.  If not `None`, caches on another device.  Typical use is to
        cache on the device where the Ops using the Variable reside, to
        deduplicate copying through `Switch` and other conditional statements.
      name: Optional name for the variable. Defaults to `'Variable'` and gets
        uniquified automatically.
      dtype: If set, initial_value will be converted to the given type.
        If None, either the datatype will be kept (if initial_value is
        a Tensor) or float32 will be used (if it is a Python object convertible
        to a Tensor).
      constraint: An optional projection function to be applied to the variable
        after being updated by an `Optimizer` (e.g. used to implement norm
        constraints or value constraints for layer weights). The function must
        take as input the unprojected Tensor representing the value of the
        variable and return the Tensor for the projected value
        (which must have the same shape). Constraints are not safe to
        use when doing asynchronous distributed training.

    Raises:
      ValueError: If the initial value is not specified, or does not have a
        shape and `validate_shape` is `True`.

    @compatibility(eager)
    When Eager Execution is enabled, variables are never added to collections.
    They are not implicitly added to the `GLOBAL_VARIABLES` or
    `TRAINABLE_VARIABLES` collections, and the `collections` argument is
    ignored.
    @end_compatibility
    """
    if initial_value is None:
      raise ValueError("initial_value must be specified.")
    init_from_fn = callable(initial_value)

    if collections is None:
      collections = [ops.GraphKeys.GLOBAL_VARIABLES]
    if not isinstance(collections, (list, tuple, set)):
      raise ValueError(
          "collections argument to Variable constructor must be a list, tuple, "
          "or set. Got %s of type %s" % (collections, type(collections)))
    if constraint is not None and not callable(constraint):
      raise ValueError("The `constraint` argument must be a callable.")

    if isinstance(initial_value, checkpointable.CheckpointInitialValue):
      self._maybe_initialize_checkpointable()
      self._update_uid = initial_value.checkpoint_position.restore_uid
      initial_value = initial_value.wrapped_value

    self._trainable = trainable
    if trainable and ops.GraphKeys.TRAINABLE_VARIABLES not in collections:
      collections = list(collections) + [ops.GraphKeys.TRAINABLE_VARIABLES]
    self._save_slice_info = None
    # Store the graph key so optimizers know how to only retrieve variables from
    # this graph.
    self._graph_key = ops.get_default_graph()._graph_key  # pylint: disable=protected-access
    with ops.init_scope():
      self._in_graph_mode = context.in_graph_mode()
      with ops.name_scope(name, "Variable", []
                          if init_from_fn else [initial_value]) as name:
        # pylint: disable=protected-access
        handle_name = ops._name_from_scope_name(name)
        if init_from_fn:
          # Use attr_scope and device(None) to simulate the behavior of
          # colocate_with when the variable we want to colocate with doesn't
          # yet exist.
          if self._in_graph_mode:
            attr = attr_value_pb2.AttrValue(
                list=attr_value_pb2.AttrValue.ListValue(
                    s=[compat.as_bytes("loc:@%s" % handle_name)]))
            with ops.get_default_graph()._attr_scope({"_class": attr}):
              with ops.name_scope("Initializer"), ops.device(None):
                initial_value = ops.convert_to_tensor(
                    initial_value(), name="initial_value", dtype=dtype)
              self._handle = _eager_safe_variable_handle(
                  shape=initial_value.get_shape(),
                  dtype=initial_value.dtype.base_dtype,
                  shared_name=handle_name,
                  name=name,
                  graph_mode=self._in_graph_mode)
              self._handle_device = (
                  self._handle.device if self._in_graph_mode else
                  context.get_default_context().device_name)
              self._shape = initial_value.get_shape()
          else:
            initial_value = initial_value()
            with ops.name_scope("Initializer"):
              initial_value = ops.convert_to_tensor(
                  initial_value, name="initial_value", dtype=dtype)
            self._handle = _eager_safe_variable_handle(
                shape=initial_value.get_shape(),
                dtype=initial_value.dtype.base_dtype,
                shared_name=handle_name,
                name=name,
                graph_mode=False)
            self._handle_device = (
                self._handle.device if self._in_graph_mode else
                context.get_default_context().device_name)
            self._shape = initial_value.get_shape()
        # pylint: enable=protected-access

        # Or get the initial value from a Tensor or Python object.
        else:
          with ops.name_scope("Initializer"):
            initial_value = ops.convert_to_tensor(
                initial_value, name="initial_value", dtype=dtype)
          # pylint: disable=protected-access
          if (self._in_graph_mode and initial_value is not None and
              initial_value.op._get_control_flow_context() is not None):
            raise ValueError(
                "Initializer for variable %s is from inside a control-flow "
                "construct, such as a loop or conditional. When creating a "
                "variable inside a loop or conditional, use a lambda as the "
                "initializer." % name)
          # pylint: enable=protected-access
          self._handle = _eager_safe_variable_handle(
              shape=initial_value.get_shape(),
              dtype=initial_value.dtype.base_dtype,
              shared_name=handle_name,
              name=name,
              graph_mode=self._in_graph_mode)
          self._handle_device = (self._handle.device if self._in_graph_mode else
                                 context.get_default_context().device_name)
          self._shape = initial_value.get_shape()

        self._initial_value = initial_value if self._in_graph_mode else None
        self._handle_name = handle_name + ":0"
        self._dtype = initial_value.dtype.base_dtype
        self._constraint = constraint

        if self._in_graph_mode:
          with ops.name_scope("IsInitialized"):
            self._is_initialized_op = (
                gen_resource_variable_ops.var_is_initialized_op(self._handle))
          if initial_value is not None:
            with ops.name_scope("Assign") as n, ops.colocate_with(self._handle):
              self._initializer_op = (
                  gen_resource_variable_ops.assign_variable_op(
                      self._handle,
                      self._try_guard_against_uninitialized_dependencies(
                          initial_value),
                      name=n))
          with ops.name_scope("Read"), ops.colocate_with(self._handle):
            # Manually assign reads to the handle's device to avoid log
            # messages.
            with ops.device(self._handle_device):
              value = self._read_variable_op()
            self._graph_element = value
            if caching_device is not None:
              # Variables may be created in a tf.device() or ops.colocate_with()
              # context. At the same time, users would expect caching device to
              # be independent of this context, and/or would not expect the
              # current device context to be merged with the caching device
              # spec.  Therefore we reset the colocation stack before creating
              # the cached value. Note that resetting the colocation stack will
              # also reset the device stack.
              with ops.colocate_with(None, ignore_existing=True):
                with ops.device(caching_device):
                  self._cached_value = array_ops.identity(value)
            else:
              self._cached_value = None
        else:
          gen_resource_variable_ops.assign_variable_op(self._handle,
                                                       initial_value)
          self._is_initialized_op = None
          self._initializer_op = None
          self._graph_element = None
          if caching_device:
            with ops.device(caching_device):
              self._cached_value = self._read_variable_op()
          else:
            self._cached_value = None
        if context.in_graph_mode():
          ops.add_to_collections(collections, self)
        elif ops.GraphKeys.GLOBAL_STEP in collections:
          ops.add_to_collections(ops.GraphKeys.GLOBAL_STEP, self)

    if not self._in_graph_mode:
      # After the handle has been created, set up a way to clean it up when
      # executing eagerly. We'll hold the only reference to the deleter, so that
      # when this object is garbage collected the deleter will be too. This
      # means ResourceVariables can be part of reference cycles without those
      # cycles being uncollectable, and means that no __del__ will be defined at
      # all in graph mode.
      self._handle_deleter = EagerResourceDeleter(
          handle=self._handle, handle_device=self._handle_device)

  def _init_from_proto(self, variable_def, import_scope=None):
496    """Initializes from `VariableDef` proto."""
497    # Note that init_from_proto is currently not supported in Eager mode.
498    assert context.in_graph_mode()
499    self._in_graph_mode = True
500    assert isinstance(variable_def, variable_pb2.VariableDef)
501    if not variable_def.is_resource:
502      raise ValueError("Trying to restore Variable as ResourceVariable.")
503
504    # Create from variable_def.
505    g = ops.get_default_graph()
506    self._handle = g.as_graph_element(
507        ops.prepend_name_scope(
508            variable_def.variable_name, import_scope=import_scope))
509    self._shape = tensor_shape.TensorShape(
510        self._handle.op.get_attr("shape"))
511    self._handle_device = self._handle.device
512    self._handle_name = self._handle.name
513    self._initializer_op = g.as_graph_element(
514        ops.prepend_name_scope(
515            variable_def.initializer_name, import_scope=import_scope))
516    # Check whether initial_value_name exists for backwards compatibility.
517    if (hasattr(variable_def, "initial_value_name") and
518        variable_def.initial_value_name):
519      self._initial_value = g.as_graph_element(
520          ops.prepend_name_scope(variable_def.initial_value_name,
521                                 import_scope=import_scope))
522    else:
523      self._initial_value = None
524    if variable_def.snapshot_name:
525      self._cached_value = g.as_graph_element(
526          ops.prepend_name_scope(
527              variable_def.snapshot_name, import_scope=import_scope))
528    else:
529      self._cached_value = None
530    if variable_def.HasField("save_slice_info_def"):
531      self._save_slice_info = variables.Variable.SaveSliceInfo(
532          save_slice_info_def=variable_def.save_slice_info_def,
533          import_scope=import_scope)
534    else:
535      self._save_slice_info = None
536    self._caching_device = None
537    self._dtype = dtypes.as_dtype(self._handle.op.get_attr("dtype"))
538    self._graph_element = self.value()
539    self._constraint = None
540
541  def __nonzero__(self):
    return self.__bool__()

  def __bool__(self):
    return bool(self.read_value())

  @property
  def dtype(self):
    """The dtype of this variable."""
    return self._dtype

  @property
  def device(self):
    """The device this variable is on."""
    return self._handle_device

  @property
  def graph(self):
    """The `Graph` of this variable."""
    return self._handle.graph

  @property
  def name(self):
    """The name of the handle for this variable."""
    return self._handle_name

  @property
  def shape(self):
    """The shape of this variable."""
    return self._shape

  @property
  def create(self):
    """The op responsible for initializing this variable."""
    if not self._in_graph_mode:
      raise RuntimeError("Calling create in EAGER mode not supported.")
    return self._initializer_op

  @property
  def handle(self):
    """The handle by which this variable can be accessed."""
    return self._handle

  def value(self):
    """A cached operation which reads the value of this variable."""
    if self._cached_value is not None:
      return self._cached_value
    with ops.colocate_with(None, ignore_existing=True):
      with ops.device(self._handle_device):
        return self._read_variable_op()

  def _as_graph_element(self):
    """Conversion function for Graph.as_graph_element()."""
    return self._graph_element

  @property
  def initializer(self):
    """The op responsible for initializing this variable."""
    return self._initializer_op

  @property
  def initial_value(self):
    """Returns the Tensor used as the initial value for the variable."""
    if context.in_eager_mode():
      raise RuntimeError("initial_value not supported in EAGER mode.")
    return self._initial_value

  @property
  def constraint(self):
    """Returns the constraint function associated with this variable.

    Returns:
      The constraint function that was passed to the variable constructor.
      Can be `None` if no constraint was passed.
    """
    return self._constraint

  @property
  def op(self):
    """The op for this variable."""
    return self._handle.op

  def eval(self, session=None):
    """Evaluates and returns the value of this variable."""
    if context.in_eager_mode():
      raise RuntimeError("Trying to eval in EAGER mode")
    return self._graph_element.eval(session=session)

  def numpy(self):
    if context.in_graph_mode():
      raise NotImplementedError(
          "numpy() is only available when eager execution is enabled.")
    return self.read_value().numpy()

  def count_up_to(self, limit):
    """Increments this variable until it reaches `limit`.

    When that Op is run it tries to increment the variable by `1`. If
    incrementing the variable would bring it above `limit` then the Op raises
    the exception `OutOfRangeError`.

    If no error is raised, the Op outputs the value of the variable before
    the increment.

    This is essentially a shortcut for `count_up_to(self, limit)`.
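
    A brief illustrative sketch (graph mode, scalar integer variable assumed):

    ```python
    v = tf.ResourceVariable(0, dtype=tf.int32)
    increment = v.count_up_to(2)
    with tf.Session() as sess:
      sess.run(v.initializer)
      print(sess.run(increment))  # 0
      print(sess.run(increment))  # 1
      sess.run(increment)  # Raises OutOfRangeError: the limit was reached.
    ```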

    Args:
      limit: value at which incrementing the variable raises an error.

    Returns:
      A `Tensor` that will hold the variable value before the increment. If no
      other Op modifies this variable, the values produced will all be
      distinct.
    """
    return gen_state_ops.resource_count_up_to(self.handle, limit=limit,
                                              T=self.dtype)

  def _set_save_slice_info(self, save_slice_info):
    """Sets the slice info for this `ResourceVariable`.

    Args:
      save_slice_info: A `Variable.SaveSliceInfo` object.
    """
    self._save_slice_info = save_slice_info

  def _get_save_slice_info(self):
    return self._save_slice_info

  def _read_variable_op(self):
    if hasattr(self, "_trainable") and self._trainable:
      tape.watch_variable(self)
    return gen_resource_variable_ops.read_variable_op(self._handle,
                                                      self._dtype)

  def read_value(self):
    """Constructs an op which reads the value of this variable.

    Should be used when there are multiple reads, or when it is desirable to
    read the value only after some condition is true.
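
    For example (an illustrative sketch mirroring the class-level example):

    ```python
    v = tf.ResourceVariable(1.0)
    update = v.assign_add(1.0)
    with tf.control_dependencies([update]):
      # Reads the value only after the increment has run.
      value = v.read_value()
    ```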

    Returns:
      The read operation.
    """
    with ops.name_scope("Read"):
      # Ensure we read the variable in the same device as the handle.
      with ops.device(self._handle_device):
        value = self._read_variable_op()
    # Return an identity so it can get placed on whatever device the context
    # specifies instead of the device where the variable is.
    return array_ops.identity(value)

  def sparse_read(self, indices, name=None):
    """Reads the value of this variable sparsely, using `gather`."""
    with ops.name_scope("Gather" if name is None else name) as name:
      if self._trainable:
        tape.watch_variable(self)
      value = gen_resource_variable_ops.resource_gather(
          self._handle, indices, dtype=self._dtype, name=name)
    return array_ops.identity(value)

  def to_proto(self, export_scope=None):
    """Converts a `ResourceVariable` to a `VariableDef` protocol buffer.

    Args:
      export_scope: Optional `string`. Name scope to remove.

    Raises:
      RuntimeError: If run in EAGER mode.

    Returns:
      A `VariableDef` protocol buffer, or `None` if the `Variable` is not
      in the specified name scope.
    """
    if context.in_eager_mode():
      raise RuntimeError("to_proto not supported in EAGER mode.")
    if export_scope is None or self.handle.name.startswith(export_scope):
      var_def = variable_pb2.VariableDef()
      var_def.variable_name = ops.strip_name_scope(self.handle.name,
                                                   export_scope)
      if self._initial_value is not None:
        # This is inside an if-statement for backwards compatibility, since
        # self._initial_value might be None for variables constructed from old
        # protos.
        var_def.initial_value_name = ops.strip_name_scope(
            self._initial_value.name, export_scope)
      var_def.initializer_name = ops.strip_name_scope(self.initializer.name,
                                                      export_scope)
      if self._cached_value is not None:
        var_def.snapshot_name = ops.strip_name_scope(self._cached_value.name,
                                                     export_scope)
      var_def.is_resource = True
      if self._save_slice_info:
        var_def.save_slice_info_def.MergeFrom(
            self._save_slice_info.to_proto(export_scope=export_scope))
      return var_def
    else:
      return None

  @staticmethod
  def from_proto(variable_def, import_scope=None):
    if context.in_eager_mode():
      raise RuntimeError("from_proto not supported in EAGER mode.")
    return ResourceVariable(
        variable_def=variable_def, import_scope=import_scope)

  @staticmethod
  def _OverloadAllOperators():  # pylint: disable=invalid-name
    """Register overloads for all operators."""
    for operator in ops.Tensor.OVERLOADABLE_OPERATORS:
      ResourceVariable._OverloadOperator(operator)
    # For slicing, bind getitem differently than a tensor (use SliceHelperVar
    # instead)
    # pylint: disable=protected-access
    setattr(ResourceVariable, "__getitem__", array_ops._SliceHelperVar)

  def _AsTensor(self):
    return self.value()

  def _ref(self):
    """Unsupported."""
    raise NotImplementedError("ResourceVariable does not implement _ref()")

  def set_shape(self, shape):
    """Unsupported."""
    raise NotImplementedError("ResourceVariable does not implement set_shape()")

  @staticmethod
  def _OverloadOperator(operator):  # pylint: disable=invalid-name
    """Defer an operator overload to `ops.Tensor`.

    We pull the operator out of ops.Tensor dynamically to avoid ordering issues.

    Args:
      operator: string. The operator name.
    """

    def _run_op(a, *args):
      # pylint: disable=protected-access
      value = a._AsTensor()
      return getattr(ops.Tensor, operator)(value, *args)

    # Propagate __doc__ to wrapper
    try:
      _run_op.__doc__ = getattr(ops.Tensor, operator).__doc__
    except AttributeError:
      pass

    setattr(ResourceVariable, operator, _run_op)

  __array_priority__ = 100

  def assign_sub(self, delta, use_locking=None, name=None):
    # TODO(apassos): this here and below is not atomic. Consider making it
    # atomic if there's a way to do so without a performance cost for those who
    # don't need it.
    return self._lazy_read(gen_resource_variable_ops.assign_sub_variable_op(
        self.handle,
        ops.convert_to_tensor(delta, dtype=self.dtype),
        name=name))

  def assign_add(self, delta, use_locking=None, name=None):
    return self._lazy_read(gen_resource_variable_ops.assign_add_variable_op(
        self.handle,
        ops.convert_to_tensor(delta, dtype=self.dtype),
        name=name))

  def _lazy_read(self, op):
    if hasattr(self, "_trainable") and self._trainable:
      tape.watch_variable(self)
    return _UnreadVariable(
        self._handle, self.dtype, self._handle_device, self._shape,
        self._in_graph_mode,
        self._handle_deleter if not self._in_graph_mode else None, op)

  def assign(self, value, use_locking=None, name=None):
    value_tensor = ops.convert_to_tensor(value, dtype=self.dtype)
    self._shape.assert_is_compatible_with(value_tensor.shape)
    return self._lazy_read(
        gen_resource_variable_ops.assign_variable_op(
            self.handle,
            value_tensor,
            name=name))

  def _strided_slice_assign(self, begin, end, strides, value, name, begin_mask,
                            end_mask, ellipsis_mask, new_axis_mask,
                            shrink_axis_mask):
    return self._lazy_read(
        gen_array_ops.resource_strided_slice_assign(
            ref=self.handle,
            begin=begin,
            end=end,
            strides=strides,
            value=value,
            name=name,
            begin_mask=begin_mask,
            end_mask=end_mask,
            ellipsis_mask=ellipsis_mask,
            new_axis_mask=new_axis_mask,
            shrink_axis_mask=shrink_axis_mask))

  def __int__(self):
    if self.dtype != dtypes.int32 and self.dtype != dtypes.int64:
      raise TypeError("Non-integer variable can't be converted to integer.")
    return int(self.value().numpy())

  def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
    del name
    if dtype is not None and dtype != self.dtype:
      print("trying to switch the dtype to ", dtype, " from ", self.dtype)
      return NotImplemented
    if as_ref:
      return self.read_value().op.inputs[0]
    else:
      return self.value()

  def __iadd__(self, unused_other):
    raise RuntimeError("Variable += value not supported. Use "
                       "variable.assign_add(value) to modify the variable "
                       "value and variable = variable + value to get a new "
                       "Tensor object.")

  def __isub__(self, unused_other):
    raise RuntimeError("Variable -= value not supported. Use "
                       "variable.assign_sub(value) to modify the variable "
                       "value and variable = variable - value to get a new "
                       "Tensor object.")

  def __imul__(self, unused_other):
    raise RuntimeError("Variable *= value not supported. Use "
                       "variable.assign_mul(value) to modify the variable "
                       "value and variable = variable * value to get a new "
                       "Tensor object.")

  def __idiv__(self, unused_other):
    raise RuntimeError("Variable /= value not supported. Use "
                       "variable.assign_div(value) to modify the variable "
                       "value and variable = variable / value to get a new "
                       "Tensor object.")

  def __itruediv__(self, unused_other):
    raise RuntimeError("Variable /= value not supported. Use "
                       "variable.assign_div(value) to modify the variable "
                       "value and variable = variable / value to get a new "
                       "Tensor object.")

  def __irealdiv__(self, unused_other):
    raise RuntimeError("Variable /= value not supported. Use "
                       "variable.assign_div(value) to modify the variable "
                       "value and variable = variable / value to get a new "
                       "Tensor object.")

  def __ipow__(self, unused_other):
    raise RuntimeError("Variable **= value not supported. Use "
                       "`variable = variable ** value` to get a new Tensor "
                       "object.")


def _dense_var_to_tensor(var, dtype=None, name=None, as_ref=False):
  return var._dense_var_to_tensor(dtype=dtype, name=name, as_ref=as_ref)  # pylint: disable=protected-access


class _UnreadVariable(ResourceVariable):
  """Represents a future for a read of a variable.

  Pretends to be the tensor if anyone looks.
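
  A rough sketch of where these objects appear (the variable name is a
  placeholder):

  ```python
  v = ResourceVariable(1.0)
  u = v.assign_add(1.0)  # An _UnreadVariable wrapping the pending assignment.
  r = u.read_value()     # The read is sequenced after the assignment.
  ```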
906  """
907
908  def __init__(self, handle, dtype, handle_device,  # pylint: disable=super-init-not-called
909               shape, in_graph_mode, deleter, parent_op):
910    # We do not call super init on purpose.
911    self._trainable = False
912    self._save_slice_info = None
913    self._graph_key = ops.get_default_graph()._graph_key  # pylint: disable=protected-access
914    self._in_graph_mode = in_graph_mode
915    self._handle = handle
916    self._handle_device = handle_device
917    self._shape = shape
918    self._initial_value = None
919    if isinstance(self._handle, ops.EagerTensor):
920      self._handle_name = ""
921    else:
922      self._handle_name = self._handle.name
923    self._dtype = dtype
924    self._constraint = None
925    self._cached_value = None
926    self._is_initialized_op = None
927    self._initializer_op = None
928    self._parent_op = parent_op
929    if context.in_graph_mode():
930      self._graph_element = self.read_value()
931    else:
932      self._graph_element = None
933    self._handle_deleter = deleter
934
935  def value(self):
936    return self._read_variable_op()
937
938  def read_value(self):
939    return self._read_variable_op()
940
941  def _read_variable_op(self):
942    with ops.control_dependencies([self._parent_op]):
943      return gen_resource_variable_ops.read_variable_op(self._handle,
944                                                        self._dtype)
945
946  def set_shape(self, shape):
947    self._shape = shape
948
949  @property
950  def op(self):
951    """The op for this variable."""
952    return self._parent_op
953
954ops.register_tensor_conversion_function(_UnreadVariable, _dense_var_to_tensor)
955ops.register_dense_tensor_like_type(_UnreadVariable)
956
957# Register a conversion function which reads the value of the variable,
958# allowing instances of the class to be used as tensors.
959
960# Note: registering for Variable after ResourceVariable because inheritance will
961# otherwise lead to the wrong behavior.
962ops.register_tensor_conversion_function(ResourceVariable, _dense_var_to_tensor)
963ops.register_tensor_conversion_function(
964    variables.Variable, variables.Variable._TensorConversionFunction)  # pylint: disable=protected-access
965
966# pylint: disable=protected-access
967ResourceVariable._OverloadAllOperators()
968ops.register_dense_tensor_like_type(ResourceVariable)
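
# The registrations above are what allow a ResourceVariable to be passed
# directly to ops that expect a Tensor; the conversion reads the variable's
# value via _dense_var_to_tensor. A rough, illustrative sketch (math_ops is
# not imported in this module and is shown here only for exposition):
#
#   v = ResourceVariable([[1.0, 2.0]])
#   y = math_ops.matmul(v, [[3.0], [4.0]])  # `v` is converted by reading it.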


@ops.RegisterGradient("ReadVariableOp")
def _ReadGrad(_, grad):
  """Gradient for read op."""
  return grad


@ops.RegisterGradient("ResourceGather")
def _GatherGrad(op, grad):
  """Gradient for gather op."""
  # Build appropriately shaped IndexedSlices
  handle = op.inputs[0]
  indices = op.inputs[1]
  params_shape = gen_resource_variable_ops.variable_shape(handle)
  size = array_ops.expand_dims(array_ops.size(indices), 0)
  values_shape = array_ops.concat([size, params_shape[1:]], 0)
  values = array_ops.reshape(grad, values_shape)
  indices = array_ops.reshape(indices, size)
  return (ops.IndexedSlices(values, indices, params_shape), None)


def _to_proto_fn(v, export_scope=None):
  """Converts Variable and ResourceVariable to VariableDef for collections."""
  return v.to_proto(export_scope=export_scope)


def _from_proto_fn(v, import_scope=None):
  """Creates Variable or ResourceVariable from VariableDef as needed."""
  if v.is_resource:
    return ResourceVariable.from_proto(v, import_scope=import_scope)
  return variables.Variable.from_proto(v, import_scope=import_scope)


ops.register_proto_function(
    ops.GraphKeys.GLOBAL_VARIABLES,
    proto_type=variable_pb2.VariableDef,
    to_proto=_to_proto_fn,
    from_proto=_from_proto_fn)
ops.register_proto_function(
    ops.GraphKeys.TRAINABLE_VARIABLES,
    proto_type=variable_pb2.VariableDef,
    to_proto=_to_proto_fn,
    from_proto=_from_proto_fn)
ops.register_proto_function(
    ops.GraphKeys.MOVING_AVERAGE_VARIABLES,
    proto_type=variable_pb2.VariableDef,
    to_proto=_to_proto_fn,
    from_proto=_from_proto_fn)
ops.register_proto_function(
    ops.GraphKeys.LOCAL_VARIABLES,
    proto_type=variable_pb2.VariableDef,
    to_proto=_to_proto_fn,
    from_proto=_from_proto_fn)
ops.register_proto_function(
    ops.GraphKeys.MODEL_VARIABLES,
    proto_type=variable_pb2.VariableDef,
    to_proto=_to_proto_fn,
    from_proto=_from_proto_fn)


def is_resource_variable(var):
1031  """"Returns True if `var` is to be considered a ResourceVariable."""
1032  return isinstance(var, ResourceVariable) or hasattr(
1033      var, "_should_act_as_resource_variable")
1034