1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""A class to store named variables and a scope operator to manage sharing."""
16
17import copy
18import enum
19import functools
20import sys
21import threading
22import traceback
23
24from tensorflow.python import tf2
25from tensorflow.python.client import session
26from tensorflow.python.eager import context
27from tensorflow.python.eager import monitoring
28from tensorflow.python.framework import dtypes
29from tensorflow.python.framework import ops
30from tensorflow.python.framework import tensor_shape
31from tensorflow.python.ops import array_ops
32from tensorflow.python.ops import init_ops
33from tensorflow.python.ops import resource_variable_ops
34from tensorflow.python.ops import variables
35from tensorflow.python.platform import tf_logging as logging
36from tensorflow.python.types import core
37from tensorflow.python.util import deprecation
38from tensorflow.python.util import function_utils
39from tensorflow.python.util import tf_contextlib
40from tensorflow.python.util import tf_inspect
41from tensorflow.python.util.compat import collections_abc
42from tensorflow.python.util.tf_export import tf_export
43
44__all__ = [
45    "AUTO_REUSE", "VariableScope", "get_variable_scope", "get_variable",
46    "get_local_variable", "variable_scope", "variable_op_scope",
47    "no_regularizer", "VariableSynchronization", "VariableAggregation"
48]
49
50_api_usage_gauge = monitoring.BoolGauge(
51    "/tensorflow/api/resource_variables",
52    "Whether variable_scope.enable_resource_variables() is called.")
53
54
55class _PartitionInfo:
56  """Holds partition info used by initializer functions."""
57
58  __slots__ = ["_full_shape", "_var_offset"]
59
60  def __init__(self, full_shape, var_offset):
61    """Constructor.
62
63    Args:
64      full_shape: Tuple or list of `int` indicating the full combined shape of
65        the partitioned variables.
66      var_offset: Tuple or list of `int` specifying offset of this partition
67        with respect to the full variable for each dimension.
68
69    Raises:
70      TypeError: If `full_shape` or `var_offset` is not a sequence.
71      ValueError: If `full_shape` or `var_offset` differ in length. If
72        `var_offset` exceeds `full_shape` in any dimension.
73    """
74    if not isinstance(full_shape, (list, tuple)):
75      raise TypeError(
76          "`full_shape` must be a sequence (like tuple or list) instead of " +
77          type(full_shape).__name__)
78
79    if not isinstance(var_offset, (list, tuple)):
80      raise TypeError(
81          "`var_offset` must be a sequence (like tuple or list) instead of " +
82          type(var_offset).__name__)
83
84    if len(var_offset) != len(full_shape):
85      raise ValueError(
86          "Expected equal length, but `var_offset` is of length {} while "
          "`full_shape` is of length {}.".format(
88              len(var_offset), len(full_shape)))
89
90    for offset, shape in zip(var_offset, full_shape):
91      if offset < 0 or offset >= shape:
92        raise ValueError(
93            "Expected 0 <= offset < shape but found offset={}, shape={} for "
94            "var_offset={}, full_shape={}".format(offset, shape, var_offset,
95                                                  full_shape))
96
97    self._full_shape = full_shape
98    self._var_offset = var_offset
99
100  @property
101  def full_shape(self):
102    return self._full_shape
103
104  @property
105  def var_offset(self):
106    return self._var_offset
107
108  def single_offset(self, shape):
109    """Returns the offset when the variable is partitioned in at most one dim.
110
111    Args:
112      shape: Tuple or list of `int` indicating the shape of one specific
113        variable partition.
114
115    Returns:
116      `int` representing the offset in the dimension along which the variable is
117       partitioned. Returns 0 if the variable is not being partitioned.
118
119    Raises:
120      ValueError: Depending on self.single_slice_dim().
121    """
122
123    single_slice_dim = self.single_slice_dim(shape)
124    # If this variable is not being partitioned at all, single_slice_dim() could
125    # return None.
126    if single_slice_dim is None:
127      return 0
128    return self.var_offset[single_slice_dim]
129
130  def single_slice_dim(self, shape):
131    """Returns the slice dim when the variable is partitioned only in one dim.
132
133    Args:
134      shape: Tuple or list of `int` indicating the shape of one specific
135        variable partition.
136
137    Returns:
138      `int` representing the dimension that the variable is partitioned in, or
139      `None` if the variable doesn't seem to be partitioned at all.
140
141    Raises:
142      TypeError: If `shape` is not a sequence.
143      ValueError: If `shape` is not the same length as `self.full_shape`. If
144        the variable is partitioned in more than one dimension.
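
    For example, with `full_shape=[4, 10]` and `var_offset=[0, 3]`, a
    partition of shape `[4, 4]` differs from the full shape only in
    dimension 1, so `single_slice_dim` returns 1 and `single_offset` returns
    3 (the numbers here are purely illustrative):

    ```python
    info = _PartitionInfo(full_shape=[4, 10], var_offset=[0, 3])
    info.single_slice_dim([4, 4])  # == 1 (only dimension 1 is partitioned)
    info.single_offset([4, 4])     # == 3 (offset along that dimension)
    ```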
145    """
146    if not isinstance(shape, (tuple, list)):
147      raise TypeError(
148          "`shape` must be a sequence (like tuple or list) instead of " +
149          type(shape).__name__)
150
151    if len(shape) != len(self.full_shape):
152      raise ValueError(
153          "Expected equal length, but received shape={} of length {} while "
154          "self.full_shape={} is of length {}.".format(shape, len(shape),
155                                                       self.full_shape,
156                                                       len(self.full_shape)))
157
158    for i in range(len(shape)):
159      if self.var_offset[i] + shape[i] > self.full_shape[i]:
160        raise ValueError(
161            "With self.var_offset={}, a partition of shape={} would exceed "
162            "self.full_shape={} in dimension {}.".format(
163                self.var_offset, shape, self.full_shape, i))
164
165    slice_dim = None
166    for i in range(len(shape)):
167      if shape[i] == self.full_shape[i]:
168        continue
169      if slice_dim is not None:
170        raise ValueError(
171            "Cannot use single_slice_dim() with shape={} and "
172            "self.full_shape={} since slice dim could be either dimension {} "
173            "or {}.".format(shape, self.full_shape, i, slice_dim))
174      slice_dim = i
175
176    return slice_dim
177
178
179class _ReuseMode(enum.Enum):
180  """Mode for variable access within a variable scope."""
181
182  # Indicates that variables are to be fetched if they already exist or
183  # otherwise created.
184  AUTO_REUSE = 1
185
# TODO(alive): For TensorFlow 2.0, deprecate the True/False/None API in favor
#              of enum values.
188  # REUSE_FALSE = 2
189  # REUSE_TRUE = 3
190
191
192# TODO(apassos) remove these forwarding symbols.
193VariableSynchronization = variables.VariableSynchronization  # pylint: disable=invalid-name
194VariableAggregation = variables.VariableAggregation  # pylint: disable=invalid-name
195
196AUTO_REUSE = _ReuseMode.AUTO_REUSE
197tf_export(v1=["AUTO_REUSE"]).export_constant(__name__, "AUTO_REUSE")
198AUTO_REUSE.__doc__ = """
199@compatibility(TF2)
200`tf.compat.v1.AUTO_REUSE` is a legacy API that is a no-op when TF2 behaviors
201are enabled.
202
203If you rely on `get_variable` and auto-reuse, see the
204[model mapping guide](https://www.tensorflow.org/guide/migrate/model_mapping)
205for more info on how to migrate your code.
206
207Note: when you use the `tf.compat.v1.keras.utils.track_tf1_style_variables`
208API as described in the above guide, `get_variable` will always behave as if
209`v1.AUTO_REUSE` is set. Without the decorator, reuse will be ignored and new
variables will always be created, regardless of whether they have already been
created.
212@end_compatibility
213
214When passed in as the value for the `reuse` flag, `AUTO_REUSE` indicates that
215get_variable() should create the requested variable if it doesn't exist or, if
216it does exist, simply return it.
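
For example, the following function creates `v` on its first call and returns
the same variable on later calls (a minimal sketch using the `tf.compat.v1`
API):

```python
def foo():
  with tf.compat.v1.variable_scope("foo", reuse=tf.compat.v1.AUTO_REUSE):
    v = tf.compat.v1.get_variable("v", [1])
  return v

v1 = foo()  # Creates v.
v2 = foo()  # Gets the same, existing v.
assert v1 is v2
```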
217"""
218
219_DEFAULT_USE_RESOURCE = tf2.enabled()
220
221
222@tf_export(v1=["enable_resource_variables"])
223def enable_resource_variables():
224  """Creates resource variables by default.
225
226  Resource variables are improved versions of TensorFlow variables with a
227  well-defined memory model. Accessing a resource variable reads its value, and
228  all ops which access a specific read value of the variable are guaranteed to
229  see the same value for that tensor. Writes which happen after a read (by
230  having a control or data dependency on the read) are guaranteed not to affect
231  the value of the read tensor, and similarly writes which happen before a read
232  are guaranteed to affect the value. No guarantees are made about unordered
233  read/write pairs.
234
235  Calling tf.enable_resource_variables() lets you opt-in to this TensorFlow 2.0
236  feature.
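
  For example, the read/write ordering described above can be relied upon as
  follows (an illustrative sketch assuming graph-mode execution; the variable
  and tensor names are arbitrary):

  ```python
  import tensorflow.compat.v1 as tf

  tf.enable_resource_variables()
  v = tf.get_variable("v", initializer=1.0)
  read = v.read_value()  # Snapshot of the current value.
  with tf.control_dependencies([read]):
    # This write is ordered after the read, so it cannot change the value
    # that `read` observes.
    assign_op = v.assign_add(1.0)
  ```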
237  """
238  global _DEFAULT_USE_RESOURCE
239  _DEFAULT_USE_RESOURCE = True
240  logging.vlog(1, "Enabling resource variables")
241  _api_usage_gauge.get_cell().set(True)
242
243
244@tf_export(v1=["resource_variables_enabled"])
245def resource_variables_enabled():
246  """Returns `True` if resource variables are enabled.
247
248  Resource variables are improved versions of TensorFlow variables with a
249  well-defined memory model. Accessing a resource variable reads its value, and
250  all ops which access a specific read value of the variable are guaranteed to
251  see the same value for that tensor. Writes which happen after a read (by
252  having a control or data dependency on the read) are guaranteed not to affect
253  the value of the read tensor, and similarly writes which happen before a read
254  are guaranteed to affect the value. No guarantees are made about unordered
255  read/write pairs.
256
257  Calling tf.enable_resource_variables() lets you opt-in to this TensorFlow 2.0
258  feature.
259  """
260  global _DEFAULT_USE_RESOURCE
261  return _DEFAULT_USE_RESOURCE
262
263
264@deprecation.deprecated(
265    None, "non-resource variables are not supported in the long term")
266@tf_export(v1=["disable_resource_variables"])
267def disable_resource_variables():
268  """Opts out of resource variables.
269
270  If your code needs tf.disable_resource_variables() to be called to work
  properly, please file a bug.
272  """
273  global _DEFAULT_USE_RESOURCE
274  _DEFAULT_USE_RESOURCE = False
275  logging.vlog(1, "Disabling resource variables")
276  _api_usage_gauge.get_cell().set(False)
277
278
279def _needs_no_arguments(python_callable):
280  """Returns true if the callable needs no arguments to call."""
281  # TODO(bfontain): Switch to inspect.signature when we are python 3 only.
282  # signature = inspect.signature(python_callable)
283  # return not [1 for param in signature.parameters.values()
284  #             if param.default == param.empty]
285  num_arguments = len(tf_inspect.getargspec(python_callable).args)
286  if not tf_inspect.isfunction(python_callable) and not isinstance(
287      python_callable, functools.partial):
288    # getargspec includes self for function objects (which aren't
289    # functools.partial). This has no default so we need to remove it.
    # It is not even an argument, so it's odd that getargspec returns this.
291    # Note that this is fixed with inspect.signature in Python 3.
292    num_arguments -= 1
293  return num_arguments == len(
294      tf_inspect.getargspec(python_callable).defaults or [])
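
# For illustration, a callable "needs no arguments" when every parameter
# has a default (the lambdas below are hypothetical examples, not part of
# this module's API):
#   _needs_no_arguments(lambda: 3)       # True: no parameters at all.
#   _needs_no_arguments(lambda x=1: x)   # True: `x` has a default value.
#   _needs_no_arguments(lambda x: x)     # False: `x` must be supplied.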
295
296
297class _VariableStore:
298  """Variable store that carries a number of named Variables.
299
300  New variable names and new variables can be created; all stored
301  variables are initialized with the initializer passed to __init__.
302
303  Attributes:
304    vars: a dictionary with string names (same as passed in GetVar) as keys and
305      the corresponding TensorFlow Variables as values.
306  """
307
308  __slots__ = ["_vars", "_partitioned_vars", "_store_eager_variables"]
309
310  def __init__(self):
311    """Create a variable store."""
312    self._vars = {}  # A dictionary of the stored TensorFlow variables.
313    self._partitioned_vars = {}  # A dict of the stored PartitionedVariables.
314    self._store_eager_variables = False
315
316  def get_variable(self,
317                   name,
318                   shape=None,
319                   dtype=dtypes.float32,
320                   initializer=None,
321                   regularizer=None,
322                   reuse=None,
323                   trainable=None,
324                   collections=None,
325                   caching_device=None,
326                   partitioner=None,
327                   validate_shape=True,
328                   use_resource=None,
329                   custom_getter=None,
330                   constraint=None,
331                   synchronization=VariableSynchronization.AUTO,
332                   aggregation=VariableAggregation.NONE):
    """Gets an existing variable with these parameters or creates a new one.
334
335    If a variable with the given name is already stored, we return the stored
336    variable. Otherwise, we create a new one.
337
338    Set `reuse` to `True` when you only want to reuse existing Variables.
339    Set `reuse` to `False` when you only want to create new Variables.
340    Set `reuse` to None (the default) or tf.compat.v1.AUTO_REUSE when you want
341    variables to be created if they don't exist or returned if they do.
342
343    If initializer is `None` (the default), the default initializer passed in
344    the constructor is used. If that one is `None` too, we use a new
345    `glorot_uniform_initializer`. If initializer is a Tensor, we use
346    it as a value and derive the shape from the initializer.
347
348    If a partitioner is provided, a `PartitionedVariable` is returned.
349    Accessing this object as a `Tensor` returns the shards concatenated along
350    the partition axis.
351
352    Some useful partitioners are available.  See, e.g.,
353    `variable_axis_size_partitioner` and `min_max_variable_partitioner`.
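
    For example, using the public `tf.compat.v1` wrappers (a minimal sketch;
    the shard count and shapes below are illustrative):

    ```python
    partitioner = tf.compat.v1.fixed_size_partitioner(num_shards=3, axis=0)
    with tf.compat.v1.variable_scope("embeddings", partitioner=partitioner):
      # Stored as three [3, 4] shards; reads concatenate them along axis 0.
      weights = tf.compat.v1.get_variable("weights", shape=[9, 4])
    ```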
354
355    Args:
356      name: The name of the new or existing variable.
357      shape: Shape of the new or existing variable.
358      dtype: Type of the new or existing variable (defaults to `DT_FLOAT`).
359      initializer: Initializer for the variable.
360      regularizer: A (Tensor -> Tensor or None) function; the result of applying
        it to a newly created variable will be added to the collection
362        GraphKeys.REGULARIZATION_LOSSES and can be used for regularization.
363      reuse: a Boolean, None, or tf.AUTO_REUSE. Controls reuse or creation of
        variables. When eager execution is enabled this argument is always
365        forced to be False.
366      trainable: If `True` also add the variable to the graph collection
367        `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). `trainable`
368        defaults to `True`, unless `synchronization` is set to `ON_READ`, in
369        which case it defaults to `False`.
370      collections: List of graph collections keys to add the `Variable` to.
371        Defaults to `[GraphKeys.GLOBAL_VARIABLES]` (see `tf.Variable`).
372      caching_device: Optional device string or function describing where the
373        Variable should be cached for reading.  Defaults to the Variable's
374        device.  If not `None`, caches on another device.  Typical use is to
375        cache on the device where the Ops using the `Variable` reside, to
376        deduplicate copying through `Switch` and other conditional statements.
377      partitioner: Optional callable that accepts a fully defined `TensorShape`
378        and dtype of the `Variable` to be created, and returns a list of
379        partitions for each axis (currently only one axis can be partitioned).
380      validate_shape: If False, allows the variable to be initialized with a
381        value of unknown shape. If True, the default, the shape of initial_value
382        must be known.
383      use_resource: If False, creates a regular Variable. If True, creates
384        instead an experimental ResourceVariable which has well-defined
385        semantics. Defaults to False (will later change to True). When eager
386        execution is enabled this argument is always forced to be true.
387      custom_getter: Callable that takes as a first argument the true getter,
388        and allows overwriting the internal get_variable method. The signature
389        of `custom_getter` should match that of this method,
390        but the most future-proof version will allow for changes: `def
391          custom_getter(getter, *args, **kwargs)`.  Direct access to
392        all `get_variable` parameters is also allowed: `def
393          custom_getter(getter, name, *args, **kwargs)`.  A simple identity
        custom getter that creates variables with modified names is:
          ```python
          def custom_getter(getter, name, *args, **kwargs):
            return getter(name + '_suffix', *args, **kwargs)
          ```
398      constraint: An optional projection function to be applied to the variable
399        after being updated by an `Optimizer` (e.g. used to implement norm
400        constraints or value constraints for layer weights). The function must
401        take as input the unprojected Tensor representing the value of the
402        variable and return the Tensor for the projected value (which must have
403        the same shape). Constraints are not safe to use when doing asynchronous
404        distributed training.
      synchronization: Indicates when a distributed variable will be
406        aggregated. Accepted values are constants defined in the class
407        `tf.VariableSynchronization`. By default the synchronization is set to
408        `AUTO` and the current `DistributionStrategy` chooses when to
409        synchronize.
410      aggregation: Indicates how a distributed variable will be aggregated.
411        Accepted values are constants defined in the class
412        `tf.VariableAggregation`.
413
414    Returns:
415      The created or existing `Variable` (or `PartitionedVariable`, if a
416      partitioner was used).
417
418    Raises:
419      ValueError: when creating a new variable and shape is not declared,
420        when reusing a variable and specifying a conflicting shape,
421        or when violating reuse during variable creation.
422      RuntimeError: when eager execution is enabled and not called from an
423        EagerVariableStore.
424    """
425    if custom_getter is not None and not callable(custom_getter):
426      raise ValueError("Passed a custom_getter which is not callable: %s" %
427                       custom_getter)
428
429    with ops.init_scope():
430      if context.executing_eagerly():
431        # Variable creation and initialization takes place in `init_scope`s;
432        # as such, if an `init_scope` lifts us into the eager context, then we
433        # need to use `ResourceVariable`s.
434        use_resource = True
435
436    # Note that it's fine to reuse eager variables whose initialization was
437    # lifted from a function-building graph into the eager context (that's why
438    # the following clause is not wrapped in an `init_scope`); lifted variables
439    # are tracked by the graph's `VariableStore`.
440    if context.executing_eagerly():
441      if not self._store_eager_variables and reuse:
442        raise RuntimeError(
443            "When eager execution is enabled variable reuse is only supported"
444            " when an EagerVariableStore is active. See the documentation on"
445            " EagerVariableStore for example usage.")
446      if self._store_eager_variables:
447        reuse = AUTO_REUSE
448
449    # If a *_ref type is passed in an error would be triggered further down the
450    # stack. We prevent this using base_dtype to get a non-ref version of the
451    # type, before doing anything else. When _ref types are removed in favor of
452    # resources, this line can be removed.
453    try:
454      dtype = dtype.base_dtype
455    except AttributeError:
      # .base_dtype not existing means that we will try to use the raw dtype
      # which was passed in - this might be a NumPy type, which is valid.
458      pass
459
460    # This is the main logic of get_variable.  However, custom_getter
461    # may override this logic.  So we save it as a callable and pass
462    # it to custom_getter.
463    # Note: the parameters of _true_getter, and their documentation, match
464    # *exactly* item-for-item with the docstring of this method.
465    def _true_getter(  # pylint: disable=missing-docstring
466        name,
467        shape=None,
468        dtype=dtypes.float32,
469        initializer=None,
470        regularizer=None,
471        reuse=None,
472        trainable=None,
473        collections=None,
474        caching_device=None,
475        partitioner=None,
476        validate_shape=True,
477        use_resource=None,
478        constraint=None,
479        synchronization=VariableSynchronization.AUTO,
480        aggregation=VariableAggregation.NONE):
481      is_scalar = (
482          shape is not None and isinstance(shape, collections_abc.Sequence) and
483          not shape)
484      # Partitioned variable case
485      if partitioner is not None and not is_scalar:
486        if not callable(partitioner):
487          raise ValueError("Partitioner must be callable, but received: %s" %
488                           partitioner)
489        with ops.name_scope(None):
490          return self._get_partitioned_variable(
491              name=name,
492              shape=shape,
493              dtype=dtype,
494              initializer=initializer,
495              regularizer=regularizer,
496              reuse=reuse,
497              trainable=trainable,
498              collections=collections,
499              caching_device=caching_device,
500              partitioner=partitioner,
501              validate_shape=validate_shape,
502              use_resource=use_resource,
503              constraint=constraint,
504              synchronization=synchronization,
505              aggregation=aggregation)
506
507      # Special case for partitioned variable to allow reuse without having to
508      # specify partitioner.
509      if (reuse is True and partitioner is None
510          and name in self._partitioned_vars):
511        return self._get_partitioned_variable(
512            name=name,
513            shape=shape,
514            dtype=dtype,
515            initializer=initializer,
516            regularizer=regularizer,
517            reuse=reuse,
518            trainable=trainable,
519            collections=collections,
520            caching_device=caching_device,
521            partitioner=None,
522            validate_shape=validate_shape,
523            use_resource=use_resource,
524            constraint=constraint,
525            synchronization=synchronization,
526            aggregation=aggregation)
527
528      # Single variable case
529      if "%s/part_0" % name in self._vars:
530        raise ValueError(
531            "No partitioner was provided, but a partitioned version of the "
532            "variable was found: %s/part_0. Perhaps a variable of the same "
533            "name was already created with partitioning?" % name)
534
535      return self._get_single_variable(
536          name=name,
537          shape=shape,
538          dtype=dtype,
539          initializer=initializer,
540          regularizer=regularizer,
541          reuse=reuse,
542          trainable=trainable,
543          collections=collections,
544          caching_device=caching_device,
545          validate_shape=validate_shape,
546          use_resource=use_resource,
547          constraint=constraint,
548          synchronization=synchronization,
549          aggregation=aggregation)
550
551    synchronization, aggregation, trainable = (
552        variables.validate_synchronization_aggregation_trainable(
553            synchronization, aggregation, trainable, name))
554
555    if custom_getter is not None:
556      # Handle backwards compatibility with getter arguments that were added
557      # to the API after users started writing custom getters.
558      custom_getter_kwargs = {
559          "getter": _true_getter,
560          "name": name,
561          "shape": shape,
562          "dtype": dtype,
563          "initializer": initializer,
564          "regularizer": regularizer,
565          "reuse": reuse,
566          "trainable": trainable,
567          "collections": collections,
568          "caching_device": caching_device,
569          "partitioner": partitioner,
570          "validate_shape": validate_shape,
571          "use_resource": use_resource,
572          "synchronization": synchronization,
573          "aggregation": aggregation,
574      }
575      # `fn_args` and `has_kwargs` can handle functions, `functools.partial`,
576      # `lambda`.
577      if ("constraint" in function_utils.fn_args(custom_getter) or
578          function_utils.has_kwargs(custom_getter)):
579        custom_getter_kwargs["constraint"] = constraint
580      return custom_getter(**custom_getter_kwargs)
581    else:
582      return _true_getter(
583          name,
584          shape=shape,
585          dtype=dtype,
586          initializer=initializer,
587          regularizer=regularizer,
588          reuse=reuse,
589          trainable=trainable,
590          collections=collections,
591          caching_device=caching_device,
592          partitioner=partitioner,
593          validate_shape=validate_shape,
594          use_resource=use_resource,
595          constraint=constraint,
596          synchronization=synchronization,
597          aggregation=aggregation)
598
599  def _get_partitioned_variable(self,
600                                name,
601                                partitioner,
602                                shape=None,
603                                dtype=dtypes.float32,
604                                initializer=None,
605                                regularizer=None,
606                                reuse=None,
607                                trainable=None,
608                                collections=None,
609                                caching_device=None,
610                                validate_shape=True,
611                                use_resource=None,
612                                constraint=None,
613                                synchronization=VariableSynchronization.AUTO,
614                                aggregation=VariableAggregation.NONE):
615    """Gets or creates a sharded variable list with these parameters.
616
617    The `partitioner` must be a callable that accepts a fully defined
618    `TensorShape` and returns a sequence of integers (the `partitions`).
619    These integers describe how to partition the given sharded `Variable`
620    along the given dimension.  That is, `partitions[1] = 3` means split
621    the `Variable` into 3 shards along dimension 1.  Currently, sharding along
622    only one axis is supported.
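
    For example, a partitioner that always splits dimension 1 into three
    shards could be written as follows (an illustrative sketch, not a
    partitioner that ships with TensorFlow under this name):

    ```python
    def split_dim1_into_3(shape, dtype):
      del dtype  # Unused in this sketch.
      partitions = [1] * len(shape)
      partitions[1] = 3
      return partitions
    ```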
623
624    If the list of variables with the given name (prefix) is already stored,
625    we return the stored variables. Otherwise, we create a new one.
626
627    Set `reuse` to `True` when you only want to reuse existing Variables.
628    Set `reuse` to `False` when you only want to create new Variables.
629    Set `reuse` to None (the default) or tf.compat.v1.AUTO_REUSE when you want
630    variables to be created if they don't exist or returned if they do.
631
632    If initializer is `None` (the default), the default initializer passed in
633    the constructor is used. If that one is `None` too, we use a new
634    `glorot_uniform_initializer`. If initializer is a Tensor, we use
635    it as a value and derive the shape from the initializer.
636
637    If the initializer is a callable, then it will be called for each
638    shard.  Otherwise the initializer should match the shape of the entire
639    sharded Variable, and it will be sliced accordingly for each shard.
640
641    Some useful partitioners are available.  See, e.g.,
642    `variable_axis_size_partitioner` and `min_max_variable_partitioner`.
643
644    Args:
645      name: the name of the new or existing sharded variable.
646      partitioner: Optional callable that accepts a fully defined `TensorShape`
647        and `dtype` of the Variable to be created, and returns a list of
648        partitions for each axis (currently only one axis can be partitioned).
649      shape: shape of the new or existing sharded variable.
650      dtype: type of the new or existing sharded variable (defaults to
651        `DT_FLOAT`).
652      initializer: initializer for the sharded variable.
653      regularizer: a (Tensor -> Tensor or None) function; the result of applying
        it to a newly created variable will be added to the collection
655        GraphKeys.REGULARIZATION_LOSSES and can be used for regularization.
656      reuse: a Boolean, None, or tf.AUTO_REUSE. Controls reuse or creation of
657        variables.
658      trainable: If `True` also add the variable to the graph collection
659        `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
660      collections: List of graph collections keys to add the Variable to.
661        Defaults to `[GraphKeys.GLOBAL_VARIABLES]` (see `tf.Variable`).
662      caching_device: Optional device string or function describing where the
663        Variable should be cached for reading.  Defaults to the Variable's
664        device.  If not `None`, caches on another device.  Typical use is to
665        cache on the device where the Ops using the Variable reside, to
666        deduplicate copying through `Switch` and other conditional statements.
667      validate_shape: If False, allows the variable to be initialized with a
668        value of unknown shape. If True, the default, the shape of initial_value
669        must be known.
670      use_resource: If False, creates a regular Variable. If True, creates an
671        experimental ResourceVariable which has well-defined semantics. Defaults
672        to False (will later change to True).
673      constraint: An optional projection function to be applied to the variable
674        after being updated by an `Optimizer` (e.g. used to implement norm
675        constraints or value constraints for layer weights). The function must
676        take as input the unprojected Tensor representing the value of the
677        variable and return the Tensor for the projected value (which must have
678        the same shape). Constraints are not safe to use when doing asynchronous
679        distributed training.
      synchronization: Indicates when a distributed variable will be
681        aggregated. Accepted values are constants defined in the class
682        `tf.VariableSynchronization`. By default the synchronization is set to
683        `AUTO` and the current `DistributionStrategy` chooses when to
684        synchronize.
685      aggregation: Indicates how a distributed variable will be aggregated.
686        Accepted values are constants defined in the class
687        `tf.VariableAggregation`.
688
689    Returns:
690      A `PartitionedVariable` object.
691
692    Raises:
693      ValueError: when creating a new variable and shape is not declared,
694        when reusing a variable and specifying a conflicting shape,
695        when violating reuse during variable creation, or if an existing
696        sharded variable exists for the given name but with different sharding.
697    """
698    initializing_from_value = initializer is not None and isinstance(
699        initializer, ops.Tensor)
700    if name in self._vars:
701      raise ValueError(
702          "A partitioner was provided, but an unpartitioned version of the "
703          "variable was found: %s.  Perhaps a variable of the same name was "
704          "already created without partitioning?" % name)
705
706    shape = tensor_shape.as_shape(shape)
707    if initializing_from_value:
708      shape = shape.merge_with(initializer.get_shape())
709
710    partitions = None
711    if not reuse or partitioner:
712      partitions = _call_partitioner(partitioner, shape, dtype)
713
714    if name in self._partitioned_vars:
715      if reuse is False:
716        raise ValueError(
717            "Partitioned variable with name %s already exists. Did you mean to "
718            "set reuse=True or reuse=tf.AUTO_REUSE in VarScope?" % name)
719
720      existing_var = self._partitioned_vars[name]
721      if not shape.is_compatible_with(existing_var.get_shape()):
722        raise ValueError(
723            "Trying to reuse partitioned variable %s, but specified shape %s "
724            "and found shape %s." % (name, shape, existing_var.get_shape()))
725      if not dtype.is_compatible_with(existing_var.dtype):
726        raise ValueError(
727            "Trying to reuse partitioned variable %s, but specified dtype %s "
728            "and found dtype %s." % (name, dtype.name, existing_var.dtype.name))
729
730      # pylint: disable=protected-access
731      if (partitions is not None and
732          existing_var._get_partitions() != partitions):
733        raise ValueError(
734            "Trying to reuse partitioned variable %s, but specified partitions "
735            "%s and found partitions %s." %
736            (name, partitions, existing_var._get_partitions()))
737      # pylint: enable=protected-access
738
739      return existing_var
740
741    if reuse is True:
742      raise ValueError("PartitionedVariable %s does not exist, or was not "
743                       "created with tf.get_variable(). Did you mean to set "
744                       "reuse=False or reuse=tf.AUTO_REUSE in VarScope?" % name)
745
746    slice_dim, num_slices = _get_slice_dim_and_num_slices(partitions)
747
748    if "%s/part_0" % name in self._vars:
749      if "%s/part_%d" % (name, num_slices - 1) not in self._vars:
750        raise ValueError(
751            "Partitioner returned a different partitioning than what was "
752            "already found.  Partitioner returned %d shards, and shard "
753            "%s/part_0 was found, but %s/part_%d was not." %
754            (num_slices, name, name, num_slices - 1))
755      if "%s/part_%d" % (name, num_slices) in self._vars:
756        raise ValueError(
757            "Partitioner returned a different partitioning than what was "
758            "already found.  Partitioner returned %d shards, and shard "
759            "%s/part_0 was found, but so was the extra shard %s/part_%d." %
760            (num_slices, name, name, num_slices))
761
762    vs = []
763    for i, (var_offset, var_shape) in enumerate(
764        _iter_slices(shape.as_list(), num_slices, slice_dim)):
765      partition_info = _PartitionInfo(
766          full_shape=shape.as_list(), var_offset=var_offset)
767      var_full_name = "%s/part_%d" % (name, i)
768      with ops.name_scope(
769          var_full_name + "/PartitionedInitializer", skip_on_eager=False):
770        # Create the tensor to initialize the variable with default value.
771        if initializer is None:
772          init, initializing_from_value = self._get_default_initializer(
773              name=name, shape=shape, dtype=dtype)
774          if initializing_from_value:
775            init_shape = None
776          else:
777            init_shape = var_shape
778        elif callable(initializer):
779          init = initializer
780          init_shape = var_shape
781        elif isinstance(initializer, ops.Tensor):
782          init = array_ops.slice(initializer, var_offset, var_shape)
783          # Use the dtype of the given tensor.
784          dtype = init.dtype.base_dtype
785          init_shape = None
786        else:
787          init = ops.convert_to_tensor(initializer, dtype=dtype)
788          init = array_ops.slice(init, var_offset, var_shape)
789          init_shape = None
790
791      with ops.name_scope(None):
792        var = self._get_single_variable(
793            name=var_full_name,
794            shape=init_shape,
795            dtype=dtype,
796            initializer=init,
797            partition_info=partition_info,
798            regularizer=regularizer,
799            reuse=reuse,
800            trainable=trainable,
801            collections=collections,
802            caching_device=caching_device,
803            validate_shape=validate_shape,
804            use_resource=use_resource,
805            constraint=constraint,
806            synchronization=synchronization,
807            aggregation=aggregation)
808
809      # pylint: disable=protected-access
810      var._set_save_slice_info(
811          variables.Variable.SaveSliceInfo(name, shape.as_list(), var_offset,
812                                           var_shape))
813      vs.append(var)
814      # pylint: enable=protected-access
815
816    partitioned_var = variables.PartitionedVariable(
817        name=name,
818        shape=shape,
819        dtype=dtype,
820        variable_list=vs,
821        partitions=partitions)
822    if not context.executing_eagerly() or self._store_eager_variables:
823      self._partitioned_vars[name] = partitioned_var
824    return partitioned_var
825
826  def _get_single_variable(self,
827                           name,
828                           shape=None,
829                           dtype=dtypes.float32,
830                           initializer=None,
831                           regularizer=None,
832                           partition_info=None,
833                           reuse=None,
834                           trainable=None,
835                           collections=None,
836                           caching_device=None,
837                           validate_shape=True,
838                           use_resource=None,
839                           constraint=None,
840                           synchronization=VariableSynchronization.AUTO,
841                           aggregation=VariableAggregation.NONE):
    """Get or create a single Variable (e.g. a shard or entire variable).
845
846    See the documentation of get_variable above (ignore partitioning components)
847    for details.
848
849    Args:
850      name: see get_variable.
851      shape: see get_variable.
852      dtype: see get_variable.
853      initializer: see get_variable.
854      regularizer: see get_variable.
855      partition_info: _PartitionInfo object.
856      reuse: see get_variable.
857      trainable: see get_variable.
858      collections: see get_variable.
859      caching_device: see get_variable.
860      validate_shape: see get_variable.
861      use_resource: see get_variable.
862      constraint: see get_variable.
863      synchronization: see get_variable.
864      aggregation: see get_variable.
865
866    Returns:
867      A Variable.  See documentation of get_variable above.
868
869    Raises:
870      ValueError: See documentation of get_variable above.
871    """
872    # Set to true if initializer is a constant.
873    initializing_from_value = False
874    if initializer is not None and not callable(initializer):
875      initializing_from_value = True
876    if shape is not None and initializing_from_value:
877      raise ValueError("If initializer is a constant, do not specify shape.")
878
879    dtype = dtypes.as_dtype(dtype)
880    shape = tensor_shape.as_shape(shape)
881
882    if name in self._vars:
883      # Here we handle the case when returning an existing variable.
884      if reuse is False:
885        var = self._vars[name]
886        err_msg = ("Variable %s already exists, disallowed."
887                   " Did you mean to set reuse=True or "
888                   "reuse=tf.AUTO_REUSE in VarScope?" % name)
        # ResourceVariables don't have an op associated with them, so there is
        # no traceback to report.
890        if isinstance(var, resource_variable_ops.ResourceVariable):
891          raise ValueError(err_msg)
892        tb = var.op.traceback[::-1]
893        # Throw away internal tf entries and only take a few lines. In some
894        # cases the traceback can be longer (e.g. if someone uses factory
895        # functions to create variables) so we take more than needed in the
896        # default case.
897        tb = [x for x in tb if "tensorflow/python" not in x[0]][:5]
898        raise ValueError("%s Originally defined at:\n\n%s" %
899                         (err_msg, "".join(traceback.format_list(tb))))
900      found_var = self._vars[name]
901      if not shape.is_compatible_with(found_var.get_shape()):
902        raise ValueError("Trying to share variable %s, but specified shape %s"
903                         " and found shape %s." %
904                         (name, shape, found_var.get_shape()))
905      if not dtype.is_compatible_with(found_var.dtype):
906        dtype_str = dtype.name
907        found_type_str = found_var.dtype.name
908        raise ValueError("Trying to share variable %s, but specified dtype %s"
909                         " and found dtype %s." %
910                         (name, dtype_str, found_type_str))
911      return found_var
912
913    # The code below handles only the case of creating a new variable.
914    if reuse is True:
915      raise ValueError("Variable %s does not exist, or was not created with "
916                       "tf.get_variable(). Did you mean to set "
917                       "reuse=tf.AUTO_REUSE in VarScope?" % name)
918
919    # Create the tensor to initialize the variable with default value.
920    if initializer is None:
921      initializer, initializing_from_value = self._get_default_initializer(
922          name=name, shape=shape, dtype=dtype)
923    # Enter an init scope when creating the initializer.
924    with ops.init_scope():
925      if initializing_from_value:
926        init_val = initializer
927        variable_dtype = None
928      else:
929        # Instantiate initializer if provided initializer is a type object.
930        if tf_inspect.isclass(initializer):
931          initializer = initializer()
932        if shape.is_fully_defined():
933          if "partition_info" in tf_inspect.getargspec(initializer).args:
934            init_val = functools.partial(initializer,
935                                         shape.as_list(),
936                                         dtype=dtype,
937                                         partition_info=partition_info)
938          else:
939            init_val = functools.partial(initializer,
940                                         shape.as_list(), dtype=dtype)
941          variable_dtype = dtype.base_dtype
942        elif _needs_no_arguments(initializer):
943          init_val = initializer
944          variable_dtype = None
945        else:
          raise ValueError("The initializer passed is not valid. It should "
                           "either be a callable with no arguments (in which "
                           "case `shape` should not be provided), or an "
                           "instance of `tf.keras.initializers.*` with "
                           "`shape` fully defined.")
951
952    # Create the variable.
953    if use_resource is None:
954      # Set the default value if unspecified.
955      use_resource = _DEFAULT_USE_RESOURCE
956    v = variables.VariableV1(
957        initial_value=init_val,
958        name=name,
959        trainable=trainable,
960        collections=collections,
961        caching_device=caching_device,
962        dtype=variable_dtype,
963        validate_shape=validate_shape,
964        constraint=constraint,
965        use_resource=use_resource,
966        synchronization=synchronization,
967        aggregation=aggregation)
968    if context.executing_eagerly() and self._store_eager_variables:
969      if collections:
970        ops.add_to_collections(collections, v)
971      else:
972        ops.add_to_collection(ops.GraphKeys.GLOBAL_VARIABLES, v)
973      if trainable:
974        ops.add_to_collection(ops.GraphKeys.TRAINABLE_VARIABLES, v)
975
976    if not context.executing_eagerly() or self._store_eager_variables:
977      # In eager mode we do not want to keep default references to Variable
978      # objects as this will prevent their memory from being released.
979      self._vars[name] = v
980    logging.vlog(1, "Created variable %s with shape %s and init %s", v.name,
981                 format(shape), initializer)
982
983    # Run the regularizer if requested and save the resulting loss.
984    if regularizer:
985      def make_regularizer_op():
986        with ops.colocate_with(v):
987          with ops.name_scope(name + "/Regularizer/"):
988            return regularizer(v)
989
990      if regularizer(v) is not None:
991        lazy_eval_tensor = _LazyEvalTensor(make_regularizer_op)
992        ops.add_to_collection(ops.GraphKeys.REGULARIZATION_LOSSES,
993                              lazy_eval_tensor)
994
995    return v
996
  # Provide a default initializer when no initializer is given.
998  def _get_default_initializer(self, name, shape=None, dtype=dtypes.float32):
999    """Provide a default initializer and a corresponding value.
1000
1001    Args:
1002      name: see get_variable.
1003      shape: see get_variable.
1004      dtype: see get_variable.
1005
1006    Returns:
1007      initializer and initializing_from_value. See get_variable above.
1008
1009    Raises:
1010      ValueError: When giving unsupported dtype.
1011    """
1012    del shape
    # If dtype is DT_FLOAT, provide a Glorot uniform initializer.
1014    if dtype.is_floating:
1015      initializer = init_ops.glorot_uniform_initializer()
1016      initializing_from_value = False
1017    # If dtype is DT_INT/DT_UINT, provide a default value `zero`
1018    # If dtype is DT_BOOL, provide a default value `FALSE`
1019    elif (dtype.is_integer or dtype.is_unsigned or dtype.is_bool or
1020          dtype == dtypes.string):
1021      initializer = init_ops.zeros_initializer()
1022      initializing_from_value = False
    # NOTE: Do we need to support handling DT_STRING and DT_COMPLEX here?
1024    else:
1025      raise ValueError("An initializer for variable %s of %s is required" %
1026                       (name, dtype.base_dtype))
1027
1028    return initializer, initializing_from_value
1029
1030
1031class _LazyEvalTensor(core.Tensor):
1032  """A Tensor-like object that only evaluates its thunk when used."""
1033
1034  def __init__(self, thunk):
1035    """Initializes a _LazyEvalTensor object.
1036
1037    Args:
1038      thunk: A callable. A thunk which computes the value of the tensor.
1039    """
1040    self._thunk = thunk
1041    self._master_tensor = thunk()
1042
1043  def _as_tensor(self, dtype=None, name=None, as_ref=False):
1044    del name
1045    assert not as_ref
1046    assert dtype in [None, self.dtype]
1047
1048    return self._thunk()
1049
1050
1051def _make_master_property(name):
1052  @property
1053  def prop(self):
1054    return getattr(self._master_tensor, name)  # pylint: disable=protected-access
1055  return prop
1056
1057_master_property_list = ("device", "dtype", "graph", "name", "op", "shape",
1058                         "value_index")
1059for _name in _master_property_list:
1060  setattr(_LazyEvalTensor, _name, _make_master_property(_name))
1061
1062
1063def _make_master_method(name):
1064  def method(self, *args, **kwargs):
1065    return getattr(self._master_tensor, name)(*args, **kwargs)  # pylint: disable=protected-access
1066  return method
1067
1068_master_method_list = ("get_shape", "__str__", "shape_as_list")
1069for _name in _master_method_list:
1070  setattr(_LazyEvalTensor, _name, _make_master_method(_name))
1071
1072
1073def _make_op_method(name):
1074  def method(self, *args, **kwargs):
1075    return getattr(self._as_tensor(), name)(*args, **kwargs)  # pylint: disable=protected-access
1076  return method
1077
1078_op_list = ("__abs__", "__add__", "__and__", "__bool__", "__div__", "__eq__",
1079            "__floordiv__", "__ge__", "__getitem__", "__gt__", "__invert__",
1080            "__iter__", "__le__", "__len__", "__lt__", "__matmul__", "__mod__",
1081            "__mul__", "__ne__", "__neg__", "__nonzero__", "__or__", "__pow__",
1082            "__radd__", "__rand__", "__rdiv__", "__rfloordiv__", "__rmatmul__",
1083            "__rmod__", "__rmul__", "__ror__", "__rpow__", "__rsub__",
1084            "__rtruediv__", "__rxor__", "__sub__", "__truediv__", "__xor__",
1085            "eval", "numpy")
1086for _name in _op_list:
1087  setattr(_LazyEvalTensor, _name, _make_op_method(_name))
1088
1089
1090ops.register_tensor_conversion_function(
1091    _LazyEvalTensor,
1092    lambda val, dtype, name, as_ref: val._as_tensor(dtype, name, as_ref)  # pylint: disable=protected-access
1093    )
1094
1095session.register_session_run_conversion_functions(
1096    _LazyEvalTensor,
1097    lambda fetch: ([fetch._master_tensor], lambda fetched_vals: fetched_vals[0])  # pylint: disable=protected-access
1098    )
1099
1100
1101# To stop regularization, use this regularizer
1102@tf_export(v1=["no_regularizer"])
1103def no_regularizer(_):
1104  """Use this function to prevent regularization of variables."""
1105  return None
1106
1107
1108# TODO(alive): support caching devices and partitioned variables in Eager mode.
1109@tf_export(v1=["VariableScope"])
1110class VariableScope:
1111  """Variable scope object to carry defaults to provide to `get_variable`.
1112
1113  Many of the arguments we need for `get_variable` in a variable store are most
1114  easily handled with a context. This object is used for the defaults.
1115
1116  Attributes:
1117    name: name of the current scope, used as prefix in get_variable.
1118    initializer: default initializer passed to get_variable.
1119    regularizer: default regularizer passed to get_variable.
1120    reuse: Boolean, None, or tf.compat.v1.AUTO_REUSE, setting the reuse in
1121      get_variable. When eager execution is enabled this argument is always
1122      forced to be False.
1123    caching_device: string, callable, or None: the caching device passed to
1124      get_variable.
1125    partitioner: callable or `None`: the partitioner passed to `get_variable`.
1126    custom_getter: default custom getter passed to get_variable.
1127    name_scope: The name passed to `tf.name_scope`.
1128    dtype: default type passed to get_variable (defaults to DT_FLOAT).
1129    use_resource: if False, create a normal Variable; if True create an
1130      experimental ResourceVariable with well-defined semantics. Defaults to
1131      False (will later change to True). When eager execution is enabled this
1132      argument is always forced to be True.
1133    constraint: An optional projection function to be applied to the variable
1134      after being updated by an `Optimizer` (e.g. used to implement norm
1135      constraints or value constraints for layer weights). The function must
1136      take as input the unprojected Tensor representing the value of the
1137      variable and return the Tensor for the projected value (which must have
1138      the same shape). Constraints are not safe to use when doing asynchronous
1139      distributed training.
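
  For example (an illustrative sketch using the `tf.compat.v1` wrappers):

  ```python
  with tf.compat.v1.variable_scope("layer1") as scope:
    w = tf.compat.v1.get_variable("w", shape=[3, 3])
    scope.reuse_variables()
    w_again = tf.compat.v1.get_variable("w", shape=[3, 3])
  assert w is w_again
  assert scope.name == "layer1"
  ```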
1140  """
1141
1142  def __init__(self,
1143               reuse,
1144               name="",
1145               initializer=None,
1146               regularizer=None,
1147               caching_device=None,
1148               partitioner=None,
1149               custom_getter=None,
1150               name_scope="",
1151               dtype=dtypes.float32,
1152               use_resource=None,
1153               constraint=None):
1154    """Creates a new VariableScope with the given properties."""
1155    self._name = name
1156    self._initializer = initializer
1157    self._regularizer = regularizer
1158    self._reuse = reuse
1159    self._caching_device = caching_device
1160    self._partitioner = partitioner
1161    self._custom_getter = custom_getter
1162    self._name_scope = name_scope
1163    self._dtype = dtype
1164    self._use_resource = use_resource
1165    self._constraint = constraint
1166    if context.executing_eagerly():
1167      if self._caching_device is not None:
        raise NotImplementedError("Caching devices are not yet supported "
1169                                  "when eager execution is enabled.")
1170      self._reuse = AUTO_REUSE
1171      self._use_resource = True
1172
1173  @property
1174  def name(self):
1175    return self._name
1176
1177  @property
1178  def original_name_scope(self):
1179    return self._name_scope
1180
1181  @property
1182  def reuse(self):
1183    return self._reuse
1184
1185  @property
1186  def initializer(self):
1187    return self._initializer
1188
1189  @property
1190  def dtype(self):
1191    return self._dtype
1192
1193  @property
1194  def use_resource(self):
1195    return self._use_resource
1196
1197  @property
1198  def regularizer(self):
1199    return self._regularizer
1200
1201  @property
1202  def caching_device(self):
1203    return self._caching_device
1204
1205  @property
1206  def partitioner(self):
1207    return self._partitioner
1208
1209  @property
1210  def custom_getter(self):
1211    return self._custom_getter
1212
1213  @property
1214  def constraint(self):
1215    return self._constraint
1216
1217  def reuse_variables(self):
1218    """Reuse variables in this scope."""
1219    self._reuse = True
1220
1221  def set_initializer(self, initializer):
1222    """Set initializer for this scope."""
1223    self._initializer = initializer
1224
1225  def set_dtype(self, dtype):
1226    """Set data type for this scope."""
1227    self._dtype = dtype
1228
1229  def set_use_resource(self, use_resource):
1230    """Sets whether to use ResourceVariables for this scope."""
1231    if context.executing_eagerly() and not use_resource:
1232      raise ValueError("When eager execution is enabled, "
1233                       "use_resource cannot be set to false.")
1234    self._use_resource = use_resource
1235
1236  def set_regularizer(self, regularizer):
1237    """Set regularizer for this scope."""
1238    self._regularizer = regularizer
1239
1240  def set_caching_device(self, caching_device):
1241    """Set caching_device for this scope."""
1242    if context.executing_eagerly():
1243      raise NotImplementedError("Caching devices are not yet supported "
1244                                "when eager execution is enabled.")
1245    self._caching_device = caching_device
1246
1247  def set_partitioner(self, partitioner):
1248    """Set partitioner for this scope."""
1249    self._partitioner = partitioner
1250
1251  def set_custom_getter(self, custom_getter):
1252    """Set custom getter for this scope."""
1253    self._custom_getter = custom_getter
1254
1255  def get_collection(self, name):
1256    """Get this scope's variables."""
1257    scope = self._name + "/" if self._name else ""
1258    return ops.get_collection(name, scope)
1259
1260  def trainable_variables(self):
1261    """Get this scope's trainable variables."""
1262    return self.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
1263
1264  def global_variables(self):
1265    """Get this scope's global variables."""
1266    return self.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
1267
1268  def local_variables(self):
1269    """Get this scope's local variables."""
1270    return self.get_collection(ops.GraphKeys.LOCAL_VARIABLES)
1271
1272  def get_variable(self,
1273                   var_store,
1274                   name,
1275                   shape=None,
1276                   dtype=None,
1277                   initializer=None,
1278                   regularizer=None,
1279                   reuse=None,
1280                   trainable=None,
1281                   collections=None,
1282                   caching_device=None,
1283                   partitioner=None,
1284                   validate_shape=True,
1285                   use_resource=None,
1286                   custom_getter=None,
1287                   constraint=None,
1288                   synchronization=VariableSynchronization.AUTO,
1289                   aggregation=VariableAggregation.NONE):
1290    """Gets an existing variable with this name or creates a new one."""
1291    if regularizer is None:
1292      regularizer = self._regularizer
1293    if caching_device is None:
1294      caching_device = self._caching_device
1295    if partitioner is None:
1296      partitioner = self._partitioner
1297    if custom_getter is None:
1298      custom_getter = self._custom_getter
1299    if context.executing_eagerly():
1300      reuse = False
1301      use_resource = True
1302    else:
1303      if reuse is None:
1304        reuse = self._reuse
1305      if use_resource is None:
1306        use_resource = self._use_resource
1307
1308    full_name = self.name + "/" + name if self.name else name
1309    # Variable names only depend on variable_scope (full_name here),
1310    # not name_scope, so we reset it below for the time of variable creation.
1311    with ops.name_scope(None, skip_on_eager=False):
1312      # Check that `initializer` dtype and `dtype` are consistent before
1313      # replacing them with defaults.
1314      if (dtype is not None and initializer is not None and
1315          not callable(initializer)):
1316        init_dtype = ops.convert_to_tensor(initializer).dtype.base_dtype
1317        if init_dtype != dtype:
1318          raise ValueError("Initializer type '%s' and explicit dtype '%s' "
1319                           "don't match." % (init_dtype, dtype))
1320      if initializer is None:
1321        initializer = self._initializer
1322      if constraint is None:
1323        constraint = self._constraint
1324      if dtype is None:
1325        dtype = self._dtype
1326      return var_store.get_variable(
1327          full_name,
1328          shape=shape,
1329          dtype=dtype,
1330          initializer=initializer,
1331          regularizer=regularizer,
1332          reuse=reuse,
1333          trainable=trainable,
1334          collections=collections,
1335          caching_device=caching_device,
1336          partitioner=partitioner,
1337          validate_shape=validate_shape,
1338          use_resource=use_resource,
1339          custom_getter=custom_getter,
1340          constraint=constraint,
1341          synchronization=synchronization,
1342          aggregation=aggregation)
1343
1344  def _get_partitioned_variable(self,
1345                                var_store,
1346                                name,
1347                                shape=None,
1348                                dtype=None,
1349                                initializer=None,
1350                                regularizer=None,
1351                                trainable=None,
1352                                collections=None,
1353                                caching_device=None,
1354                                partitioner=None,
1355                                validate_shape=True,
1356                                use_resource=None,
1357                                constraint=None,
1358                                synchronization=VariableSynchronization.AUTO,
1359                                aggregation=VariableAggregation.NONE):
1360    """Gets an existing variable with this name or creates a new one."""
1361    if initializer is None:
1362      initializer = self._initializer
1363    if regularizer is None:
1364      regularizer = self._regularizer
1365    if constraint is None:
1366      constraint = self._constraint
1367    if caching_device is None:
1368      caching_device = self._caching_device
1369    if partitioner is None:
1370      partitioner = self._partitioner
1371    if dtype is None:
1372      dtype = self._dtype
1373    if use_resource is None:
1374      use_resource = self._use_resource
1375
1376    if self._custom_getter is not None:
1377      raise ValueError(
1378          "Private access to _get_partitioned_variable is not allowed when "
1379          "a custom getter is set.  Current custom getter: %s.  "
1380          "It is likely that you're using create_partitioned_variables.  "
1381          "If so, consider using get_variable with a non-empty "
1382          "partitioner parameter instead." % self._custom_getter)
1383
1384    if partitioner is None:
1385      raise ValueError("No partitioner was specified")
1386
1387    # This allows the variable scope name to be used as the variable name if
1388    # this function is invoked with an empty name arg, for backward
1389    # compatibility with create_partitioned_variables().
1390    full_name_list = []
1391    if self.name:
1392      full_name_list.append(self.name)
1393    if name:
1394      full_name_list.append(name)
1395    full_name = "/".join(full_name_list)
1396
1397    # Variable names only depend on variable_scope (full_name here),
1398    # not name_scope, so we reset it below for the time of variable creation.
1399    with ops.name_scope(None, skip_on_eager=False):
1400      # pylint: disable=protected-access
1401      return var_store._get_partitioned_variable(
1402          full_name,
1403          shape=shape,
1404          dtype=dtype,
1405          initializer=initializer,
1406          regularizer=regularizer,
1407          reuse=self.reuse,
1408          trainable=trainable,
1409          collections=collections,
1410          caching_device=caching_device,
1411          partitioner=partitioner,
1412          validate_shape=validate_shape,
1413          use_resource=use_resource,
1414          constraint=constraint,
1415          synchronization=synchronization,
1416          aggregation=aggregation)
1417      # pylint: enable=protected-access
1418
1419
1420_VARSTORE_KEY = ("__variable_store",)
1421_VARSCOPESTORE_KEY = ("__varscope",)
1422
1423
1424class _VariableScopeStore(threading.local):
1425  """A thread local store for the current variable scope and scope counts."""
1426
1427  def __init__(self):
1428    super(_VariableScopeStore, self).__init__()
1429    self.current_scope = VariableScope(False)
1430    self.variable_scopes_count = {}
1431
1432  def open_variable_scope(self, scope_name):
1433    if scope_name in self.variable_scopes_count:
1434      self.variable_scopes_count[scope_name] += 1
1435    else:
1436      self.variable_scopes_count[scope_name] = 1
1437
1438  def close_variable_subscopes(self, scope_name):
1439    if scope_name is None:
1440      for k in self.variable_scopes_count:
1441        self.variable_scopes_count[k] = 0
1442    else:
1443      startswith_check = scope_name + "/"
1444      startswith_len = len(startswith_check)
1445      for k in self.variable_scopes_count:
1446        if k[:startswith_len] == startswith_check:
1447          self.variable_scopes_count[k] = 0
1448
1449  def variable_scope_count(self, scope_name):
1450    return self.variable_scopes_count.get(scope_name, 0)
1451
1452
1453def get_variable_scope_store():
1454  """Returns the variable scope store for current thread."""
1455  scope_store = ops.get_collection(_VARSCOPESTORE_KEY)
1456
1457  if not scope_store:
1458    scope_store = _VariableScopeStore()
1459    ops.add_to_collection(_VARSCOPESTORE_KEY, scope_store)
1460  else:
1461    scope_store = scope_store[0]
1462
1463  return scope_store
1464
1465
1466@tf_export(v1=["get_variable_scope"])
1467def get_variable_scope():
1468  """Returns the current variable scope.
1469
1470  @compatibility(TF2)
1471  Although it is a legacy `compat.v1` api,
1472  `tf.compat.v1.get_variable` is compatible with eager
1473  execution and `tf.function`.
1474
1475  However, to maintain variable-scope based variable reuse
1476  you will need to combine it with
1477  `tf.compat.v1.keras.utils.track_tf1_style_variables`. (Though
1478  it will behave as if reuse is always set to `tf.compat.v1.AUTO_REUSE`.)
1479
1480  See the
1481  [migration guide](https://www.tensorflow.org/guide/migrate/model_mapping)
1482  for more info.
1483
1484  The TF2 equivalent, if you are just trying to track
1485  variable name prefixes and not control `get_variable`-based variable reuse,
1486  would be to use `tf.name_scope` and capture the output of opening the
1487  scope (which represents the current name prefix).
1488
1489  For example:
1490  ```python
1491  with tf.name_scope('foo') as current_scope:
1492    ...
1493  ```
1494  @end_compatibility
1495  """
1496  return get_variable_scope_store().current_scope
1497
1498
1499def _get_default_variable_store():
1500  store = ops.get_collection(_VARSTORE_KEY)
1501  if store:
1502    return store[0]
1503  store = _VariableStore()
1504  ops.add_to_collection(_VARSTORE_KEY, store)
1505  return store
1506
1507
1508@tf_contextlib.contextmanager
1509def with_variable_store(store):
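  """Context manager that scopes the default variable store to `store`."""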
1510  store_collection = ops.get_collection_ref(_VARSTORE_KEY)
1511  old = list(store_collection)
1512  store_collection[:] = [store]
1513  try:
1514    yield
1515  finally:
1516    store_collection[:] = old
1517
1518
1519class EagerVariableStore:
1520  """Wrapper allowing functional layers to be used with eager execution.
1521
1522  When eager execution is enabled, Variables get deleted when they go out of
1523  scope, and are not stored in global collections by default. A lot of code
1524  (mostly the functional layers in tf.layers) assumes that variables are kept in
1525  a global list.
1526
1527  EagerVariableStore can be used in conjunction with this code to make it
1528  eager-friendly. For example, to create a dense layer, use:
1529
1530  ```
1531    container = tfe.EagerVariableStore()
1532    for input in dataset_iterator:
1533      with container.as_default():
1534        x = tf.compat.v1.layers.dense(input, units=4, name="l1")
1535    print(container.variables())  # Should print the variables used in the layer.
1536  ```
1537  """
1538
1539  def __init__(self, store=None):
1540    if store is not None:
1541      if not store._store_eager_variables:  # pylint: disable=protected-access
1542        raise ValueError("Cannot construct EagerVariableStore from a "
1543                         "VariableStore object that does not hold eager "
1544                         "variables.")
1545      self._store = store
1546    else:
1547      self._store = _VariableStore()
1548    self._store._store_eager_variables = True  # pylint: disable=protected-access
1549
1550  def as_default(self):
1551    return with_variable_store(self._store)
1552
1553  def variables(self):
1554    return sorted(self._store._vars.values(), key=lambda x: x.name)  # pylint: disable=protected-access
1555
1556  def trainable_variables(self):
1557    # pylint: disable=protected-access
1558    return sorted([x for x in self._store._vars.values() if x.trainable],
1559                  key=lambda x: x.name)
1560    # pylint: enable=protected-access
1561
1562  def non_trainable_variables(self):
1563    # pylint: disable=protected-access
1564    return sorted([x for x in self._store._vars.values() if not x.trainable],
1565                  key=lambda x: x.name)
1566    # pylint: enable=protected-access
1567
1568  def copy(self):
1569    """Copy this variable store and all of its contents.
1570
1571    Variables contained in this store will be copied over to the new variable
1572    store, meaning that they can be modified without affecting the variables in
1573    this store.
1574
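    For example, a minimal sketch (the dense layer below is illustrative):

    ```python
    container = EagerVariableStore()
    with container.as_default():
      y = tf.compat.v1.layers.dense(tf.ones([1, 2]), units=3, name="l1")
    snapshot = container.copy()  # Holds independent copies of the variables.
    ```
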
1575    Returns:
1576      A new EagerVariableStore instance containing copied variables.
1577    """
1578    # pylint: disable=protected-access
1579    new_store = EagerVariableStore()
1580    for key, var in self._store._vars.items():
1581      # Strip device out of variable name.
1582      try:
1583        index = var.name.index(":")
1584      except ValueError:
1585        stripped_var_name = var.name
1586      else:
1587        stripped_var_name = var.name[:index]
1588
1589      # Create new variable with same value, name, and "trainable" flag.
1590      new_var = resource_variable_ops.ResourceVariable(
1591          var.read_value(), name=stripped_var_name, trainable=var.trainable)
1592      new_store._store._vars[key] = new_var
1593    return new_store
1594    # pylint: enable=protected-access
1595
1596
1597# The argument list for get_variable must match arguments to get_local_variable.
1598# So, if you are updating the arguments, also update arguments to
1599# get_local_variable below.
1600@tf_export(v1=["get_variable"])
1601def get_variable(name,
1602                 shape=None,
1603                 dtype=None,
1604                 initializer=None,
1605                 regularizer=None,
1606                 trainable=None,
1607                 collections=None,
1608                 caching_device=None,
1609                 partitioner=None,
1610                 validate_shape=True,
1611                 use_resource=None,
1612                 custom_getter=None,
1613                 constraint=None,
1614                 synchronization=VariableSynchronization.AUTO,
1615                 aggregation=VariableAggregation.NONE):
1616  return get_variable_scope().get_variable(
1617      _get_default_variable_store(),
1618      name,
1619      shape=shape,
1620      dtype=dtype,
1621      initializer=initializer,
1622      regularizer=regularizer,
1623      trainable=trainable,
1624      collections=collections,
1625      caching_device=caching_device,
1626      partitioner=partitioner,
1627      validate_shape=validate_shape,
1628      use_resource=use_resource,
1629      custom_getter=custom_getter,
1630      constraint=constraint,
1631      synchronization=synchronization,
1632      aggregation=aggregation)
1633
1634
1635get_variable_or_local_docstring = ("""%s
1636
1637@compatibility(TF2)
1638Although it is a legacy `compat.v1` api,
1639`tf.compat.v1.get_variable` is mostly compatible with eager
1640execution and `tf.function` but only if you combine it with the
1641`tf.compat.v1.keras.utils.track_tf1_style_variables` decorator. (Though
1642it will behave as if reuse is always set to `AUTO_REUSE`.)
1643
1644See the
1645[model migration guide](https://www.tensorflow.org/guide/migrate/model_mapping)
1646for more info.
1647
1648If you do not combine it with
1649`tf.compat.v1.keras.utils.track_tf1_style_variables`, `get_variable` will create
1650a brand new variable every single time it is called and will never reuse
1651variables, regardless of variable names or `reuse` arguments.
1652
1653The TF2 equivalent of this symbol would be `tf.Variable`, but note
1654that when using `tf.Variable` you must make sure you track your variables
1655(and regularizer arguments) either manually or via `tf.Module` or
1656`tf.keras.layers.Layer` mechanisms.
1657
1658A section of the
1659[migration guide](https://www.tensorflow.org/guide/migrate/model_mapping#incremental_migration_to_native_tf2)
1660provides more details on incrementally migrating these usages to `tf.Variable`
1661as well.
1662
1663Note: The `partitioner` arg is not compatible with TF2 behaviors even when
1664using `tf.compat.v1.keras.utils.track_tf1_style_variables`. It can be replaced
1665by using `ParameterServerStrategy` and its partitioners. See the
1666[multi-gpu migration guide](https://www.tensorflow.org/guide/migrate/multi_worker_cpu_gpu_training)
1667and the ParameterServerStrategy guides it references for more info.
1668@end_compatibility
1669
1670%sThis function prefixes the name with the current variable scope
1671and performs reuse checks. See the
1672[Variable Scope How To](https://tensorflow.org/guide/variables)
1673for an extensive description of how reusing works. Here is a basic example:
1674
1675```python
1676def foo():
1677  with tf.variable_scope("foo", reuse=tf.AUTO_REUSE):
1678    v = tf.get_variable("v", [1])
1679  return v
1680
1681v1 = foo()  # Creates v.
1682v2 = foo()  # Gets the same, existing v.
1683assert v1 == v2
1684```
1685
1686If initializer is `None` (the default), the default initializer passed in
1687the variable scope will be used. If that one is `None` too, a
1688`glorot_uniform_initializer` will be used. The initializer can also be
1689a Tensor, in which case the variable is initialized to this value and shape.
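
For example, a minimal sketch of initializing from a `Tensor` (the names here
are illustrative):

```python
v = tf.get_variable("w", initializer=tf.constant([[1.0, 2.0]]))
# v takes its shape [1, 2] and dtype float32 from the initializer tensor.
```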
1690
1691Similarly, if the regularizer is `None` (the default), the default regularizer
1692passed in the variable scope will be used (if that is `None` too,
1693then by default no regularization is performed).
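
For example, a minimal sketch (the names are illustrative); the regularizer's
output is added to the `tf.GraphKeys.REGULARIZATION_LOSSES` collection:

```python
w = tf.get_variable("w_reg", shape=[2],
                    regularizer=lambda t: 0.01 * tf.nn.l2_loss(t))
reg_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
```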
1694
1695If a partitioner is provided, a `PartitionedVariable` is returned.
1696Accessing this object as a `Tensor` returns the shards concatenated along
1697the partition axis.
1698
1699Some useful partitioners are available.  See, e.g.,
1700`variable_axis_size_partitioner` and `min_max_variable_partitioner`.
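
For example, a minimal sketch using `tf.fixed_size_partitioner` (the scope and
variable names are illustrative):

```python
with tf.variable_scope("part", partitioner=tf.fixed_size_partitioner(2)):
  w = tf.get_variable("w", shape=[4, 3])  # Two shards along axis 0.
```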
1701
1702Args:
1703  name: The name of the new or existing variable.
1704  shape: Shape of the new or existing variable.
1705  dtype: Type of the new or existing variable (defaults to `DT_FLOAT`).
1706  initializer: Initializer for the variable if one is created. Can either be
1707    an initializer object or a Tensor. If it's a Tensor, its shape must be known
1708    unless validate_shape is False.
1709  regularizer: A (Tensor -> Tensor or None) function; the result of
1710    applying it on a newly created variable will be added to the collection
1711    `tf.GraphKeys.REGULARIZATION_LOSSES` and can be used for regularization.
1712  %scollections: List of graph collections keys to add the Variable to.
1713    Defaults to `[%s]` (see `tf.Variable`).
1714  caching_device: Optional device string or function describing where the
1715    Variable should be cached for reading.  Defaults to the Variable's
1716    device.  If not `None`, caches on another device.  Typical use is to
1717    cache on the device where the Ops using the Variable reside, to
1718    deduplicate copying through `Switch` and other conditional statements.
1719  partitioner: Optional callable that accepts a fully defined `TensorShape`
1720    and `dtype` of the Variable to be created, and returns a list of
1721    partitions for each axis (currently only one axis can be partitioned).
1722  validate_shape: If False, allows the variable to be initialized with a
1723      value of unknown shape. If True, the default, the shape of initial_value
1724      must be known. For this to be used the initializer must be a Tensor and
1725      not an initializer object.
1726  use_resource: If False, creates a regular Variable. If true, creates an
1727    experimental ResourceVariable instead with well-defined semantics.
1728    Defaults to False (will later change to True). When eager execution is
1729    enabled this argument is always forced to be True.
1730  custom_getter: Callable that takes as a first argument the true getter, and
1731    allows overwriting the internal get_variable method.
1732    The signature of `custom_getter` should match that of this method,
1733    but the most future-proof version will allow for changes:
1734    `def custom_getter(getter, *args, **kwargs)`.  Direct access to
1735    all `get_variable` parameters is also allowed:
1736    `def custom_getter(getter, name, *args, **kwargs)`.  A simple custom
1737    getter that creates variables with modified names is:
1738    ```python
1739    def custom_getter(getter, name, *args, **kwargs):
1740      return getter(name + '_suffix', *args, **kwargs)
1741    ```
1742  constraint: An optional projection function to be applied to the variable
1743    after being updated by an `Optimizer` (e.g. used to implement norm
1744    constraints or value constraints for layer weights). The function must
1745    take as input the unprojected Tensor representing the value of the
1746    variable and return the Tensor for the projected value
1747    (which must have the same shape). Constraints are not safe to
1748    use when doing asynchronous distributed training.
1749  synchronization: Indicates when a distributed variable will be
1750    aggregated. Accepted values are constants defined in the class
1751    `tf.VariableSynchronization`. By default the synchronization is set to
1752    `AUTO` and the current `DistributionStrategy` chooses
1753    when to synchronize.
1754  aggregation: Indicates how a distributed variable will be aggregated.
1755    Accepted values are constants defined in the class
1756    `tf.VariableAggregation`.
1757
1758Returns:
1759  The created or existing `Variable` (or `PartitionedVariable`, if a
1760  partitioner was used).
1761
1762Raises:
1763  ValueError: when creating a new variable and shape is not declared,
1764    when violating reuse during variable creation, or when `initializer` dtype
1765    and `dtype` don't match. Reuse is set inside `variable_scope`.
1766""")
1767get_variable.__doc__ = get_variable_or_local_docstring % (
1768    "Gets an existing variable with these parameters or create a new one.", "",
1769    "trainable: If `True` also add the variable to the graph collection\n"
1770    "    `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).\n  ",
1771    "GraphKeys.GLOBAL_VARIABLES")
1772
1773
1774# The argument list for get_local_variable must match arguments to get_variable.
1775# So, if you are updating the arguments, also update arguments to get_variable.
1776@tf_export(v1=["get_local_variable"])
1777def get_local_variable(  # pylint: disable=missing-docstring
1778    name,
1779    shape=None,
1780    dtype=None,
1781    initializer=None,
1782    regularizer=None,
1783    trainable=False,  # pylint: disable=unused-argument
1784    collections=None,
1785    caching_device=None,
1786    partitioner=None,
1787    validate_shape=True,
1788    use_resource=None,
1789    custom_getter=None,
1790    constraint=None,
1791    synchronization=VariableSynchronization.AUTO,
1792    aggregation=VariableAggregation.NONE):
1793  if collections:
1794    collections += [ops.GraphKeys.LOCAL_VARIABLES]
1795  else:
1796    collections = [ops.GraphKeys.LOCAL_VARIABLES]
1797  return get_variable(
1798      name,
1799      shape=shape,
1800      dtype=dtype,
1801      initializer=initializer,
1802      regularizer=regularizer,
1803      trainable=False,
1804      collections=collections,
1805      caching_device=caching_device,
1806      partitioner=partitioner,
1807      validate_shape=validate_shape,
1808      use_resource=use_resource,
1809      synchronization=synchronization,
1810      aggregation=aggregation,
1811      custom_getter=custom_getter,
1812      constraint=constraint)
1813
1814
1815get_local_variable.__doc__ = get_variable_or_local_docstring % (
1816    "Gets an existing *local* variable or creates a new one.",
1817    "Behavior is the same as in `get_variable`, except that variables are\n"
1818    "added to the `LOCAL_VARIABLES` collection and `trainable` is set to\n"
1819    "`False`.\n", "", "GraphKeys.LOCAL_VARIABLES")
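
# For example (illustrative): a non-trainable local counter that is added to
# GraphKeys.LOCAL_VARIABLES rather than GraphKeys.GLOBAL_VARIABLES:
#
#   count = tf.compat.v1.get_local_variable(
#       "count", shape=[], initializer=tf.compat.v1.zeros_initializer())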
1820
1821
1822def _get_partitioned_variable(name,
1823                              shape=None,
1824                              dtype=None,
1825                              initializer=None,
1826                              regularizer=None,
1827                              trainable=True,
1828                              collections=None,
1829                              caching_device=None,
1830                              partitioner=None,
1831                              validate_shape=True,
1832                              use_resource=None,
1833                              constraint=None,
1834                              synchronization=VariableSynchronization.AUTO,
1835                              aggregation=VariableAggregation.NONE):
1836  """Gets or creates a sharded variable list with these parameters.
1837
1838  The `partitioner` must be a callable that accepts a fully defined
1839  `TensorShape` and returns a sequence of integers (the `partitions`).
1840  These integers describe how to partition the given sharded `Variable`
1841  along the given dimension.  That is, `partitions[1] = 3` means split
1842  the `Variable` into 3 shards along dimension 1.  Currently, sharding along
1843  only one axis is supported.
1844
1845  If the list of variables with the given name (prefix) is already stored,
1846  we return the stored variables. Otherwise, we create a new one.
1847
1848  If initializer is `None` (the default), the default initializer passed in
1849  the constructor is used. If that one is `None` too, we use a new
1850  `glorot_uniform_initializer`. If initializer is a Tensor, we use
1851  it as a value and derive the shape from the initializer.
1852
1853  If the initializer is a callable, then it will be called for each
1854  shard.  Otherwise the initializer should match the shape of the entire
1855  sharded Variable, and it will be sliced accordingly for each shard.
1856
1857  Some useful partitioners are available.  See, e.g.,
1858  `variable_axis_size_partitioner` and `min_max_variable_partitioner`.
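
  For example, a minimal sketch of a partitioner callable (illustrative only):

  ```python
  def two_way_partitioner(shape, dtype):
    del dtype  # Unused in this sketch.
    partitions = [1] * len(shape)
    partitions[0] = 2  # Split the first axis into two shards.
    return partitions
  ```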
1859
1860  Args:
1861    name: The name of the new or existing variable.
1862    shape: Shape of the new or existing variable.
1863    dtype: Type of the new or existing variable (defaults to `DT_FLOAT`).
1864    initializer: Initializer for the variable if one is created.
1865    regularizer: A (Tensor -> Tensor or None) function; the result of applying
1866      it on a newly created variable will be added to the collection
1867      GraphKeys.REGULARIZATION_LOSSES and can be used for regularization.
1868    trainable: If `True` also add the variable to the graph collection
1869      `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
1870    collections: List of graph collections keys to add the Variable to. Defaults
1871      to `[GraphKeys.GLOBAL_VARIABLES]` (see `tf.Variable`).
1872    caching_device: Optional device string or function describing where the
1873      Variable should be cached for reading.  Defaults to the Variable's device.
1874      If not `None`, caches on another device.  Typical use is to cache on the
1875      device where the Ops using the Variable reside, to deduplicate copying
1876      through `Switch` and other conditional statements.
1877    partitioner: Optional callable that accepts a fully defined `TensorShape`
1878      and `dtype` of the Variable to be created, and returns a list of
1879      partitions for each axis (currently only one axis can be partitioned).
1880    validate_shape: If False, allows the variable to be initialized with a value
1881      of unknown shape. If True, the default, the shape of initial_value must be
1882      known.
1883    use_resource: If False, creates a regular Variable. If True, creates an
1884      experimental ResourceVariable instead which has well-defined semantics.
1885      Defaults to False (will later change to True).
1886    constraint: An optional projection function to be applied to the variable
1887      after being updated by an `Optimizer` (e.g. used to implement norm
1888      constraints or value constraints for layer weights). The function must
1889      take as input the unprojected Tensor representing the value of the
1890      variable and return the Tensor for the projected value (which must have
1891      the same shape). Constraints are not safe to use when doing asynchronous
1892      distributed training.
1893    synchronization: Indicates when a distributed variable will be aggregated.
1894      Accepted values are constants defined in the class
1895      `tf.VariableSynchronization`. By default the synchronization is set to
1896      `AUTO` and the current `DistributionStrategy` chooses when to synchronize.
1897    aggregation: Indicates how a distributed variable will be aggregated.
1898      Accepted values are constants defined in the class
1899      `tf.VariableAggregation`.
1900
1901  Returns:
1902    A tuple `(shards, partitions)` where `shards` is the list of `Variable`
1903    shards and `partitions` is the output of the partitioner on the input
1904    shape.
1905
1906  Raises:
1907    ValueError: when creating a new variable and shape is not declared,
1908      or when violating reuse during variable creation. Reuse is set inside
1909      `variable_scope`.
1910  """
1911  # pylint: disable=protected-access
1912  scope = get_variable_scope()
1913  if scope.custom_getter is not None:
1914    raise ValueError(
1915        "Private access to _get_partitioned_variable is not allowed when "
1916        "a custom getter is set.  Current custom getter: %s.  "
1917        "It is likely that you're using create_partitioned_variables.  "
1918        "If so, consider using get_variable with a non-empty "
1919        "partitioner parameter instead." % scope.custom_getter)
1920  return scope._get_partitioned_variable(
1921      _get_default_variable_store(),
1922      name,
1923      shape=shape,
1924      dtype=dtype,
1925      initializer=initializer,
1926      regularizer=regularizer,
1927      trainable=trainable,
1928      collections=collections,
1929      caching_device=caching_device,
1930      partitioner=partitioner,
1931      validate_shape=validate_shape,
1932      use_resource=use_resource,
1933      constraint=constraint,
1934      synchronization=synchronization,
1935      aggregation=aggregation)
1936  # pylint: enable=protected-access
1937
1938
1939# Named like a function for compatibility with the previous
1940# @tf_contextlib.contextmanager definition.
1941class _pure_variable_scope:  # pylint: disable=invalid-name
1942  """A context for the variable_scope; see `variable_scope` for docs."""
1943
1944  def __init__(self,
1945               name_or_scope,
1946               reuse=None,
1947               initializer=None,
1948               regularizer=None,
1949               caching_device=None,
1950               partitioner=None,
1951               custom_getter=None,
1952               old_name_scope=None,
1953               dtype=dtypes.float32,
1954               use_resource=None,
1955               constraint=None):
1956    """Creates a context for the variable_scope; see `variable_scope` for docs.
1957
1958    Note: this does not create a name scope.
1959
1960    Args:
1961      name_or_scope: `string` or `VariableScope`: the scope to open.
1962      reuse: `True`, None, or tf.compat.v1.AUTO_REUSE; if `None`, we inherit
1963        the parent scope's reuse flag.
1964      initializer: default initializer for variables within this scope.
1965      regularizer: default regularizer for variables within this scope.
1966      caching_device: default caching device for variables within this scope.
1967      partitioner: default partitioner for variables within this scope.
1968      custom_getter: default custom getter for variables within this scope.
1969      old_name_scope: the original name scope when re-entering a variable scope.
1970      dtype: type of the variables within this scope (defaults to `DT_FLOAT`).
1971      use_resource: If False, variables in this scope will be regular Variables.
1972        If True, experimental ResourceVariables will be created instead, with
1973        well-defined semantics. Defaults to False (will later change to True).
1974      constraint: An optional projection function to be applied to the variable
1975        after being updated by an `Optimizer` (e.g. used to implement norm
1976        constraints or value constraints for layer weights). The function must
1977        take as input the unprojected Tensor representing the value of the
1978        variable and return the Tensor for the projected value (which must have
1979        the same shape). Constraints are not safe to use when doing asynchronous
1980        distributed training.
1981    """
1982    self._name_or_scope = name_or_scope
1983    self._reuse = reuse
1984    self._initializer = initializer
1985    self._regularizer = regularizer
1986    self._caching_device = caching_device
1987    self._partitioner = partitioner
1988    self._custom_getter = custom_getter
1989    self._old_name_scope = old_name_scope
1990    self._dtype = dtype
1991    self._use_resource = use_resource
1992    self._constraint = constraint
1993    self._var_store = _get_default_variable_store()
1994    self._var_scope_store = get_variable_scope_store()
1995    self._last_variable_scope_object = None
1996    if isinstance(self._name_or_scope, VariableScope):
1997      self._new_name = self._name_or_scope.name
1998      name_scope = self._name_or_scope._name_scope  # pylint: disable=protected-access
1999      # Handler for the case when we jump to a shared scope.  We create a new
2000      #   VariableScope (self._var_scope_object) that contains a copy of the
2001      #   provided shared scope, possibly with changed reuse and initializer, if
2002      #   the user requested this.
2003      variable_scope_object = VariableScope(
2004          self._name_or_scope.reuse if not self._reuse else self._reuse,
2005          name=self._new_name,
2006          initializer=self._name_or_scope.initializer,
2007          regularizer=self._name_or_scope.regularizer,
2008          caching_device=self._name_or_scope.caching_device,
2009          partitioner=self._name_or_scope.partitioner,
2010          dtype=self._name_or_scope.dtype,
2011          custom_getter=self._name_or_scope.custom_getter,
2012          name_scope=name_scope,
2013          use_resource=self._name_or_scope.use_resource,
2014          constraint=self._constraint)
2015      if self._initializer is not None:
2016        variable_scope_object.set_initializer(self._initializer)
2017      if self._regularizer is not None:
2018        variable_scope_object.set_regularizer(self._regularizer)
2019      if self._caching_device is not None:
2020        variable_scope_object.set_caching_device(self._caching_device)
2021      if self._partitioner is not None:
2022        variable_scope_object.set_partitioner(self._partitioner)
2023      if self._custom_getter is not None:
2024        variable_scope_object.set_custom_getter(
2025            _maybe_wrap_custom_getter(self._custom_getter,
2026                                      self._name_or_scope.custom_getter))
2027      if self._dtype is not None:
2028        variable_scope_object.set_dtype(self._dtype)
2029      if self._use_resource is not None:
2030        variable_scope_object.set_use_resource(self._use_resource)
2031      self._cached_variable_scope_object = variable_scope_object
2032
2033  def __enter__(self):
2034    """Begins the scope block.
2035
2036    Returns:
2037      A VariableScope.
2038    Raises:
2039      ValueError: when trying to reuse within a create scope, or create within
2040        a reuse scope, or if reuse is not `None` or `True`.
2041      TypeError: when the types of some arguments are not appropriate.
2042    """
2043    self._old = self._var_scope_store.current_scope
2044    if isinstance(self._name_or_scope, VariableScope):
2045      self._var_scope_store.open_variable_scope(self._new_name)
2046      self._old_subscopes = copy.copy(
2047          self._var_scope_store.variable_scopes_count)
2048      variable_scope_object = self._cached_variable_scope_object
2049    else:
2050      # Handler for the case when we just prolong current variable scope.
2051      #   VariableScope with name extended by the provided one, and inherited
2052      #   reuse and initializer (except if the user provided values to set).
2053      self._new_name = (
2054          self._old.name + "/" +
2055          self._name_or_scope if self._old.name else self._name_or_scope)
2056      self._reuse = (self._reuse or
2057                     self._old.reuse)  # Re-using is inherited by sub-scopes.
2058      if self._old_name_scope is None:
2059        name_scope = self._name_or_scope
2060      else:
2061        name_scope = self._old_name_scope
2062      variable_scope_object = VariableScope(
2063          self._reuse,
2064          name=self._new_name,
2065          initializer=self._old.initializer,
2066          regularizer=self._old.regularizer,
2067          caching_device=self._old.caching_device,
2068          partitioner=self._old.partitioner,
2069          dtype=self._old.dtype,
2070          use_resource=self._old.use_resource,
2071          custom_getter=self._old.custom_getter,
2072          name_scope=name_scope,
2073          constraint=self._constraint)
2074      if self._initializer is not None:
2075        variable_scope_object.set_initializer(self._initializer)
2076      if self._regularizer is not None:
2077        variable_scope_object.set_regularizer(self._regularizer)
2078      if self._caching_device is not None:
2079        variable_scope_object.set_caching_device(self._caching_device)
2080      if self._partitioner is not None:
2081        variable_scope_object.set_partitioner(self._partitioner)
2082      if self._custom_getter is not None:
2083        variable_scope_object.set_custom_getter(
2084            _maybe_wrap_custom_getter(self._custom_getter,
2085                                      self._old.custom_getter))
2086      if self._dtype is not None:
2087        variable_scope_object.set_dtype(self._dtype)
2088      if self._use_resource is not None:
2089        variable_scope_object.set_use_resource(self._use_resource)
2090      self._var_scope_store.open_variable_scope(self._new_name)
2091    self._var_scope_store.current_scope = variable_scope_object
2092    self._last_variable_scope_object = variable_scope_object
2093    return variable_scope_object
2094
2095  def __exit__(self, type_arg, value_arg, traceback_arg):
2096    if (self._var_scope_store.current_scope is
2097        not self._last_variable_scope_object):
2098      raise RuntimeError("Improper nesting of variable_scope.")
2099    # If jumping out from a non-prolonged scope, restore counts.
2100    if isinstance(self._name_or_scope, VariableScope):
2101      self._var_scope_store.variable_scopes_count = self._old_subscopes
2102    else:
2103      self._var_scope_store.close_variable_subscopes(self._new_name)
2104    self._var_scope_store.current_scope = self._old
2105
2106
2107def _maybe_wrap_custom_getter(custom_getter, old_getter):
2108  """Wrap a call to a custom_getter to use the old_getter internally."""
2109  if old_getter is None:
2110    return custom_getter
2111
2112  # The new custom_getter should call the old one
2113  def wrapped_custom_getter(getter, *args, **kwargs):
2114    # Call:
2115    #  custom_getter(
2116    #    lambda: old_getter(true_getter, ...), *args, **kwargs)
2117    # which means custom_getter will call old_getter, which
2118    # will call the true_getter, perform any intermediate
2119    # processing, and return the results to the current
2120    # getter, which will also perform additional processing.
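    #
    # For illustration only (the getter names are hypothetical): if an outer
    # scope installed `getter_a` and an inner scope installs `getter_b`, the
    # wrapped getter behaves like
    #   getter_b(functools.partial(getter_a, true_getter), *args, **kwargs)
    # so `getter_b` runs first and delegates to `getter_a`, which in turn
    # delegates to the true getter.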
2121    return custom_getter(functools.partial(old_getter, getter), *args, **kwargs)
2122
2123  return wrapped_custom_getter
2124
2125
2126def _get_unique_variable_scope(prefix):
2127  """Get a name with the given prefix unique in the current variable scope."""
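  # For example (illustrative): for prefix "bar" under the same parent scope,
  # successive calls yield "bar", then "bar_1", "bar_2", and so on.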
2128  var_scope_store = get_variable_scope_store()
2129  current_scope = get_variable_scope()
2130  name = current_scope.name + "/" + prefix if current_scope.name else prefix
2131  if var_scope_store.variable_scope_count(name) == 0:
2132    return prefix
2133  idx = 1
2134  while var_scope_store.variable_scope_count(name + ("_%d" % idx)) > 0:
2135    idx += 1
2136  return prefix + ("_%d" % idx)
2137
2138
2139# Named like a function for backwards compatibility with the
2140# @tf_contextlib.contextmanager version, which was switched to a class to avoid
2141# some object creation overhead.
2142@tf_export(v1=["variable_scope"])  # pylint: disable=invalid-name
2143class variable_scope:
2144  """A context manager for defining ops that create variables (layers).
2145
2146  @compatibility(TF2)
2147  Although it is a legacy `compat.v1` api,
2148  `tf.compat.v1.variable_scope` is mostly compatible with eager
2149  execution and `tf.function` as long as you combine it with the
2150  `tf.compat.v1.keras.utils.track_tf1_style_variables` decorator (though
2151  it will behave as if reuse is always set to `AUTO_REUSE`.)
2152
2153  See the
2154  [model migration guide](
2155      https://www.tensorflow.org/guide/migrate/model_mapping)
2156  for more info on
2157  migrating code that relies on `variable_scope`-based variable reuse.
2158
2159  When you use it with eager execution enabled but without
2160  `tf.compat.v1.keras.utils.track_tf1_style_variables`,
2161  `tf.compat.v1.variable_scope` will still be able to prefix the names
2162  of variables created within the scope but it will not enable variable reuse
2163  or error-raising checks around variable reuse (`get_variable` calls within
2164  it would always create new variables).
2165
2166  Once you have switched away from `get_variable`-based variable reuse
2167  mechanisms, to switch to TF2 APIs you can just use
2168  `tf.name_scope` to prefix variable names.
2169  @end_compatibility
2170
2171  This context manager validates that the (optional) `values` are from the same
2172  graph, ensures that graph is the default graph, and pushes a name scope and a
2173  variable scope.
2174
2175  If `name_or_scope` is not None, it is used as is. If `name_or_scope` is None,
2176  then `default_name` is used.  In that case, if the same name has been
2177  previously used in the same scope, it will be made unique by appending `_N`
2178  to it.
2179
2180  Variable scope allows you to create new variables and to share already created
2181  ones while providing checks to not create or share by accident. For details,
2182  see the [Variable Scope How To](https://tensorflow.org/guide/variables), here
2183  we present only a few basic examples.
2184
2185  Variable scope works as expected when eager execution is disabled:
2186
2187  ```python
2188  tf.compat.v1.disable_eager_execution()
2189  ```
2190
2191  Simple example of how to create a new variable:
2192
2193  ```python
2194  with tf.compat.v1.variable_scope("foo"):
2195      with tf.compat.v1.variable_scope("bar"):
2196          v = tf.compat.v1.get_variable("v", [1])
2197          assert v.name == "foo/bar/v:0"
2198  ```
2199
2200  Simple example of how to reenter a premade variable scope safely:
2201
2202  ```python
2203  with tf.compat.v1.variable_scope("foo") as vs:
2204    pass
2205
2206  # Re-enter the variable scope.
2207  with tf.compat.v1.variable_scope(vs,
2208                         auxiliary_name_scope=False) as vs1:
2209    # Restore the original name_scope.
2210    with tf.name_scope(vs1.original_name_scope):
2211        v = tf.compat.v1.get_variable("v", [1])
2212        assert v.name == "foo/v:0"
2213        c = tf.constant([1], name="c")
2214        assert c.name == "foo/c:0"
2215  ```
2216
2217  Keep in mind that the counters for `default_name` are discarded once the
2218  parent scope is exited. Therefore when the code re-enters the scope (for
2219  instance by saving it), all nested default_name counters will be restarted.
2220
2221  For instance:
2222
2223  ```python
2224  with tf.compat.v1.variable_scope("foo") as vs:
2225    with tf.compat.v1.variable_scope(None, default_name="bar"):
2226      v = tf.compat.v1.get_variable("a", [1])
2227      assert v.name == "foo/bar/a:0", v.name
2228    with tf.compat.v1.variable_scope(None, default_name="bar"):
2229      v = tf.compat.v1.get_variable("b", [1])
2230      assert v.name == "foo/bar_1/b:0"
2231
2232  with tf.compat.v1.variable_scope(vs):
2233    with tf.compat.v1.variable_scope(None, default_name="bar"):
2234      v = tf.compat.v1.get_variable("c", [1])
2235      assert v.name == "foo/bar/c:0"   # Uses bar instead of bar_2!
2236  ```
2237
2238  Basic example of sharing a variable with AUTO_REUSE:
2239
2240  ```python
2241  def foo():
2242    with tf.compat.v1.variable_scope("foo", reuse=tf.compat.v1.AUTO_REUSE):
2243      v = tf.compat.v1.get_variable("v", [1])
2244    return v
2245
2246  v1 = foo()  # Creates v.
2247  v2 = foo()  # Gets the same, existing v.
2248  assert v1 == v2
2249  ```
2250
2251  Basic example of sharing a variable with reuse=True:
2252
2253  ```python
2254  with tf.compat.v1.variable_scope("foo"):
2255      v = tf.compat.v1.get_variable("v", [1])
2256  with tf.compat.v1.variable_scope("foo", reuse=True):
2257      v1 = tf.compat.v1.get_variable("v", [1])
2258  assert v1 == v
2259  ```
2260
2261  Sharing a variable by capturing a scope and setting reuse:
2262
2263  ```python
2264  with tf.compat.v1.variable_scope("foo") as scope:
2265      v = tf.compat.v1.get_variable("v", [1])
2266      scope.reuse_variables()
2267      v1 = tf.compat.v1.get_variable("v", [1])
2268  assert v1 == v
2269  ```
2270
2271  To prevent accidental sharing of variables, we raise an exception when getting
2272  an existing variable in a non-reusing scope.
2273
2274  ```python
2275  with tf.compat.v1.variable_scope("foo"):
2276      v = tf.compat.v1.get_variable("v", [1])
2277      v1 = tf.compat.v1.get_variable("v", [1])
2278      #  Raises ValueError("... v already exists ...").
2279  ```
2280
2281  Similarly, we raise an exception when trying to get a variable that does not
2282  exist in reuse mode.
2283
2284  ```python
2285  with tf.compat.v1.variable_scope("foo", reuse=True):
2286      v = tf.compat.v1.get_variable("v", [1])
2287      #  Raises ValueError("... v does not exist ...").
2288  ```
2289
2290  Note that the `reuse` flag is inherited: if we open a reusing scope, then all
2291  its sub-scopes become reusing as well.
2292
2293  A note about name scoping: Setting `reuse` does not impact the naming of other
2294  ops such as multiplication. See related discussion on
2295  [github#6189](https://github.com/tensorflow/tensorflow/issues/6189)
2296
2297  Note that up to and including version 1.0, it was allowed (though explicitly
2298  discouraged) to pass False to the reuse argument, yielding undocumented
2299  behaviour slightly different from None. Starting at 1.1.0 passing None and
2300  False as reuse has exactly the same effect.
2301
2302  A note about using variable scopes in multi-threaded environment: Variable
2303  scopes are thread local, so one thread will not see another thread's current
2304  scope. Also, when using `default_name`, unique scope names are generated
2305  only on a per-thread basis. If the same name was used within a different
2306  thread, that doesn't prevent a new thread from creating the same scope.
2307  However, the underlying variable store is shared across threads (within the
2308  same graph). As such, if another thread tries to create a new variable with
2309  the same name as a variable created by a previous thread, it will fail unless
2310  reuse is True.
2311
2312  Further, each thread starts with an empty variable scope. So if you wish to
2313  preserve name prefixes from a scope from the main thread, you should capture
2314  the main thread's scope and re-enter it in each thread. For example:
2315
2316  ```
2317  main_thread_scope = variable_scope.get_variable_scope()
2318
2319  # Thread's target function:
2320  def thread_target_fn(captured_scope):
2321    with variable_scope.variable_scope(captured_scope):
2322      # .... regular code for this thread
2323
2324
2325  thread = threading.Thread(target=thread_target_fn, args=(main_thread_scope,))
2326  ```
2327  """
2328
2329  def __init__(self,
2330               name_or_scope,
2331               default_name=None,
2332               values=None,
2333               initializer=None,
2334               regularizer=None,
2335               caching_device=None,
2336               partitioner=None,
2337               custom_getter=None,
2338               reuse=None,
2339               dtype=None,
2340               use_resource=None,
2341               constraint=None,
2342               auxiliary_name_scope=True):
2343    """Initialize the context manager.
2344
2345    Args:
2346      name_or_scope: `string` or `VariableScope`: the scope to open.
2347      default_name: The default name to use if the `name_or_scope` argument is
2348        `None`; this name will be uniquified. If `name_or_scope` is provided, it
2349        won't be used and therefore it is not required and can be None.
2350      values: The list of `Tensor` arguments that are passed to the op function.
2351      initializer: default initializer for variables within this scope.
2352      regularizer: default regularizer for variables within this scope.
2353      caching_device: default caching device for variables within this scope.
2354      partitioner: default partitioner for variables within this scope.
2355      custom_getter: default custom getter for variables within this scope.
2356      reuse: `True`, None, or tf.compat.v1.AUTO_REUSE; if `True`, we go into
2357        reuse mode for this scope as well as all sub-scopes; if
2358        tf.compat.v1.AUTO_REUSE, we create variables if they do not exist, and
2359        return them otherwise; if None, we inherit the parent scope's reuse
2360        flag. When eager execution is enabled, new variables are always created
2361        unless an EagerVariableStore or template is currently active.
2362      dtype: type of variables created in this scope (defaults to the type in
2363        the passed scope, or inherited from parent scope).
2364      use_resource: If False, all variables will be regular Variables. If True,
2365        experimental ResourceVariables with well-defined semantics will be used
2366        instead. Defaults to False (will later change to True). When eager
2367        execution is enabled this argument is always forced to be True.
2368      constraint: An optional projection function to be applied to the variable
2369        after being updated by an `Optimizer` (e.g. used to implement norm
2370        constraints or value constraints for layer weights). The function must
2371        take as input the unprojected Tensor representing the value of the
2372        variable and return the Tensor for the projected value (which must have
2373        the same shape). Constraints are not safe to use when doing asynchronous
2374        distributed training.
2375      auxiliary_name_scope: If `True`, we create an auxiliary name scope with
2376        the scope. If `False`, we don't create it. Note that the argument is not
2377        inherited, and it only takes effect once, when the scope is created. You should
2378        only use it for re-entering a premade variable scope.
2379
2380    Returns:
2381      A scope that can be captured and reused.
2382
2383    Raises:
2384      ValueError: when trying to reuse within a create scope, or create within
2385        a reuse scope.
2386      TypeError: when the types of some arguments are not appropriate.
2387    """
2388    self._name_or_scope = name_or_scope
2389    self._default_name = default_name
2390    self._values = values
2391    self._initializer = initializer
2392    self._regularizer = regularizer
2393    self._caching_device = caching_device
2394    self._partitioner = partitioner
2395    self._custom_getter = custom_getter
2396    self._reuse = reuse
2397    self._dtype = dtype
2398    self._use_resource = use_resource
2399    self._constraint = constraint
2400    if self._default_name is None and self._name_or_scope is None:
2401      raise TypeError("If default_name is None then name_or_scope is required")
2402    if self._reuse is False:
2403      # We don't allow non-inheriting scopes, False = None here.
2404      self._reuse = None
2405    if not (self._reuse is True
2406            or self._reuse is None
2407            or self._reuse is AUTO_REUSE):
2408      raise ValueError("The reuse parameter must be True, False, None, or AUTO_REUSE.")
2409    if self._values is None:
2410      self._values = []
2411    self._in_graph_mode = not context.executing_eagerly()
2412    if self._in_graph_mode:
2413      self._graph = ops._get_graph_from_inputs(self._values)  # pylint: disable=protected-access
2414    self._cached_pure_variable_scope = None
2415    self._current_name_scope = None
2416    if not isinstance(auxiliary_name_scope, bool):
2417      raise TypeError("The auxiliary_name_scope must be `True` or `False`, "
2418                      "but got {}".format(auxiliary_name_scope))
2419    self._auxiliary_name_scope = auxiliary_name_scope
2420
2421  def __enter__(self):
2422    # If the default graph is building a function, then we should not replace it
2423    # with the cached graph.
2424    if ops.get_default_graph().building_function:
2425      self._building_function = True
2426    else:
2427      self._building_function = False
2428    if self._in_graph_mode and not self._building_function:
2429      self._graph_context_manager = self._graph.as_default()
2430      self._graph_context_manager.__enter__()
2431    if self._cached_pure_variable_scope is not None:
2432      # Fast path for re-entering variable_scopes. We've held on to the pure
2433      # variable scope from a previous successful __enter__, so we avoid some
2434      # overhead by re-using that object.
2435      if self._current_name_scope is not None:
2436        self._current_name_scope.__enter__()
2437      return self._cached_pure_variable_scope.__enter__()
2438
2439    try:
      return self._enter_scope_uncached()
    except:
      if (self._in_graph_mode and not self._building_function and
          self._graph_context_manager is not None):
        self._graph_context_manager.__exit__(*sys.exc_info())
      raise

  def _enter_scope_uncached(self):
    """Enters the context manager when there is no cached scope yet.

    Returns:
      The entered variable scope.

    Raises:
      TypeError: A wrong type is passed as `scope` at __init__().
      ValueError: `reuse` is incorrectly set at __init__().
    """
    if self._auxiliary_name_scope:
      # Create a new name scope later
      current_name_scope = None
    else:
      # Reenter the current name scope
      name_scope = ops.get_name_scope()
      if name_scope:
        # Hack to reenter
        name_scope += "/"
        current_name_scope = ops.name_scope(name_scope, skip_on_eager=False)
      else:
        # Root scope
        current_name_scope = ops.name_scope(name_scope, skip_on_eager=False)

    # IMPORTANT: Only assign to self._cached_pure_variable_scope and
    # self._current_name_scope after successful __enter__() calls.
    if self._name_or_scope is not None:
      if not isinstance(self._name_or_scope, (VariableScope, str)):
        raise TypeError("VariableScope: name_or_scope must be a string or "
                        "VariableScope.")
      if isinstance(self._name_or_scope, str):
        name_scope = self._name_or_scope
      else:
        name_scope = self._name_or_scope.name.split("/")[-1]
      if name_scope or current_name_scope:
        current_name_scope = current_name_scope or ops.name_scope(
            name_scope, skip_on_eager=False)
        try:
          current_name_scope_name = current_name_scope.__enter__()
        except:
          current_name_scope.__exit__(*sys.exc_info())
          raise
        self._current_name_scope = current_name_scope
        if isinstance(self._name_or_scope, str):
          old_name_scope = current_name_scope_name
        else:
          old_name_scope = self._name_or_scope.original_name_scope
        pure_variable_scope = _pure_variable_scope(
            self._name_or_scope,
            reuse=self._reuse,
            initializer=self._initializer,
            regularizer=self._regularizer,
            caching_device=self._caching_device,
            partitioner=self._partitioner,
            custom_getter=self._custom_getter,
            old_name_scope=old_name_scope,
            dtype=self._dtype,
            use_resource=self._use_resource,
            constraint=self._constraint)
        try:
          entered_pure_variable_scope = pure_variable_scope.__enter__()
        except:
          pure_variable_scope.__exit__(*sys.exc_info())
          raise
        self._cached_pure_variable_scope = pure_variable_scope
        return entered_pure_variable_scope
      else:
        self._current_name_scope = None
        # This can only happen if someone is entering the root variable scope.
        pure_variable_scope = _pure_variable_scope(
            self._name_or_scope,
            reuse=self._reuse,
            initializer=self._initializer,
            regularizer=self._regularizer,
            caching_device=self._caching_device,
            partitioner=self._partitioner,
            custom_getter=self._custom_getter,
            dtype=self._dtype,
            use_resource=self._use_resource,
            constraint=self._constraint)
        try:
          entered_pure_variable_scope = pure_variable_scope.__enter__()
        except:
          pure_variable_scope.__exit__(*sys.exc_info())
          raise
        self._cached_pure_variable_scope = pure_variable_scope
        return entered_pure_variable_scope

    else:  # Here name_or_scope is None. Using default name, but made unique.
      if self._reuse:
        raise ValueError("reuse=True cannot be used without a name_or_scope")
      current_name_scope = current_name_scope or ops.name_scope(
          self._default_name, skip_on_eager=False)
      try:
        current_name_scope_name = current_name_scope.__enter__()
      except:
        current_name_scope.__exit__(*sys.exc_info())
        raise
      self._current_name_scope = current_name_scope
      unique_default_name = _get_unique_variable_scope(self._default_name)
      pure_variable_scope = _pure_variable_scope(
          unique_default_name,
          initializer=self._initializer,
          regularizer=self._regularizer,
          caching_device=self._caching_device,
          partitioner=self._partitioner,
          custom_getter=self._custom_getter,
          old_name_scope=current_name_scope_name,
          dtype=self._dtype,
          use_resource=self._use_resource,
          constraint=self._constraint)
      try:
        entered_pure_variable_scope = pure_variable_scope.__enter__()
      except:
        pure_variable_scope.__exit__(*sys.exc_info())
        raise
      self._cached_pure_variable_scope = pure_variable_scope
      return entered_pure_variable_scope

  def __exit__(self, type_arg, value_arg, traceback_arg):
    try:
      self._cached_pure_variable_scope.__exit__(type_arg, value_arg,
                                                traceback_arg)
    finally:
      try:
        if self._current_name_scope:
          self._current_name_scope.__exit__(type_arg, value_arg,
                                            traceback_arg)
      finally:
        if self._in_graph_mode and not self._building_function:
          self._graph_context_manager.__exit__(type_arg, value_arg,
                                               traceback_arg)


# pylint: disable=g-doc-return-or-yield
@tf_export(v1=["variable_op_scope"])
@tf_contextlib.contextmanager
def variable_op_scope(values,
                      name_or_scope,
                      default_name=None,
                      initializer=None,
                      regularizer=None,
                      caching_device=None,
                      partitioner=None,
                      custom_getter=None,
                      reuse=None,
                      dtype=None,
                      use_resource=None,
                      constraint=None):
  """Deprecated: context manager for defining an op that creates variables."""
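  # For example (illustrative call), variable_op_scope([x], None, "my_op")
  # simply forwards to variable_scope(None, default_name="my_op", values=[x])
  # below.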
  logging.warn("tf.variable_op_scope(values, name, default_name) is deprecated,"
               " use tf.variable_scope(name, default_name, values)")
  with variable_scope(
      name_or_scope,
      default_name=default_name,
      values=values,
      initializer=initializer,
      regularizer=regularizer,
      caching_device=caching_device,
      partitioner=partitioner,
      custom_getter=custom_getter,
      reuse=reuse,
      dtype=dtype,
      use_resource=use_resource,
      constraint=constraint) as scope:
    yield scope


def _call_partitioner(partitioner, shape, dtype):
  """Calls `partitioner`, validating its inputs and output.

  Args:
    partitioner: a function mapping `Tensor` shape and dtype to a list of
      partitions.
    shape: shape of the `Tensor` to partition; must be fully defined and have
      rank at least 1.
    dtype: dtype of the elements in the `Tensor`.

  Returns:
    A list with elements >= 1, of which at most one is > 1. The index of that
    element (if any) corresponds to the partitioning axis.
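
  For illustration, a hypothetical partitioner that always splits the first
  axis into two pieces satisfies this contract:

  ```
  def split_axis0_in_two(shape, dtype):
    del dtype  # This toy partitioner ignores the element dtype.
    return [2] + [1] * (shape.ndims - 1)

  # _call_partitioner(split_axis0_in_two,
  #                   tensor_shape.TensorShape([10, 4]),
  #                   dtypes.float32) returns [2, 1].
  ```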
  """
  if not shape.is_fully_defined():
    raise ValueError("Shape of a new partitioned variable must be "
                     "fully defined, but instead was %s." % (shape,))
  if shape.ndims < 1:
    raise ValueError("A partitioned Variable must have rank at least 1, "
                     "shape: %s" % shape)

  slicing = partitioner(shape=shape, dtype=dtype)
  if not isinstance(slicing, collections_abc.Sequence):
    raise ValueError("Partitioner must return a sequence, but saw: %s" %
                     slicing)
  if len(slicing) != shape.ndims:
    raise ValueError(
        "Partitioner returned a partition list that does not match the "
        "Variable's rank: %s vs. %s" % (slicing, shape))
  if any(p < 1 for p in slicing):
    raise ValueError("Partitioner returned zero partitions for some axes: %s" %
                     slicing)
  if sum(p > 1 for p in slicing) > 1:
    raise ValueError("Can only slice a variable along one dimension: "
                     "shape: %s, partitioning: %s" % (shape, slicing))
  return slicing


# TODO(slebedev): could be inlined, but
# `_VariableStore._get_partitioned_variable` is too complex even
# without this logic.
def _get_slice_dim_and_num_slices(slicing):
  """Get slicing dimension and number of slices from the partitioner output."""
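  # For example (illustrative values): [1, 3, 1] -> (1, 3); an all-ones
  # slicing such as [1, 1] falls through to the degenerate case -> (0, 1).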
  for slice_dim, num_slices in enumerate(slicing):
    if num_slices > 1:
      break
  else:
    # Degenerate case: no partitioning applied.
    slice_dim = 0
    num_slices = 1
  return slice_dim, num_slices


def _iter_slices(full_shape, num_slices, slice_dim):
  """Slices a given shape along the specified dimension."""
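  # For example (illustrative values): full_shape=[10, 4], num_slices=3,
  # slice_dim=0 yields ([0, 0], [4, 4]), ([4, 0], [3, 4]), ([7, 0], [3, 4]);
  # earlier slices absorb the remainder of an uneven split.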
  num_slices_with_excess = full_shape[slice_dim] % num_slices
  offset = [0] * len(full_shape)
  min_slice_len = full_shape[slice_dim] // num_slices
  for i in range(num_slices):
    shape = full_shape[:]
    shape[slice_dim] = min_slice_len + bool(i < num_slices_with_excess)
    yield offset[:], shape
    offset[slice_dim] += shape[slice_dim]


def default_variable_creator(next_creator=None, **kwargs):
  """Default variable creator."""
  assert next_creator is None
  initial_value = kwargs.get("initial_value", None)
  trainable = kwargs.get("trainable", None)
  collections = kwargs.get("collections", None)
  validate_shape = kwargs.get("validate_shape", True)
  caching_device = kwargs.get("caching_device", None)
  name = kwargs.get("name", None)
  variable_def = kwargs.get("variable_def", None)
  dtype = kwargs.get("dtype", None)
  expected_shape = kwargs.get("expected_shape", None)
  import_scope = kwargs.get("import_scope", None)
  constraint = kwargs.get("constraint", None)
  use_resource = kwargs.get("use_resource", None)
  synchronization = kwargs.get("synchronization", None)
  aggregation = kwargs.get("aggregation", None)
  shape = kwargs.get("shape", None)

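  # Resolve `use_resource` in priority order: the explicit kwarg, the
  # enclosing variable scope, then the module-wide default; eager execution
  # always forces resource variables.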
  if use_resource is None:
    use_resource = get_variable_scope().use_resource
  if use_resource is None:
    use_resource = _DEFAULT_USE_RESOURCE
  use_resource = use_resource or context.executing_eagerly()
  if use_resource:
    distribute_strategy = kwargs.get("distribute_strategy", None)
    return resource_variable_ops.ResourceVariable(
        initial_value=initial_value,
        trainable=trainable,
        collections=collections,
        validate_shape=validate_shape,
        caching_device=caching_device,
        name=name,
        dtype=dtype,
        constraint=constraint,
        variable_def=variable_def,
        import_scope=import_scope,
        distribute_strategy=distribute_strategy,
        synchronization=synchronization,
        aggregation=aggregation,
        shape=shape)
  else:
    return variables.RefVariable(
        initial_value=initial_value,
        trainable=trainable,
        collections=collections,
        validate_shape=validate_shape,
        caching_device=caching_device,
        name=name,
        dtype=dtype,
        constraint=constraint,
        variable_def=variable_def,
        expected_shape=expected_shape,
        import_scope=import_scope,
        synchronization=synchronization,
        aggregation=aggregation,
        shape=shape)


def default_variable_creator_v2(next_creator=None, **kwargs):
  """Default variable creator."""
  assert next_creator is None
  initial_value = kwargs.get("initial_value", None)
  trainable = kwargs.get("trainable", None)
  validate_shape = kwargs.get("validate_shape", True)
  caching_device = kwargs.get("caching_device", None)
  name = kwargs.get("name", None)
  variable_def = kwargs.get("variable_def", None)
  dtype = kwargs.get("dtype", None)
  import_scope = kwargs.get("import_scope", None)
  constraint = kwargs.get("constraint", None)
  distribute_strategy = kwargs.get("distribute_strategy", None)
  synchronization = kwargs.get("synchronization", None)
  aggregation = kwargs.get("aggregation", None)
  shape = kwargs.get("shape", None)

  return resource_variable_ops.ResourceVariable(
      initial_value=initial_value,
      trainable=trainable,
      validate_shape=validate_shape,
      caching_device=caching_device,
      name=name,
      dtype=dtype,
      constraint=constraint,
      variable_def=variable_def,
      import_scope=import_scope,
      distribute_strategy=distribute_strategy,
      synchronization=synchronization,
      aggregation=aggregation,
      shape=shape)


variables.default_variable_creator = default_variable_creator
variables.default_variable_creator_v2 = default_variable_creator_v2


def _make_getter(captured_getter, captured_previous):
  """Works around Python's late binding of loop variables in closures."""
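  # A lambda defined directly in a loop would capture the loop variable by
  # reference (late binding), so every getter would end up seeing the final
  # value; binding the arguments here freezes them per call.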
  return lambda **kwargs: captured_getter(captured_previous, **kwargs)


# TODO(apassos) remove forwarding symbol
variable = variables.VariableV1


@tf_export(v1=["variable_creator_scope"])
@tf_contextlib.contextmanager
def variable_creator_scope_v1(variable_creator):
  """Scope which defines a variable creation function to be used by variable().

  variable_creator is expected to be a function with the following signature:

  ```
    def variable_creator(next_creator, **kwargs)
  ```

  The creator is supposed to eventually call `next_creator` to create a
  variable if it does want to create a variable, rather than calling Variable
  or ResourceVariable directly. This helps make creators composable (see the
  sketch below). A creator may choose to create multiple variables, return
  already existing variables, or simply register that a variable was created
  and defer to the next creators in line. Creators can also modify the keyword
  arguments seen by the next creators.

  Custom getters in the variable scope will eventually resolve down to these
  custom creators when they do create variables.

  The valid keyword arguments in kwargs are:

   * initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
        which is the initial value for the Variable. The initial value must have
        a shape specified unless `validate_shape` is set to False. Can also be a
        callable with no argument that returns the initial value when called. In
        that case, `dtype` must be specified. (Note that initializer functions
        from init_ops.py must first be bound to a shape before being used here.)
   * trainable: If `True`, the default, also adds the variable to the graph
        collection `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as
        the default list of variables to use by the `Optimizer` classes.
        `trainable` defaults to `True`, unless `synchronization` is
        set to `ON_READ`, in which case it defaults to `False`.
   * collections: List of graph collections keys. The new variable is added to
        these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
   * validate_shape: If `False`, allows the variable to be initialized with a
        value of unknown shape. If `True`, the default, the shape of
        `initial_value` must be known.
   * caching_device: Optional device string describing where the Variable
        should be cached for reading.  Defaults to the Variable's device.
        If not `None`, caches on another device.  Typical use is to cache
        on the device where the Ops using the Variable reside, to deduplicate
        copying through `Switch` and other conditional statements.
   * name: Optional name for the variable. Defaults to `'Variable'` and gets
        uniquified automatically.
   * dtype: If set, initial_value will be converted to the given type.
        If `None`, either the datatype will be kept (if `initial_value` is
        a Tensor), or `convert_to_tensor` will decide.
   * constraint: A constraint function to be applied to the variable after
        updates by some algorithms.
   * use_resource: if True, a ResourceVariable is always created.
   * synchronization: Indicates when a distributed variable will be
        aggregated. Accepted values are constants defined in the class
        `tf.VariableSynchronization`. By default the synchronization is set to
        `AUTO` and the current `DistributionStrategy` chooses
        when to synchronize.
   * aggregation: Indicates how a distributed variable will be aggregated.
        Accepted values are constants defined in the class
        `tf.VariableAggregation`.

  This set may grow over time, so it's important that the signature of creators
  is as mentioned above.

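  As an illustrative sketch (the `logging_creator` name below is made up for
  this example and is not part of the API), a creator that logs the requested
  variable name and then defers to the next creator in line could look like:

  ```
  def logging_creator(next_creator, **kwargs):
    print("Creating variable:", kwargs.get("name"))
    return next_creator(**kwargs)

  with tf.compat.v1.variable_creator_scope(logging_creator):
    v = tf.compat.v1.Variable(1.0)
  ```
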
  Args:
    variable_creator: the passed creator

  Yields:
    A scope in which the creator is active
  """
  with ops.get_default_graph()._variable_creator_scope(variable_creator):  # pylint: disable=protected-access
    yield


# Note: only the docstrings differ between this and v1.
@tf_export("variable_creator_scope", v1=[])
@tf_contextlib.contextmanager
def variable_creator_scope(variable_creator):
  """Scope which defines a variable creation function to be used by variable().

  variable_creator is expected to be a function with the following signature:

  ```
    def variable_creator(next_creator, **kwargs)
  ```

  The creator is supposed to eventually call `next_creator` to create a
  variable if it does want to create a variable, rather than calling Variable
  or ResourceVariable directly. This helps make creators composable (see the
  sketch below). A creator may choose to create multiple variables, return
  already existing variables, or simply register that a variable was created
  and defer to the next creators in line. Creators can also modify the keyword
  arguments seen by the next creators.

  Custom getters in the variable scope will eventually resolve down to these
  custom creators when they do create variables.

  The valid keyword arguments in kwargs are:

   * initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
        which is the initial value for the Variable. The initial value must have
        a shape specified unless `validate_shape` is set to False. Can also be a
        callable with no argument that returns the initial value when called. In
        that case, `dtype` must be specified. (Note that initializer functions
        from init_ops.py must first be bound to a shape before being used here.)
   * trainable: If `True`, the default, GradientTapes automatically watch
        uses of this Variable.
   * validate_shape: If `False`, allows the variable to be initialized with a
        value of unknown shape. If `True`, the default, the shape of
        `initial_value` must be known.
   * caching_device: Optional device string describing where the Variable
        should be cached for reading.  Defaults to the Variable's device.
        If not `None`, caches on another device.  Typical use is to cache
        on the device where the Ops using the Variable reside, to deduplicate
        copying through `Switch` and other conditional statements.
   * name: Optional name for the variable. Defaults to `'Variable'` and gets
        uniquified automatically.
   * dtype: If set, initial_value will be converted to the given type.
        If `None`, either the datatype will be kept (if `initial_value` is
        a Tensor), or `convert_to_tensor` will decide.
   * constraint: A constraint function to be applied to the variable after
        updates by some algorithms.
   * synchronization: Indicates when a distributed variable will be
        aggregated. Accepted values are constants defined in the class
        `tf.VariableSynchronization`. By default the synchronization is set to
        `AUTO` and the current `DistributionStrategy` chooses
        when to synchronize.
   * aggregation: Indicates how a distributed variable will be aggregated.
        Accepted values are constants defined in the class
        `tf.VariableAggregation`.

  This set may grow over time, so it's important that the signature of creators
  is as mentioned above.

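  As an illustrative sketch (the `logging_creator` name below is made up for
  this example and is not part of the API), a creator that logs the requested
  variable name and then defers to the next creator in line could look like:

  ```
  def logging_creator(next_creator, **kwargs):
    print("Creating variable:", kwargs.get("name"))
    return next_creator(**kwargs)

  with tf.variable_creator_scope(logging_creator):
    v = tf.Variable(1.0)
  ```
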
  Args:
    variable_creator: the passed creator

  Yields:
    A scope in which the creator is active
  """
  with ops.get_default_graph()._variable_creator_scope(variable_creator):  # pylint: disable=protected-access
    yield
