1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""A class to store named variables and a scope operator to manage sharing."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import copy
22import enum  # pylint: disable=g-bad-import-order
23import functools
24import sys
25import threading
26import traceback
27
28import six
29from six import iteritems
30from six.moves import xrange, zip  # pylint: disable=redefined-builtin
31
32from tensorflow.python import tf2
33from tensorflow.python.client import session
34from tensorflow.python.eager import context
35from tensorflow.python.eager import monitoring
36from tensorflow.python.framework import dtypes
37from tensorflow.python.framework import ops
38from tensorflow.python.framework import tensor_shape
39from tensorflow.python.ops import array_ops
40from tensorflow.python.ops import init_ops
41from tensorflow.python.ops import resource_variable_ops
42from tensorflow.python.ops import variables
43from tensorflow.python.platform import tf_logging as logging
44from tensorflow.python.types import core
45from tensorflow.python.util import deprecation
46from tensorflow.python.util import function_utils
47from tensorflow.python.util import tf_contextlib
48from tensorflow.python.util import tf_inspect
49from tensorflow.python.util.compat import collections_abc
50from tensorflow.python.util.tf_export import tf_export
51
52__all__ = [
53    "AUTO_REUSE", "VariableScope", "get_variable_scope", "get_variable",
54    "get_local_variable", "variable_scope", "variable_op_scope",
55    "no_regularizer", "VariableSynchronization", "VariableAggregation"
56]
57
58_api_usage_gauge = monitoring.BoolGauge(
59    "/tensorflow/api/resource_variables",
60    "Whether variable_scope.enable_resource_variables() is called.")
61
62
63class _PartitionInfo(object):
64  """Holds partition info used by initializer functions."""
65
66  __slots__ = ["_full_shape", "_var_offset"]
67
68  def __init__(self, full_shape, var_offset):
69    """Constructor.
70
71    Args:
72      full_shape: Tuple or list of `int` indicating the full combined shape of
73        the partitioned variables.
74      var_offset: Tuple or list of `int` specifying offset of this partition
75        with respect to the full variable for each dimension.
76
77    Raises:
78      TypeError: If `full_shape` or `var_offset` is not a sequence.
79      ValueError: If `full_shape` or `var_offset` differ in length. If
80        `var_offset` exceeds `full_shape` in any dimension.
81    """
82    if not isinstance(full_shape, collections_abc.Sequence) or isinstance(
83        full_shape, six.string_types):
84      raise TypeError(
85          "`full_shape` must be a sequence (like tuple or list) instead of " +
86          type(full_shape).__name__)
87
88    if not isinstance(var_offset, collections_abc.Sequence) or isinstance(
89        var_offset, six.string_types):
90      raise TypeError(
91          "`var_offset` must be a sequence (like tuple or list) instead of " +
92          type(var_offset).__name__)
93
94    if len(var_offset) != len(full_shape):
95      raise ValueError(
96          "Expected equal length, but `var_offset` is of length {} while "
97          "full_shape is of length {}.".format(
98              len(var_offset), len(full_shape)))
99
100    for offset, shape in zip(var_offset, full_shape):
101      if offset < 0 or offset >= shape:
102        raise ValueError(
103            "Expected 0 <= offset < shape but found offset={}, shape={} for "
104            "var_offset={}, full_shape={}".format(offset, shape, var_offset,
105                                                  full_shape))
106
107    self._full_shape = full_shape
108    self._var_offset = var_offset
109
110  @property
111  def full_shape(self):
112    return self._full_shape
113
114  @property
115  def var_offset(self):
116    return self._var_offset
117
118  def single_offset(self, shape):
119    """Returns the offset when the variable is partitioned in at most one dim.
120
121    Args:
122      shape: Tuple or list of `int` indicating the shape of one specific
123        variable partition.
124
125    Returns:
126      `int` representing the offset in the dimension along which the variable is
127       partitioned. Returns 0 if the variable is not being partitioned.
128
129    Raises:
130      ValueError: Depending on self.single_slice_dim().
131    """
132
133    single_slice_dim = self.single_slice_dim(shape)
134    # If this variable is not being partitioned at all, single_slice_dim() could
135    # return None.
136    if single_slice_dim is None:
137      return 0
138    return self.var_offset[single_slice_dim]
139
140  def single_slice_dim(self, shape):
141    """Returns the slice dim when the variable is partitioned only in one dim.
142
143    Args:
144      shape: Tuple or list of `int` indicating the shape of one specific
145        variable partition.
146
147    Returns:
148      `int` representing the dimension that the variable is partitioned in, or
149      `None` if the variable doesn't seem to be partitioned at all.
150
151    Raises:
152      TypeError: If `shape` is not a sequence.
153      ValueError: If `shape` is not the same length as `self.full_shape`. If
154        the variable is partitioned in more than one dimension.
155    """
156    if not isinstance(shape, collections_abc.Sequence) or isinstance(
157        shape, six.string_types):
158      raise TypeError(
159          "`shape` must be a sequence (like tuple or list) instead of " +
160          type(shape).__name__)
161
162    if len(shape) != len(self.full_shape):
163      raise ValueError(
164          "Expected equal length, but received shape={} of length {} while "
165          "self.full_shape={} is of length {}.".format(shape, len(shape),
166                                                       self.full_shape,
167                                                       len(self.full_shape)))
168
169    for i in xrange(len(shape)):
170      if self.var_offset[i] + shape[i] > self.full_shape[i]:
171        raise ValueError(
172            "With self.var_offset={}, a partition of shape={} would exceed "
173            "self.full_shape={} in dimension {}.".format(
174                self.var_offset, shape, self.full_shape, i))
175
176    slice_dim = None
177    for i in xrange(len(shape)):
178      if shape[i] == self.full_shape[i]:
179        continue
180      if slice_dim is not None:
181        raise ValueError(
182            "Cannot use single_slice_dim() with shape={} and "
183            "self.full_shape={} since slice dim could be either dimension {} "
184            "or {}.".format(shape, self.full_shape, i, slice_dim))
185      slice_dim = i
186
187    return slice_dim
188
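# Worked example for _PartitionInfo (illustrative): for a variable of full
# shape [10, 4] split into two shards along dimension 0, the second shard's
# partition info behaves as follows:
#   info = _PartitionInfo(full_shape=[10, 4], var_offset=[5, 0])
#   info.single_slice_dim([5, 4])  # -> 0 (partitioned along dimension 0)
#   info.single_offset([5, 4])     # -> 5 (this shard starts at row 5)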
189
190class _ReuseMode(enum.Enum):
191  """Mode for variable access within a variable scope."""
192
193  # Indicates that variables are to be fetched if they already exist or
194  # otherwise created.
195  AUTO_REUSE = 1
196
197  # TODO(alive): For TensorFlow 2.0, Deprecate True/False/None API in favor of
198  #              enum values.
199  # REUSE_FALSE = 2
200  # REUSE_TRUE = 3
201
202
203# TODO(apassos) remove these forwarding symbols.
204VariableSynchronization = variables.VariableSynchronization  # pylint: disable=invalid-name
205VariableAggregation = variables.VariableAggregation  # pylint: disable=invalid-name
206
207AUTO_REUSE = _ReuseMode.AUTO_REUSE
208tf_export(v1=["AUTO_REUSE"]).export_constant(__name__, "AUTO_REUSE")
209AUTO_REUSE.__doc__ = """
210When passed in as the value for the `reuse` flag, AUTO_REUSE indicates that
211get_variable() should create the requested variable if it doesn't exist or, if
212it does exist, simply return it.
213"""
214
215_DEFAULT_USE_RESOURCE = tf2.enabled()
216
217
218@tf_export(v1=["enable_resource_variables"])
219def enable_resource_variables():
220  """Creates resource variables by default.
221
222  Resource variables are improved versions of TensorFlow variables with a
223  well-defined memory model. Accessing a resource variable reads its value, and
224  all ops which access a specific read value of the variable are guaranteed to
225  see the same value for that tensor. Writes which happen after a read (by
226  having a control or data dependency on the read) are guaranteed not to affect
227  the value of the read tensor, and similarly writes which happen before a read
228  are guaranteed to affect the value. No guarantees are made about unordered
229  read/write pairs.
230
231  Calling tf.enable_resource_variables() lets you opt-in to this TensorFlow 2.0
232  feature.
233  """
234  global _DEFAULT_USE_RESOURCE
235  _DEFAULT_USE_RESOURCE = True
236  _api_usage_gauge.get_cell().set(True)
237
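# Illustrative usage of enable_resource_variables() in TF1-style graph code
# (under eager execution resource variables are already forced on):
#
#   tf.compat.v1.enable_resource_variables()
#   v = tf.compat.v1.get_variable("v", shape=[])  # now backed by a
#                                                 # ResourceVariable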
238
239@tf_export(v1=["resource_variables_enabled"])
240def resource_variables_enabled():
241  """Returns `True` if resource variables are enabled.
242
243  Resource variables are improved versions of TensorFlow variables with a
244  well-defined memory model. Accessing a resource variable reads its value, and
245  all ops which access a specific read value of the variable are guaranteed to
246  see the same value for that tensor. Writes which happen after a read (by
247  having a control or data dependency on the read) are guaranteed not to affect
248  the value of the read tensor, and similarly writes which happen before a read
249  are guaranteed to affect the value. No guarantees are made about unordered
250  read/write pairs.
251
252  Calling tf.enable_resource_variables() lets you opt-in to this TensorFlow 2.0
253  feature.
254  """
255  global _DEFAULT_USE_RESOURCE
256  return _DEFAULT_USE_RESOURCE
257
258
259@deprecation.deprecated(
260    None, "non-resource variables are not supported in the long term")
261@tf_export(v1=["disable_resource_variables"])
262def disable_resource_variables():
263  """Opts out of resource variables.
264
265  If your code needs tf.disable_resource_variables() to be called to work
266  properly, please file a bug.
267  """
268  global _DEFAULT_USE_RESOURCE
269  _DEFAULT_USE_RESOURCE = False
270  _api_usage_gauge.get_cell().set(False)
271
272
273def _needs_no_arguments(python_callable):
274  """Returns true if the callable needs no arguments to call."""
275  # TODO(bfontain): Switch to inspect.signature when we are python 3 only.
276  # signature = inspect.signature(python_callable)
277  # return not [1 for param in signature.parameters.values()
278  #             if param.default == param.empty]
279  num_arguments = len(tf_inspect.getargspec(python_callable).args)
280  if not tf_inspect.isfunction(python_callable) and not isinstance(
281      python_callable, functools.partial):
282    # getargspec includes `self` for callables that are not plain functions
283    # (and not functools.partial objects). `self` has no default, so we
284    # remove it from the count; it's odd that getargspec returns it at all.
285    # Note that this is fixed with inspect.signature in Python 3.
286    num_arguments -= 1
287  return num_arguments == len(
288      tf_inspect.getargspec(python_callable).defaults or [])
289
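# For example (hypothetical callables, shown for illustration only):
#   _needs_no_arguments(lambda: 1)      # -> True  (no parameters)
#   _needs_no_arguments(lambda x=3: x)  # -> True  (every parameter defaulted)
#   _needs_no_arguments(lambda x: x)    # -> False (x has no default)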
290
291class _VariableStore(object):
292  """Variable store that carries a number of named Variables.
293
294  New variable names and new variables can be created; all stored
295  variables are initialized with the initializer passed to __init__.
296
297  Attributes:
298    vars: a dictionary with string names (same as passed to get_variable) as
299      keys and the corresponding TensorFlow Variables as values.
300  """
301
302  __slots__ = ["_vars", "_partitioned_vars", "_store_eager_variables"]
303
304  def __init__(self):
305    """Create a variable store."""
306    self._vars = {}  # A dictionary of the stored TensorFlow variables.
307    self._partitioned_vars = {}  # A dict of the stored PartitionedVariables.
308    self._store_eager_variables = False
309
310  def get_variable(self,
311                   name,
312                   shape=None,
313                   dtype=dtypes.float32,
314                   initializer=None,
315                   regularizer=None,
316                   reuse=None,
317                   trainable=None,
318                   collections=None,
319                   caching_device=None,
320                   partitioner=None,
321                   validate_shape=True,
322                   use_resource=None,
323                   custom_getter=None,
324                   constraint=None,
325                   synchronization=VariableSynchronization.AUTO,
326                   aggregation=VariableAggregation.NONE):
327    """Gets an existing variable with these parameters or create a new one.
328
329    If a variable with the given name is already stored, we return the stored
330    variable. Otherwise, we create a new one.
331
332    Set `reuse` to `True` when you only want to reuse existing Variables.
333    Set `reuse` to `False` when you only want to create new Variables.
334    Set `reuse` to None (the default) or tf.compat.v1.AUTO_REUSE when you want
335    variables to be created if they don't exist or returned if they do.
336
337    If initializer is `None` (the default), the default initializer passed in
338    the constructor is used. If that one is `None` too, we use a new
339    `glorot_uniform_initializer`. If initializer is a Tensor, we use
340    it as a value and derive the shape from the initializer.
341
342    If a partitioner is provided, a `PartitionedVariable` is returned.
343    Accessing this object as a `Tensor` returns the shards concatenated along
344    the partition axis.
345
346    Some useful partitioners are available.  See, e.g.,
347    `variable_axis_size_partitioner` and `min_max_variable_partitioner`.
348
349    Args:
350      name: The name of the new or existing variable.
351      shape: Shape of the new or existing variable.
352      dtype: Type of the new or existing variable (defaults to `DT_FLOAT`).
353      initializer: Initializer for the variable.
354      regularizer: A (Tensor -> Tensor or None) function; the result of applying
355        it to a newly created variable will be added to the collection
356        GraphKeys.REGULARIZATION_LOSSES and can be used for regularization.
357      reuse: a Boolean, None, or tf.AUTO_REUSE. Controls reuse or creation of
358        variables. When eager execution is enabled, this argument is always
359        forced to be False.
360      trainable: If `True` also add the variable to the graph collection
361        `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). `trainable`
362        defaults to `True`, unless `synchronization` is set to `ON_READ`, in
363        which case it defaults to `False`.
364      collections: List of graph collections keys to add the `Variable` to.
365        Defaults to `[GraphKeys.GLOBAL_VARIABLES]` (see `tf.Variable`).
366      caching_device: Optional device string or function describing where the
367        Variable should be cached for reading.  Defaults to the Variable's
368        device.  If not `None`, caches on another device.  Typical use is to
369        cache on the device where the Ops using the `Variable` reside, to
370        deduplicate copying through `Switch` and other conditional statements.
371      partitioner: Optional callable that accepts a fully defined `TensorShape`
372        and dtype of the `Variable` to be created, and returns a list of
373        partitions for each axis (currently only one axis can be partitioned).
374      validate_shape: If False, allows the variable to be initialized with a
375        value of unknown shape. If True, the default, the shape of initial_value
376        must be known.
377      use_resource: If False, creates a regular Variable. If True, creates
378        instead an experimental ResourceVariable which has well-defined
379        semantics. Defaults to False (will later change to True). When eager
380        execution is enabled, this argument is always forced to be True.
381      custom_getter: Callable that takes as a first argument the true getter,
382        and allows overwriting the internal get_variable method. The signature
383        of `custom_getter` should match that of this method, but the most
384        future-proof version will allow for changes:
385        `def custom_getter(getter, *args, **kwargs)`.  Direct access to all
386        `get_variable` parameters is also allowed:
387        `def custom_getter(getter, name, *args, **kwargs)`.  A simple identity
388        custom getter that simply creates variables with modified names is:
389          ```python
390          def custom_getter(getter, name, *args, **kwargs):
391            return getter(name + '_suffix', *args, **kwargs) ```
392      constraint: An optional projection function to be applied to the variable
393        after being updated by an `Optimizer` (e.g. used to implement norm
394        constraints or value constraints for layer weights). The function must
395        take as input the unprojected Tensor representing the value of the
396        variable and return the Tensor for the projected value (which must have
397        the same shape). Constraints are not safe to use when doing asynchronous
398        distributed training.
399      synchronization: Indicates when a distributed variable will be
400        aggregated. Accepted values are constants defined in the class
401        `tf.VariableSynchronization`. By default the synchronization is set to
402        `AUTO` and the current `DistributionStrategy` chooses when to
403        synchronize.
404      aggregation: Indicates how a distributed variable will be aggregated.
405        Accepted values are constants defined in the class
406        `tf.VariableAggregation`.
407
408    Returns:
409      The created or existing `Variable` (or `PartitionedVariable`, if a
410      partitioner was used).
411
412    Raises:
413      ValueError: when creating a new variable and shape is not declared,
414        when reusing a variable and specifying a conflicting shape,
415        or when violating reuse during variable creation.
416      RuntimeError: when eager execution is enabled and not called from an
417        EagerVariableStore.
418    """
419    if custom_getter is not None and not callable(custom_getter):
420      raise ValueError("Passed a custom_getter which is not callable: %s" %
421                       custom_getter)
422
423    with ops.init_scope():
424      if context.executing_eagerly():
425        # Variable creation and initialization takes place in `init_scope`s;
426        # as such, if an `init_scope` lifts us into the eager context, then we
427        # need to use `ResourceVariable`s.
428        use_resource = True
429
430    # Note that it's fine to reuse eager variables whose initialization was
431    # lifted from a function-building graph into the eager context (that's why
432    # the following clause is not wrapped in an `init_scope`); lifted variables
433    # are tracked by the graph's `VariableStore`.
434    if context.executing_eagerly():
435      if not self._store_eager_variables and reuse:
436        raise RuntimeError(
437            "When eager execution is enabled variable reuse is only supported"
438            " when an EagerVariableStore is active. See the documentation on"
439            " EagerVariableStore for example usage.")
440      if self._store_eager_variables:
441        reuse = AUTO_REUSE
442
443    # If a *_ref type is passed in an error would be triggered further down the
444    # stack. We prevent this using base_dtype to get a non-ref version of the
445    # type, before doing anything else. When _ref types are removed in favor of
446    # resources, this line can be removed.
447    try:
448      dtype = dtype.base_dtype
449    except AttributeError:
450      # If .base_dtype does not exist we fall back to the raw dtype that was
451      # passed in; this might be a NumPy type, which is valid.
452      pass
453
454    # This is the main logic of get_variable.  However, custom_getter
455    # may override this logic.  So we save it as a callable and pass
456    # it to custom_getter.
457    # Note: the parameters of _true_getter, and their documentation, match
458    # *exactly* item-for-item with the docstring of this method.
459    def _true_getter(  # pylint: disable=missing-docstring
460        name,
461        shape=None,
462        dtype=dtypes.float32,
463        initializer=None,
464        regularizer=None,
465        reuse=None,
466        trainable=None,
467        collections=None,
468        caching_device=None,
469        partitioner=None,
470        validate_shape=True,
471        use_resource=None,
472        constraint=None,
473        synchronization=VariableSynchronization.AUTO,
474        aggregation=VariableAggregation.NONE):
475      is_scalar = (
476          shape is not None and isinstance(shape, collections_abc.Sequence) and
477          not shape)
478      # Partitioned variable case
479      if partitioner is not None and not is_scalar:
480        if not callable(partitioner):
481          raise ValueError("Partitioner must be callable, but received: %s" %
482                           partitioner)
483        with ops.name_scope(None):
484          return self._get_partitioned_variable(
485              name=name,
486              shape=shape,
487              dtype=dtype,
488              initializer=initializer,
489              regularizer=regularizer,
490              reuse=reuse,
491              trainable=trainable,
492              collections=collections,
493              caching_device=caching_device,
494              partitioner=partitioner,
495              validate_shape=validate_shape,
496              use_resource=use_resource,
497              constraint=constraint,
498              synchronization=synchronization,
499              aggregation=aggregation)
500
501      # Special case for partitioned variable to allow reuse without having to
502      # specify partitioner.
503      if (reuse is True and partitioner is None
504          and name in self._partitioned_vars):
505        return self._get_partitioned_variable(
506            name=name,
507            shape=shape,
508            dtype=dtype,
509            initializer=initializer,
510            regularizer=regularizer,
511            reuse=reuse,
512            trainable=trainable,
513            collections=collections,
514            caching_device=caching_device,
515            partitioner=None,
516            validate_shape=validate_shape,
517            use_resource=use_resource,
518            constraint=constraint,
519            synchronization=synchronization,
520            aggregation=aggregation)
521
522      # Single variable case
523      if "%s/part_0" % name in self._vars:
524        raise ValueError(
525            "No partitioner was provided, but a partitioned version of the "
526            "variable was found: %s/part_0. Perhaps a variable of the same "
527            "name was already created with partitioning?" % name)
528
529      return self._get_single_variable(
530          name=name,
531          shape=shape,
532          dtype=dtype,
533          initializer=initializer,
534          regularizer=regularizer,
535          reuse=reuse,
536          trainable=trainable,
537          collections=collections,
538          caching_device=caching_device,
539          validate_shape=validate_shape,
540          use_resource=use_resource,
541          constraint=constraint,
542          synchronization=synchronization,
543          aggregation=aggregation)
544
545    synchronization, aggregation, trainable = (
546        variables.validate_synchronization_aggregation_trainable(
547            synchronization, aggregation, trainable, name))
548
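    # A custom getter is typically installed via
    # `tf.compat.v1.variable_scope(..., custom_getter=...)`; when present it is
    # called with `_true_getter` as its first argument plus the keyword
    # arguments assembled below, so it can adjust them before delegating.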
549    if custom_getter is not None:
550      # Handle backwards compatibility with getter arguments that were added
551      # to the API after users started writing custom getters.
552      custom_getter_kwargs = {
553          "getter": _true_getter,
554          "name": name,
555          "shape": shape,
556          "dtype": dtype,
557          "initializer": initializer,
558          "regularizer": regularizer,
559          "reuse": reuse,
560          "trainable": trainable,
561          "collections": collections,
562          "caching_device": caching_device,
563          "partitioner": partitioner,
564          "validate_shape": validate_shape,
565          "use_resource": use_resource,
566          "synchronization": synchronization,
567          "aggregation": aggregation,
568      }
569      # `fn_args` and `has_kwargs` can handle functions, `functools.partial`,
570      # `lambda`.
571      if ("constraint" in function_utils.fn_args(custom_getter) or
572          function_utils.has_kwargs(custom_getter)):
573        custom_getter_kwargs["constraint"] = constraint
574      return custom_getter(**custom_getter_kwargs)
575    else:
576      return _true_getter(
577          name,
578          shape=shape,
579          dtype=dtype,
580          initializer=initializer,
581          regularizer=regularizer,
582          reuse=reuse,
583          trainable=trainable,
584          collections=collections,
585          caching_device=caching_device,
586          partitioner=partitioner,
587          validate_shape=validate_shape,
588          use_resource=use_resource,
589          constraint=constraint,
590          synchronization=synchronization,
591          aggregation=aggregation)
592
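  # Illustrative sharding example for _get_partitioned_variable below: with
  # shape=[10, 4] and a partitioner that returns [5, 1], five shards named
  # "<name>/part_0" ... "<name>/part_4", each of shape [2, 4], are created and
  # wrapped in a single PartitionedVariable.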
593  def _get_partitioned_variable(self,
594                                name,
595                                partitioner,
596                                shape=None,
597                                dtype=dtypes.float32,
598                                initializer=None,
599                                regularizer=None,
600                                reuse=None,
601                                trainable=None,
602                                collections=None,
603                                caching_device=None,
604                                validate_shape=True,
605                                use_resource=None,
606                                constraint=None,
607                                synchronization=VariableSynchronization.AUTO,
608                                aggregation=VariableAggregation.NONE):
609    """Gets or creates a sharded variable list with these parameters.
610
611    The `partitioner` must be a callable that accepts a fully defined
612    `TensorShape` and returns a sequence of integers (the `partitions`).
613    These integers describe how to partition the given sharded `Variable`
614    along the given dimension.  That is, `partitions[1] = 3` means split
615    the `Variable` into 3 shards along dimension 1.  Currently, sharding along
616    only one axis is supported.
617
618    If the list of variables with the given name (prefix) is already stored,
619    we return the stored variables. Otherwise, we create new ones.
620
621    Set `reuse` to `True` when you only want to reuse existing Variables.
622    Set `reuse` to `False` when you only want to create new Variables.
623    Set `reuse` to None (the default) or tf.compat.v1.AUTO_REUSE when you want
624    variables to be created if they don't exist or returned if they do.
625
626    If initializer is `None` (the default), the default initializer passed in
627    the constructor is used. If that one is `None` too, we use a new
628    `glorot_uniform_initializer`. If initializer is a Tensor, we use
629    it as a value and derive the shape from the initializer.
630
631    If the initializer is a callable, then it will be called for each
632    shard.  Otherwise the initializer should match the shape of the entire
633    sharded Variable, and it will be sliced accordingly for each shard.
634
635    Some useful partitioners are available.  See, e.g.,
636    `variable_axis_size_partitioner` and `min_max_variable_partitioner`.
637
638    Args:
639      name: the name of the new or existing sharded variable.
640      partitioner: Optional callable that accepts a fully defined `TensorShape`
641        and `dtype` of the Variable to be created, and returns a list of
642        partitions for each axis (currently only one axis can be partitioned).
643      shape: shape of the new or existing sharded variable.
644      dtype: type of the new or existing sharded variable (defaults to
645        `DT_FLOAT`).
646      initializer: initializer for the sharded variable.
647      regularizer: a (Tensor -> Tensor or None) function; the result of applying
648        it to a newly created variable will be added to the collection
649        GraphKeys.REGULARIZATION_LOSSES and can be used for regularization.
650      reuse: a Boolean, None, or tf.AUTO_REUSE. Controls reuse or creation of
651        variables.
652      trainable: If `True` also add the variable to the graph collection
653        `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
654      collections: List of graph collections keys to add the Variable to.
655        Defaults to `[GraphKeys.GLOBAL_VARIABLES]` (see `tf.Variable`).
656      caching_device: Optional device string or function describing where the
657        Variable should be cached for reading.  Defaults to the Variable's
658        device.  If not `None`, caches on another device.  Typical use is to
659        cache on the device where the Ops using the Variable reside, to
660        deduplicate copying through `Switch` and other conditional statements.
661      validate_shape: If False, allows the variable to be initialized with a
662        value of unknown shape. If True, the default, the shape of initial_value
663        must be known.
664      use_resource: If False, creates a regular Variable. If True, creates an
665        experimental ResourceVariable which has well-defined semantics. Defaults
666        to False (will later change to True).
667      constraint: An optional projection function to be applied to the variable
668        after being updated by an `Optimizer` (e.g. used to implement norm
669        constraints or value constraints for layer weights). The function must
670        take as input the unprojected Tensor representing the value of the
671        variable and return the Tensor for the projected value (which must have
672        the same shape). Constraints are not safe to use when doing asynchronous
673        distributed training.
674      synchronization: Indicates when a distributed variable will be
675        aggregated. Accepted values are constants defined in the class
676        `tf.VariableSynchronization`. By default the synchronization is set to
677        `AUTO` and the current `DistributionStrategy` chooses when to
678        synchronize.
679      aggregation: Indicates how a distributed variable will be aggregated.
680        Accepted values are constants defined in the class
681        `tf.VariableAggregation`.
682
683    Returns:
684      A `PartitionedVariable` object.
685
686    Raises:
687      ValueError: when creating a new variable and shape is not declared,
688        when reusing a variable and specifying a conflicting shape,
689        when violating reuse during variable creation, or if an existing
690        sharded variable exists for the given name but with different sharding.
691    """
692    initializing_from_value = initializer is not None and isinstance(
693        initializer, ops.Tensor)
694    if name in self._vars:
695      raise ValueError(
696          "A partitioner was provided, but an unpartitioned version of the "
697          "variable was found: %s.  Perhaps a variable of the same name was "
698          "already created without partitioning?" % name)
699
700    shape = tensor_shape.as_shape(shape)
701    if initializing_from_value:
702      shape = shape.merge_with(initializer.get_shape())
703
704    partitions = None
705    if not reuse or partitioner:
706      partitions = _call_partitioner(partitioner, shape, dtype)
707
708    if name in self._partitioned_vars:
709      if reuse is False:
710        raise ValueError(
711            "Partitioned variable with name %s already exists. Did you mean to "
712            "set reuse=True or reuse=tf.AUTO_REUSE in VarScope?" % name)
713
714      existing_var = self._partitioned_vars[name]
715      if not shape.is_compatible_with(existing_var.get_shape()):
716        raise ValueError(
717            "Trying to reuse partitioned variable %s, but specified shape %s "
718            "and found shape %s." % (name, shape, existing_var.get_shape()))
719      if not dtype.is_compatible_with(existing_var.dtype):
720        raise ValueError(
721            "Trying to reuse partitioned variable %s, but specified dtype %s "
722            "and found dtype %s." % (name, dtype.name, existing_var.dtype.name))
723
724      # pylint: disable=protected-access
725      if (partitions is not None and
726          existing_var._get_partitions() != partitions):
727        raise ValueError(
728            "Trying to reuse partitioned variable %s, but specified partitions "
729            "%s and found partitions %s." %
730            (name, partitions, existing_var._get_partitions()))
731      # pylint: enable=protected-access
732
733      return existing_var
734
735    if reuse is True:
736      raise ValueError("PartitionedVariable %s does not exist, or was not "
737                       "created with tf.get_variable(). Did you mean to set "
738                       "reuse=False or reuse=tf.AUTO_REUSE in VarScope?" % name)
739
740    slice_dim, num_slices = _get_slice_dim_and_num_slices(partitions)
741
742    if "%s/part_0" % name in self._vars:
743      if "%s/part_%d" % (name, num_slices - 1) not in self._vars:
744        raise ValueError(
745            "Partitioner returned a different partitioning than what was "
746            "already found.  Partitioner returned %d shards, and shard "
747            "%s/part_0 was found, but %s/part_%d was not." %
748            (num_slices, name, name, num_slices - 1))
749      if "%s/part_%d" % (name, num_slices) in self._vars:
750        raise ValueError(
751            "Partitioner returned a different partitioning than what was "
752            "already found.  Partitioner returned %d shards, and shard "
753            "%s/part_0 was found, but so was the extra shard %s/part_%d." %
754            (num_slices, name, name, num_slices))
755
756    vs = []
757    for i, (var_offset, var_shape) in enumerate(
758        _iter_slices(shape.as_list(), num_slices, slice_dim)):
759      partition_info = _PartitionInfo(
760          full_shape=shape.as_list(), var_offset=var_offset)
761      var_full_name = "%s/part_%d" % (name, i)
762      with ops.name_scope(
763          var_full_name + "/PartitionedInitializer", skip_on_eager=False):
764        # Create the tensor to initialize the variable with default value.
765        if initializer is None:
766          init, initializing_from_value = self._get_default_initializer(
767              name=name, shape=shape, dtype=dtype)
768          if initializing_from_value:
769            init_shape = None
770          else:
771            init_shape = var_shape
772        elif callable(initializer):
773          init = initializer
774          init_shape = var_shape
775        elif isinstance(initializer, ops.Tensor):
776          init = array_ops.slice(initializer, var_offset, var_shape)
777          # Use the dtype of the given tensor.
778          dtype = init.dtype.base_dtype
779          init_shape = None
780        else:
781          init = ops.convert_to_tensor(initializer, dtype=dtype)
782          init = array_ops.slice(init, var_offset, var_shape)
783          init_shape = None
784
785      with ops.name_scope(None):
786        var = self._get_single_variable(
787            name=var_full_name,
788            shape=init_shape,
789            dtype=dtype,
790            initializer=init,
791            partition_info=partition_info,
792            regularizer=regularizer,
793            reuse=reuse,
794            trainable=trainable,
795            collections=collections,
796            caching_device=caching_device,
797            validate_shape=validate_shape,
798            use_resource=use_resource,
799            constraint=constraint,
800            synchronization=synchronization,
801            aggregation=aggregation)
802
803      # pylint: disable=protected-access
804      var._set_save_slice_info(
805          variables.Variable.SaveSliceInfo(name, shape.as_list(), var_offset,
806                                           var_shape))
807      vs.append(var)
808      # pylint: enable=protected-access
809
810    partitioned_var = variables.PartitionedVariable(
811        name=name,
812        shape=shape,
813        dtype=dtype,
814        variable_list=vs,
815        partitions=partitions)
816    if not context.executing_eagerly() or self._store_eager_variables:
817      self._partitioned_vars[name] = partitioned_var
818    return partitioned_var
819
820  def _get_single_variable(self,
821                           name,
822                           shape=None,
823                           dtype=dtypes.float32,
824                           initializer=None,
825                           regularizer=None,
826                           partition_info=None,
827                           reuse=None,
828                           trainable=None,
829                           collections=None,
830                           caching_device=None,
831                           validate_shape=True,
832                           use_resource=None,
833                           constraint=None,
834                           synchronization=VariableSynchronization.AUTO,
835                           aggregation=VariableAggregation.NONE):
836    """Get or create a single Variable (e.g.
837
838    a shard or entire variable).
839
840    See the documentation of get_variable above (ignore partitioning components)
841    for details.
842
843    Args:
844      name: see get_variable.
845      shape: see get_variable.
846      dtype: see get_variable.
847      initializer: see get_variable.
848      regularizer: see get_variable.
849      partition_info: _PartitionInfo object.
850      reuse: see get_variable.
851      trainable: see get_variable.
852      collections: see get_variable.
853      caching_device: see get_variable.
854      validate_shape: see get_variable.
855      use_resource: see get_variable.
856      constraint: see get_variable.
857      synchronization: see get_variable.
858      aggregation: see get_variable.
859
860    Returns:
861      A Variable.  See documentation of get_variable above.
862
863    Raises:
864      ValueError: See documentation of get_variable above.
865    """
866    # Set to true if initializer is a constant.
867    initializing_from_value = False
868    if initializer is not None and not callable(initializer):
869      initializing_from_value = True
870    if shape is not None and initializing_from_value:
871      raise ValueError("If initializer is a constant, do not specify shape.")
872
873    dtype = dtypes.as_dtype(dtype)
874    shape = tensor_shape.as_shape(shape)
875
876    if name in self._vars:
877      # Here we handle the case when returning an existing variable.
878      if reuse is False:
879        var = self._vars[name]
880        err_msg = ("Variable %s already exists, disallowed."
881                   " Did you mean to set reuse=True or "
882                   "reuse=tf.AUTO_REUSE in VarScope?" % name)
883        # ResourceVariables have no op associated with them, so no traceback
884        if isinstance(var, resource_variable_ops.ResourceVariable):
885          raise ValueError(err_msg)
886        tb = var.op.traceback[::-1]
887        # Throw away internal tf entries and only take a few lines. In some
888        # cases the traceback can be longer (e.g. if someone uses factory
889        # functions to create variables) so we take more than needed in the
890        # default case.
891        tb = [x for x in tb if "tensorflow/python" not in x[0]][:5]
892        raise ValueError("%s Originally defined at:\n\n%s" %
893                         (err_msg, "".join(traceback.format_list(tb))))
894      found_var = self._vars[name]
895      if not shape.is_compatible_with(found_var.get_shape()):
896        raise ValueError("Trying to share variable %s, but specified shape %s"
897                         " and found shape %s." %
898                         (name, shape, found_var.get_shape()))
899      if not dtype.is_compatible_with(found_var.dtype):
900        dtype_str = dtype.name
901        found_type_str = found_var.dtype.name
902        raise ValueError("Trying to share variable %s, but specified dtype %s"
903                         " and found dtype %s." %
904                         (name, dtype_str, found_type_str))
905      return found_var
906
907    # The code below handles only the case of creating a new variable.
908    if reuse is True:
909      raise ValueError("Variable %s does not exist, or was not created with "
910                       "tf.get_variable(). Did you mean to set "
911                       "reuse=tf.AUTO_REUSE in VarScope?" % name)
912
913    # Create the tensor to initialize the variable with default value.
914    if initializer is None:
915      initializer, initializing_from_value = self._get_default_initializer(
916          name=name, shape=shape, dtype=dtype)
917    # Enter an init scope when creating the initializer.
918    with ops.init_scope():
919      if initializing_from_value:
920        init_val = initializer
921        variable_dtype = None
922      else:
923        # Instantiate initializer if provided initializer is a type object.
924        if tf_inspect.isclass(initializer):
925          initializer = initializer()
926        if shape.is_fully_defined():
927          if "partition_info" in tf_inspect.getargspec(initializer).args:
928            init_val = functools.partial(initializer,
929                                         shape.as_list(),
930                                         dtype=dtype,
931                                         partition_info=partition_info)
932          else:
933            init_val = functools.partial(initializer,
934                                         shape.as_list(), dtype=dtype)
935          variable_dtype = dtype.base_dtype
936        elif _needs_no_arguments(initializer):
937          init_val = initializer
938          variable_dtype = None
939        else:
940          raise ValueError("The initializer passed is not valid. It should "
941                           "either be a callable with no arguments (in which "
942                           "case the shape need not be provided), or an "
943                           "instance of `tf.keras.initializers.*` (in which "
944                           "case `shape` must be fully defined).")
945
946    # Create the variable.
947    if use_resource is None:
948      # Set the default value if unspecified.
949      use_resource = _DEFAULT_USE_RESOURCE
950    v = variables.VariableV1(
951        initial_value=init_val,
952        name=name,
953        trainable=trainable,
954        collections=collections,
955        caching_device=caching_device,
956        dtype=variable_dtype,
957        validate_shape=validate_shape,
958        constraint=constraint,
959        use_resource=use_resource,
960        synchronization=synchronization,
961        aggregation=aggregation)
962    if context.executing_eagerly() and self._store_eager_variables:
963      if collections:
964        ops.add_to_collections(collections, v)
965      else:
966        ops.add_to_collection(ops.GraphKeys.GLOBAL_VARIABLES, v)
967      if trainable:
968        ops.add_to_collection(ops.GraphKeys.TRAINABLE_VARIABLES, v)
969
970    if not context.executing_eagerly() or self._store_eager_variables:
971      # In eager mode we do not want to keep default references to Variable
972      # objects as this will prevent their memory from being released.
973      self._vars[name] = v
974    logging.vlog(1, "Created variable %s with shape %s and init %s", v.name,
975                 format(shape), initializer)
976
977    # Run the regularizer if requested and save the resulting loss.
978    if regularizer:
979      def make_regularizer_op():
980        with ops.colocate_with(v):
981          with ops.name_scope(name + "/Regularizer/"):
982            return regularizer(v)
983
984      if regularizer(v) is not None:
985        lazy_eval_tensor = _LazyEvalTensor(make_regularizer_op)
986        ops.add_to_collection(ops.GraphKeys.REGULARIZATION_LOSSES,
987                              lazy_eval_tensor)
988
989    return v
990
991  # Initialize variable when no initializer provided
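  # (floating-point dtypes fall back to glorot_uniform, while integer, bool and
  # string dtypes fall back to zeros; other dtypes require an explicit
  # initializer).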
992  def _get_default_initializer(self, name, shape=None, dtype=dtypes.float32):
993    """Provide a default initializer and a corresponding value.
994
995    Args:
996      name: see get_variable.
997      shape: see get_variable.
998      dtype: see get_variable.
999
1000    Returns:
1001      initializer and initializing_from_value. See get_variable above.
1002
1003    Raises:
1004      ValueError: When giving unsupported dtype.
1005    """
1006    del shape
1007    # If dtype is DT_FLOAT, provide a Glorot uniform initializer
1008    if dtype.is_floating:
1009      initializer = init_ops.glorot_uniform_initializer()
1010      initializing_from_value = False
1011    # If dtype is DT_INT/DT_UINT, provide a default value `zero`
1012    # If dtype is DT_BOOL, provide a default value `FALSE`
1013    elif (dtype.is_integer or dtype.is_unsigned or dtype.is_bool or
1014          dtype == dtypes.string):
1015      initializer = init_ops.zeros_initializer()
1016      initializing_from_value = False
1017    # NOTE: Do we need to support handling DT_STRING and DT_COMPLEX here?
1018    else:
1019      raise ValueError("An initializer for variable %s of %s is required" %
1020                       (name, dtype.base_dtype))
1021
1022    return initializer, initializing_from_value
1023
1024
1025class _LazyEvalTensor(core.Tensor):
1026  """A Tensor-like object that only evaluates its thunk when used."""
1027
1028  def __init__(self, thunk):
1029    """Initializes a _LazyEvalTensor object.
1030
1031    Args:
1032      thunk: A callable. A thunk which computes the value of the tensor.
1033    """
1034    self._thunk = thunk
1035    self._master_tensor = thunk()
1036
1037  def _as_tensor(self, dtype=None, name=None, as_ref=False):
1038    del name
1039    assert not as_ref
1040    assert dtype in [None, self.dtype]
1041
1042    return self._thunk()
1043
1044
1045def _make_master_property(name):
1046  @property
1047  def prop(self):
1048    return getattr(self._master_tensor, name)  # pylint: disable=protected-access
1049  return prop
1050
1051_master_property_list = ("device", "dtype", "graph", "name", "op", "shape",
1052                         "value_index")
1053for _name in _master_property_list:
1054  setattr(_LazyEvalTensor, _name, _make_master_property(_name))
1055
1056
1057def _make_master_method(name):
1058  def method(self, *args, **kwargs):
1059    return getattr(self._master_tensor, name)(*args, **kwargs)  # pylint: disable=protected-access
1060  return method
1061
1062_master_method_list = ("get_shape", "__str__", "shape_as_list")
1063for _name in _master_method_list:
1064  setattr(_LazyEvalTensor, _name, _make_master_method(_name))
1065
1066
1067def _make_op_method(name):
1068  def method(self, *args, **kwargs):
1069    return getattr(self._as_tensor(), name)(*args, **kwargs)  # pylint: disable=protected-access
1070  return method
1071
1072_op_list = ("__abs__", "__add__", "__and__", "__bool__", "__div__", "__eq__",
1073            "__floordiv__", "__ge__", "__getitem__", "__gt__", "__invert__",
1074            "__iter__", "__le__", "__len__", "__lt__", "__matmul__", "__mod__",
1075            "__mul__", "__ne__", "__neg__", "__nonzero__", "__or__", "__pow__",
1076            "__radd__", "__rand__", "__rdiv__", "__rfloordiv__", "__rmatmul__",
1077            "__rmod__", "__rmul__", "__ror__", "__rpow__", "__rsub__",
1078            "__rtruediv__", "__rxor__", "__sub__", "__truediv__", "__xor__",
1079            "eval", "numpy")
1080for _name in _op_list:
1081  setattr(_LazyEvalTensor, _name, _make_op_method(_name))
1082
1083
1084ops.register_tensor_conversion_function(
1085    _LazyEvalTensor,
1086    lambda val, dtype, name, as_ref: val._as_tensor(dtype, name, as_ref)  # pylint: disable=protected-access
1087    )
1088
1089session.register_session_run_conversion_functions(
1090    _LazyEvalTensor,
1091    lambda fetch: ([fetch._master_tensor], lambda fetched_vals: fetched_vals[0])  # pylint: disable=protected-access
1092    )
1093
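# With the two registrations above a _LazyEvalTensor can be used wherever a
# Tensor is expected: `ops.convert_to_tensor` re-evaluates the thunk, while
# `Session.run` fetches the pre-built master tensor.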
1094
1095# To stop regularization, use this regularizer
1096@tf_export(v1=["no_regularizer"])
1097def no_regularizer(_):
1098  """Use this function to prevent regularization of variables."""
1099  return None
1100
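# Illustrative sketch: passing `no_regularizer` as the `regularizer` argument
# of `tf.compat.v1.get_variable` overrides any default regularizer inherited
# from the enclosing variable scope, so nothing is added to
# GraphKeys.REGULARIZATION_LOSSES for that variable.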
1101
1102# TODO(alive): support caching devices and partitioned variables in Eager mode.
1103@tf_export(v1=["VariableScope"])
1104class VariableScope(object):
1105  """Variable scope object to carry defaults to provide to `get_variable`.
1106
1107  Many of the arguments we need for `get_variable` in a variable store are most
1108  easily handled with a context. This object is used for the defaults.
1109
1110  Attributes:
1111    name: name of the current scope, used as prefix in get_variable.
1112    initializer: default initializer passed to get_variable.
1113    regularizer: default regularizer passed to get_variable.
1114    reuse: Boolean, None, or tf.compat.v1.AUTO_REUSE, setting the reuse in
1115      get_variable. When eager execution is enabled this argument is always
1116      forced to be False.
1117    caching_device: string, callable, or None: the caching device passed to
1118      get_variable.
1119    partitioner: callable or `None`: the partitioner passed to `get_variable`.
1120    custom_getter: default custom getter passed to get_variable.
1121    name_scope: The name passed to `tf.name_scope`.
1122    dtype: default type passed to get_variable (defaults to DT_FLOAT).
1123    use_resource: if False, create a normal Variable; if True create an
1124      experimental ResourceVariable with well-defined semantics. Defaults to
1125      False (will later change to True). When eager execution is enabled this
1126      argument is always forced to be True.
1127    constraint: An optional projection function to be applied to the variable
1128      after being updated by an `Optimizer` (e.g. used to implement norm
1129      constraints or value constraints for layer weights). The function must
1130      take as input the unprojected Tensor representing the value of the
1131      variable and return the Tensor for the projected value (which must have
1132      the same shape). Constraints are not safe to use when doing asynchronous
1133      distributed training.
1134  """
1135
1136  def __init__(self,
1137               reuse,
1138               name="",
1139               initializer=None,
1140               regularizer=None,
1141               caching_device=None,
1142               partitioner=None,
1143               custom_getter=None,
1144               name_scope="",
1145               dtype=dtypes.float32,
1146               use_resource=None,
1147               constraint=None):
1148    """Creates a new VariableScope with the given properties."""
1149    self._name = name
1150    self._initializer = initializer
1151    self._regularizer = regularizer
1152    self._reuse = reuse
1153    self._caching_device = caching_device
1154    self._partitioner = partitioner
1155    self._custom_getter = custom_getter
1156    self._name_scope = name_scope
1157    self._dtype = dtype
1158    self._use_resource = use_resource
1159    self._constraint = constraint
1160    if context.executing_eagerly():
1161      if self._caching_device is not None:
1162        raise NotImplementedError("Caching devices are not yet supported "
1163                                  "when eager execution is enabled.")
1164      self._reuse = AUTO_REUSE
1165      self._use_resource = True
1166
1167  @property
1168  def name(self):
1169    return self._name
1170
1171  @property
1172  def original_name_scope(self):
1173    return self._name_scope
1174
1175  @property
1176  def reuse(self):
1177    return self._reuse
1178
1179  @property
1180  def initializer(self):
1181    return self._initializer
1182
1183  @property
1184  def dtype(self):
1185    return self._dtype
1186
1187  @property
1188  def use_resource(self):
1189    return self._use_resource
1190
1191  @property
1192  def regularizer(self):
1193    return self._regularizer
1194
1195  @property
1196  def caching_device(self):
1197    return self._caching_device
1198
1199  @property
1200  def partitioner(self):
1201    return self._partitioner
1202
1203  @property
1204  def custom_getter(self):
1205    return self._custom_getter
1206
1207  @property
1208  def constraint(self):
1209    return self._constraint
1210
1211  def reuse_variables(self):
1212    """Reuse variables in this scope."""
1213    self._reuse = True
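    # Note: reuse then stays True for the remainder of this scope; subsequent
    # get_variable calls return existing variables and raise if none exists.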
1214
1215  def set_initializer(self, initializer):
1216    """Set initializer for this scope."""
1217    self._initializer = initializer
1218
1219  def set_dtype(self, dtype):
1220    """Set data type for this scope."""
1221    self._dtype = dtype
1222
1223  def set_use_resource(self, use_resource):
1224    """Sets whether to use ResourceVariables for this scope."""
1225    if context.executing_eagerly() and not use_resource:
1226      raise ValueError("When eager execution is enabled, "
1227                       "use_resource cannot be set to false.")
1228    self._use_resource = use_resource
1229
1230  def set_regularizer(self, regularizer):
1231    """Set regularizer for this scope."""
1232    self._regularizer = regularizer
1233
1234  def set_caching_device(self, caching_device):
1235    """Set caching_device for this scope."""
1236    if context.executing_eagerly():
1237      raise NotImplementedError("Caching devices are not yet supported "
1238                                "when eager execution is enabled.")
1239    self._caching_device = caching_device
1240
1241  def set_partitioner(self, partitioner):
1242    """Set partitioner for this scope."""
1243    self._partitioner = partitioner
1244
1245  def set_custom_getter(self, custom_getter):
1246    """Set custom getter for this scope."""
1247    self._custom_getter = custom_getter
1248
1249  def get_collection(self, name):
1250    """Get this scope's variables."""
1251    scope = self._name + "/" if self._name else ""
1252    return ops.get_collection(name, scope)
1253
1254  def trainable_variables(self):
1255    """Get this scope's trainable variables."""
1256    return self.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
1257
1258  def global_variables(self):
1259    """Get this scope's global variables."""
1260    return self.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
1261
1262  def local_variables(self):
1263    """Get this scope's local variables."""
1264    return self.get_collection(ops.GraphKeys.LOCAL_VARIABLES)
1265
1266  def get_variable(self,
1267                   var_store,
1268                   name,
1269                   shape=None,
1270                   dtype=None,
1271                   initializer=None,
1272                   regularizer=None,
1273                   reuse=None,
1274                   trainable=None,
1275                   collections=None,
1276                   caching_device=None,
1277                   partitioner=None,
1278                   validate_shape=True,
1279                   use_resource=None,
1280                   custom_getter=None,
1281                   constraint=None,
1282                   synchronization=VariableSynchronization.AUTO,
1283                   aggregation=VariableAggregation.NONE):
1284    """Gets an existing variable with this name or create a new one."""
1285    if regularizer is None:
1286      regularizer = self._regularizer
1287    if caching_device is None:
1288      caching_device = self._caching_device
1289    if partitioner is None:
1290      partitioner = self._partitioner
1291    if custom_getter is None:
1292      custom_getter = self._custom_getter
1293    if context.executing_eagerly():
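      # When executing eagerly, ignore the caller-provided/scope reuse setting
      # and always use resource variables.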
1294      reuse = False
1295      use_resource = True
1296    else:
1297      if reuse is None:
1298        reuse = self._reuse
1299      if use_resource is None:
1300        use_resource = self._use_resource
1301
1302    full_name = self.name + "/" + name if self.name else name
    # Variable names only depend on the variable scope (full_name here), not
    # the name scope, so we reset the name scope below for the duration of
    # variable creation.
1305    with ops.name_scope(None, skip_on_eager=False):
1306      # Check that `initializer` dtype and `dtype` are consistent before
1307      # replacing them with defaults.
1308      if (dtype is not None and initializer is not None and
1309          not callable(initializer)):
1310        init_dtype = ops.convert_to_tensor(initializer).dtype.base_dtype
1311        if init_dtype != dtype:
1312          raise ValueError("Initializer type '%s' and explicit dtype '%s' "
1313                           "don't match." % (init_dtype, dtype))
1314      if initializer is None:
1315        initializer = self._initializer
1316      if constraint is None:
1317        constraint = self._constraint
1318      if dtype is None:
1319        dtype = self._dtype
1320      return var_store.get_variable(
1321          full_name,
1322          shape=shape,
1323          dtype=dtype,
1324          initializer=initializer,
1325          regularizer=regularizer,
1326          reuse=reuse,
1327          trainable=trainable,
1328          collections=collections,
1329          caching_device=caching_device,
1330          partitioner=partitioner,
1331          validate_shape=validate_shape,
1332          use_resource=use_resource,
1333          custom_getter=custom_getter,
1334          constraint=constraint,
1335          synchronization=synchronization,
1336          aggregation=aggregation)
1337
1338  def _get_partitioned_variable(self,
1339                                var_store,
1340                                name,
1341                                shape=None,
1342                                dtype=None,
1343                                initializer=None,
1344                                regularizer=None,
1345                                trainable=None,
1346                                collections=None,
1347                                caching_device=None,
1348                                partitioner=None,
1349                                validate_shape=True,
1350                                use_resource=None,
1351                                constraint=None,
1352                                synchronization=VariableSynchronization.AUTO,
1353                                aggregation=VariableAggregation.NONE):
1354    """Gets an existing variable with this name or create a new one."""
1355    if initializer is None:
1356      initializer = self._initializer
1357    if regularizer is None:
1358      regularizer = self._regularizer
1359    if constraint is None:
1360      constraint = self._constraint
1361    if caching_device is None:
1362      caching_device = self._caching_device
1363    if partitioner is None:
1364      partitioner = self._partitioner
1365    if dtype is None:
1366      dtype = self._dtype
1367    if use_resource is None:
1368      use_resource = self._use_resource
1369
1370    if self._custom_getter is not None:
1371      raise ValueError(
1372          "Private access to _get_partitioned_variable is not allowed when "
1373          "a custom getter is set.  Current custom getter: %s.  "
1374          "It is likely that you're using create_partitioned_variables.  "
1375          "If so, consider instead using get_variable with a non-empty "
1376          "partitioner parameter instead." % self._custom_getter)
1377
1378    if partitioner is None:
1379      raise ValueError("No partitioner was specified")
1380
1381    # This allows the variable scope name to be used as the variable name if
1382    # this function is invoked with an empty name arg, for backward
1383    # compatibility with create_partitioned_variables().
1384    full_name_list = []
1385    if self.name:
1386      full_name_list.append(self.name)
1387    if name:
1388      full_name_list.append(name)
1389    full_name = "/".join(full_name_list)
1390
    # Variable names only depend on the variable scope (full_name here), not
    # the name scope, so we reset the name scope below for the duration of
    # variable creation.
1393    with ops.name_scope(None, skip_on_eager=False):
1394      # pylint: disable=protected-access
1395      return var_store._get_partitioned_variable(
1396          full_name,
1397          shape=shape,
1398          dtype=dtype,
1399          initializer=initializer,
1400          regularizer=regularizer,
1401          reuse=self.reuse,
1402          trainable=trainable,
1403          collections=collections,
1404          caching_device=caching_device,
1405          partitioner=partitioner,
1406          validate_shape=validate_shape,
1407          use_resource=use_resource,
1408          constraint=constraint,
1409          synchronization=synchronization,
1410          aggregation=aggregation)
1411      # pylint: enable=protected-access
1412
1413
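# Graph-collection keys under which the default variable store and the
# variable scope store are stashed. Tuple keys cannot collide with
# user-supplied string collection names.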
1414_VARSTORE_KEY = ("__variable_store",)
1415_VARSCOPESTORE_KEY = ("__varscope",)
1416
1417
1418class _VariableScopeStore(threading.local):
1419  """A thread local store for the current variable scope and scope counts."""
1420
1421  def __init__(self):
1422    super(_VariableScopeStore, self).__init__()
1423    self.current_scope = VariableScope(False)
1424    self.variable_scopes_count = {}
1425
1426  def open_variable_scope(self, scope_name):
1427    if scope_name in self.variable_scopes_count:
1428      self.variable_scopes_count[scope_name] += 1
1429    else:
1430      self.variable_scopes_count[scope_name] = 1
1431
1432  def close_variable_subscopes(self, scope_name):
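    """Resets the usage counters for all subscopes of `scope_name`.

    If `scope_name` is None, counters for all scopes are reset.
    """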
1433    for k in list(self.variable_scopes_count.keys()):
1434      if scope_name is None or k.startswith(scope_name + "/"):
1435        self.variable_scopes_count[k] = 0
1436
1437  def variable_scope_count(self, scope_name):
1438    return self.variable_scopes_count.get(scope_name, 0)
1439
1440
1441def get_variable_scope_store():
1442  """Returns the variable scope store for current thread."""
1443  scope_store = ops.get_collection(_VARSCOPESTORE_KEY)
1444
1445  if not scope_store:
1446    scope_store = _VariableScopeStore()
1447    ops.add_to_collection(_VARSCOPESTORE_KEY, scope_store)
1448  else:
1449    scope_store = scope_store[0]
1450
1451  return scope_store
1452
1453
1454@tf_export(v1=["get_variable_scope"])
1455def get_variable_scope():
1456  """Returns the current variable scope."""
1457  return get_variable_scope_store().current_scope
1458
1459
1460def _get_default_variable_store():
1461  store = ops.get_collection(_VARSTORE_KEY)
1462  if store:
1463    return store[0]
1464  store = _VariableStore()
1465  ops.add_to_collection(_VARSTORE_KEY, store)
1466  return store
1467
1468
1469@tf_contextlib.contextmanager
1470def with_variable_store(store):
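  """Temporarily installs `store` as the default variable store."""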
1471  store_collection = ops.get_collection_ref(_VARSTORE_KEY)
1472  old = list(store_collection)
1473  store_collection[:] = [store]
1474  try:
1475    yield
1476  finally:
1477    store_collection[:] = old
1478
1479
1480class EagerVariableStore(object):
1481  """Wrapper allowing functional layers to be used with eager execution.
1482
  When eager execution is enabled, Variables get deleted when they go out of
  scope, and are not stored in global collections by default. A lot of code
  (mostly the functional layers in tf.layers) assumes that variables are kept
  in a global list.
1487
1488  EagerVariableStore can be used in conjunction with this code to make it
1489  eager-friendly. For example, to create a dense layer, use:
1490
1491  ```
1492    container = tfe.EagerVariableStore()
1493    for input in dataset_iterator:
1494      with container.as_default():
1495        x = tf.compat.v1.layers.dense(input, name="l1")
    print(container.variables())  # Should print the variables used in the layer.
1497  ```
1498  """
1499
1500  def __init__(self, store=None):
1501    if store is not None:
1502      if not store._store_eager_variables:  # pylint: disable=protected-access
1503        raise ValueError("Cannot construct EagerVariableStore from a "
1504                         "VariableStore object that does not hold eager "
1505                         "variables.")
1506      self._store = store
1507    else:
1508      self._store = _VariableStore()
1509    self._store._store_eager_variables = True  # pylint: disable=protected-access
1510
1511  def as_default(self):
1512    return with_variable_store(self._store)
1513
1514  def variables(self):
1515    return sorted(self._store._vars.values(), key=lambda x: x.name)  # pylint: disable=protected-access
1516
1517  def trainable_variables(self):
1518    # pylint: disable=protected-access
1519    return sorted([x for x in self._store._vars.values() if x.trainable],
1520                  key=lambda x: x.name)
1521    # pylint: enable=protected-access
1522
1523  def non_trainable_variables(self):
1524    # pylint: disable=protected-access
1525    return sorted([x for x in self._store._vars.values() if not x.trainable],
1526                  key=lambda x: x.name)
1527    # pylint: enable=protected-access
1528
1529  def copy(self):
1530    """Copy this variable store and all of its contents.
1531
1532    Variables contained in this store will be copied over to the new variable
1533    store, meaning that they can be modified without affecting the variables in
1534    this store.
1535
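    A minimal sketch of typical usage (assuming `container` is an existing
    `EagerVariableStore` and eager execution is enabled):

    ```python
    snapshot = container.copy()
    # Variables in `snapshot` can be mutated without affecting `container`.
    ```
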
1536    Returns:
1537      A new EagerVariableStore instance containing copied variables.
1538    """
1539    # pylint: disable=protected-access
1540    new_store = EagerVariableStore()
1541    for key, var in iteritems(self._store._vars):
      # Strip the output suffix (e.g. ":0") from the variable name.
1543      try:
1544        index = var.name.index(":")
1545      except ValueError:
1546        stripped_var_name = var.name
1547      else:
1548        stripped_var_name = var.name[:index]
1549
1550      # Create new variable with same value, name, and "trainable" flag.
1551      new_var = resource_variable_ops.ResourceVariable(
1552          var.read_value(), name=stripped_var_name, trainable=var.trainable)
1553      new_store._store._vars[key] = new_var
1554    return new_store
1555    # pylint: enable=protected-access
1556
1557
1558# The argument list for get_variable must match arguments to get_local_variable.
1559# So, if you are updating the arguments, also update arguments to
1560# get_local_variable below.
1561@tf_export(v1=["get_variable"])
1562def get_variable(name,
1563                 shape=None,
1564                 dtype=None,
1565                 initializer=None,
1566                 regularizer=None,
1567                 trainable=None,
1568                 collections=None,
1569                 caching_device=None,
1570                 partitioner=None,
1571                 validate_shape=True,
1572                 use_resource=None,
1573                 custom_getter=None,
1574                 constraint=None,
1575                 synchronization=VariableSynchronization.AUTO,
1576                 aggregation=VariableAggregation.NONE):
1577  return get_variable_scope().get_variable(
1578      _get_default_variable_store(),
1579      name,
1580      shape=shape,
1581      dtype=dtype,
1582      initializer=initializer,
1583      regularizer=regularizer,
1584      trainable=trainable,
1585      collections=collections,
1586      caching_device=caching_device,
1587      partitioner=partitioner,
1588      validate_shape=validate_shape,
1589      use_resource=use_resource,
1590      custom_getter=custom_getter,
1591      constraint=constraint,
1592      synchronization=synchronization,
1593      aggregation=aggregation)
1594
1595
1596get_variable_or_local_docstring = ("""%s
1597
1598%sThis function prefixes the name with the current variable scope
1599and performs reuse checks. See the
1600[Variable Scope How To](https://tensorflow.org/guide/variables)
1601for an extensive description of how reusing works. Here is a basic example:
1602
1603```python
1604def foo():
1605  with tf.variable_scope("foo", reuse=tf.AUTO_REUSE):
1606    v = tf.get_variable("v", [1])
1607  return v
1608
1609v1 = foo()  # Creates v.
1610v2 = foo()  # Gets the same, existing v.
1611assert v1 == v2
1612```
1613
1614If initializer is `None` (the default), the default initializer passed in
1615the variable scope will be used. If that one is `None` too, a
1616`glorot_uniform_initializer` will be used. The initializer can also be
1617a Tensor, in which case the variable is initialized to this value and shape.
1618
1619Similarly, if the regularizer is `None` (the default), the default regularizer
1620passed in the variable scope will be used (if that is `None` too,
1621then by default no regularization is performed).
1622
1623If a partitioner is provided, a `PartitionedVariable` is returned.
1624Accessing this object as a `Tensor` returns the shards concatenated along
1625the partition axis.
1626
1627Some useful partitioners are available.  See, e.g.,
1628`variable_axis_size_partitioner` and `min_max_variable_partitioner`.
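
As a minimal sketch (the scope name, shape, and the choice of
`tf.fixed_size_partitioner` below are illustrative only):

```python
with tf.variable_scope(
    "embeddings", partitioner=tf.fixed_size_partitioner(3)):
  # Returns a `PartitionedVariable` with 3 shards along axis 0.
  w = tf.get_variable("w", shape=[9, 4])
```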
1629
1630Args:
1631  name: The name of the new or existing variable.
1632  shape: Shape of the new or existing variable.
1633  dtype: Type of the new or existing variable (defaults to `DT_FLOAT`).
1634  initializer: Initializer for the variable if one is created. Can either be
1635    an initializer object or a Tensor. If it's a Tensor, its shape must be known
1636    unless validate_shape is False.
1637  regularizer: A (Tensor -> Tensor or None) function; the result of
1638    applying it on a newly created variable will be added to the collection
1639    `tf.GraphKeys.REGULARIZATION_LOSSES` and can be used for regularization.
1640  %scollections: List of graph collections keys to add the Variable to.
1641    Defaults to `[%s]` (see `tf.Variable`).
1642  caching_device: Optional device string or function describing where the
1643    Variable should be cached for reading.  Defaults to the Variable's
1644    device.  If not `None`, caches on another device.  Typical use is to
1645    cache on the device where the Ops using the Variable reside, to
1646    deduplicate copying through `Switch` and other conditional statements.
1647  partitioner: Optional callable that accepts a fully defined `TensorShape`
1648    and `dtype` of the Variable to be created, and returns a list of
1649    partitions for each axis (currently only one axis can be partitioned).
  validate_shape: If False, allows the variable to be initialized with a
      value of unknown shape. If True, the default, the shape of initial_value
      must be known. For this to be used, the initializer must be a Tensor and
      not an initializer object.
  use_resource: If False, creates a regular Variable. If True, creates an
    experimental ResourceVariable instead with well-defined semantics.
    Defaults to False (will later change to True). When eager execution is
    enabled this argument is always forced to be True.
1658  custom_getter: Callable that takes as a first argument the true getter, and
1659    allows overwriting the internal get_variable method.
1660    The signature of `custom_getter` should match that of this method,
1661    but the most future-proof version will allow for changes:
1662    `def custom_getter(getter, *args, **kwargs)`.  Direct access to
1663    all `get_variable` parameters is also allowed:
1664    `def custom_getter(getter, name, *args, **kwargs)`.  A simple identity
1665    custom getter that simply creates variables with modified names is:
1666    ```python
1667    def custom_getter(getter, name, *args, **kwargs):
1668      return getter(name + '_suffix', *args, **kwargs)
1669    ```
1670  constraint: An optional projection function to be applied to the variable
1671    after being updated by an `Optimizer` (e.g. used to implement norm
1672    constraints or value constraints for layer weights). The function must
1673    take as input the unprojected Tensor representing the value of the
1674    variable and return the Tensor for the projected value
1675    (which must have the same shape). Constraints are not safe to
1676    use when doing asynchronous distributed training.
  synchronization: Indicates when a distributed variable will be
1678    aggregated. Accepted values are constants defined in the class
1679    `tf.VariableSynchronization`. By default the synchronization is set to
1680    `AUTO` and the current `DistributionStrategy` chooses
1681    when to synchronize.
1682  aggregation: Indicates how a distributed variable will be aggregated.
1683    Accepted values are constants defined in the class
1684    `tf.VariableAggregation`.
1685
1686Returns:
1687  The created or existing `Variable` (or `PartitionedVariable`, if a
1688  partitioner was used).
1689
1690Raises:
1691  ValueError: when creating a new variable and shape is not declared,
1692    when violating reuse during variable creation, or when `initializer` dtype
1693    and `dtype` don't match. Reuse is set inside `variable_scope`.
1694""")
1695get_variable.__doc__ = get_variable_or_local_docstring % (
1696    "Gets an existing variable with these parameters or create a new one.", "",
1697    "trainable: If `True` also add the variable to the graph collection\n"
1698    "    `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).\n  ",
1699    "GraphKeys.GLOBAL_VARIABLES")
1700
1701
1702# The argument list for get_local_variable must match arguments to get_variable.
1703# So, if you are updating the arguments, also update arguments to get_variable.
1704@tf_export(v1=["get_local_variable"])
1705def get_local_variable(  # pylint: disable=missing-docstring
1706    name,
1707    shape=None,
1708    dtype=None,
1709    initializer=None,
1710    regularizer=None,
1711    trainable=False,  # pylint: disable=unused-argument
1712    collections=None,
1713    caching_device=None,
1714    partitioner=None,
1715    validate_shape=True,
1716    use_resource=None,
1717    custom_getter=None,
1718    constraint=None,
1719    synchronization=VariableSynchronization.AUTO,
1720    aggregation=VariableAggregation.NONE):
1721  if collections:
1722    collections += [ops.GraphKeys.LOCAL_VARIABLES]
1723  else:
1724    collections = [ops.GraphKeys.LOCAL_VARIABLES]
1725  return get_variable(
1726      name,
1727      shape=shape,
1728      dtype=dtype,
1729      initializer=initializer,
1730      regularizer=regularizer,
1731      trainable=False,
1732      collections=collections,
1733      caching_device=caching_device,
1734      partitioner=partitioner,
1735      validate_shape=validate_shape,
1736      use_resource=use_resource,
1737      synchronization=synchronization,
1738      aggregation=aggregation,
1739      custom_getter=custom_getter,
1740      constraint=constraint)
1741
1742
1743get_local_variable.__doc__ = get_variable_or_local_docstring % (
1744    "Gets an existing *local* variable or creates a new one.",
1745    "Behavior is the same as in `get_variable`, except that variables are\n"
1746    "added to the `LOCAL_VARIABLES` collection and `trainable` is set to\n"
1747    "`False`.\n", "", "GraphKeys.LOCAL_VARIABLES")
1748
1749
1750def _get_partitioned_variable(name,
1751                              shape=None,
1752                              dtype=None,
1753                              initializer=None,
1754                              regularizer=None,
1755                              trainable=True,
1756                              collections=None,
1757                              caching_device=None,
1758                              partitioner=None,
1759                              validate_shape=True,
1760                              use_resource=None,
1761                              constraint=None,
1762                              synchronization=VariableSynchronization.AUTO,
1763                              aggregation=VariableAggregation.NONE):
1764  """Gets or creates a sharded variable list with these parameters.
1765
1766  The `partitioner` must be a callable that accepts a fully defined
1767  `TensorShape` and returns a sequence of integers (the `partitions`).
1768  These integers describe how to partition the given sharded `Variable`
1769  along the given dimension.  That is, `partitions[1] = 3` means split
1770  the `Variable` into 3 shards along dimension 1.  Currently, sharding along
1771  only one axis is supported.
1772
1773  If the list of variables with the given name (prefix) is already stored,
1774  we return the stored variables. Otherwise, we create a new one.
1775
1776  If initializer is `None` (the default), the default initializer passed in
1777  the constructor is used. If that one is `None` too, we use a new
1778  `glorot_uniform_initializer`. If initializer is a Tensor, we use
1779  it as a value and derive the shape from the initializer.
1780
1781  If the initializer is a callable, then it will be called for each
1782  shard.  Otherwise the initializer should match the shape of the entire
1783  sharded Variable, and it will be sliced accordingly for each shard.
1784
1785  Some useful partitioners are available.  See, e.g.,
1786  `variable_axis_size_partitioner` and `min_max_variable_partitioner`.
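
  For illustration only, a hand-written partitioner matching this contract
  (the helper name is hypothetical) could look like:

  ```python
  def split_axis_one_into_three(shape, dtype):  # hypothetical example
    del dtype  # unused by this simple partitioner
    partitions = [1] * shape.ndims  # `shape` must have rank >= 2 here
    partitions[1] = 3  # 3 shards along dimension 1
    return partitions
  ```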
1787
1788  Args:
1789    name: The name of the new or existing variable.
1790    shape: Shape of the new or existing variable.
1791    dtype: Type of the new or existing variable (defaults to `DT_FLOAT`).
1792    initializer: Initializer for the variable if one is created.
1793    regularizer: A (Tensor -> Tensor or None) function; the result of applying
1794      it on a newly created variable will be added to the collection
1795      GraphKeys.REGULARIZATION_LOSSES and can be used for regularization.
1796    trainable: If `True` also add the variable to the graph collection
1797      `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
1798    collections: List of graph collections keys to add the Variable to. Defaults
1799      to `[GraphKeys.GLOBAL_VARIABLES]` (see `tf.Variable`).
1800    caching_device: Optional device string or function describing where the
1801      Variable should be cached for reading.  Defaults to the Variable's device.
1802      If not `None`, caches on another device.  Typical use is to cache on the
1803      device where the Ops using the Variable reside, to deduplicate copying
1804      through `Switch` and other conditional statements.
1805    partitioner: Optional callable that accepts a fully defined `TensorShape`
1806      and `dtype` of the Variable to be created, and returns a list of
1807      partitions for each axis (currently only one axis can be partitioned).
1808    validate_shape: If False, allows the variable to be initialized with a value
1809      of unknown shape. If True, the default, the shape of initial_value must be
1810      known.
1811    use_resource: If False, creates a regular Variable. If True, creates an
1812      experimental ResourceVariable instead which has well-defined semantics.
1813      Defaults to False (will later change to True).
1814    constraint: An optional projection function to be applied to the variable
1815      after being updated by an `Optimizer` (e.g. used to implement norm
1816      constraints or value constraints for layer weights). The function must
1817      take as input the unprojected Tensor representing the value of the
1818      variable and return the Tensor for the projected value (which must have
1819      the same shape). Constraints are not safe to use when doing asynchronous
1820      distributed training.
    synchronization: Indicates when a distributed variable will be aggregated.
1822      Accepted values are constants defined in the class
1823      `tf.VariableSynchronization`. By default the synchronization is set to
1824      `AUTO` and the current `DistributionStrategy` chooses when to synchronize.
1825    aggregation: Indicates how a distributed variable will be aggregated.
1826      Accepted values are constants defined in the class
1827      `tf.VariableAggregation`.
1828
1829  Returns:
1830    A tuple `(shards, partitions)` where `shards` is the list of `Variable`
1831    shards and `partitions` is the output of the partitioner on the input
1832    shape.
1833
1834  Raises:
1835    ValueError: when creating a new variable and shape is not declared,
1836      or when violating reuse during variable creation. Reuse is set inside
1837      `variable_scope`.
1838  """
1839  # pylint: disable=protected-access
1840  scope = get_variable_scope()
1841  if scope.custom_getter is not None:
1842    raise ValueError(
1843        "Private access to _get_partitioned_variable is not allowed when "
1844        "a custom getter is set.  Current custom getter: %s.  "
1845        "It is likely that you're using create_partitioned_variables.  "
1846        "If so, consider instead using get_variable with a non-empty "
1847        "partitioner parameter instead." % scope.custom_getter)
1848  return scope._get_partitioned_variable(
1849      _get_default_variable_store(),
1850      name,
1851      shape=shape,
1852      dtype=dtype,
1853      initializer=initializer,
1854      regularizer=regularizer,
1855      trainable=trainable,
1856      collections=collections,
1857      caching_device=caching_device,
1858      partitioner=partitioner,
1859      validate_shape=validate_shape,
1860      use_resource=use_resource,
1861      constraint=constraint,
1862      synchronization=synchronization,
1863      aggregation=aggregation)
1864  # pylint: enable=protected-access
1865
1866
1867# Named like a function for compatibility with the previous
1868# @tf_contextlib.contextmanager definition.
1869class _pure_variable_scope(object):  # pylint: disable=invalid-name
1870  """A context for the variable_scope, see `variable_scope` for docs."""
1871
1872  def __init__(self,
1873               name_or_scope,
1874               reuse=None,
1875               initializer=None,
1876               regularizer=None,
1877               caching_device=None,
1878               partitioner=None,
1879               custom_getter=None,
1880               old_name_scope=None,
1881               dtype=dtypes.float32,
1882               use_resource=None,
1883               constraint=None):
1884    """Creates a context for the variable_scope, see `variable_scope` for docs.
1885
1886    Note: this does not create a name scope.
1887
1888    Args:
1889      name_or_scope: `string` or `VariableScope`: the scope to open.
1890      reuse: `True` or None, or tf.compat.v1.AUTO_REUSE; if `None`, we inherit
1891        the parent scope's reuse flag.
1892      initializer: default initializer for variables within this scope.
1893      regularizer: default regularizer for variables within this scope.
1894      caching_device: default caching device for variables within this scope.
1895      partitioner: default partitioner for variables within this scope.
1896      custom_getter: default custom getter for variables within this scope.
1897      old_name_scope: the original name scope when re-entering a variable scope.
1898      dtype: type of the variables within this scope (defaults to `DT_FLOAT`).
1899      use_resource: If False, variables in this scope will be regular Variables.
        If True, experimental ResourceVariables will be created instead, with
1901        well-defined semantics. Defaults to False (will later change to True).
1902      constraint: An optional projection function to be applied to the variable
1903        after being updated by an `Optimizer` (e.g. used to implement norm
1904        constraints or value constraints for layer weights). The function must
1905        take as input the unprojected Tensor representing the value of the
1906        variable and return the Tensor for the projected value (which must have
1907        the same shape). Constraints are not safe to use when doing asynchronous
1908        distributed training.
1909    """
1910    self._name_or_scope = name_or_scope
1911    self._reuse = reuse
1912    self._initializer = initializer
1913    self._regularizer = regularizer
1914    self._caching_device = caching_device
1915    self._partitioner = partitioner
1916    self._custom_getter = custom_getter
1917    self._old_name_scope = old_name_scope
1918    self._dtype = dtype
1919    self._use_resource = use_resource
1920    self._constraint = constraint
1921    self._var_store = _get_default_variable_store()
1922    self._var_scope_store = get_variable_scope_store()
1923    self._last_variable_scope_object = None
1924    if isinstance(self._name_or_scope, VariableScope):
1925      self._new_name = self._name_or_scope.name
1926      name_scope = self._name_or_scope._name_scope  # pylint: disable=protected-access
      # Handler for the case when we jump to a shared scope.  We create a new
      #   VariableScope (self._cached_variable_scope_object) that contains a
      #   copy of the provided shared scope, possibly with changed reuse and
      #   initializer, if the user requested this.
1931      variable_scope_object = VariableScope(
1932          self._name_or_scope.reuse if not self._reuse else self._reuse,
1933          name=self._new_name,
1934          initializer=self._name_or_scope.initializer,
1935          regularizer=self._name_or_scope.regularizer,
1936          caching_device=self._name_or_scope.caching_device,
1937          partitioner=self._name_or_scope.partitioner,
1938          dtype=self._name_or_scope.dtype,
1939          custom_getter=self._name_or_scope.custom_getter,
1940          name_scope=name_scope,
1941          use_resource=self._name_or_scope.use_resource,
1942          constraint=self._constraint)
1943      if self._initializer is not None:
1944        variable_scope_object.set_initializer(self._initializer)
1945      if self._regularizer is not None:
1946        variable_scope_object.set_regularizer(self._regularizer)
1947      if self._caching_device is not None:
1948        variable_scope_object.set_caching_device(self._caching_device)
1949      if self._partitioner is not None:
1950        variable_scope_object.set_partitioner(self._partitioner)
1951      if self._custom_getter is not None:
1952        variable_scope_object.set_custom_getter(
1953            _maybe_wrap_custom_getter(self._custom_getter,
1954                                      self._name_or_scope.custom_getter))
1955      if self._dtype is not None:
1956        variable_scope_object.set_dtype(self._dtype)
1957      if self._use_resource is not None:
1958        variable_scope_object.set_use_resource(self._use_resource)
1959      self._cached_variable_scope_object = variable_scope_object
1960
1961  def __enter__(self):
1962    """Begins the scope block.
1963
1964    Returns:
1965      A VariableScope.
1966    Raises:
1967      ValueError: when trying to reuse within a create scope, or create within
1968        a reuse scope, or if reuse is not `None` or `True`.
1969      TypeError: when the types of some arguments are not appropriate.
1970    """
1971    self._old = self._var_scope_store.current_scope
1972    if isinstance(self._name_or_scope, VariableScope):
1973      self._var_scope_store.open_variable_scope(self._new_name)
1974      self._old_subscopes = copy.copy(
1975          self._var_scope_store.variable_scopes_count)
1976      variable_scope_object = self._cached_variable_scope_object
1977    else:
      # Handler for the case when we just prolong the current variable scope:
      #   we build a VariableScope with the name extended by the provided one,
      #   inheriting reuse and initializer unless the user provided overrides.
1981      self._new_name = (
1982          self._old.name + "/" +
1983          self._name_or_scope if self._old.name else self._name_or_scope)
1984      self._reuse = (self._reuse or
1985                     self._old.reuse)  # Re-using is inherited by sub-scopes.
1986      if self._old_name_scope is None:
1987        name_scope = self._name_or_scope
1988      else:
1989        name_scope = self._old_name_scope
1990      variable_scope_object = VariableScope(
1991          self._reuse,
1992          name=self._new_name,
1993          initializer=self._old.initializer,
1994          regularizer=self._old.regularizer,
1995          caching_device=self._old.caching_device,
1996          partitioner=self._old.partitioner,
1997          dtype=self._old.dtype,
1998          use_resource=self._old.use_resource,
1999          custom_getter=self._old.custom_getter,
2000          name_scope=name_scope,
2001          constraint=self._constraint)
2002      if self._initializer is not None:
2003        variable_scope_object.set_initializer(self._initializer)
2004      if self._regularizer is not None:
2005        variable_scope_object.set_regularizer(self._regularizer)
2006      if self._caching_device is not None:
2007        variable_scope_object.set_caching_device(self._caching_device)
2008      if self._partitioner is not None:
2009        variable_scope_object.set_partitioner(self._partitioner)
2010      if self._custom_getter is not None:
2011        variable_scope_object.set_custom_getter(
2012            _maybe_wrap_custom_getter(self._custom_getter,
2013                                      self._old.custom_getter))
2014      if self._dtype is not None:
2015        variable_scope_object.set_dtype(self._dtype)
2016      if self._use_resource is not None:
2017        variable_scope_object.set_use_resource(self._use_resource)
2018      self._var_scope_store.open_variable_scope(self._new_name)
2019    self._var_scope_store.current_scope = variable_scope_object
2020    self._last_variable_scope_object = variable_scope_object
2021    return variable_scope_object
2022
2023  def __exit__(self, type_arg, value_arg, traceback_arg):
2024    if (self._var_scope_store.current_scope is
2025        not self._last_variable_scope_object):
2026      raise RuntimeError("Improper nesting of variable_scope.")
2027    # If jumping out from a non-prolonged scope, restore counts.
2028    if isinstance(self._name_or_scope, VariableScope):
2029      self._var_scope_store.variable_scopes_count = self._old_subscopes
2030    else:
2031      self._var_scope_store.close_variable_subscopes(self._new_name)
2032    self._var_scope_store.current_scope = self._old
2033
2034
2035def _maybe_wrap_custom_getter(custom_getter, old_getter):
2036  """Wrap a call to a custom_getter to use the old_getter internally."""
2037  if old_getter is None:
2038    return custom_getter
2039
2040  # The new custom_getter should call the old one
2041  def wrapped_custom_getter(getter, *args, **kwargs):
2042    # Call:
2043    #  custom_getter(
2044    #    lambda: old_getter(true_getter, ...), *args, **kwargs)
2045    # which means custom_getter will call old_getter, which
2046    # will call the true_getter, perform any intermediate
2047    # processing, and return the results to the current
2048    # getter, which will also perform additional processing.
2049    return custom_getter(functools.partial(old_getter, getter), *args, **kwargs)
2050
2051  return wrapped_custom_getter
2052
2053
2054def _get_unique_variable_scope(prefix):
2055  """Get a name with the given prefix unique in the current variable scope."""
2056  var_scope_store = get_variable_scope_store()
2057  current_scope = get_variable_scope()
2058  name = current_scope.name + "/" + prefix if current_scope.name else prefix
2059  if var_scope_store.variable_scope_count(name) == 0:
2060    return prefix
2061  idx = 1
2062  while var_scope_store.variable_scope_count(name + ("_%d" % idx)) > 0:
2063    idx += 1
2064  return prefix + ("_%d" % idx)
2065
2066
2067# Named like a function for backwards compatibility with the
2068# @tf_contextlib.contextmanager version, which was switched to a class to avoid
2069# some object creation overhead.
2070@tf_export(v1=["variable_scope"])  # pylint: disable=invalid-name
2071class variable_scope(object):
2072  """A context manager for defining ops that creates variables (layers).
2073
2074  This context manager validates that the (optional) `values` are from the same
2075  graph, ensures that graph is the default graph, and pushes a name scope and a
2076  variable scope.
2077
2078  If `name_or_scope` is not None, it is used as is. If `name_or_scope` is None,
2079  then `default_name` is used.  In that case, if the same name has been
2080  previously used in the same scope, it will be made unique by appending `_N`
2081  to it.
2082
  Variable scope allows you to create new variables and to share already
  created ones while providing checks against accidental creation or sharing.
  For details, see the
  [Variable Scope How To](https://tensorflow.org/guide/variables); here we
  present only a few basic examples.
2087
  Variable scope works as expected when eager execution is disabled:
2089
2090  ```python
2091  tf.compat.v1.disable_eager_execution()
2092  ```
2093
2094  Simple example of how to create a new variable:
2095
2096  ```python
2097  with tf.compat.v1.variable_scope("foo"):
2098      with tf.compat.v1.variable_scope("bar"):
2099          v = tf.compat.v1.get_variable("v", [1])
2100          assert v.name == "foo/bar/v:0"
2101  ```
2102
2103  Simple example of how to reenter a premade variable scope safely:
2104
2105  ```python
2106  with tf.compat.v1.variable_scope("foo") as vs:
2107    pass
2108
2109  # Re-enter the variable scope.
2110  with tf.compat.v1.variable_scope(vs,
2111                         auxiliary_name_scope=False) as vs1:
2112    # Restore the original name_scope.
2113    with tf.name_scope(vs1.original_name_scope):
2114        v = tf.compat.v1.get_variable("v", [1])
2115        assert v.name == "foo/v:0"
2116        c = tf.constant([1], name="c")
2117        assert c.name == "foo/c:0"
2118  ```
2119
  Keep in mind that the counters for `default_name` are discarded once the
  parent scope is exited. Therefore when the code re-enters the scope (for
  instance by capturing it and entering it again), all nested default_name
  counters will be restarted.
2123
2124  For instance:
2125
2126  ```python
2127  with tf.compat.v1.variable_scope("foo") as vs:
2128    with tf.compat.v1.variable_scope(None, default_name="bar"):
2129      v = tf.compat.v1.get_variable("a", [1])
2130      assert v.name == "foo/bar/a:0", v.name
2131    with tf.compat.v1.variable_scope(None, default_name="bar"):
2132      v = tf.compat.v1.get_variable("b", [1])
2133      assert v.name == "foo/bar_1/b:0"
2134
2135  with tf.compat.v1.variable_scope(vs):
2136    with tf.compat.v1.variable_scope(None, default_name="bar"):
2137      v = tf.compat.v1.get_variable("c", [1])
2138      assert v.name == "foo/bar/c:0"   # Uses bar instead of bar_2!
2139  ```
2140
  Basic example of sharing a variable with AUTO_REUSE:
2142
2143  ```python
2144  def foo():
2145    with tf.compat.v1.variable_scope("foo", reuse=tf.compat.v1.AUTO_REUSE):
2146      v = tf.compat.v1.get_variable("v", [1])
2147    return v
2148
2149  v1 = foo()  # Creates v.
2150  v2 = foo()  # Gets the same, existing v.
2151  assert v1 == v2
2152  ```
2153
2154  Basic example of sharing a variable with reuse=True:
2155
2156  ```python
2157  with tf.compat.v1.variable_scope("foo"):
2158      v = tf.compat.v1.get_variable("v", [1])
2159  with tf.compat.v1.variable_scope("foo", reuse=True):
2160      v1 = tf.compat.v1.get_variable("v", [1])
2161  assert v1 == v
2162  ```
2163
2164  Sharing a variable by capturing a scope and setting reuse:
2165
2166  ```python
2167  with tf.compat.v1.variable_scope("foo") as scope:
2168      v = tf.compat.v1.get_variable("v", [1])
2169      scope.reuse_variables()
2170      v1 = tf.compat.v1.get_variable("v", [1])
2171  assert v1 == v
2172  ```
2173
2174  To prevent accidental sharing of variables, we raise an exception when getting
2175  an existing variable in a non-reusing scope.
2176
2177  ```python
2178  with tf.compat.v1.variable_scope("foo"):
2179      v = tf.compat.v1.get_variable("v", [1])
2180      v1 = tf.compat.v1.get_variable("v", [1])
2181      #  Raises ValueError("... v already exists ...").
2182  ```
2183
2184  Similarly, we raise an exception when trying to get a variable that does not
2185  exist in reuse mode.
2186
2187  ```python
2188  with tf.compat.v1.variable_scope("foo", reuse=True):
2189      v = tf.compat.v1.get_variable("v", [1])
      #  Raises ValueError("... v does not exist ...").
2191  ```
2192
2193  Note that the `reuse` flag is inherited: if we open a reusing scope, then all
2194  its sub-scopes become reusing as well.
2195
  A note about name scoping: setting `reuse` does not impact the naming of
  other ops, such as multiply ops. See the related discussion on
  [github#6189](https://github.com/tensorflow/tensorflow/issues/6189).
2199
2200  Note that up to and including version 1.0, it was allowed (though explicitly
2201  discouraged) to pass False to the reuse argument, yielding undocumented
2202  behaviour slightly different from None. Starting at 1.1.0 passing None and
2203  False as reuse has exactly the same effect.
2204
  A note about using variable scopes in a multi-threaded environment: variable
  scopes are thread local, so one thread will not see another thread's current
  scope. Also, when using `default_name`, unique scope names are generated only
  on a per-thread basis; a name already used in one thread does not prevent a
  new thread from creating a scope with the same name. However, the underlying
  variable store is shared across threads (within the same graph). As such, if
  another thread tries to create a new variable with the same name as a
  variable created by a previous thread, it will fail unless reuse is True.
2214
  Further, each thread starts with an empty variable scope. So if you wish to
  preserve the name prefixes of a scope from the main thread, you should
  capture the main thread's scope and re-enter it in each thread. For example:
2218
2219  ```
2220  main_thread_scope = variable_scope.get_variable_scope()
2221
2222  # Thread's target function:
2223  def thread_target_fn(captured_scope):
2224    with variable_scope.variable_scope(captured_scope):
2225      # .... regular code for this thread
2226
2227
2228  thread = threading.Thread(target=thread_target_fn, args=(main_thread_scope,))
2229  ```
2230  """
2231
2232  def __init__(self,
2233               name_or_scope,
2234               default_name=None,
2235               values=None,
2236               initializer=None,
2237               regularizer=None,
2238               caching_device=None,
2239               partitioner=None,
2240               custom_getter=None,
2241               reuse=None,
2242               dtype=None,
2243               use_resource=None,
2244               constraint=None,
2245               auxiliary_name_scope=True):
2246    """Initialize the context manager.
2247
2248    Args:
2249      name_or_scope: `string` or `VariableScope`: the scope to open.
      default_name: The default name to use if the `name_or_scope` argument is
        `None`; this name will be uniquified. If `name_or_scope` is provided,
        `default_name` is unused, so it is not required and can be None.
2253      values: The list of `Tensor` arguments that are passed to the op function.
2254      initializer: default initializer for variables within this scope.
2255      regularizer: default regularizer for variables within this scope.
2256      caching_device: default caching device for variables within this scope.
2257      partitioner: default partitioner for variables within this scope.
2258      custom_getter: default custom getter for variables within this scope.
2259      reuse: `True`, None, or tf.compat.v1.AUTO_REUSE; if `True`, we go into
2260        reuse mode for this scope as well as all sub-scopes; if
2261        tf.compat.v1.AUTO_REUSE, we create variables if they do not exist, and
2262        return them otherwise; if None, we inherit the parent scope's reuse
2263        flag. When eager execution is enabled, new variables are always created
2264        unless an EagerVariableStore or template is currently active.
2265      dtype: type of variables created in this scope (defaults to the type in
2266        the passed scope, or inherited from parent scope).
2267      use_resource: If False, all variables will be regular Variables. If True,
2268        experimental ResourceVariables with well-defined semantics will be used
2269        instead. Defaults to False (will later change to True). When eager
2270        execution is enabled this argument is always forced to be True.
2271      constraint: An optional projection function to be applied to the variable
2272        after being updated by an `Optimizer` (e.g. used to implement norm
2273        constraints or value constraints for layer weights). The function must
2274        take as input the unprojected Tensor representing the value of the
2275        variable and return the Tensor for the projected value (which must have
2276        the same shape). Constraints are not safe to use when doing asynchronous
2277        distributed training.
      auxiliary_name_scope: If `True`, we create an auxiliary name scope with
        the scope. If `False`, we don't create it. Note that the argument is
        not inherited, and it only takes effect once, when the scope is
        created. You should only use it for re-entering a premade variable
        scope.
2282
2283    Returns:
2284      A scope that can be captured and reused.
2285
2286    Raises:
2287      ValueError: when trying to reuse within a create scope, or create within
2288        a reuse scope.
2289      TypeError: when the types of some arguments are not appropriate.
2290    """
2291    self._name_or_scope = name_or_scope
2292    self._default_name = default_name
2293    self._values = values
2294    self._initializer = initializer
2295    self._regularizer = regularizer
2296    self._caching_device = caching_device
2297    self._partitioner = partitioner
2298    self._custom_getter = custom_getter
2299    self._reuse = reuse
2300    self._dtype = dtype
2301    self._use_resource = use_resource
2302    self._constraint = constraint
2303    if self._default_name is None and self._name_or_scope is None:
2304      raise TypeError("If default_name is None then name_or_scope is required")
2305    if self._reuse is False:
2306      # We don't allow non-inheriting scopes, False = None here.
2307      self._reuse = None
2308    if not (self._reuse is True
2309            or self._reuse is None
2310            or self._reuse is AUTO_REUSE):
      raise ValueError("The reuse parameter must be True, False, None, or "
                       "tf.compat.v1.AUTO_REUSE.")
2312    if self._values is None:
2313      self._values = []
2314    self._in_graph_mode = not context.executing_eagerly()
2315    if self._in_graph_mode:
2316      self._graph = ops._get_graph_from_inputs(self._values)  # pylint: disable=protected-access
2317    self._cached_pure_variable_scope = None
2318    self._current_name_scope = None
2319    if not isinstance(auxiliary_name_scope, bool):
      raise TypeError("The auxiliary_name_scope must be `True` or `False`, "
                      "but got {}".format(auxiliary_name_scope))
2322    self._auxiliary_name_scope = auxiliary_name_scope
2323
2324  def __enter__(self):
2325    # If the default graph is building a function, then we should not replace it
2326    # with the cached graph.
2327    if ops.get_default_graph().building_function:
2328      self._building_function = True
2329    else:
2330      self._building_function = False
2331    if self._in_graph_mode and not self._building_function:
2332      self._graph_context_manager = self._graph.as_default()
2333      self._graph_context_manager.__enter__()
2334    if self._cached_pure_variable_scope is not None:
2335      # Fast path for re-entering variable_scopes. We've held on to the pure
2336      # variable scope from a previous successful __enter__, so we avoid some
2337      # overhead by re-using that object.
2338      if self._current_name_scope is not None:
2339        self._current_name_scope.__enter__()
2340      return self._cached_pure_variable_scope.__enter__()
2341
2342    try:
2343      return self._enter_scope_uncached()
2344    except:
2345      if (self._in_graph_mode and not self._building_function and
2346          self._graph_context_manager is not None):
2347        self._graph_context_manager.__exit__(*sys.exc_info())
2348      raise
2349
2350  def _enter_scope_uncached(self):
2351    """Enters the context manager when there is no cached scope yet.
2352
2353    Returns:
2354      The entered variable scope.
2355
2356    Raises:
2357      TypeError: A wrong type is passed as `scope` at __init__().
2358      ValueError: `reuse` is incorrectly set at __init__().
2359    """
2360    if self._auxiliary_name_scope:
2361      # Create a new name scope later
2362      current_name_scope = None
2363    else:
2364      # Reenter the current name scope
2365      name_scope = ops.get_name_scope()
2366      if name_scope:
2367        # Hack to reenter
2368        name_scope += "/"
2369        current_name_scope = ops.name_scope(name_scope, skip_on_eager=False)
2370      else:
2371        # Root scope
2372        current_name_scope = ops.name_scope(name_scope, skip_on_eager=False)
2373
2374    # IMPORTANT: Only assign to self._cached_pure_variable_scope and
2375    # self._current_name_scope after successful __enter__() calls.
2376    if self._name_or_scope is not None:
2377      if not isinstance(self._name_or_scope,
2378                        (VariableScope,) + six.string_types):
2379        raise TypeError("VariableScope: name_or_scope must be a string or "
2380                        "VariableScope.")
2381      if isinstance(self._name_or_scope, six.string_types):
2382        name_scope = self._name_or_scope
2383      else:
2384        name_scope = self._name_or_scope.name.split("/")[-1]
2385      if name_scope or current_name_scope:
2386        current_name_scope = current_name_scope or ops.name_scope(
2387            name_scope, skip_on_eager=False)
2388        try:
2389          current_name_scope_name = current_name_scope.__enter__()
2390        except:
2391          current_name_scope.__exit__(*sys.exc_info())
2392          raise
2393        self._current_name_scope = current_name_scope
2394        if isinstance(self._name_or_scope, six.string_types):
2395          old_name_scope = current_name_scope_name
2396        else:
2397          old_name_scope = self._name_or_scope.original_name_scope
2398        pure_variable_scope = _pure_variable_scope(
2399            self._name_or_scope,
2400            reuse=self._reuse,
2401            initializer=self._initializer,
2402            regularizer=self._regularizer,
2403            caching_device=self._caching_device,
2404            partitioner=self._partitioner,
2405            custom_getter=self._custom_getter,
2406            old_name_scope=old_name_scope,
2407            dtype=self._dtype,
2408            use_resource=self._use_resource,
2409            constraint=self._constraint)
2410        try:
2411          entered_pure_variable_scope = pure_variable_scope.__enter__()
2412        except:
2413          pure_variable_scope.__exit__(*sys.exc_info())
2414          raise
2415        self._cached_pure_variable_scope = pure_variable_scope
2416        return entered_pure_variable_scope
2417      else:
2418        self._current_name_scope = None
2419        # This can only happen if someone is entering the root variable scope.
2420        pure_variable_scope = _pure_variable_scope(
2421            self._name_or_scope,
2422            reuse=self._reuse,
2423            initializer=self._initializer,
2424            regularizer=self._regularizer,
2425            caching_device=self._caching_device,
2426            partitioner=self._partitioner,
2427            custom_getter=self._custom_getter,
2428            dtype=self._dtype,
2429            use_resource=self._use_resource,
2430            constraint=self._constraint)
2431        try:
2432          entered_pure_variable_scope = pure_variable_scope.__enter__()
2433        except:
2434          pure_variable_scope.__exit__(*sys.exc_info())
2435          raise
2436        self._cached_pure_variable_scope = pure_variable_scope
2437        return entered_pure_variable_scope
2438
2439    else:  # Here name_or_scope is None. Using default name, but made unique.
2440      if self._reuse:
2441        raise ValueError("reuse=True cannot be used without a name_or_scope")
2442      current_name_scope = current_name_scope or ops.name_scope(
2443          self._default_name, skip_on_eager=False)
2444      try:
2445        current_name_scope_name = current_name_scope.__enter__()
2446      except:
2447        current_name_scope.__exit__(*sys.exc_info())
2448        raise
2449      self._current_name_scope = current_name_scope
2450      unique_default_name = _get_unique_variable_scope(self._default_name)
2451      pure_variable_scope = _pure_variable_scope(
2452          unique_default_name,
2453          initializer=self._initializer,
2454          regularizer=self._regularizer,
2455          caching_device=self._caching_device,
2456          partitioner=self._partitioner,
2457          custom_getter=self._custom_getter,
2458          old_name_scope=current_name_scope_name,
2459          dtype=self._dtype,
2460          use_resource=self._use_resource,
2461          constraint=self._constraint)
2462      try:
2463        entered_pure_variable_scope = pure_variable_scope.__enter__()
2464      except:
2465        pure_variable_scope.__exit__(*sys.exc_info())
2466        raise
2467      self._cached_pure_variable_scope = pure_variable_scope
2468      return entered_pure_variable_scope
2469
2470  def __exit__(self, type_arg, value_arg, traceback_arg):
2471    try:
2472      self._cached_pure_variable_scope.__exit__(type_arg, value_arg,
2473                                                traceback_arg)
2474    finally:
2475      try:
2476        if self._current_name_scope:
2477          self._current_name_scope.__exit__(type_arg, value_arg,
2478                                            traceback_arg)
2479      finally:
2480        if self._in_graph_mode and not self._building_function:
2481          self._graph_context_manager.__exit__(type_arg, value_arg,
2482                                               traceback_arg)
2483
2484
2485# pylint: disable=g-doc-return-or-yield
2486@tf_export(v1=["variable_op_scope"])
2487@tf_contextlib.contextmanager
2488def variable_op_scope(values,
2489                      name_or_scope,
2490                      default_name=None,
2491                      initializer=None,
2492                      regularizer=None,
2493                      caching_device=None,
2494                      partitioner=None,
2495                      custom_getter=None,
2496                      reuse=None,
2497                      dtype=None,
2498                      use_resource=None,
2499                      constraint=None):
2500  """Deprecated: context manager for defining an op that creates variables."""
2501  logging.warn("tf.variable_op_scope(values, name, default_name) is deprecated,"
2502               " use tf.variable_scope(name, default_name, values)")
2503  with variable_scope(
2504      name_or_scope,
2505      default_name=default_name,
2506      values=values,
2507      initializer=initializer,
2508      regularizer=regularizer,
2509      caching_device=caching_device,
2510      partitioner=partitioner,
2511      custom_getter=custom_getter,
2512      reuse=reuse,
2513      dtype=dtype,
2514      use_resource=use_resource,
2515      constraint=constraint) as scope:
2516    yield scope
2517
2518
2519def _call_partitioner(partitioner, shape, dtype):
2520  """Call partitioner validating its inputs/output.
2521
2522  Args:
2523    partitioner: a function mapping `Tensor` shape and dtype to a list of
2524      partitions.
    shape: shape of the `Tensor` to partition; must be fully defined and have
      rank at least 1.
2527    dtype: dtype of the elements in the `Tensor`.
2528
2529  Returns:
2530    A list with elements >=1 and exactly one >1. The index of that
2531    element corresponds to the partitioning axis.
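
  For example (an illustrative sketch; `split_axis0` is not a built-in
  partitioner, though `fixed_size_partitioner` behaves similarly), a
  partitioner that splits the first axis into four pieces:

  ```
  def split_axis0(shape, dtype):
    del dtype  # This toy partitioner only looks at the shape.
    return [4] + [1] * (shape.ndims - 1)

  _call_partitioner(split_axis0, tensor_shape.TensorShape([20, 10]),
                    dtypes.float32)  # -> [4, 1]
  ```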
  """
  if not shape.is_fully_defined():
    raise ValueError("Shape of a new partitioned variable must be "
                     "fully defined, but instead was %s." % (shape,))
  if shape.ndims < 1:
    raise ValueError("A partitioned Variable must have rank at least 1, "
                     "shape: %s" % shape)

  slicing = partitioner(shape=shape, dtype=dtype)
  if not isinstance(slicing, collections_abc.Sequence):
    raise ValueError("Partitioner must return a sequence, but saw: %s" %
                     slicing)
  if len(slicing) != shape.ndims:
    raise ValueError(
        "Partitioner returned a partition list that does not match the "
        "Variable's rank: %s vs. %s" % (slicing, shape))
  if any(p < 1 for p in slicing):
    raise ValueError("Partitioner returned zero partitions for some axes: %s" %
                     slicing)
  if sum(p > 1 for p in slicing) > 1:
    raise ValueError("Can only slice a variable along one dimension: "
                     "shape: %s, partitioning: %s" % (shape, slicing))
  return slicing


# TODO(slebedev): could be inlined, but
# `_VariableStore._get_partitioned_variable` is too complex even
# without this logic.
def _get_slice_dim_and_num_slices(slicing):
  """Get slicing dimension and number of slices from the partitioner output."""
  for slice_dim, num_slices in enumerate(slicing):
    if num_slices > 1:
      break
  else:
    # Degenerate case: no partitioning applied.
    slice_dim = 0
    num_slices = 1
  return slice_dim, num_slices


def _iter_slices(full_shape, num_slices, slice_dim):
  """Slices a given shape along the specified dimension.
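
  For illustration (the numbers are arbitrary): splitting a [10, 4] shape into
  three slices along axis 0 yields pieces of sizes 4, 3 and 3.

  ```
  slice_dim, num_slices = _get_slice_dim_and_num_slices([3, 1])  # (0, 3)
  list(_iter_slices([10, 4], num_slices, slice_dim))
  # -> [([0, 0], [4, 4]), ([4, 0], [3, 4]), ([7, 0], [3, 4])]
  ```
  """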
  num_slices_with_excess = full_shape[slice_dim] % num_slices
  offset = [0] * len(full_shape)
  min_slice_len = full_shape[slice_dim] // num_slices
  for i in xrange(num_slices):
    shape = full_shape[:]
    shape[slice_dim] = min_slice_len + bool(i < num_slices_with_excess)
    yield offset[:], shape
    offset[slice_dim] += shape[slice_dim]


def default_variable_creator(next_creator=None, **kwargs):
  """Default variable creator."""
  assert next_creator is None
  initial_value = kwargs.get("initial_value", None)
  trainable = kwargs.get("trainable", None)
  collections = kwargs.get("collections", None)
  validate_shape = kwargs.get("validate_shape", True)
  caching_device = kwargs.get("caching_device", None)
  name = kwargs.get("name", None)
  variable_def = kwargs.get("variable_def", None)
  dtype = kwargs.get("dtype", None)
  expected_shape = kwargs.get("expected_shape", None)
  import_scope = kwargs.get("import_scope", None)
  constraint = kwargs.get("constraint", None)
  use_resource = kwargs.get("use_resource", None)
  synchronization = kwargs.get("synchronization", None)
  aggregation = kwargs.get("aggregation", None)
  shape = kwargs.get("shape", None)

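  # Resolve `use_resource`: an explicit argument wins; otherwise fall back to
  # the enclosing variable scope's setting, then to the process-wide default.
  # Eager execution always forces resource variables.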
  if use_resource is None:
    use_resource = get_variable_scope().use_resource
  if use_resource is None:
    use_resource = _DEFAULT_USE_RESOURCE
  use_resource = use_resource or context.executing_eagerly()
  if use_resource:
    distribute_strategy = kwargs.get("distribute_strategy", None)
    return resource_variable_ops.ResourceVariable(
        initial_value=initial_value,
        trainable=trainable,
        collections=collections,
        validate_shape=validate_shape,
        caching_device=caching_device,
        name=name,
        dtype=dtype,
        constraint=constraint,
        variable_def=variable_def,
        import_scope=import_scope,
        distribute_strategy=distribute_strategy,
        synchronization=synchronization,
        aggregation=aggregation,
        shape=shape)
  else:
    return variables.RefVariable(
        initial_value=initial_value,
        trainable=trainable,
        collections=collections,
        validate_shape=validate_shape,
        caching_device=caching_device,
        name=name,
        dtype=dtype,
        constraint=constraint,
        variable_def=variable_def,
        expected_shape=expected_shape,
        import_scope=import_scope,
        synchronization=synchronization,
        aggregation=aggregation,
        shape=shape)


def default_variable_creator_v2(next_creator=None, **kwargs):
  """Default variable creator."""
  assert next_creator is None
  initial_value = kwargs.get("initial_value", None)
  trainable = kwargs.get("trainable", None)
  validate_shape = kwargs.get("validate_shape", True)
  caching_device = kwargs.get("caching_device", None)
  name = kwargs.get("name", None)
  variable_def = kwargs.get("variable_def", None)
  dtype = kwargs.get("dtype", None)
  import_scope = kwargs.get("import_scope", None)
  constraint = kwargs.get("constraint", None)
  distribute_strategy = kwargs.get("distribute_strategy", None)
  synchronization = kwargs.get("synchronization", None)
  aggregation = kwargs.get("aggregation", None)
  shape = kwargs.get("shape", None)

  return resource_variable_ops.ResourceVariable(
      initial_value=initial_value,
      trainable=trainable,
      validate_shape=validate_shape,
      caching_device=caching_device,
      name=name,
      dtype=dtype,
      constraint=constraint,
      variable_def=variable_def,
      import_scope=import_scope,
      distribute_strategy=distribute_strategy,
      synchronization=synchronization,
      aggregation=aggregation,
      shape=shape)


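# Hook the default creators defined above into the `variables` module, which
# (presumably to avoid a circular import) only declares these entry points.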
variables.default_variable_creator = default_variable_creator
variables.default_variable_creator_v2 = default_variable_creator_v2


def _make_getter(captured_getter, captured_previous):
  """Returns a getter that avoids Python's late binding of loop variables."""
  return lambda **kwargs: captured_getter(captured_previous, **kwargs)


# TODO(apassos) remove forwarding symbol
variable = variables.VariableV1


@tf_export(v1=["variable_creator_scope"])
@tf_contextlib.contextmanager
def variable_creator_scope_v1(variable_creator):
  """Scope which defines a variable creation function to be used by variable().

  variable_creator is expected to be a function with the following signature:

  ```
    def variable_creator(next_creator, **kwargs)
  ```

  A creator that wants to create a variable should eventually call
  `next_creator` rather than calling `Variable` or `ResourceVariable`
  directly; this keeps creators composable. A creator may also choose to
  create multiple variables, return already existing variables, or simply
  register that a variable was created and defer to the next creator in line.
  Creators can also modify the keyword arguments seen by the next creators.

  Custom getters in the variable scope will eventually resolve down to these
  custom creators when they do create variables.
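
  For example, a creator that forces every variable created inside the scope
  to be non-trainable might look like this (an illustrative sketch; the
  function and variable names are arbitrary):

  ```
  def untrainable_creator(next_creator, **kwargs):
    kwargs["trainable"] = False
    return next_creator(**kwargs)

  with tf.compat.v1.variable_creator_scope(untrainable_creator):
    v = tf.compat.v1.get_variable("v", shape=[])
  # `v` is not added to the TRAINABLE_VARIABLES collection.
  ```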

  The valid keyword arguments in `kwargs` are:

   * initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
        which is the initial value for the Variable. The initial value must have
        a shape specified unless `validate_shape` is set to False. Can also be a
        callable with no argument that returns the initial value when called. In
        that case, `dtype` must be specified. (Note that initializer functions
        from init_ops.py must first be bound to a shape before being used here.)
   * trainable: If `True`, the default, also adds the variable to the graph
        collection `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as
        the default list of variables to use by the `Optimizer` classes.
        `trainable` defaults to `True`, unless `synchronization` is
        set to `ON_READ`, in which case it defaults to `False`.
   * collections: List of graph collections keys. The new variable is added to
        these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
   * validate_shape: If `False`, allows the variable to be initialized with a
        value of unknown shape. If `True`, the default, the shape of
        `initial_value` must be known.
   * caching_device: Optional device string describing where the Variable
        should be cached for reading.  Defaults to the Variable's device.
        If not `None`, caches on another device.  Typical use is to cache
        on the device where the Ops using the Variable reside, to deduplicate
        copying through `Switch` and other conditional statements.
   * name: Optional name for the variable. Defaults to `'Variable'` and gets
        uniquified automatically.
   * dtype: If set, initial_value will be converted to the given type.
        If `None`, either the datatype will be kept (if `initial_value` is
        a Tensor), or `convert_to_tensor` will decide.
   * constraint: A constraint function to be applied to the variable after
        updates by some algorithms.
   * use_resource: if True, a ResourceVariable is always created.
   * synchronization: Indicates when a distributed variable will be
        aggregated. Accepted values are constants defined in the class
        `tf.VariableSynchronization`. By default the synchronization is set to
        `AUTO` and the current `DistributionStrategy` chooses
        when to synchronize.
   * aggregation: Indicates how a distributed variable will be aggregated.
        Accepted values are constants defined in the class
        `tf.VariableAggregation`.

  This set may grow over time, so it's important that creators accept
  `**kwargs`, as in the signature above.

  Args:
    variable_creator: the passed creator

  Yields:
    A scope in which the creator is active
  """
  with ops.get_default_graph()._variable_creator_scope(variable_creator):  # pylint: disable=protected-access
    yield


# Note: only the docstrings differ between this and v1.
@tf_export("variable_creator_scope", v1=[])
@tf_contextlib.contextmanager
def variable_creator_scope(variable_creator):
  """Scope which defines a variable creation function to be used by variable().

  variable_creator is expected to be a function with the following signature:

  ```
    def variable_creator(next_creator, **kwargs)
  ```

  A creator that wants to create a variable should eventually call
  `next_creator` rather than calling `Variable` or `ResourceVariable`
  directly; this keeps creators composable. A creator may also choose to
  create multiple variables, return already existing variables, or simply
  register that a variable was created and defer to the next creator in line.
  Creators can also modify the keyword arguments seen by the next creators.

  Custom getters in the variable scope will eventually resolve down to these
  custom creators when they do create variables.
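
  For example, a creator that records the names of the variables built inside
  the scope might look like this (an illustrative sketch; the names are
  arbitrary):

  ```
  created = []

  def tracking_creator(next_creator, **kwargs):
    v = next_creator(**kwargs)
    created.append(v.name)
    return v

  with tf.variable_creator_scope(tracking_creator):
    tf.Variable(1., name="weights")
  # `created` now holds the new variable's name, e.g. ["weights:0"].
  ```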

  The valid keyword arguments in `kwargs` are:

   * initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
        which is the initial value for the Variable. The initial value must have
        a shape specified unless `validate_shape` is set to False. Can also be a
        callable with no argument that returns the initial value when called. In
        that case, `dtype` must be specified. (Note that initializer functions
        from init_ops.py must first be bound to a shape before being used here.)
   * trainable: If `True`, the default, GradientTapes automatically watch
        uses of this Variable.
   * validate_shape: If `False`, allows the variable to be initialized with a
        value of unknown shape. If `True`, the default, the shape of
        `initial_value` must be known.
   * caching_device: Optional device string describing where the Variable
        should be cached for reading.  Defaults to the Variable's device.
        If not `None`, caches on another device.  Typical use is to cache
        on the device where the Ops using the Variable reside, to deduplicate
        copying through `Switch` and other conditional statements.
   * name: Optional name for the variable. Defaults to `'Variable'` and gets
        uniquified automatically.
   * dtype: If set, initial_value will be converted to the given type.
        If `None`, either the datatype will be kept (if `initial_value` is
        a Tensor), or `convert_to_tensor` will decide.
   * constraint: A constraint function to be applied to the variable after
        updates by some algorithms.
   * synchronization: Indicates when a distributed variable will be
        aggregated. Accepted values are constants defined in the class
        `tf.VariableSynchronization`. By default the synchronization is set to
        `AUTO` and the current `DistributionStrategy` chooses
        when to synchronize.
   * aggregation: Indicates how a distributed variable will be aggregated.
        Accepted values are constants defined in the class
        `tf.VariableAggregation`.

  This set may grow over time, so it's important that creators accept
  `**kwargs`, as in the signature above.

  Args:
    variable_creator: the passed creator

  Yields:
    A scope in which the creator is active
  """
  with ops.get_default_graph()._variable_creator_scope(variable_creator):  # pylint: disable=protected-access
    yield
