# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""A class to store named variables and a scope operator to manage sharing."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections as collections_lib
import copy
import enum  # pylint: disable=g-bad-import-order
import functools
import sys
import threading
import traceback

import six
from six import iteritems
from six.moves import xrange  # pylint: disable=redefined-builtin

from tensorflow.python import tf2
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import deprecation
from tensorflow.python.util import function_utils
from tensorflow.python.util import tf_contextlib
from tensorflow.python.util import tf_inspect
from tensorflow.python.util.tf_export import tf_export

__all__ = [
    "AUTO_REUSE", "VariableScope", "get_variable_scope", "get_variable",
    "get_local_variable", "variable_scope", "variable_op_scope",
    "no_regularizer", "VariableSynchronization", "VariableAggregation"
]


class _PartitionInfo(object):
  """Holds partition info used by initializer functions."""

  def __init__(self, full_shape, var_offset):
    """Constructor.

    Args:
      full_shape: Tuple or list of `int` indicating the full combined shape
        of the partitioned variables.
      var_offset: Tuple or list of `int` specifying offset of this partition
        with respect to the full variable for each dimension.

    Raises:
      TypeError: If `full_shape` or `var_offset` is not a sequence.
      ValueError: If `full_shape` or `var_offset` differ in length. If
        `var_offset` exceeds `full_shape` in any dimension.
    """
    if not isinstance(full_shape, collections_lib.Sequence) or isinstance(
        full_shape, six.string_types):
      raise TypeError(
          "`full_shape` must be a sequence (like tuple or list) instead of " +
          type(full_shape).__name__)

    if not isinstance(var_offset, collections_lib.Sequence) or isinstance(
        var_offset, six.string_types):
      raise TypeError(
          "`var_offset` must be a sequence (like tuple or list) instead of " +
          type(var_offset).__name__)

    if len(var_offset) != len(full_shape):
      raise ValueError(
          "Expected equal length, but `var_offset` is of length {} while "
          "full_shape is of length {}.".format(
              len(var_offset), len(full_shape)))

    for i in xrange(len(full_shape)):
      offset = var_offset[i]
      shape = full_shape[i]
      if offset < 0 or offset >= shape:
        raise ValueError(
            "Expected 0 <= offset < shape but found offset={}, shape={} for "
            "var_offset={}, full_shape={}".format(offset, shape, var_offset,
                                                  full_shape))

    self._full_shape = full_shape
    self._var_offset = var_offset

  @property
  def full_shape(self):
    return self._full_shape

  @property
  def var_offset(self):
    return self._var_offset

  def single_offset(self, shape):
    """Returns the offset when the variable is partitioned in at most one dim.

    Args:
      shape: Tuple or list of `int` indicating the shape of one specific
        variable partition.

    Returns:
      `int` representing the offset in the dimension along which the variable
      is partitioned. Returns 0 if the variable is not being partitioned.

    Raises:
      ValueError: Depending on self.single_slice_dim().
    """

    single_slice_dim = self.single_slice_dim(shape)
    # If this variable is not being partitioned at all, single_slice_dim()
    # could return None.
    if single_slice_dim is None:
      return 0
    return self.var_offset[single_slice_dim]

  def single_slice_dim(self, shape):
    """Returns the slice dim when the variable is partitioned only in one dim.

    Args:
      shape: Tuple or list of `int` indicating the shape of one specific
        variable partition.

    Returns:
      `int` representing the dimension that the variable is partitioned in, or
      `None` if the variable doesn't seem to be partitioned at all.

    Raises:
      TypeError: If `shape` is not a sequence.
      ValueError: If `shape` is not the same length as `self.full_shape`. If
        the variable is partitioned in more than one dimension.
    """
    if not isinstance(shape, collections_lib.Sequence) or isinstance(
        shape, six.string_types):
      raise TypeError(
          "`shape` must be a sequence (like tuple or list) instead of " +
          type(shape).__name__)

    if len(shape) != len(self.full_shape):
      raise ValueError(
          "Expected equal length, but received shape={} of length {} while "
          "self.full_shape={} is of length {}.".format(shape, len(
              shape), self.full_shape, len(self.full_shape)))

    for i in xrange(len(shape)):
      if self.var_offset[i] + shape[i] > self.full_shape[i]:
        raise ValueError(
            "With self.var_offset={}, a partition of shape={} would exceed "
            "self.full_shape={} in dimension {}.".format(
                self.var_offset, shape, self.full_shape, i))

    slice_dim = None
    for i in xrange(len(shape)):
      if shape[i] == self.full_shape[i]:
        continue
      if slice_dim is not None:
        raise ValueError(
            "Cannot use single_slice_dim() with shape={} and "
            "self.full_shape={} since slice dim could be either dimension {} "
            "or {}.".format(shape, self.full_shape, i, slice_dim))
      slice_dim = i

    return slice_dim

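# Illustrative usage sketch for the class above (the shapes and offsets here
# are assumed values, not taken from real code): how an initializer might
# consume the _PartitionInfo it receives for a [20, 20] variable split along
# dimension 0 into two [10, 20] shards.
#
#   info = _PartitionInfo(full_shape=[20, 20], var_offset=[10, 0])
#   info.single_slice_dim([10, 20])  # -> 0 (partitioned along dimension 0)
#   info.single_offset([10, 20])     # -> 10 (this shard starts at row 10)
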

class _ReuseMode(enum.Enum):
  """Mode for variable access within a variable scope."""

  # Indicates that variables are to be fetched if they already exist or
  # otherwise created.
  AUTO_REUSE = 1

  # TODO(alive): For TensorFlow 2.0, Deprecate True/False/None API in favor of
  #              enum values.
  # REUSE_FALSE = 2
  # REUSE_TRUE = 3


# TODO(apassos) remove these forwarding symbols.
VariableSynchronization = variables.VariableSynchronization  # pylint: disable=invalid-name
VariableAggregation = variables.VariableAggregation  # pylint: disable=invalid-name

AUTO_REUSE = _ReuseMode.AUTO_REUSE
tf_export(v1=["AUTO_REUSE"]).export_constant(__name__, "AUTO_REUSE")
AUTO_REUSE.__doc__ = """
When passed in as the value for the `reuse` flag, AUTO_REUSE indicates that
get_variable() should create the requested variable if it doesn't exist or, if
it does exist, simply return it.
"""

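# Illustrative usage sketch (assumes the tf.compat.v1 public API): with
# reuse=AUTO_REUSE the first call creates "foo/v" and the second call returns
# the same variable.
#
#   with tf.compat.v1.variable_scope("foo", reuse=tf.compat.v1.AUTO_REUSE):
#     v1 = tf.compat.v1.get_variable("v", [1])
#   with tf.compat.v1.variable_scope("foo", reuse=tf.compat.v1.AUTO_REUSE):
#     v2 = tf.compat.v1.get_variable("v", [1])
#   assert v1 is v2
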

_DEFAULT_USE_RESOURCE = tf2.enabled()


@tf_export(v1=["enable_resource_variables"])
def enable_resource_variables():
  """Creates resource variables by default.

  Resource variables are improved versions of TensorFlow variables with a
  well-defined memory model. Accessing a resource variable reads its value, and
  all ops which access a specific read value of the variable are guaranteed to
  see the same value for that tensor. Writes which happen after a read (by
  having a control or data dependency on the read) are guaranteed not to affect
  the value of the read tensor, and similarly writes which happen before a read
  are guaranteed to affect the value. No guarantees are made about unordered
  read/write pairs.

  Calling tf.enable_resource_variables() lets you opt in to this TensorFlow 2.0
  feature.
  """
  global _DEFAULT_USE_RESOURCE
  _DEFAULT_USE_RESOURCE = True


@tf_export(v1=["resource_variables_enabled"])
def resource_variables_enabled():
  """Returns `True` if resource variables are enabled.

  Resource variables are improved versions of TensorFlow variables with a
  well-defined memory model. Accessing a resource variable reads its value, and
  all ops which access a specific read value of the variable are guaranteed to
  see the same value for that tensor. Writes which happen after a read (by
  having a control or data dependency on the read) are guaranteed not to affect
  the value of the read tensor, and similarly writes which happen before a read
  are guaranteed to affect the value. No guarantees are made about unordered
  read/write pairs.

  Calling tf.enable_resource_variables() lets you opt in to this TensorFlow 2.0
  feature.
  """
  global _DEFAULT_USE_RESOURCE
  return _DEFAULT_USE_RESOURCE


@deprecation.deprecated(
    None, "non-resource variables are not supported in the long term")
@tf_export(v1=["disable_resource_variables"])
def disable_resource_variables():
  """Opts out of resource variables.

  If your code needs tf.disable_resource_variables() to be called to work
  properly please file a bug.
  """
  global _DEFAULT_USE_RESOURCE
  _DEFAULT_USE_RESOURCE = False

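# Illustrative usage sketch (assumes the tf.compat.v1 public API): toggling
# the module-level default that get_variable consults when use_resource is
# left unspecified.
#
#   tf.compat.v1.enable_resource_variables()
#   assert tf.compat.v1.resource_variables_enabled()
#   v = tf.compat.v1.get_variable("w", shape=[2])  # now a ResourceVariable
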

class _VariableStore(object):
  """Variable store that carries a number of named Variables.

  New variable names and new variables can be created; all stored
  variables are initialized with the initializer passed to __init__.

  Attributes:
    vars: a dictionary with string names (same as passed in GetVar) as keys
          and the corresponding TensorFlow Variables as values.
  """

  def __init__(self):
    """Create a variable store."""
    self._vars = {}  # A dictionary of the stored TensorFlow variables.
    self._partitioned_vars = {}  # A dict of the stored PartitionedVariables.
    self._store_eager_variables = False

  def get_variable(self,
                   name,
                   shape=None,
                   dtype=dtypes.float32,
                   initializer=None,
                   regularizer=None,
                   reuse=None,
                   trainable=None,
                   collections=None,
                   caching_device=None,
                   partitioner=None,
                   validate_shape=True,
                   use_resource=None,
                   custom_getter=None,
                   constraint=None,
                   synchronization=VariableSynchronization.AUTO,
                   aggregation=VariableAggregation.NONE):
    """Gets an existing variable with these parameters or creates a new one.

    If a variable with the given name is already stored, we return the stored
    variable. Otherwise, we create a new one.

    Set `reuse` to `True` when you only want to reuse existing Variables.
    Set `reuse` to `False` when you only want to create new Variables.
    Set `reuse` to None (the default) or tf.AUTO_REUSE when you want
    variables to be created if they don't exist or returned if they do.

    If initializer is `None` (the default), the default initializer passed in
    the constructor is used. If that one is `None` too, we use a new
    `glorot_uniform_initializer`. If initializer is a Tensor, we use
    it as a value and derive the shape from the initializer.

    If a partitioner is provided, a `PartitionedVariable` is returned.
    Accessing this object as a `Tensor` returns the shards concatenated along
    the partition axis.

    Some useful partitioners are available.  See, e.g.,
    `variable_axis_size_partitioner` and `min_max_variable_partitioner`.

    Args:
      name: The name of the new or existing variable.
      shape: Shape of the new or existing variable.
      dtype: Type of the new or existing variable (defaults to `DT_FLOAT`).
      initializer: Initializer for the variable.
      regularizer: A (Tensor -> Tensor or None) function; the result of
        applying it on a newly created variable will be added to the collection
        GraphKeys.REGULARIZATION_LOSSES and can be used for regularization.
      reuse: a Boolean, None, or tf.AUTO_REUSE. Controls reuse or creation
        of variables. When eager execution is enabled this argument is always
        forced to be False.
      trainable: If `True` also add the variable to the graph collection
        `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
        `trainable` defaults to `True` unless `synchronization` is
        set to `ON_READ`.
      collections: List of graph collections keys to add the `Variable` to.
        Defaults to `[GraphKeys.GLOBAL_VARIABLES]` (see `tf.Variable`).
      caching_device: Optional device string or function describing where the
        Variable should be cached for reading.  Defaults to the Variable's
        device.  If not `None`, caches on another device.  Typical use is to
        cache on the device where the Ops using the `Variable` reside, to
        deduplicate copying through `Switch` and other conditional statements.
      partitioner: Optional callable that accepts a fully defined `TensorShape`
        and dtype of the `Variable` to be created, and returns a list of
        partitions for each axis (currently only one axis can be partitioned).
      validate_shape: If False, allows the variable to be initialized with a
        value of unknown shape. If True, the default, the shape of initial_value
        must be known.
      use_resource: If False, creates a regular Variable. If True, creates
        instead an experimental ResourceVariable which has well-defined
        semantics. Defaults to False (will later change to True).
        When eager execution is enabled this argument is always forced to be
        True.
      custom_getter: Callable that takes as a first argument the true getter,
        and allows overwriting the internal get_variable method.
        The signature of `custom_getter` should match that of this method,
        but the most future-proof version will allow for changes:
        `def custom_getter(getter, *args, **kwargs)`.  Direct access to
        all `get_variable` parameters is also allowed:
        `def custom_getter(getter, name, *args, **kwargs)`.  A simple identity
        custom getter that simply creates variables with modified names is:
        ```python
        def custom_getter(getter, name, *args, **kwargs):
          return getter(name + '_suffix', *args, **kwargs)
        ```
      constraint: An optional projection function to be applied to the variable
        after being updated by an `Optimizer` (e.g. used to implement norm
        constraints or value constraints for layer weights). The function must
        take as input the unprojected Tensor representing the value of the
        variable and return the Tensor for the projected value
        (which must have the same shape). Constraints are not safe to
        use when doing asynchronous distributed training.
      synchronization: Indicates when a distributed variable will be
        aggregated. Accepted values are constants defined in the class
        `tf.VariableSynchronization`. By default the synchronization is set to
        `AUTO` and the current `DistributionStrategy` chooses
        when to synchronize. If `synchronization` is set to `ON_READ`,
        `trainable` must not be set to `True`.
      aggregation: Indicates how a distributed variable will be aggregated.
        Accepted values are constants defined in the class
        `tf.VariableAggregation`.

    Returns:
      The created or existing `Variable` (or `PartitionedVariable`, if a
      partitioner was used).

    Raises:
      ValueError: when creating a new variable and shape is not declared,
        when reusing a variable and specifying a conflicting shape,
        or when violating reuse during variable creation.
      RuntimeError: when eager execution is enabled and not called from an
        EagerVariableStore.
    """
    if custom_getter is not None and not callable(custom_getter):
      raise ValueError(
          "Passed a custom_getter which is not callable: %s" % custom_getter)

    with ops.init_scope():
      if context.executing_eagerly():
        # Variable creation and initialization takes place in `init_scope`s;
        # as such, if an `init_scope` lifts us into the eager context, then we
        # need to use `ResourceVariable`s.
        use_resource = True

    # Note that it's fine to reuse eager variables whose initialization was
    # lifted from a function-building graph into the eager context (that's why
    # the following clause is not wrapped in an `init_scope`); lifted variables
    # are tracked by the graph's `VariableStore`.
    if context.executing_eagerly():
      if not self._store_eager_variables and reuse:
        raise RuntimeError(
            "When eager execution is enabled variable reuse is only supported"
            " when an EagerVariableStore is active. See the documentation on"
            " EagerVariableStore for example usage.")
      if self._store_eager_variables:
        reuse = AUTO_REUSE

    # If a *_ref type is passed in an error would be triggered further down the
    # stack. We prevent this using base_dtype to get a non-ref version of the
    # type, before doing anything else. When _ref types are removed in favor of
    # resources, this line can be removed.
    try:
      dtype = dtype.base_dtype
    except AttributeError:
      # .base_dtype not existing means that we will try and use the raw dtype
      # which was passed in - this might be a NumPy type which is valid.
      pass

    # This is the main logic of get_variable.  However, custom_getter
    # may override this logic.  So we save it as a callable and pass
    # it to custom_getter.
    # Note: the parameters of _true_getter, and their documentation, match
    # *exactly* item-for-item with the docstring of this method.
    def _true_getter(  # pylint: disable=missing-docstring
        name,
        shape=None,
        dtype=dtypes.float32,
        initializer=None,
        regularizer=None,
        reuse=None,
        trainable=None,
        collections=None,
        caching_device=None,
        partitioner=None,
        validate_shape=True,
        use_resource=None,
        constraint=None,
        synchronization=VariableSynchronization.AUTO,
        aggregation=VariableAggregation.NONE):
      is_scalar = (shape is not None
                   and isinstance(shape, collections_lib.Sequence)
                   and not shape)
      # Partitioned variable case
      if partitioner is not None and not is_scalar:
        if not callable(partitioner):
          raise ValueError(
              "Partitioner must be callable, but received: %s" % partitioner)
        with ops.name_scope(None):
          return self._get_partitioned_variable(
              name=name,
              shape=shape,
              dtype=dtype,
              initializer=initializer,
              regularizer=regularizer,
              reuse=reuse,
              trainable=trainable,
              collections=collections,
              caching_device=caching_device,
              partitioner=partitioner,
              validate_shape=validate_shape,
              use_resource=use_resource,
              constraint=constraint,
              synchronization=synchronization,
              aggregation=aggregation)

      # Special case for partitioned variable to allow reuse without having to
      # specify partitioner.
      if (reuse is True and partitioner is None
          and name in self._partitioned_vars):
        return self._get_partitioned_variable(
            name=name,
            shape=shape,
            dtype=dtype,
            initializer=initializer,
            regularizer=regularizer,
            reuse=reuse,
            trainable=trainable,
            collections=collections,
            caching_device=caching_device,
            partitioner=None,
            validate_shape=validate_shape,
            use_resource=use_resource,
            constraint=constraint,
            synchronization=synchronization,
            aggregation=aggregation)

      # Single variable case
      if "%s/part_0" % name in self._vars:
        raise ValueError(
            "No partitioner was provided, but a partitioned version of the "
            "variable was found: %s/part_0. Perhaps a variable of the same "
            "name was already created with partitioning?" % name)

      return self._get_single_variable(
          name=name,
          shape=shape,
          dtype=dtype,
          initializer=initializer,
          regularizer=regularizer,
          reuse=reuse,
          trainable=trainable,
          collections=collections,
          caching_device=caching_device,
          validate_shape=validate_shape,
          use_resource=use_resource,
          constraint=constraint,
          synchronization=synchronization,
          aggregation=aggregation)

    # Set trainable value based on synchronization value.
    trainable = _get_trainable_value(
        synchronization=synchronization, trainable=trainable)

    if custom_getter is not None:
      # Handle backwards compatibility with getter arguments that were added
      # to the API after users started writing custom getters.
      custom_getter_kwargs = {
          "getter": _true_getter,
          "name": name,
          "shape": shape,
          "dtype": dtype,
          "initializer": initializer,
          "regularizer": regularizer,
          "reuse": reuse,
          "trainable": trainable,
          "collections": collections,
          "caching_device": caching_device,
          "partitioner": partitioner,
          "validate_shape": validate_shape,
          "use_resource": use_resource,
          "synchronization": synchronization,
          "aggregation": aggregation,
      }
      # `fn_args` and `has_kwargs` can handle functions, `functools.partial`,
      # `lambda`.
      if ("constraint" in function_utils.fn_args(custom_getter) or
          function_utils.has_kwargs(custom_getter)):
        custom_getter_kwargs["constraint"] = constraint
      return custom_getter(**custom_getter_kwargs)
    else:
      return _true_getter(
          name,
          shape=shape,
          dtype=dtype,
          initializer=initializer,
          regularizer=regularizer,
          reuse=reuse,
          trainable=trainable,
          collections=collections,
          caching_device=caching_device,
          partitioner=partitioner,
          validate_shape=validate_shape,
          use_resource=use_resource,
          constraint=constraint,
          synchronization=synchronization,
          aggregation=aggregation)

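  # Illustrative usage sketch (assumes the tf.compat.v1 public API): a
  # get_variable call made under a scope with a fixed-shard partitioner
  # routes through _true_getter into _get_partitioned_variable below.
  #
  #   part = tf.compat.v1.fixed_size_partitioner(num_shards=2, axis=0)
  #   with tf.compat.v1.variable_scope("scope", partitioner=part):
  #     w = tf.compat.v1.get_variable("w", shape=[4, 3])
  #   # w is a PartitionedVariable backed by "scope/w/part_0" and
  #   # "scope/w/part_1".
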
  def _get_partitioned_variable(self,
                                name,
                                partitioner,
                                shape=None,
                                dtype=dtypes.float32,
                                initializer=None,
                                regularizer=None,
                                reuse=None,
                                trainable=None,
                                collections=None,
                                caching_device=None,
                                validate_shape=True,
                                use_resource=None,
                                constraint=None,
                                synchronization=VariableSynchronization.AUTO,
                                aggregation=VariableAggregation.NONE):
    """Gets or creates a sharded variable list with these parameters.

    The `partitioner` must be a callable that accepts a fully defined
    `TensorShape` and returns a sequence of integers (the `partitions`).
    These integers describe how to partition the given sharded `Variable`
    along the given dimension.  That is, `partitions[1] = 3` means split
    the `Variable` into 3 shards along dimension 1.  Currently, sharding along
    only one axis is supported.

    If the list of variables with the given name (prefix) is already stored,
    we return the stored variables. Otherwise, we create a new one.

    Set `reuse` to `True` when you only want to reuse existing Variables.
    Set `reuse` to `False` when you only want to create new Variables.
    Set `reuse` to None (the default) or tf.AUTO_REUSE when you want
    variables to be created if they don't exist or returned if they do.

    If initializer is `None` (the default), the default initializer passed in
    the constructor is used. If that one is `None` too, we use a new
    `glorot_uniform_initializer`. If initializer is a Tensor, we use
    it as a value and derive the shape from the initializer.

    If the initializer is a callable, then it will be called for each
    shard.  Otherwise the initializer should match the shape of the entire
    sharded Variable, and it will be sliced accordingly for each shard.

    Some useful partitioners are available.  See, e.g.,
    `variable_axis_size_partitioner` and `min_max_variable_partitioner`.

    Args:
      name: the name of the new or existing sharded variable.
      partitioner: Optional callable that accepts a fully defined `TensorShape`
        and `dtype` of the Variable to be created, and returns a list of
        partitions for each axis (currently only one axis can be partitioned).
      shape: shape of the new or existing sharded variable.
      dtype: type of the new or existing sharded variable
        (defaults to `DT_FLOAT`).
      initializer: initializer for the sharded variable.
      regularizer: a (Tensor -> Tensor or None) function; the result of
        applying it on a newly created variable will be added to the collection
        GraphKeys.REGULARIZATION_LOSSES and can be used for regularization.
      reuse: a Boolean, None, or tf.AUTO_REUSE. Controls reuse or creation
        of variables.
      trainable: If `True` also add the variable to the graph collection
        `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
      collections: List of graph collections keys to add the Variable to.
        Defaults to `[GraphKeys.GLOBAL_VARIABLES]` (see `tf.Variable`).
      caching_device: Optional device string or function describing where the
        Variable should be cached for reading.  Defaults to the Variable's
        device.  If not `None`, caches on another device.  Typical use is to
        cache on the device where the Ops using the Variable reside, to
        deduplicate copying through `Switch` and other conditional statements.
      validate_shape: If False, allows the variable to be initialized with a
        value of unknown shape. If True, the default, the shape of initial_value
        must be known.
      use_resource: If False, creates a regular Variable. If True, creates an
        experimental ResourceVariable which has well-defined semantics. Defaults
        to False (will later change to True).
      constraint: An optional projection function to be applied to the variable
        after being updated by an `Optimizer` (e.g. used to implement norm
        constraints or value constraints for layer weights). The function must
        take as input the unprojected Tensor representing the value of the
        variable and return the Tensor for the projected value
        (which must have the same shape). Constraints are not safe to
        use when doing asynchronous distributed training.
      synchronization: Indicates when a distributed variable will be
        aggregated. Accepted values are constants defined in the class
        `tf.VariableSynchronization`. By default the synchronization is set to
        `AUTO` and the current `DistributionStrategy` chooses
        when to synchronize. If `synchronization` is set to `ON_READ`,
        `trainable` must not be set to `True`.
      aggregation: Indicates how a distributed variable will be aggregated.
        Accepted values are constants defined in the class
        `tf.VariableAggregation`.

    Returns:
      A `PartitionedVariable` object.

    Raises:
      ValueError: when creating a new variable and shape is not declared,
        when reusing a variable and specifying a conflicting shape,
        when violating reuse during variable creation, or if an existing
        sharded variable exists for the given name but with different sharding.
    """
    initializing_from_value = initializer is not None and isinstance(
        initializer, ops.Tensor)
    if name in self._vars:
      raise ValueError(
          "A partitioner was provided, but an unpartitioned version of the "
          "variable was found: %s.  Perhaps a variable of the same name was "
          "already created without partitioning?" % name)

    shape = tensor_shape.as_shape(shape)
    if initializing_from_value:
      shape = shape.merge_with(initializer.get_shape())

    partitions = None
    if not reuse or partitioner:
      partitions = _call_partitioner(partitioner, shape, dtype)

    if name in self._partitioned_vars:
      if reuse is False:
        raise ValueError(
            "Partitioned variable with name %s already exists. Did you mean to "
            "set reuse=True or reuse=tf.AUTO_REUSE in VarScope?"
            % name)

      existing_var = self._partitioned_vars[name]
      if not shape.is_compatible_with(existing_var.get_shape()):
        raise ValueError(
            "Trying to reuse partitioned variable %s, but specified shape %s "
            "and found shape %s."
            % (name, shape, existing_var.get_shape()))
      if not dtype.is_compatible_with(existing_var.dtype):
        raise ValueError(
            "Trying to reuse partitioned variable %s, but specified dtype %s "
            "and found dtype %s."
            % (name, dtype.name, existing_var.dtype.name))

      # pylint: disable=protected-access
      if (partitions is not None and
          existing_var._get_partitions() != partitions):
        raise ValueError(
            "Trying to reuse partitioned variable %s, but specified partitions "
            "%s and found partitions %s." %
            (name, partitions, existing_var._get_partitions()))
      # pylint: enable=protected-access

      return existing_var

    if reuse is True:
      raise ValueError("PartitionedVariable %s does not exist, or was not "
                       "created with tf.get_variable(). Did you mean to set "
                       "reuse=False or reuse=tf.AUTO_REUSE in VarScope?" % name)

    slice_dim, num_slices = _get_slice_dim_and_num_slices(partitions)

    if "%s/part_0" % name in self._vars:
      if "%s/part_%d" % (name, num_slices - 1) not in self._vars:
        raise ValueError(
            "Partitioner returned a different partitioning than what was "
            "already found.  Partitioner returned %d shards, and shard "
            "%s/part_0 was found, but %s/part_%d was not."
            % (num_slices, name, name, num_slices - 1))
      if "%s/part_%d" % (name, num_slices) in self._vars:
        raise ValueError(
            "Partitioner returned a different partitioning than what was "
            "already found.  Partitioner returned %d shards, and shard "
            "%s/part_0 was found, but so was the extra shard %s/part_%d."
            % (num_slices, name, name, num_slices))

    vs = []
    for i, (var_offset, var_shape) in enumerate(_iter_slices(
        shape.as_list(),
        num_slices,
        slice_dim
    )):
      partition_info = _PartitionInfo(
          full_shape=shape.as_list(), var_offset=var_offset)
      var_full_name = "%s/part_%d" % (name, i)
      with ops.name_scope(var_full_name + "/PartitionedInitializer"):
        # Create the tensor to initialize the variable with default value.
        if initializer is None:
          init, initializing_from_value = self._get_default_initializer(
              name=name, shape=shape, dtype=dtype)
          if initializing_from_value:
            init_shape = None
          else:
            init_shape = var_shape
        elif callable(initializer):
          init = initializer
          init_shape = var_shape
        elif isinstance(initializer, ops.Tensor):
          init = array_ops.slice(initializer, var_offset, var_shape)
          # Use the dtype of the given tensor.
          dtype = init.dtype.base_dtype
          init_shape = None
        else:
          init = ops.convert_to_tensor(initializer, dtype=dtype)
          init = array_ops.slice(init, var_offset, var_shape)
          init_shape = None

      with ops.name_scope(None):
        var = self._get_single_variable(
            name=var_full_name,
            shape=init_shape,
            dtype=dtype,
            initializer=init,
            partition_info=partition_info,
            regularizer=regularizer,
            reuse=reuse,
            trainable=trainable,
            collections=collections,
            caching_device=caching_device,
            validate_shape=validate_shape,
            use_resource=use_resource,
            constraint=constraint,
            synchronization=synchronization,
            aggregation=aggregation)

      # pylint: disable=protected-access
      var._set_save_slice_info(variables.Variable.SaveSliceInfo(
          name, shape.as_list(), var_offset, var_shape))
      vs.append(var)
      # pylint: enable=protected-access

    partitioned_var = variables.PartitionedVariable(name=name,
                                                    shape=shape,
                                                    dtype=dtype,
                                                    variable_list=vs,
                                                    partitions=partitions)
    if not context.executing_eagerly() or self._store_eager_variables:
      self._partitioned_vars[name] = partitioned_var
    return partitioned_var

  def _get_single_variable(self,
                           name,
                           shape=None,
                           dtype=dtypes.float32,
                           initializer=None,
                           regularizer=None,
                           partition_info=None,
                           reuse=None,
                           trainable=None,
                           collections=None,
                           caching_device=None,
                           validate_shape=True,
                           use_resource=None,
                           constraint=None,
                           synchronization=VariableSynchronization.AUTO,
                           aggregation=VariableAggregation.NONE):
    """Get or create a single Variable (e.g. a shard or entire variable).

    See the documentation of get_variable above (ignore partitioning components)
    for details.

    Args:
      name: see get_variable.
      shape: see get_variable.
      dtype: see get_variable.
      initializer: see get_variable.
      regularizer: see get_variable.
      partition_info: _PartitionInfo object.
      reuse: see get_variable.
      trainable: see get_variable.
      collections: see get_variable.
      caching_device: see get_variable.
      validate_shape: see get_variable.
      use_resource: see get_variable.
      constraint: see get_variable.
      synchronization: see get_variable.
      aggregation: see get_variable.

    Returns:
      A Variable.  See documentation of get_variable above.

    Raises:
      ValueError: See documentation of get_variable above.
    """
    # Set to true if initializer is a constant.
    initializing_from_value = False
    if initializer is not None and not callable(initializer):
      initializing_from_value = True
    if shape is not None and initializing_from_value:
      raise ValueError("If initializer is a constant, do not specify shape.")

    dtype = dtypes.as_dtype(dtype)
    shape = tensor_shape.as_shape(shape)

    if name in self._vars:
      # Here we handle the case when returning an existing variable.
      if reuse is False:
        var = self._vars[name]
        err_msg = ("Variable %s already exists, disallowed."
                   " Did you mean to set reuse=True or "
                   "reuse=tf.AUTO_REUSE in VarScope?" % name)
        # ResourceVariables don't have an op associated with them, so there is
        # no traceback to report.
        if isinstance(var, resource_variable_ops.ResourceVariable):
          raise ValueError(err_msg)
        tb = var.op.traceback[::-1]
        # Throw away internal tf entries and only take a few lines. In some
        # cases the traceback can be longer (e.g. if someone uses factory
        # functions to create variables) so we take more than needed in the
        # default case.
        tb = [x for x in tb if "tensorflow/python" not in x[0]][:5]
        raise ValueError("%s Originally defined at:\n\n%s" % (err_msg, "".join(
            traceback.format_list(tb))))
      found_var = self._vars[name]
      if not shape.is_compatible_with(found_var.get_shape()):
        raise ValueError("Trying to share variable %s, but specified shape %s"
                         " and found shape %s." % (name, shape,
                                                   found_var.get_shape()))
      if not dtype.is_compatible_with(found_var.dtype):
        dtype_str = dtype.name
        found_type_str = found_var.dtype.name
        raise ValueError("Trying to share variable %s, but specified dtype %s"
                         " and found dtype %s." % (name, dtype_str,
                                                   found_type_str))
      return found_var

    # The code below handles only the case of creating a new variable.
    if reuse is True:
      raise ValueError("Variable %s does not exist, or was not created with "
                       "tf.get_variable(). Did you mean to set "
                       "reuse=tf.AUTO_REUSE in VarScope?" % name)

    # Create the tensor to initialize the variable with default value.
    if initializer is None:
      initializer, initializing_from_value = self._get_default_initializer(
          name=name, shape=shape, dtype=dtype)
    # Enter an init scope when creating the initializer.
    with ops.init_scope():
      if initializing_from_value:
        init_val = initializer
        variable_dtype = None
      else:
        # Instantiate initializer if provided initializer is a type object.
        if tf_inspect.isclass(initializer):
          initializer = initializer(dtype=dtype)
        if shape is not None and shape.is_fully_defined():
          init_val = lambda: initializer(  # pylint: disable=g-long-lambda
              shape.as_list(), dtype=dtype, partition_info=partition_info)
          variable_dtype = dtype.base_dtype
        elif len(tf_inspect.getargspec(initializer).args) == len(
            tf_inspect.getargspec(initializer).defaults or []):
          init_val = initializer
          variable_dtype = None
        else:
          raise ValueError("The initializer passed is not valid. It should "
                           "either be a callable with no arguments (in which "
                           "case `shape` should not be provided), or an "
                           "instance of `tf.keras.initializers.*` (in which "
                           "case `shape` should be fully defined).")

    # Create the variable.
    if use_resource is None:
      # Set the default value if unspecified.
      use_resource = _DEFAULT_USE_RESOURCE
    v = variables.VariableV1(
        initial_value=init_val,
        name=name,
        trainable=trainable,
        collections=collections,
        caching_device=caching_device,
        dtype=variable_dtype,
        validate_shape=validate_shape,
        constraint=constraint,
        use_resource=use_resource,
        synchronization=synchronization,
        aggregation=aggregation)
    if context.executing_eagerly() and self._store_eager_variables:
      if collections:
        ops.add_to_collections(collections, v)
      else:
        ops.add_to_collection(ops.GraphKeys.GLOBAL_VARIABLES, v)
      if trainable:
        ops.add_to_collection(ops.GraphKeys.TRAINABLE_VARIABLES, v)

    if not context.executing_eagerly() or self._store_eager_variables:
      # In eager mode we do not want to keep default references to Variable
      # objects as this will prevent their memory from being released.
      self._vars[name] = v
    logging.vlog(1, "Created variable %s with shape %s and init %s", v.name,
                 format(shape), initializer)

    # Run the regularizer if requested and save the resulting loss.
    if regularizer:
      with ops.colocate_with(v):
        with ops.name_scope(name + "/Regularizer/"):
          with ops.init_scope():
            loss = regularizer(v)
        if loss is not None:
          if context.executing_eagerly():
            v_name = "v_%s" % type(v)
            loss_name = "loss_%s" % type(loss)
          else:
            v_name = v.name
            loss_name = loss.name
          logging.vlog(1, "Applied regularizer to %s and added the result %s "
                       "to REGULARIZATION_LOSSES.", v_name, loss_name)
          ops.add_to_collection(ops.GraphKeys.REGULARIZATION_LOSSES, loss)
    return v

  # Initialize variable when no initializer provided
  def _get_default_initializer(self, name, shape=None, dtype=dtypes.float32):
    """Provide a default initializer and a corresponding value.

    Args:
      name: see get_variable.
      shape: see get_variable.
      dtype: see get_variable.

    Returns:
      initializer and initializing_from_value. See get_variable above.

    Raises:
      ValueError: When giving unsupported dtype.
    """
    del shape
    # If dtype is DT_FLOAT, provide a uniform unit scaling initializer
    if dtype.is_floating:
      initializer = init_ops.glorot_uniform_initializer()
      initializing_from_value = False
    # If dtype is DT_INT/DT_UINT, provide a default value `zero`
    # If dtype is DT_BOOL, provide a default value `FALSE`
    elif (dtype.is_integer or dtype.is_unsigned or dtype.is_bool
          or dtype == dtypes.string):
      initializer = init_ops.zeros_initializer()
      initializing_from_value = False
    # NOTE: Do we need to support handling DT_STRING and DT_COMPLEX here?
    else:
      raise ValueError("An initializer for variable %s of %s is required"
                       % (name, dtype.base_dtype))

    return initializer, initializing_from_value


# To stop regularization, use this regularizer
@tf_export("no_regularizer")
def no_regularizer(_):
  """Use this function to prevent regularization of variables."""
  return None


# TODO(alive): support caching devices and partitioned variables in Eager mode.
@tf_export(v1=["VariableScope"])
class VariableScope(object):
  """Variable scope object to carry defaults to provide to `get_variable`.

  Many of the arguments we need for `get_variable` in a variable store are most
  easily handled with a context. This object is used for the defaults.

  Attributes:
    name: name of the current scope, used as prefix in get_variable.
    initializer: default initializer passed to get_variable.
    regularizer: default regularizer passed to get_variable.
    reuse: Boolean, None, or tf.AUTO_REUSE, setting the reuse in
      get_variable. When eager execution is enabled this argument is always
      forced to be False.
    caching_device: string, callable, or None: the caching device passed to
      get_variable.
    partitioner: callable or `None`: the partitioner passed to `get_variable`.
    custom_getter: default custom getter passed to get_variable.
    name_scope: The name passed to `tf.name_scope`.
    dtype: default type passed to get_variable (defaults to DT_FLOAT).
    use_resource: if False, create a normal Variable; if True create an
      experimental ResourceVariable with well-defined semantics. Defaults
      to False (will later change to True). When eager execution is enabled
      this argument is always forced to be True.
    constraint: An optional projection function to be applied to the variable
      after being updated by an `Optimizer` (e.g. used to implement norm
      constraints or value constraints for layer weights). The function must
      take as input the unprojected Tensor representing the value of the
      variable and return the Tensor for the projected value
      (which must have the same shape). Constraints are not safe to
      use when doing asynchronous distributed training.
  """

  def __init__(self,
               reuse,
               name="",
               initializer=None,
               regularizer=None,
               caching_device=None,
               partitioner=None,
               custom_getter=None,
               name_scope="",
               dtype=dtypes.float32,
               use_resource=None,
               constraint=None):
    """Creates a new VariableScope with the given properties."""
    self._name = name
    self._initializer = initializer
    self._regularizer = regularizer
    self._reuse = reuse
    self._caching_device = caching_device
    self._partitioner = partitioner
    self._custom_getter = custom_getter
    self._name_scope = name_scope
    self._dtype = dtype
    self._use_resource = use_resource
    self._constraint = constraint
    if context.executing_eagerly():
      if self._caching_device is not None:
        raise NotImplementedError("Caching devices are not yet supported "
                                  "when eager execution is enabled.")
      self._reuse = AUTO_REUSE
      self._use_resource = True

  @property
  def name(self):
    return self._name

  @property
  def original_name_scope(self):
    return self._name_scope

  @property
  def reuse(self):
    return self._reuse

  @property
  def initializer(self):
    return self._initializer

  @property
  def dtype(self):
    return self._dtype

  @property
  def use_resource(self):
    return self._use_resource

  @property
  def regularizer(self):
    return self._regularizer

  @property
  def caching_device(self):
    return self._caching_device

  @property
  def partitioner(self):
    return self._partitioner

  @property
  def custom_getter(self):
    return self._custom_getter

  @property
  def constraint(self):
    return self._constraint

  def reuse_variables(self):
    """Reuse variables in this scope."""
    self._reuse = True

  def set_initializer(self, initializer):
    """Set initializer for this scope."""
    self._initializer = initializer

  def set_dtype(self, dtype):
    """Set data type for this scope."""
    self._dtype = dtype

  def set_use_resource(self, use_resource):
    """Sets whether to use ResourceVariables for this scope."""
    if context.executing_eagerly() and not use_resource:
      raise ValueError("When eager execution is enabled, "
                       "use_resource cannot be set to False.")
    self._use_resource = use_resource

  def set_regularizer(self, regularizer):
    """Set regularizer for this scope."""
    self._regularizer = regularizer

  def set_caching_device(self, caching_device):
    """Set caching_device for this scope."""
    if context.executing_eagerly():
      raise NotImplementedError("Caching devices are not yet supported "
                                "when eager execution is enabled.")
    self._caching_device = caching_device

  def set_partitioner(self, partitioner):
    """Set partitioner for this scope."""
    self._partitioner = partitioner

  def set_custom_getter(self, custom_getter):
    """Set custom getter for this scope."""
    self._custom_getter = custom_getter

  def get_collection(self, name):
    """Get this scope's variables."""
    scope = self._name + "/" if self._name else ""
    return ops.get_collection(name, scope)

  def trainable_variables(self):
    """Get this scope's trainable variables."""
    return self.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)

  def global_variables(self):
    """Get this scope's global variables."""
    return self.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)

  def local_variables(self):
    """Get this scope's local variables."""
    return self.get_collection(ops.GraphKeys.LOCAL_VARIABLES)

  def get_variable(self,
                   var_store,
                   name,
                   shape=None,
                   dtype=None,
                   initializer=None,
                   regularizer=None,
                   reuse=None,
                   trainable=None,
                   collections=None,
                   caching_device=None,
                   partitioner=None,
                   validate_shape=True,
                   use_resource=None,
                   custom_getter=None,
                   constraint=None,
                   synchronization=VariableSynchronization.AUTO,
                   aggregation=VariableAggregation.NONE):
    """Gets an existing variable with this name or creates a new one."""
    if regularizer is None:
      regularizer = self._regularizer
    if caching_device is None:
      caching_device = self._caching_device
    if partitioner is None:
      partitioner = self._partitioner
    if custom_getter is None:
      custom_getter = self._custom_getter
    if context.executing_eagerly():
      reuse = False
      use_resource = True
    else:
      if reuse is None:
        reuse = self._reuse
      if use_resource is None:
        use_resource = self._use_resource

    full_name = self.name + "/" + name if self.name else name
    # Variable names only depend on variable_scope (full_name here),
    # not name_scope, so we reset it below for the time of variable creation.
    with ops.name_scope(None):
      # Check that `initializer` dtype and `dtype` are consistent before
      # replacing them with defaults.
      if (dtype is not None and initializer is not None and
          not callable(initializer)):
        init_dtype = ops.convert_to_tensor(initializer).dtype.base_dtype
        if init_dtype != dtype:
          raise ValueError("Initializer type '%s' and explicit dtype '%s' "
                           "don't match." % (init_dtype, dtype))
      if initializer is None:
        initializer = self._initializer
      if constraint is None:
        constraint = self._constraint
      if dtype is None:
        dtype = self._dtype
      return var_store.get_variable(
          full_name,
          shape=shape,
          dtype=dtype,
          initializer=initializer,
          regularizer=regularizer,
          reuse=reuse,
          trainable=trainable,
          collections=collections,
          caching_device=caching_device,
          partitioner=partitioner,
          validate_shape=validate_shape,
          use_resource=use_resource,
          custom_getter=custom_getter,
          constraint=constraint,
          synchronization=synchronization,
          aggregation=aggregation)

1245  def _get_partitioned_variable(self,
1246                                var_store,
1247                                name,
1248                                shape=None,
1249                                dtype=None,
1250                                initializer=None,
1251                                regularizer=None,
1252                                trainable=None,
1253                                collections=None,
1254                                caching_device=None,
1255                                partitioner=None,
1256                                validate_shape=True,
1257                                use_resource=None,
1258                                constraint=None,
1259                                synchronization=VariableSynchronization.AUTO,
1260                                aggregation=VariableAggregation.NONE):
1261    """Gets an existing variable with this name or create a new one."""
1262    if initializer is None:
1263      initializer = self._initializer
1264    if regularizer is None:
1265      regularizer = self._regularizer
1266    if constraint is None:
1267      constraint = self._constraint
1268    if caching_device is None:
1269      caching_device = self._caching_device
1270    if partitioner is None:
1271      partitioner = self._partitioner
1272    if dtype is None:
1273      dtype = self._dtype
1274    if use_resource is None:
1275      use_resource = self._use_resource
1276
1277    if self._custom_getter is not None:
1278      raise ValueError(
1279          "Private access to _get_partitioned_variable is not allowed when "
1280          "a custom getter is set.  Current custom getter: %s.  "
1281          "It is likely that you're using create_partitioned_variables.  "
1282          "If so, consider instead using get_variable with a non-empty "
1283          "partitioner parameter." % self._custom_getter)
1284
1285    if partitioner is None:
1286      raise ValueError("No partitioner was specified")
1287
1288    # This allows the variable scope name to be used as the variable name if
1289    # this function is invoked with an empty name arg, for backward
1290    # compatibility with create_partitioned_variables().
1291    full_name_list = []
1292    if self.name:
1293      full_name_list.append(self.name)
1294    if name:
1295      full_name_list.append(name)
1296    full_name = "/".join(full_name_list)
1297
1298    # Variable names only depend on variable_scope (full_name here),
1299    # not name_scope, so we reset it below for the time of variable creation.
1300    with ops.name_scope(None):
1301      # pylint: disable=protected-access
1302      return var_store._get_partitioned_variable(
1303          full_name,
1304          shape=shape,
1305          dtype=dtype,
1306          initializer=initializer,
1307          regularizer=regularizer,
1308          reuse=self.reuse,
1309          trainable=trainable,
1310          collections=collections,
1311          caching_device=caching_device,
1312          partitioner=partitioner,
1313          validate_shape=validate_shape,
1314          use_resource=use_resource,
1315          constraint=constraint,
1316          synchronization=synchronization,
1317          aggregation=aggregation)
1318      # pylint: enable=protected-access
1319
1320
1321_VARSTORE_KEY = ("__variable_store",)
1322_VARSCOPESTORE_KEY = ("__varscope",)
1323
1324
1325class _VariableScopeStore(threading.local):
1326  """A thread local store for the current variable scope and scope counts."""
1327
1328  def __init__(self):
1329    super(_VariableScopeStore, self).__init__()
1330    self.current_scope = VariableScope(False)
1331    self.variable_scopes_count = {}
1332
1333  def open_variable_scope(self, scope_name):
1334    if scope_name in self.variable_scopes_count:
1335      self.variable_scopes_count[scope_name] += 1
1336    else:
1337      self.variable_scopes_count[scope_name] = 1
1338
1339  def close_variable_subscopes(self, scope_name):
1340    for k in list(self.variable_scopes_count.keys()):
1341      if scope_name is None or k.startswith(scope_name + "/"):
1342        self.variable_scopes_count[k] = 0
1343
1344  def variable_scope_count(self, scope_name):
1345    return self.variable_scopes_count.get(scope_name, 0)
1346
1347
1348def get_variable_scope_store():
1349  """Returns the variable scope store for current thread."""
1350  scope_store = ops.get_collection(_VARSCOPESTORE_KEY)
1351
1352  if not scope_store:
1353    scope_store = _VariableScopeStore()
1354    ops.add_to_collection(_VARSCOPESTORE_KEY, scope_store)
1355  else:
1356    scope_store = scope_store[0]
1357
1358  return scope_store
1359
1360
1361@tf_export(v1=["get_variable_scope"])
1362def get_variable_scope():
1363  """Returns the current variable scope."""
1364  return get_variable_scope_store().current_scope
1365
1366
1367def _get_default_variable_store():
1368  store = ops.get_collection(_VARSTORE_KEY)
1369  if store:
1370    return store[0]
1371  store = _VariableStore()
1372  ops.add_to_collection(_VARSTORE_KEY, store)
1373  return store
1374
1375
1376@tf_contextlib.contextmanager
1377def with_variable_store(store):
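  """Temporarily installs `store` as the default variable store."""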
1378  store_collection = ops.get_collection_ref(_VARSTORE_KEY)
1379  old = list(store_collection)
1380  store_collection[:] = [store]
1381  try:
1382    yield
1383  finally:
1384    store_collection[:] = old
1385
1386
1387class EagerVariableStore(object):
1388  """Wrapper allowing functional layers to be used with eager execution.
1389
1390  When eager execution is enabled, Variables get deleted when they go out of
1391  scope, and are not stored in global collections by default. A lot of code
1392  (mostly the functional layers in tf.layers) assumes that variables are kept in
1393  a global list.
1394
1395  EagerVariableStore can be used in conjunction with this code to make it
1396  eager-friendly. For example, to create a dense layer, use:
1397
1398  ```
1399    container = tfe.EagerVariableStore()
1400    for input in dataset_iterator:
1401      with container.as_default():
1402        x = tf.layers.dense(input, name="l1")
1403    print(container.variables())  # Prints the variables used in the layer.
1404  ```
1405  """
1406
1407  def __init__(self, store=None):
1408    if store is not None:
1409      if not store._store_eager_variables:  # pylint: disable=protected-access
1410        raise ValueError("Cannot construct EagerVariableStore from a "
1411                         "VariableStore object that does not hold eager "
1412                         "variables.")
1413      self._store = store
1414    else:
1415      self._store = _VariableStore()
1416    self._store._store_eager_variables = True  # pylint: disable=protected-access
1417
1418  def as_default(self):
1419    return with_variable_store(self._store)
1420
1421  def variables(self):
1422    return sorted(self._store._vars.values(), key=lambda x: x.name)  # pylint: disable=protected-access
1423
1424  def trainable_variables(self):
1425    # pylint: disable=protected-access
1426    return sorted([x for x in self._store._vars.values() if x.trainable],
1427                  key=lambda x: x.name)
1428    # pylint: enable=protected-access
1429
1430  def non_trainable_variables(self):
1431    # pylint: disable=protected-access
1432    return sorted([x for x in self._store._vars.values() if not x.trainable],
1433                  key=lambda x: x.name)
1434    # pylint: enable=protected-access
1435
1436  def copy(self):
1437    """Copy this variable store and all of its contents.
1438
1439    Variables contained in this store will be copied over to the new variable
1440    store, meaning that they can be modified without affecting the variables in
1441    this store.
1442
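    For example (a sketch; `tfe` refers to the same eager module used in the
    class docstring above, and the names below are illustrative):

    ```
    container = tfe.EagerVariableStore()
    with container.as_default():
      w = tf.get_variable("w", [10])
    snapshot = container.copy()
    # Mutating variables in `snapshot` does not affect those in `container`.
    ```
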
1443    Returns:
1444      A new EagerVariableStore instance containing copied variables.
1445    """
1446    # pylint: disable=protected-access
1447    new_store = EagerVariableStore()
1448    for key, var in iteritems(self._store._vars):
1449      # Strip device out of variable name.
1450      try:
1451        index = var.name.index(":")
1452      except ValueError:
1453        stripped_var_name = var.name
1454      else:
1455        stripped_var_name = var.name[:index]
1456
1457      # Create new variable with same value, name, and "trainable" flag.
1458      new_var = resource_variable_ops.ResourceVariable(
1459          var.read_value(),
1460          name=stripped_var_name,
1461          trainable=var.trainable)
1462      new_store._store._vars[key] = new_var
1463    return new_store
1464    # pylint: enable=protected-access
1465
1466
1467# The argument list for get_variable must match arguments to get_local_variable.
1468# So, if you are updating the arguments, also update arguments to
1469# get_local_variable below.
1470@tf_export(v1=["get_variable"])
1471def get_variable(name,
1472                 shape=None,
1473                 dtype=None,
1474                 initializer=None,
1475                 regularizer=None,
1476                 trainable=None,
1477                 collections=None,
1478                 caching_device=None,
1479                 partitioner=None,
1480                 validate_shape=True,
1481                 use_resource=None,
1482                 custom_getter=None,
1483                 constraint=None,
1484                 synchronization=VariableSynchronization.AUTO,
1485                 aggregation=VariableAggregation.NONE):
1486  return get_variable_scope().get_variable(
1487      _get_default_variable_store(),
1488      name,
1489      shape=shape,
1490      dtype=dtype,
1491      initializer=initializer,
1492      regularizer=regularizer,
1493      trainable=trainable,
1494      collections=collections,
1495      caching_device=caching_device,
1496      partitioner=partitioner,
1497      validate_shape=validate_shape,
1498      use_resource=use_resource,
1499      custom_getter=custom_getter,
1500      constraint=constraint,
1501      synchronization=synchronization,
1502      aggregation=aggregation)
1503
1504
1505get_variable_or_local_docstring = ("""%s
1506
1507%sThis function prefixes the name with the current variable scope
1508and performs reuse checks. See the
1509[Variable Scope How To](https://tensorflow.org/guide/variables)
1510for an extensive description of how reusing works. Here is a basic example:
1511
1512```python
1513def foo():
1514  with tf.variable_scope("foo", reuse=tf.AUTO_REUSE):
1515    v = tf.get_variable("v", [1])
1516  return v
1517
1518v1 = foo()  # Creates v.
1519v2 = foo()  # Gets the same, existing v.
1520assert v1 == v2
1521```
1522
1523If initializer is `None` (the default), the default initializer passed in
1524the variable scope will be used. If that one is `None` too, a
1525`glorot_uniform_initializer` will be used. The initializer can also be
1526a Tensor, in which case the variable is initialized to this value and shape.
1527
1528Similarly, if the regularizer is `None` (the default), the default regularizer
1529passed in the variable scope will be used (if that is `None` too,
1530then by default no regularization is performed).
1531
1532If a partitioner is provided, a `PartitionedVariable` is returned.
1533Accessing this object as a `Tensor` returns the shards concatenated along
1534the partition axis.
1535
1536Some useful partitioners are available.  See, e.g.,
1537`variable_axis_size_partitioner` and `min_max_variable_partitioner`.
1538
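For example, a minimal sketch of creating a partitioned variable (the shard
size and shapes below are arbitrary):

```python
partitioner = tf.variable_axis_size_partitioner(max_shard_bytes=64 << 20)
with tf.variable_scope("embeddings", partitioner=partitioner):
  embedding = tf.get_variable("weights", [100000, 128])
```
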
1539Args:
1540  name: The name of the new or existing variable.
1541  shape: Shape of the new or existing variable.
1542  dtype: Type of the new or existing variable (defaults to `DT_FLOAT`).
1543  initializer: Initializer for the variable if one is created. Can either be
1544    an initializer object or a Tensor. If it's a Tensor, its shape must be known
1545    unless validate_shape is False.
1546  regularizer: A (Tensor -> Tensor or None) function; the result of
1547    applying it on a newly created variable will be added to the collection
1548    `tf.GraphKeys.REGULARIZATION_LOSSES` and can be used for regularization.
1549  %scollections: List of graph collections keys to add the Variable to.
1550    Defaults to `[%s]` (see `tf.Variable`).
1551  caching_device: Optional device string or function describing where the
1552    Variable should be cached for reading.  Defaults to the Variable's
1553    device.  If not `None`, caches on another device.  Typical use is to
1554    cache on the device where the Ops using the Variable reside, to
1555    deduplicate copying through `Switch` and other conditional statements.
1556  partitioner: Optional callable that accepts a fully defined `TensorShape`
1557    and `dtype` of the Variable to be created, and returns a list of
1558    partitions for each axis (currently only one axis can be partitioned).
1559  validate_shape: If False, allows the variable to be initialized with a
1560      value of unknown shape. If True, the default, the shape of initial_value
1561      must be known. For this to be used the initializer must be a Tensor and
1562      not an initializer object.
1563  use_resource: If False, creates a regular Variable. If true, creates an
1564    experimental ResourceVariable instead with well-defined semantics.
1565    Defaults to False (will later change to True). When eager execution is
1566    enabled this argument is always forced to be True.
1567  custom_getter: Callable that takes as a first argument the true getter, and
1568    allows overwriting the internal get_variable method.
1569    The signature of `custom_getter` should match that of this method,
1570    but the most future-proof version will allow for changes:
1571    `def custom_getter(getter, *args, **kwargs)`.  Direct access to
1572    all `get_variable` parameters is also allowed:
1573    `def custom_getter(getter, name, *args, **kwargs)`.  A simple identity
1574    custom getter that creates variables with modified names is:
1575    ```python
1576    def custom_getter(getter, name, *args, **kwargs):
1577      return getter(name + '_suffix', *args, **kwargs)
1578    ```
1579  constraint: An optional projection function to be applied to the variable
1580    after being updated by an `Optimizer` (e.g. used to implement norm
1581    constraints or value constraints for layer weights). The function must
1582    take as input the unprojected Tensor representing the value of the
1583    variable and return the Tensor for the projected value
1584    (which must have the same shape). Constraints are not safe to
1585    use when doing asynchronous distributed training.
1586  synchronization: Indicates when a distributed variable will be
1587    aggregated. Accepted values are constants defined in the class
1588    `tf.VariableSynchronization`. By default the synchronization is set to
1589    `AUTO` and the current `DistributionStrategy` chooses
1590    when to synchronize. If `synchronization` is set to `ON_READ`,
1591    `trainable` must not be set to `True`.
1592  aggregation: Indicates how a distributed variable will be aggregated.
1593    Accepted values are constants defined in the class
1594    `tf.VariableAggregation`.
1595
1596Returns:
1597  The created or existing `Variable` (or `PartitionedVariable`, if a
1598  partitioner was used).
1599
1600Raises:
1601  ValueError: when creating a new variable and shape is not declared,
1602    when violating reuse during variable creation, or when `initializer` dtype
1603    and `dtype` don't match. Reuse is set inside `variable_scope`.
1604""")
1605get_variable.__doc__ = get_variable_or_local_docstring % (
1606    "Gets an existing variable with these parameters or creates a new one.",
1607    "",
1608    "trainable: If `True` also add the variable to the graph collection\n"
1609    "    `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).\n  ",
1610    "GraphKeys.GLOBAL_VARIABLES")
1611
1612
1613# The argument list for get_local_variable must match arguments to get_variable.
1614# So, if you are updating the arguments, also update arguments to get_variable.
1615@tf_export(v1=["get_local_variable"])
1616def get_local_variable(  # pylint: disable=missing-docstring
1617    name,
1618    shape=None,
1619    dtype=None,
1620    initializer=None,
1621    regularizer=None,
1622    trainable=False,  # pylint: disable=unused-argument
1623    collections=None,
1624    caching_device=None,
1625    partitioner=None,
1626    validate_shape=True,
1627    use_resource=None,
1628    custom_getter=None,
1629    constraint=None,
1630    synchronization=VariableSynchronization.AUTO,
1631    aggregation=VariableAggregation.NONE):
1632  if collections:
1633    collections += [ops.GraphKeys.LOCAL_VARIABLES]
1634  else:
1635    collections = [ops.GraphKeys.LOCAL_VARIABLES]
1636  return get_variable(
1637      name,
1638      shape=shape,
1639      dtype=dtype,
1640      initializer=initializer,
1641      regularizer=regularizer,
1642      trainable=False,
1643      collections=collections,
1644      caching_device=caching_device,
1645      partitioner=partitioner,
1646      validate_shape=validate_shape,
1647      use_resource=use_resource,
1648      synchronization=synchronization,
1649      aggregation=aggregation,
1650      custom_getter=custom_getter,
1651      constraint=constraint)
1652
1653
1654get_local_variable.__doc__ = get_variable_or_local_docstring % (
1655    "Gets an existing *local* variable or creates a new one.",
1656    "Behavior is the same as in `get_variable`, except that variables are\n"
1657    "added to the `LOCAL_VARIABLES` collection and `trainable` is set to\n"
1658    "`False`.\n",
1659    "",
1660    "GraphKeys.LOCAL_VARIABLES")
1661
1662
1663def _get_partitioned_variable(name,
1664                              shape=None,
1665                              dtype=None,
1666                              initializer=None,
1667                              regularizer=None,
1668                              trainable=True,
1669                              collections=None,
1670                              caching_device=None,
1671                              partitioner=None,
1672                              validate_shape=True,
1673                              use_resource=None,
1674                              constraint=None,
1675                              synchronization=VariableSynchronization.AUTO,
1676                              aggregation=VariableAggregation.NONE):
1677  """Gets or creates a sharded variable list with these parameters.
1678
1679  The `partitioner` must be a callable that accepts a fully defined
1680  `TensorShape` and returns a sequence of integers (the `partitions`).
1681  These integers describe how to partition the given sharded `Variable`
1682  along the given dimension.  That is, `partitions[1] = 3` means split
1683  the `Variable` into 3 shards along dimension 1.  Currently, sharding along
1684  only one axis is supported.
1685
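  For instance, a sketch of a partitioner obeying this contract (the name
  `example_partitioner` is illustrative and not part of this module):

  ```python
  def example_partitioner(shape, dtype):
    del shape, dtype  # A real partitioner would typically inspect these.
    return [1, 3]  # Three shards along dimension 1, no partitioning along 0.
  ```
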
1686  If the list of variables with the given name (prefix) is already stored,
1687  we return the stored variables. Otherwise, we create a new one.
1688
1689  If initializer is `None` (the default), the default initializer passed in
1690  the constructor is used. If that one is `None` too, we use a new
1691  `glorot_uniform_initializer`. If initializer is a Tensor, we use
1692  it as a value and derive the shape from the initializer.
1693
1694  If the initializer is a callable, then it will be called for each
1695  shard.  Otherwise the initializer should match the shape of the entire
1696  sharded Variable, and it will be sliced accordingly for each shard.
1697
1698  Some useful partitioners are available.  See, e.g.,
1699  `variable_axis_size_partitioner` and `min_max_variable_partitioner`.
1700
1701  Args:
1702    name: The name of the new or existing variable.
1703    shape: Shape of the new or existing variable.
1704    dtype: Type of the new or existing variable (defaults to `DT_FLOAT`).
1705    initializer: Initializer for the variable if one is created.
1706    regularizer: A (Tensor -> Tensor or None) function; the result of
1707      applying it on a newly created variable will be added to the collection
1708      GraphKeys.REGULARIZATION_LOSSES and can be used for regularization.
1709    trainable: If `True` also add the variable to the graph collection
1710      `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
1711    collections: List of graph collections keys to add the Variable to.
1712      Defaults to `[GraphKeys.GLOBAL_VARIABLES]` (see `tf.Variable`).
1713    caching_device: Optional device string or function describing where the
1714      Variable should be cached for reading.  Defaults to the Variable's
1715      device.  If not `None`, caches on another device.  Typical use is to
1716      cache on the device where the Ops using the Variable reside, to
1717      deduplicate copying through `Switch` and other conditional statements.
1718    partitioner: Optional callable that accepts a fully defined `TensorShape`
1719      and `dtype` of the Variable to be created, and returns a list of
1720      partitions for each axis (currently only one axis can be partitioned).
1721    validate_shape: If False, allows the variable to be initialized with a
1722        value of unknown shape. If True, the default, the shape of initial_value
1723        must be known.
1724    use_resource: If False, creates a regular Variable. If True, creates an
1725      experimental ResourceVariable instead which has well-defined semantics.
1726      Defaults to False (will later change to True).
1727    constraint: An optional projection function to be applied to the variable
1728      after being updated by an `Optimizer` (e.g. used to implement norm
1729      constraints or value constraints for layer weights). The function must
1730      take as input the unprojected Tensor representing the value of the
1731      variable and return the Tensor for the projected value
1732      (which must have the same shape). Constraints are not safe to
1733      use when doing asynchronous distributed training.
1734    synchronization: Indicates when a distributed variable will be
1735      aggregated. Accepted values are constants defined in the class
1736      `tf.VariableSynchronization`. By default the synchronization is set to
1737      `AUTO` and the current `DistributionStrategy` chooses
1738      when to synchronize. If `synchronization` is set to `ON_READ`,
1739      `trainable` must not be set to `True`.
1740    aggregation: Indicates how a distributed variable will be aggregated.
1741      Accepted values are constants defined in the class
1742      `tf.VariableAggregation`.
1743
1744  Returns:
1745    A tuple `(shards, partitions)` where `shards` is the list of `Variable`
1746    shards and `partitions` is the output of the partitioner on the input
1747    shape.
1748
1749  Raises:
1750    ValueError: when creating a new variable and shape is not declared,
1751      or when violating reuse during variable creation. Reuse is set inside
1752      `variable_scope`.
1753  """
1754  # pylint: disable=protected-access
1755  scope = get_variable_scope()
1756  if scope.custom_getter is not None:
1757    raise ValueError(
1758        "Private access to _get_partitioned_variable is not allowed when "
1759        "a custom getter is set.  Current custom getter: %s.  "
1760        "It is likely that you're using create_partitioned_variables.  "
1761        "If so, consider instead using get_variable with a non-empty "
1762        "partitioner parameter." % scope.custom_getter)
1763  return scope._get_partitioned_variable(
1764      _get_default_variable_store(),
1765      name,
1766      shape=shape,
1767      dtype=dtype,
1768      initializer=initializer,
1769      regularizer=regularizer,
1770      trainable=trainable,
1771      collections=collections,
1772      caching_device=caching_device,
1773      partitioner=partitioner,
1774      validate_shape=validate_shape,
1775      use_resource=use_resource,
1776      constraint=constraint,
1777      synchronization=synchronization,
1778      aggregation=aggregation)
1779  # pylint: enable=protected-access
1780
1781
1782# Named like a function for compatibility with the previous
1783# @tf_contextlib.contextmanager definition.
1784class _pure_variable_scope(object):  # pylint: disable=invalid-name
1785  """A context for the variable_scope, see `variable_scope` for docs."""
1786
1787  def __init__(self,
1788               name_or_scope,
1789               reuse=None,
1790               initializer=None,
1791               regularizer=None,
1792               caching_device=None,
1793               partitioner=None,
1794               custom_getter=None,
1795               old_name_scope=None,
1796               dtype=dtypes.float32,
1797               use_resource=None,
1798               constraint=None):
1799    """Creates a context for the variable_scope, see `variable_scope` for docs.
1800
1801    Note: this does not create a name scope.
1802
1803    Args:
1804      name_or_scope: `string` or `VariableScope`: the scope to open.
1805      reuse: `True`, None, or tf.AUTO_REUSE; if `None`, we inherit the parent
1806        scope's reuse flag.
1807      initializer: default initializer for variables within this scope.
1808      regularizer: default regularizer for variables within this scope.
1809      caching_device: default caching device for variables within this scope.
1810      partitioner: default partitioner for variables within this scope.
1811      custom_getter: default custom getter for variables within this scope.
1812      old_name_scope: the original name scope when re-entering a variable scope.
1813      dtype: type of the variables within this scope (defaults to `DT_FLOAT`).
1814      use_resource: If False, variables in this scope will be regular Variables.
1815        If True, experimental ResourceVariables will be created instead, with
1816        well-defined semantics. Defaults to False (will later change to True).
1817      constraint: An optional projection function to be applied to the variable
1818        after being updated by an `Optimizer` (e.g. used to implement norm
1819        constraints or value constraints for layer weights). The function must
1820        take as input the unprojected Tensor representing the value of the
1821        variable and return the Tensor for the projected value
1822        (which must have the same shape). Constraints are not safe to
1823        use when doing asynchronous distributed training.
1824    """
1825    self._name_or_scope = name_or_scope
1826    self._reuse = reuse
1827    self._initializer = initializer
1828    self._regularizer = regularizer
1829    self._caching_device = caching_device
1830    self._partitioner = partitioner
1831    self._custom_getter = custom_getter
1832    self._old_name_scope = old_name_scope
1833    self._dtype = dtype
1834    self._use_resource = use_resource
1835    self._constraint = constraint
1836    self._var_store = _get_default_variable_store()
1837    self._var_scope_store = get_variable_scope_store()
1838    if isinstance(self._name_or_scope, VariableScope):
1839      self._new_name = self._name_or_scope.name
1840      name_scope = self._name_or_scope._name_scope  # pylint: disable=protected-access
1841      # Handler for the case when we jump to a shared scope.  We create a new
1842      #   VariableScope (self._cached_variable_scope_object) that contains a
1843      #   copy of the provided shared scope, possibly with changed reuse and
1844      #   initializer, if the user requested this.
1845      variable_scope_object = VariableScope(
1846          self._name_or_scope.reuse if not self._reuse else self._reuse,
1847          name=self._new_name,
1848          initializer=self._name_or_scope.initializer,
1849          regularizer=self._name_or_scope.regularizer,
1850          caching_device=self._name_or_scope.caching_device,
1851          partitioner=self._name_or_scope.partitioner,
1852          dtype=self._name_or_scope.dtype,
1853          custom_getter=self._name_or_scope.custom_getter,
1854          name_scope=name_scope,
1855          use_resource=self._name_or_scope.use_resource,
1856          constraint=self._constraint)
1857      if self._initializer is not None:
1858        variable_scope_object.set_initializer(self._initializer)
1859      if self._regularizer is not None:
1860        variable_scope_object.set_regularizer(self._regularizer)
1861      if self._caching_device is not None:
1862        variable_scope_object.set_caching_device(self._caching_device)
1863      if self._partitioner is not None:
1864        variable_scope_object.set_partitioner(self._partitioner)
1865      if self._custom_getter is not None:
1866        variable_scope_object.set_custom_getter(
1867            _maybe_wrap_custom_getter(
1868                self._custom_getter, self._name_or_scope.custom_getter))
1869      if self._dtype is not None:
1870        variable_scope_object.set_dtype(self._dtype)
1871      if self._use_resource is not None:
1872        variable_scope_object.set_use_resource(self._use_resource)
1873      self._cached_variable_scope_object = variable_scope_object
1874
1875  def __enter__(self):
1876    """Begins the scope block.
1877
1878    Returns:
1879      A VariableScope.
1880    Raises:
1881      ValueError: when trying to reuse within a create scope, or create within
1882        a reuse scope, or if reuse is not `None` or `True`.
1883      TypeError: when the types of some arguments are not appropriate.
1884    """
1885    self._old = self._var_scope_store.current_scope
1886    if isinstance(self._name_or_scope, VariableScope):
1887      self._var_scope_store.open_variable_scope(self._new_name)
1888      self._old_subscopes = copy.copy(
1889          self._var_scope_store.variable_scopes_count)
1890      variable_scope_object = self._cached_variable_scope_object
1891    else:
1892      # Handler for the case when we just prolong the current variable scope.
1893      #   VariableScope with name extended by the provided one, and inherited
1894      #   reuse and initializer (except if the user provided values to set).
1895      self._new_name = (
1896          self._old.name + "/" + self._name_or_scope if self._old.name
1897          else self._name_or_scope)
1898      self._reuse = (self._reuse
1899                     or self._old.reuse)  # Re-using is inherited by sub-scopes.
1900      if self._old_name_scope is None:
1901        name_scope = self._name_or_scope
1902      else:
1903        name_scope = self._old_name_scope
1904      variable_scope_object = VariableScope(
1905          self._reuse,
1906          name=self._new_name,
1907          initializer=self._old.initializer,
1908          regularizer=self._old.regularizer,
1909          caching_device=self._old.caching_device,
1910          partitioner=self._old.partitioner,
1911          dtype=self._old.dtype,
1912          use_resource=self._old.use_resource,
1913          custom_getter=self._old.custom_getter,
1914          name_scope=name_scope,
1915          constraint=self._constraint)
1916      if self._initializer is not None:
1917        variable_scope_object.set_initializer(self._initializer)
1918      if self._regularizer is not None:
1919        variable_scope_object.set_regularizer(self._regularizer)
1920      if self._caching_device is not None:
1921        variable_scope_object.set_caching_device(self._caching_device)
1922      if self._partitioner is not None:
1923        variable_scope_object.set_partitioner(self._partitioner)
1924      if self._custom_getter is not None:
1925        variable_scope_object.set_custom_getter(
1926            _maybe_wrap_custom_getter(self._custom_getter,
1927                                      self._old.custom_getter))
1928      if self._dtype is not None:
1929        variable_scope_object.set_dtype(self._dtype)
1930      if self._use_resource is not None:
1931        variable_scope_object.set_use_resource(self._use_resource)
1932      self._var_scope_store.open_variable_scope(self._new_name)
1933    self._var_scope_store.current_scope = variable_scope_object
1934    return variable_scope_object
1935
1936  def __exit__(self, type_arg, value_arg, traceback_arg):
1937    # If jumping out from a non-prolonged scope, restore counts.
1938    if isinstance(self._name_or_scope, VariableScope):
1939      self._var_scope_store.variable_scopes_count = self._old_subscopes
1940    else:
1941      self._var_scope_store.close_variable_subscopes(self._new_name)
1942    self._var_scope_store.current_scope = self._old
1943
1944
1945def _maybe_wrap_custom_getter(custom_getter, old_getter):
1946  """Wrap a call to a custom_getter to use the old_getter internally."""
1947  if old_getter is None:
1948    return custom_getter
1949
1950  # The new custom_getter should call the old one
1951  def wrapped_custom_getter(getter, *args, **kwargs):
1952    # Call:
1953    #  custom_getter(
1954    #    lambda: old_getter(true_getter, ...), *args, **kwargs)
1955    # which means custom_getter will call old_getter, which
1956    # will call the true_getter, perform any intermediate
1957    # processing, and return the results to the current
1958    # getter, which will also perform additional processing.
1959    return custom_getter(
1960        functools.partial(old_getter, getter),
1961        *args, **kwargs)
1962  return wrapped_custom_getter
1963
1964
1965def _get_unique_variable_scope(prefix):
1966  """Get a name with the given prefix unique in the current variable scope."""
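  # Illustrative behavior: if "foo" has already been opened under the current
  # scope, this returns "foo_1"; if "foo" and "foo_1" have been, "foo_2"; etc.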
1967  var_scope_store = get_variable_scope_store()
1968  current_scope = get_variable_scope()
1969  name = current_scope.name + "/" + prefix if current_scope.name else prefix
1970  if var_scope_store.variable_scope_count(name) == 0:
1971    return prefix
1972  idx = 1
1973  while var_scope_store.variable_scope_count(name + ("_%d" % idx)) > 0:
1974    idx += 1
1975  return prefix + ("_%d" % idx)
1976
1977
1978# Named like a function for backwards compatibility with the
1979# @tf_contextlib.contextmanager version, which was switched to a class to avoid
1980# some object creation overhead.
1981@tf_export(v1=["variable_scope"])  # pylint: disable=invalid-name
1982class variable_scope(object):
1983  """A context manager for defining ops that create variables (layers).
1984
1985  This context manager validates that the (optional) `values` are from the same
1986  graph, ensures that graph is the default graph, and pushes a name scope and a
1987  variable scope.
1988
1989  If `name_or_scope` is not None, it is used as is. If `name_or_scope` is None,
1990  then `default_name` is used.  In that case, if the same name has been
1991  previously used in the same scope, it will be made unique by appending `_N`
1992  to it.
1993
1994  Variable scope allows you to create new variables and to share already created
1995  ones while providing checks to not create or share by accident. For details,
1996  see the [Variable Scope How To](https://tensorflow.org/guide/variables); here
1997  we present only a few basic examples.
1998
1999  Simple example of how to create a new variable:
2000
2001  ```python
2002  with tf.variable_scope("foo"):
2003      with tf.variable_scope("bar"):
2004          v = tf.get_variable("v", [1])
2005          assert v.name == "foo/bar/v:0"
2006  ```
2007
2008  Simple example of how to reenter a premade variable scope safely:
2009
2010  ```python
2011  with tf.variable_scope("foo") as vs:
2012    pass
2013
2014  # Re-enter the variable scope.
2015  with tf.variable_scope(vs,
2016                         auxiliary_name_scope=False) as vs1:
2017    # Restore the original name_scope.
2018    with tf.name_scope(vs1.original_name_scope):
2019        v = tf.get_variable("v", [1])
2020        assert v.name == "foo/v:0"
2021        c = tf.constant([1], name="c")
2022        assert c.name == "foo/c:0"
2023  ```
2024
2025  Basic example of sharing a variable with AUTO_REUSE:
2026
2027  ```python
2028  def foo():
2029    with tf.variable_scope("foo", reuse=tf.AUTO_REUSE):
2030      v = tf.get_variable("v", [1])
2031    return v
2032
2033  v1 = foo()  # Creates v.
2034  v2 = foo()  # Gets the same, existing v.
2035  assert v1 == v2
2036  ```
2037
2038  Basic example of sharing a variable with reuse=True:
2039
2040  ```python
2041  with tf.variable_scope("foo"):
2042      v = tf.get_variable("v", [1])
2043  with tf.variable_scope("foo", reuse=True):
2044      v1 = tf.get_variable("v", [1])
2045  assert v1 == v
2046  ```
2047
2048  Sharing a variable by capturing a scope and setting reuse:
2049
2050  ```python
2051  with tf.variable_scope("foo") as scope:
2052      v = tf.get_variable("v", [1])
2053      scope.reuse_variables()
2054      v1 = tf.get_variable("v", [1])
2055  assert v1 == v
2056  ```
2057
2058  To prevent accidental sharing of variables, we raise an exception when getting
2059  an existing variable in a non-reusing scope.
2060
2061  ```python
2062  with tf.variable_scope("foo"):
2063      v = tf.get_variable("v", [1])
2064      v1 = tf.get_variable("v", [1])
2065      #  Raises ValueError("... v already exists ...").
2066  ```
2067
2068  Similarly, we raise an exception when trying to get a variable that does not
2069  exist in reuse mode.
2070
2071  ```python
2072  with tf.variable_scope("foo", reuse=True):
2073      v = tf.get_variable("v", [1])
2074      #  Raises ValueError("... v does not exist ...").
2075  ```
2076
2077  Note that the `reuse` flag is inherited: if we open a reusing scope, then all
2078  its sub-scopes become reusing as well.
2079
2080  A note about name scoping: Setting `reuse` does not impact the naming of other
2081  ops such as mult. See related discussion on
2082  [github#6189](https://github.com/tensorflow/tensorflow/issues/6189)
2083
2084  Note that up to and including version 1.0, it was allowed (though explicitly
2085  discouraged) to pass False to the reuse argument, yielding undocumented
2086  behaviour slightly different from None. Starting at 1.1.0 passing None and
2087  False as reuse has exactly the same effect.
2088
2089  A note about using variable scopes in a multi-threaded environment: Variable
2090  scopes are thread local, so one thread will not see another thread's current
2091  scope. Also, when using `default_name`, unique scope names are generated
2092  only on a per-thread basis. If the same name was used within a different
2093  thread, that doesn't prevent a new thread from creating the same scope.
2094  However, the underlying variable store is shared across threads (within the
2095  same graph). As such, if another thread tries to create a new variable with
2096  the same name as a variable created by a previous thread, it will fail unless
2097  reuse is True.
2098
2099  Further, each thread starts with an empty variable scope. So if you wish to
2100  preserve name prefixes from a scope from the main thread, you should capture
2101  the main thread's scope and re-enter it in each thread. For example:
2102
2103  ```
2104  main_thread_scope = variable_scope.get_variable_scope()
2105
2106  # Thread's target function:
2107  def thread_target_fn(captured_scope):
2108    with variable_scope.variable_scope(captured_scope):
2109      # .... regular code for this thread
2110
2111
2112  thread = threading.Thread(target=thread_target_fn, args=(main_thread_scope,))
2113  ```
2114  """
2115
2116  def __init__(self,
2117               name_or_scope,
2118               default_name=None,
2119               values=None,
2120               initializer=None,
2121               regularizer=None,
2122               caching_device=None,
2123               partitioner=None,
2124               custom_getter=None,
2125               reuse=None,
2126               dtype=None,
2127               use_resource=None,
2128               constraint=None,
2129               auxiliary_name_scope=True):
2130    """Initialize the context manager.
2131
2132    Args:
2133      name_or_scope: `string` or `VariableScope`: the scope to open.
2134      default_name: The default name to use if the `name_or_scope` argument is
2135        `None`; this name will be uniquified. If name_or_scope is provided, it
2136        won't be used and therefore it is not required and can be None.
2137      values: The list of `Tensor` arguments that are passed to the op function.
2138      initializer: default initializer for variables within this scope.
2139      regularizer: default regularizer for variables within this scope.
2140      caching_device: default caching device for variables within this scope.
2141      partitioner: default partitioner for variables within this scope.
2142      custom_getter: default custom getter for variables within this scope.
2143      reuse: `True`, None, or tf.AUTO_REUSE; if `True`, we go into reuse mode
2144        for this scope as well as all sub-scopes; if tf.AUTO_REUSE, we create
2145        variables if they do not exist, and return them otherwise; if None, we
2146        inherit the parent scope's reuse flag. When eager execution is enabled,
2147        new variables are always created unless an EagerVariableStore or
2148        template is currently active.
2149      dtype: type of variables created in this scope (defaults to the type
2150        in the passed scope, or inherited from parent scope).
2151      use_resource: If False, all variables will be regular Variables. If True,
2152        experimental ResourceVariables with well-defined semantics will be used
2153        instead. Defaults to False (will later change to True). When eager
2154        execution is enabled this argument is always forced to be True.
2155      constraint: An optional projection function to be applied to the variable
2156        after being updated by an `Optimizer` (e.g. used to implement norm
2157        constraints or value constraints for layer weights). The function must
2158        take as input the unprojected Tensor representing the value of the
2159        variable and return the Tensor for the projected value
2160        (which must have the same shape). Constraints are not safe to
2161        use when doing asynchronous distributed training.
2162      auxiliary_name_scope: If `True`, we create an auxiliary name scope with
2163        the scope. If `False`, we don't create it. Note that the argument is
2164        not inherited, and it only takes effect once, when creating. You
2165        should only use it for re-entering a premade variable scope.
2166
2167    Returns:
2168      A scope that can be captured and reused.
2169
2170    Raises:
2171      ValueError: when trying to reuse within a create scope, or create within
2172        a reuse scope.
2173      TypeError: when the types of some arguments are not appropriate.
2174    """
2175    self._name_or_scope = name_or_scope
2176    self._default_name = default_name
2177    self._values = values
2178    self._initializer = initializer
2179    self._regularizer = regularizer
2180    self._caching_device = caching_device
2181    self._partitioner = partitioner
2182    self._custom_getter = custom_getter
2183    self._reuse = reuse
2184    self._dtype = dtype
2185    self._use_resource = use_resource
2186    self._constraint = constraint
2187    if self._default_name is None and self._name_or_scope is None:
2188      raise TypeError("If default_name is None then name_or_scope is required")
2189    if self._reuse is False:
2190      # We don't allow non-inheriting scopes, False = None here.
2191      self._reuse = None
2192    if not (self._reuse is True
2193            or self._reuse is None
2194            or self._reuse is AUTO_REUSE):
2195      raise ValueError("reuse must be True, False, None, or tf.AUTO_REUSE.")
2196    if self._values is None:
2197      self._values = []
2198    self._in_graph_mode = not context.executing_eagerly()
2199    if self._in_graph_mode:
2200      self._graph = ops._get_graph_from_inputs(self._values)  # pylint: disable=protected-access
2201    self._cached_pure_variable_scope = None
2202    self._current_name_scope = None
2203    if not isinstance(auxiliary_name_scope, bool):
2204      raise TypeError("The auxiliary_name_scope must be `True` or `False`, "
2205                      "but got {}".format(auxiliary_name_scope))
2206    self._auxiliary_name_scope = auxiliary_name_scope
2207
2208  def __enter__(self):
2209    # If the default graph is building a function, then we should not replace it
2210    # with the cached graph.
2211    if ops.get_default_graph().building_function:
2212      self._building_function = True
2213    else:
2214      self._building_function = False
2215    if self._in_graph_mode and not self._building_function:
2216      self._graph_context_manager = self._graph.as_default()
2217      self._graph_context_manager.__enter__()
2218    if self._cached_pure_variable_scope is not None:
2219      # Fast path for re-entering variable_scopes. We've held on to the pure
2220      # variable scope from a previous successful __enter__, so we avoid some
2221      # overhead by re-using that object.
2222      if self._current_name_scope is not None:
2223        self._current_name_scope.__enter__()
2224      return self._cached_pure_variable_scope.__enter__()
2225
2226    try:
2227      return self._enter_scope_uncached()
2228    except Exception:
2229      if self._in_graph_mode and not self._building_function:
2230        if self._graph_context_manager is not None:
2231          self._graph_context_manager.__exit__(*sys.exc_info())
2232      raise
2233
2234  def _enter_scope_uncached(self):
2235    """Enters the context manager when there is no cached scope yet.
2236
2237    Returns:
2238      The entered variable scope.
2239
2240    Raises:
2241      TypeError: A wrong type is passed as `scope` at __init__().
2242      ValueError: `reuse` is incorrectly set at __init__().
2243    """
2244    if self._auxiliary_name_scope:
2245      # Create a new name scope later
2246      current_name_scope = None
2247    else:
2248      # Reenter the current name scope
2249      name_scope = ops.get_name_scope()
2250      if name_scope:
2251        # Hack to reenter
2252        name_scope += "/"
2253        current_name_scope = ops.name_scope(name_scope)
2254      else:
2255        # Root scope
2256        current_name_scope = ops.name_scope(name_scope)
2257
2258    # IMPORTANT: Only assign to self._cached_pure_variable_scope and
2259    # self._current_name_scope after successful __enter__() calls.
2260    if self._name_or_scope is not None:
2261      if not isinstance(self._name_or_scope,
2262                        (VariableScope,) + six.string_types):
2263        raise TypeError("VariableScope: name_or_scope must be a string or "
2264                        "VariableScope.")
2265      if isinstance(self._name_or_scope, six.string_types):
2266        name_scope = self._name_or_scope
2267      else:
2268        name_scope = self._name_or_scope.name.split("/")[-1]
2269      if name_scope or current_name_scope:
2270        current_name_scope = current_name_scope or ops.name_scope(name_scope)
2271        try:
2272          current_name_scope_name = current_name_scope.__enter__()
2273        except:
2274          current_name_scope.__exit__(*sys.exc_info())
2275          raise
2276        self._current_name_scope = current_name_scope
2277        if isinstance(self._name_or_scope, six.string_types):
2278          old_name_scope = current_name_scope_name
2279        else:
2280          old_name_scope = self._name_or_scope.original_name_scope
2281        pure_variable_scope = _pure_variable_scope(
2282            self._name_or_scope,
2283            reuse=self._reuse,
2284            initializer=self._initializer,
2285            regularizer=self._regularizer,
2286            caching_device=self._caching_device,
2287            partitioner=self._partitioner,
2288            custom_getter=self._custom_getter,
2289            old_name_scope=old_name_scope,
2290            dtype=self._dtype,
2291            use_resource=self._use_resource,
2292            constraint=self._constraint)
2293        try:
2294          entered_pure_variable_scope = pure_variable_scope.__enter__()
2295        except:
2296          pure_variable_scope.__exit__(*sys.exc_info())
2297          raise
2298        self._cached_pure_variable_scope = pure_variable_scope
2299        return entered_pure_variable_scope
2300      else:
2301        self._current_name_scope = None
2302        # This can only happen if someone is entering the root variable scope.
2303        pure_variable_scope = _pure_variable_scope(
2304            self._name_or_scope,
2305            reuse=self._reuse,
2306            initializer=self._initializer,
2307            regularizer=self._regularizer,
2308            caching_device=self._caching_device,
2309            partitioner=self._partitioner,
2310            custom_getter=self._custom_getter,
2311            dtype=self._dtype,
2312            use_resource=self._use_resource,
2313            constraint=self._constraint)
2314        try:
2315          entered_pure_variable_scope = pure_variable_scope.__enter__()
2316        except:
2317          pure_variable_scope.__exit__(*sys.exc_info())
2318          raise
2319        self._cached_pure_variable_scope = pure_variable_scope
2320        return entered_pure_variable_scope
2321
2322    else:  # Here name_or_scope is None. Using default name, but made unique.
2323      if self._reuse:
2324        raise ValueError("reuse=True cannot be used without a name_or_scope")
2325      current_name_scope = current_name_scope or ops.name_scope(
2326          self._default_name)
2327      try:
2328        current_name_scope_name = current_name_scope.__enter__()
2329      except:
2330        current_name_scope.__exit__(*sys.exc_info())
2331        raise
2332      self._current_name_scope = current_name_scope
2333      unique_default_name = _get_unique_variable_scope(self._default_name)
2334      pure_variable_scope = _pure_variable_scope(
2335          unique_default_name,
2336          initializer=self._initializer,
2337          regularizer=self._regularizer,
2338          caching_device=self._caching_device,
2339          partitioner=self._partitioner,
2340          custom_getter=self._custom_getter,
2341          old_name_scope=current_name_scope_name,
2342          dtype=self._dtype,
2343          use_resource=self._use_resource,
2344          constraint=self._constraint)
2345      try:
2346        entered_pure_variable_scope = pure_variable_scope.__enter__()
2347      except:
2348        pure_variable_scope.__exit__(*sys.exc_info())
2349        raise
2350      self._cached_pure_variable_scope = pure_variable_scope
2351      return entered_pure_variable_scope
2352
2353  def __exit__(self, type_arg, value_arg, traceback_arg):
2354    self._cached_pure_variable_scope.__exit__(
2355        type_arg, value_arg, traceback_arg)
2356    if self._current_name_scope:
2357      self._current_name_scope.__exit__(type_arg, value_arg, traceback_arg)
2358    if self._in_graph_mode and not self._building_function:
2359      self._graph_context_manager.__exit__(type_arg, value_arg, traceback_arg)
2360
2361
2362# pylint: disable=g-doc-return-or-yield
2363@tf_export(v1=["variable_op_scope"])
2364@tf_contextlib.contextmanager
2365def variable_op_scope(values,
2366                      name_or_scope,
2367                      default_name=None,
2368                      initializer=None,
2369                      regularizer=None,
2370                      caching_device=None,
2371                      partitioner=None,
2372                      custom_getter=None,
2373                      reuse=None,
2374                      dtype=None,
2375                      use_resource=None,
2376                      constraint=None):
2377  """Deprecated: context manager for defining an op that creates variables."""
2378  logging.warn("tf.variable_op_scope(values, name, default_name) is deprecated,"
2379               " use tf.variable_scope(name, default_name, values)")
2380  with variable_scope(name_or_scope,
2381                      default_name=default_name,
2382                      values=values,
2383                      initializer=initializer,
2384                      regularizer=regularizer,
2385                      caching_device=caching_device,
2386                      partitioner=partitioner,
2387                      custom_getter=custom_getter,
2388                      reuse=reuse,
2389                      dtype=dtype,
2390                      use_resource=use_resource,
2391                      constraint=constraint) as scope:
2392    yield scope
2393
2394
2395def _call_partitioner(partitioner, shape, dtype):
2396  """Call partitioner validating its inputs/output.
2397
2398  Args:
2399    partitioner: a function mapping `Tensor` shape and dtype to a
2400        list of partitions.
2401    shape: shape of the `Tensor` to partition, must have at least one
2402        dimension.
2403    dtype: dtype of the elements in the `Tensor`.
2404
2405  Returns:
2406    A list with elements >= 1 and at most one element > 1. The index of
2407    that element (if any) corresponds to the partitioning axis.
2408  """
2409  if not shape.is_fully_defined():
2410    raise ValueError("Shape of a new partitioned variable must be "
2411                     "fully defined, but instead was %s." % (shape,))
2412  if shape.ndims < 1:
2413    raise ValueError("A partitioned Variable must have rank at least 1, "
2414                     "shape: %s" % shape)
2415
2416  slicing = partitioner(shape=shape, dtype=dtype)
2417  if not isinstance(slicing, collections_lib.Sequence):
2418    raise ValueError("Partitioner must return a sequence, but saw: %s"
2419                     % slicing)
2420  if len(slicing) != shape.ndims:
2421    raise ValueError(
2422        "Partitioner returned a partition list that does not match the "
2423        "Variable's rank: %s vs. %s" % (slicing, shape))
2424  if any(p < 1 for p in slicing):
2425    raise ValueError(
2426        "Partitioner returned zero partitions for some axes: %s" %
2427        slicing)
2428  if sum(p > 1 for p in slicing) > 1:
2429    raise ValueError(
2430        "Can only slice a variable along one dimension: "
2431        "shape: %s, partitioning: %s" % (shape, slicing))
2432  return slicing
2433
2434
2435# TODO(slebedev): could be inlined, but
2436# `_VariableStore._get_partitioned_variable` is too complex even
2437# without this logic.
2438def _get_slice_dim_and_num_slices(slicing):
2439  """Get slicing dimension and number of slices from the partitioner output."""
2440  for slice_dim, num_slices in enumerate(slicing):
2441    if num_slices > 1:
2442      break
2443  else:
2444    # Degenerate case: no partitioning applied.
2445    slice_dim = 0
2446    num_slices = 1
2447  return slice_dim, num_slices
2448
2449
2450def _iter_slices(full_shape, num_slices, slice_dim):
2451  """Slices a given shape along the specified dimension."""
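  # Worked example (illustrative only): full_shape=[10, 4], num_slices=3,
  # slice_dim=0 yields ([0, 0], [4, 4]), ([4, 0], [3, 4]), ([7, 0], [3, 4]);
  # the remainder (10 % 3 == 1) is absorbed by the earliest shard(s).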
2452  num_slices_with_excess = full_shape[slice_dim] % num_slices
2453  offset = [0] * len(full_shape)
2454  min_slice_len = full_shape[slice_dim] // num_slices
2455  for i in xrange(num_slices):
2456    shape = full_shape[:]
2457    shape[slice_dim] = min_slice_len + bool(i < num_slices_with_excess)
2458    yield offset[:], shape
2459    offset[slice_dim] += shape[slice_dim]
2460
2461
2462def _get_trainable_value(synchronization, trainable):
2463  """Computes the trainable value based on the given arguments."""
2464  if synchronization == VariableSynchronization.ON_READ:
2465    if trainable:
2466      raise ValueError(
2467          "Synchronization value can be set to "
2468          "VariableSynchronization.ON_READ only for non-trainable variables. "
2469          "You have specified trainable=True and "
2470          "synchronization=VariableSynchronization.ON_READ.")
2471    else:
2472      # Set trainable to be false when variable is to be synced on read.
2473      trainable = False
2474  elif trainable is None:
2475    trainable = True
2476  return trainable
2477
2478
2479def default_variable_creator(next_creator=None, **kwargs):
2480  """Default variable creator."""
2481  assert next_creator is None
2482  initial_value = kwargs.get("initial_value", None)
2483  trainable = kwargs.get("trainable", None)
2484  collections = kwargs.get("collections", None)
2485  validate_shape = kwargs.get("validate_shape", True)
2486  caching_device = kwargs.get("caching_device", None)
2487  name = kwargs.get("name", None)
2488  variable_def = kwargs.get("variable_def", None)
2489  dtype = kwargs.get("dtype", None)
2490  expected_shape = kwargs.get("expected_shape", None)
2491  import_scope = kwargs.get("import_scope", None)
2492  constraint = kwargs.get("constraint", None)
2493  use_resource = kwargs.get("use_resource", None)
2494
2495  # Set trainable value based on synchronization value.
2496  synchronization = kwargs.get("synchronization", VariableSynchronization.AUTO)
2497  trainable = _get_trainable_value(
2498      synchronization=synchronization, trainable=trainable)
2499
  if use_resource is None:
    use_resource = get_variable_scope().use_resource
  if use_resource is None:
    use_resource = _DEFAULT_USE_RESOURCE
  use_resource = use_resource or context.executing_eagerly()
  if use_resource:
    distribute_strategy = kwargs.get("distribute_strategy", None)
    return resource_variable_ops.ResourceVariable(
        initial_value=initial_value, trainable=trainable,
        collections=collections, validate_shape=validate_shape,
        caching_device=caching_device, name=name, dtype=dtype,
        constraint=constraint, variable_def=variable_def,
        import_scope=import_scope, distribute_strategy=distribute_strategy)
  else:
    return variables.RefVariable(
        initial_value=initial_value, trainable=trainable,
        collections=collections, validate_shape=validate_shape,
        caching_device=caching_device, name=name, dtype=dtype,
        constraint=constraint, variable_def=variable_def,
        expected_shape=expected_shape, import_scope=import_scope)


def default_variable_creator_v2(next_creator=None, **kwargs):
  """Default variable creator."""
  assert next_creator is None
  initial_value = kwargs.get("initial_value", None)
  trainable = kwargs.get("trainable", None)
  validate_shape = kwargs.get("validate_shape", True)
  caching_device = kwargs.get("caching_device", None)
  name = kwargs.get("name", None)
  variable_def = kwargs.get("variable_def", None)
  dtype = kwargs.get("dtype", None)
  import_scope = kwargs.get("import_scope", None)
  constraint = kwargs.get("constraint", None)
  distribute_strategy = kwargs.get("distribute_strategy", None)

  # Set trainable value based on synchronization value.
  synchronization = kwargs.get("synchronization", VariableSynchronization.AUTO)
  trainable = _get_trainable_value(
      synchronization=synchronization, trainable=trainable)

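  # The V2 default creator has no ref-variable fallback: it always returns a
  # ResourceVariable.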
  return resource_variable_ops.ResourceVariable(
      initial_value=initial_value, trainable=trainable,
      validate_shape=validate_shape, caching_device=caching_device,
      name=name, dtype=dtype, constraint=constraint, variable_def=variable_def,
      import_scope=import_scope, distribute_strategy=distribute_strategy)


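# Install the creators defined above as the defaults that variables.py
# dispatches to when constructing variables (assigned here rather than in
# variables.py, presumably to avoid a circular import).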
variables.default_variable_creator = default_variable_creator
variables.default_variable_creator_v2 = default_variable_creator_v2


def _make_getter(captured_getter, captured_previous):
  """Returns a getter that avoids Python's late binding of loop variables."""
  return lambda **kwargs: captured_getter(captured_previous, **kwargs)


# TODO(apassos) remove forwarding symbol
variable = variables.VariableV1


@tf_export(v1=["variable_creator_scope"])
@tf_contextlib.contextmanager
def variable_creator_scope_v1(variable_creator):
  """Scope which defines a variable creation function to be used by variable().

  variable_creator is expected to be a function with the following signature:

  ```
    def variable_creator(next_creator, **kwargs)
  ```

  A creator that wants to create a variable is expected to eventually call
  next_creator to do so, rather than calling Variable or ResourceVariable
  directly. This helps make creators composable. A creator may choose to
  create multiple variables, return already existing variables, or simply
  register that a variable was created and defer to the next creators in
  line. Creators can also modify the keyword arguments seen by the next
  creators.

  Custom getters in the variable scope will eventually resolve down to these
  custom creators when they do create variables.

  The valid keyword arguments in kwargs are:
      initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
        which is the initial value for the Variable. The initial value must have
        a shape specified unless `validate_shape` is set to False. Can also be a
        callable with no argument that returns the initial value when called. In
        that case, `dtype` must be specified. (Note that initializer functions
        from init_ops.py must first be bound to a shape before being used here.)
      trainable: If `True`, the default, also adds the variable to the graph
        collection `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as
        the default list of variables to use by the `Optimizer` classes.
        `trainable` defaults to `True` unless `synchronization` is
        set to `ON_READ`.
      collections: List of graph collections keys. The new variable is added to
        these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
      validate_shape: If `False`, allows the variable to be initialized with a
        value of unknown shape. If `True`, the default, the shape of
        `initial_value` must be known.
      caching_device: Optional device string describing where the Variable
        should be cached for reading.  Defaults to the Variable's device.
        If not `None`, caches on another device.  Typical use is to cache
        on the device where the Ops using the Variable reside, to deduplicate
        copying through `Switch` and other conditional statements.
      name: Optional name for the variable. Defaults to `'Variable'` and gets
        uniquified automatically.
      dtype: If set, initial_value will be converted to the given type.
        If `None`, either the datatype will be kept (if `initial_value` is
        a Tensor), or `convert_to_tensor` will decide.
      constraint: A constraint function to be applied to the variable after
        updates by some algorithms.
      use_resource: if True, a ResourceVariable is always created.
      synchronization: Indicates when a distributed variable will be
        aggregated. Accepted values are constants defined in the class
        `tf.VariableSynchronization`. By default the synchronization is set to
        `AUTO` and the current `DistributionStrategy` chooses
        when to synchronize. If `synchronization` is set to `ON_READ`,
        `trainable` must not be set to `True`.
      aggregation: Indicates how a distributed variable will be aggregated.
        Accepted values are constants defined in the class
        `tf.VariableAggregation`.

  This set may grow over time, so it's important that the signature of creators
  is as mentioned above.
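
  For example, a creator that logs each variable name and then defers to the
  next creator in line might look like this (an illustrative sketch, not part
  of this module):

  ```
    def logging_creator(next_creator, **kwargs):
      print("Creating variable", kwargs.get("name"))
      return next_creator(**kwargs)

    with tf.compat.v1.variable_creator_scope(logging_creator):
      v = tf.compat.v1.Variable(1.0, name="v")
  ```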

  Args:
    variable_creator: The creator function, with the signature described above.

  Yields:
    A scope in which the creator is active
  """
  with ops.get_default_graph()._variable_creator_scope(variable_creator):  # pylint: disable=protected-access
    yield


# Note: only the docstrings differ between this and v1.
@tf_export("variable_creator_scope", v1=[])
@tf_contextlib.contextmanager
def variable_creator_scope(variable_creator):
  """Scope which defines a variable creation function to be used by variable().

  variable_creator is expected to be a function with the following signature:

  ```
    def variable_creator(next_creator, **kwargs)
  ```

  A creator that wants to create a variable is expected to eventually call
  next_creator to do so, rather than calling Variable or ResourceVariable
  directly. This helps make creators composable. A creator may choose to
  create multiple variables, return already existing variables, or simply
  register that a variable was created and defer to the next creators in
  line. Creators can also modify the keyword arguments seen by the next
  creators.

  Custom getters in the variable scope will eventually resolve down to these
  custom creators when they do create variables.

  The valid keyword arguments in kwargs are:
      initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
        which is the initial value for the Variable. The initial value must have
        a shape specified unless `validate_shape` is set to False. Can also be a
        callable with no argument that returns the initial value when called. In
        that case, `dtype` must be specified. (Note that initializer functions
        from init_ops.py must first be bound to a shape before being used here.)
      trainable: If `True`, the default, GradientTapes automatically watch
        uses of this Variable.
      validate_shape: If `False`, allows the variable to be initialized with a
        value of unknown shape. If `True`, the default, the shape of
        `initial_value` must be known.
      caching_device: Optional device string describing where the Variable
        should be cached for reading.  Defaults to the Variable's device.
        If not `None`, caches on another device.  Typical use is to cache
        on the device where the Ops using the Variable reside, to deduplicate
        copying through `Switch` and other conditional statements.
      name: Optional name for the variable. Defaults to `'Variable'` and gets
        uniquified automatically.
      dtype: If set, initial_value will be converted to the given type.
        If `None`, either the datatype will be kept (if `initial_value` is
        a Tensor), or `convert_to_tensor` will decide.
      constraint: A constraint function to be applied to the variable after
        updates by some algorithms.
      synchronization: Indicates when a distributed variable will be
        aggregated. Accepted values are constants defined in the class
        `tf.VariableSynchronization`. By default the synchronization is set to
        `AUTO` and the current `DistributionStrategy` chooses
        when to synchronize. If `synchronization` is set to `ON_READ`,
        `trainable` must not be set to `True`.
      aggregation: Indicates how a distributed variable will be aggregated.
        Accepted values are constants defined in the class
        `tf.VariableAggregation`.

  This set may grow over time, so it's important that the signature of creators
  is as mentioned above.
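
  For example, a creator that logs each variable name and then defers to the
  next creator in line might look like this (an illustrative sketch, not part
  of this module):

  ```
    def logging_creator(next_creator, **kwargs):
      print("Creating variable", kwargs.get("name"))
      return next_creator(**kwargs)

    with tf.variable_creator_scope(logging_creator):
      v = tf.Variable(1.0, name="v")
  ```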

  Args:
    variable_creator: The creator function, with the signature described above.

  Yields:
    A scope in which the creator is active
  """
  with ops.get_default_graph()._variable_creator_scope(variable_creator):  # pylint: disable=protected-access
    yield
