1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""A class to store named variables and a scope operator to manage sharing."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import collections as collections_lib
22import copy
23import enum  # pylint: disable=g-bad-import-order
24import functools
25import sys
26import threading
27import traceback
28
29import six
30from six import iteritems
31from six.moves import xrange, zip  # pylint: disable=redefined-builtin
32
33from tensorflow.python import tf2
34from tensorflow.python.client import session
35from tensorflow.python.eager import context
36from tensorflow.python.eager import monitoring
37from tensorflow.python.framework import dtypes
38from tensorflow.python.framework import ops
39from tensorflow.python.framework import tensor_shape
40from tensorflow.python.ops import array_ops
41from tensorflow.python.ops import init_ops
42from tensorflow.python.ops import resource_variable_ops
43from tensorflow.python.ops import variables
44from tensorflow.python.platform import tf_logging as logging
45from tensorflow.python.util import deprecation
46from tensorflow.python.util import function_utils
47from tensorflow.python.util import tf_contextlib
48from tensorflow.python.util import tf_inspect
49from tensorflow.python.util.tf_export import tf_export
50
51__all__ = [
52    "AUTO_REUSE", "VariableScope", "get_variable_scope", "get_variable",
53    "get_local_variable", "variable_scope", "variable_op_scope",
54    "no_regularizer", "VariableSynchronization", "VariableAggregation"
55]
56
57_api_usage_gauge = monitoring.BoolGauge(
58    "/tensorflow/api/resource_variables",
59    "Whether variable_scope.enable_resource_variables() is called.")
60
61
62class _PartitionInfo(object):
63  """Holds partition info used by initializer functions."""
64
65  def __init__(self, full_shape, var_offset):
66    """Constructor.
67
68    Args:
69      full_shape: Tuple or list of `int` indicating the full combined shape of
70        the partitioned variables.
71      var_offset: Tuple or list of `int` specifying offset of this partition
72        with respect to the full variable for each dimension.
73
74    Raises:
75      TypeError: If `full_shape` or `var_offset` is not a sequence.
76      ValueError: If `full_shape` or `var_offset` differ in length. If
77        `var_offset` exceeds `full_shape` in any dimension.
78    """
79    if not isinstance(full_shape, collections_lib.Sequence) or isinstance(
80        full_shape, six.string_types):
81      raise TypeError(
82          "`full_shape` must be a sequence (like tuple or list) instead of " +
83          type(full_shape).__name__)
84
85    if not isinstance(var_offset, collections_lib.Sequence) or isinstance(
86        var_offset, six.string_types):
87      raise TypeError(
88          "`var_offset` must be a sequence (like tuple or list) instead of " +
89          type(var_offset).__name__)
90
91    if len(var_offset) != len(full_shape):
92      raise ValueError(
93          "Expected equal length, but `var_offset` is of length {} while "
94          "full_shape is of length {}.".format(
95              len(var_offset), len(full_shape)))
96
97    for offset, shape in zip(var_offset, full_shape):
98      if offset < 0 or offset >= shape:
99        raise ValueError(
100            "Expected 0 <= offset < shape but found offset={}, shape={} for "
101            "var_offset={}, full_shape={}".format(offset, shape, var_offset,
102                                                  full_shape))
103
104    self._full_shape = full_shape
105    self._var_offset = var_offset
106
107  @property
108  def full_shape(self):
109    return self._full_shape
110
111  @property
112  def var_offset(self):
113    return self._var_offset
114
115  def single_offset(self, shape):
116    """Returns the offset when the variable is partitioned in at most one dim.
117
118    Args:
119      shape: Tuple or list of `int` indicating the shape of one specific
120        variable partition.
121
122    Returns:
123      `int` representing the offset in the dimension along which the variable is
124       partitioned. Returns 0 if the variable is not being partitioned.
125
126    Raises:
127      ValueError: Depending on self.single_slice_dim().
128    """
129
130    single_slice_dim = self.single_slice_dim(shape)
131    # If this variable is not being partitioned at all, single_slice_dim() could
132    # return None.
133    if single_slice_dim is None:
134      return 0
135    return self.var_offset[single_slice_dim]
136
137  def single_slice_dim(self, shape):
138    """Returns the slice dim when the variable is partitioned only in one dim.
139
140    Args:
141      shape: Tuple or list of `int` indicating the shape of one specific
142        variable partition.
143
144    Returns:
145      `int` representing the dimension that the variable is partitioned in, or
146      `None` if the variable doesn't seem to be partitioned at all.
147
148    Raises:
149      TypeError: If `shape` is not a sequence.
150      ValueError: If `shape` is not the same length as `self.full_shape`. If
151        the variable is partitioned in more than one dimension.
152    """
153    if not isinstance(shape, collections_lib.Sequence) or isinstance(
154        shape, six.string_types):
155      raise TypeError(
156          "`shape` must be a sequence (like tuple or list) instead of " +
157          type(shape).__name__)
158
159    if len(shape) != len(self.full_shape):
160      raise ValueError(
161          "Expected equal length, but received shape={} of length {} while "
162          "self.full_shape={} is of length {}.".format(shape, len(shape),
163                                                       self.full_shape,
164                                                       len(self.full_shape)))
165
166    for i in xrange(len(shape)):
167      if self.var_offset[i] + shape[i] > self.full_shape[i]:
168        raise ValueError(
169            "With self.var_offset={}, a partition of shape={} would exceed "
170            "self.full_shape={} in dimension {}.".format(
171                self.var_offset, shape, self.full_shape, i))
172
173    slice_dim = None
174    for i in xrange(len(shape)):
175      if shape[i] == self.full_shape[i]:
176        continue
177      if slice_dim is not None:
178        raise ValueError(
179            "Cannot use single_slice_dim() with shape={} and "
180            "self.full_shape={} since slice dim could be either dimension {} "
181            "or {}.".format(shape, self.full_shape, i, slice_dim))
182      slice_dim = i
183
184    return slice_dim
185
186
187class _ReuseMode(enum.Enum):
188  """Mode for variable access within a variable scope."""
189
190  # Indicates that variables are to be fetched if they already exist or
191  # otherwise created.
192  AUTO_REUSE = 1
193
194  # TODO(alive): For TensorFlow 2.0, Deprecate True/False/None API in favor of
195  #              enum values.
196  # REUSE_FALSE = 2
197  # REUSE_TRUE = 3
198
199
200# TODO(apassos) remove these forwarding symbols.
201VariableSynchronization = variables.VariableSynchronization  # pylint: disable=invalid-name
202VariableAggregation = variables.VariableAggregation  # pylint: disable=invalid-name
203
204AUTO_REUSE = _ReuseMode.AUTO_REUSE
205tf_export(v1=["AUTO_REUSE"]).export_constant(__name__, "AUTO_REUSE")
206AUTO_REUSE.__doc__ = """
207When passed in as the value for the `reuse` flag, AUTO_REUSE indicates that
208get_variable() should create the requested variable if it doesn't exist or, if
209it does exist, simply return it.
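For example (a sketch assuming graph-mode execution, as in TF 1.x):

```python
def foo():
  with tf.compat.v1.variable_scope("foo", reuse=tf.compat.v1.AUTO_REUSE):
    return tf.compat.v1.get_variable("v", [1])

v1 = foo()  # Creates v.
v2 = foo()  # Gets the same, existing v.
assert v1 is v2
```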
210"""
211
212_DEFAULT_USE_RESOURCE = tf2.enabled()
213
214
215@tf_export(v1=["enable_resource_variables"])
216def enable_resource_variables():
217  """Creates resource variables by default.
218
219  Resource variables are improved versions of TensorFlow variables with a
220  well-defined memory model. Accessing a resource variable reads its value, and
221  all ops which access a specific read value of the variable are guaranteed to
222  see the same value for that tensor. Writes which happen after a read (by
223  having a control or data dependency on the read) are guaranteed not to affect
224  the value of the read tensor, and similarly writes which happen before a read
225  are guaranteed to affect the value. No guarantees are made about unordered
226  read/write pairs.
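
  A rough sketch of that guarantee (graph mode; the names here are
  illustrative):

  ```python
  v = tf.compat.v1.get_variable("v", initializer=1.0, use_resource=True)
  read = v.read_value()
  with tf.control_dependencies([read]):
    assign = v.assign(2.0)
  # `read` evaluates to 1.0 even when fetched together with `assign`, because
  # the write is ordered after the read.
  ```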
227
228  Calling tf.enable_resource_variables() lets you opt in to this TensorFlow 2.0
229  feature.
230  """
231  global _DEFAULT_USE_RESOURCE
232  _DEFAULT_USE_RESOURCE = True
233  _api_usage_gauge.get_cell().set(True)
234
235
236@tf_export(v1=["resource_variables_enabled"])
237def resource_variables_enabled():
238  """Returns `True` if resource variables are enabled.
239
240  Resource variables are improved versions of TensorFlow variables with a
241  well-defined memory model. Accessing a resource variable reads its value, and
242  all ops which access a specific read value of the variable are guaranteed to
243  see the same value for that tensor. Writes which happen after a read (by
244  having a control or data dependency on the read) are guaranteed not to affect
245  the value of the read tensor, and similarly writes which happen before a read
246  are guaranteed to affect the value. No guarantees are made about unordered
247  read/write pairs.
248
249  Calling tf.enable_resource_variables() lets you opt in to this TensorFlow 2.0
250  feature.
251  """
252  global _DEFAULT_USE_RESOURCE
253  return _DEFAULT_USE_RESOURCE
254
255
256@deprecation.deprecated(
257    None, "non-resource variables are not supported in the long term")
258@tf_export(v1=["disable_resource_variables"])
259def disable_resource_variables():
260  """Opts out of resource variables.
261
262  If your code needs tf.disable_resource_variables() to be called to work
263  properly, please file a bug.
264  """
265  global _DEFAULT_USE_RESOURCE
266  _DEFAULT_USE_RESOURCE = False
267  _api_usage_gauge.get_cell().set(False)
268
269
270class _VariableStore(object):
271  """Variable store that carries a number of named Variables.
272
273  New variable names and new variables can be created; all stored
274  variables are initialized with the initializer passed to __init__.
275
276  Attributes:
277    vars: a dictionary with string names (same as passed in GetVar) as keys and
278      the corresponding TensorFlow Variables as values.
279  """
280
281  def __init__(self):
282    """Create a variable store."""
283    self._vars = {}  # A dictionary of the stored TensorFlow variables.
284    self._partitioned_vars = {}  # A dict of the stored PartitionedVariables.
285    self._store_eager_variables = False
286
287  def get_variable(self,
288                   name,
289                   shape=None,
290                   dtype=dtypes.float32,
291                   initializer=None,
292                   regularizer=None,
293                   reuse=None,
294                   trainable=None,
295                   collections=None,
296                   caching_device=None,
297                   partitioner=None,
298                   validate_shape=True,
299                   use_resource=None,
300                   custom_getter=None,
301                   constraint=None,
302                   synchronization=VariableSynchronization.AUTO,
303                   aggregation=VariableAggregation.NONE):
304    """Gets an existing variable with these parameters or create a new one.
305
306    If a variable with the given name is already stored, we return the stored
307    variable. Otherwise, we create a new one.
308
309    Set `reuse` to `True` when you only want to reuse existing Variables.
310    Set `reuse` to `False` when you only want to create new Variables.
311    Set `reuse` to None (the default) or tf.compat.v1.AUTO_REUSE when you want
312    variables to be created if they don't exist or returned if they do.
313
314    If initializer is `None` (the default), the default initializer passed in
315    the constructor is used. If that one is `None` too, we use a new
316    `glorot_uniform_initializer`. If initializer is a Tensor, we use
317    it as a value and derive the shape from the initializer.
318
319    If a partitioner is provided, a `PartitionedVariable` is returned.
320    Accessing this object as a `Tensor` returns the shards concatenated along
321    the partition axis.
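
    For example (an illustrative sketch using the public `tf.compat.v1`
    wrappers):

    ```python
    part = tf.compat.v1.fixed_size_partitioner(2, axis=0)
    with tf.compat.v1.variable_scope("scope", partitioner=part):
      v = tf.compat.v1.get_variable("v", shape=[10, 20])  # PartitionedVariable
    t = tf.convert_to_tensor(v)  # [10, 20]: shards concatenated along axis 0.
    ```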
322
323    Some useful partitioners are available.  See, e.g.,
324    `variable_axis_size_partitioner` and `min_max_variable_partitioner`.
325
326    Args:
327      name: The name of the new or existing variable.
328      shape: Shape of the new or existing variable.
329      dtype: Type of the new or existing variable (defaults to `DT_FLOAT`).
330      initializer: Initializer for the variable.
331      regularizer: A (Tensor -> Tensor or None) function; the result of applying
332        it to a newly created variable will be added to the collection
333        GraphKeys.REGULARIZATION_LOSSES and can be used for regularization.
334      reuse: a Boolean, None, or tf.AUTO_REUSE. Controls reuse or creation of
335        variables. When eager execution is enabled, this argument is always
336        forced to be False.
337      trainable: If `True` also add the variable to the graph collection
338        `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). `trainable`
339        defaults to `True`, unless `synchronization` is set to `ON_READ`, in
340        which case it defaults to `False`.
341      collections: List of graph collections keys to add the `Variable` to.
342        Defaults to `[GraphKeys.GLOBAL_VARIABLES]` (see `tf.Variable`).
343      caching_device: Optional device string or function describing where the
344        Variable should be cached for reading.  Defaults to the Variable's
345        device.  If not `None`, caches on another device.  Typical use is to
346        cache on the device where the Ops using the `Variable` reside, to
347        deduplicate copying through `Switch` and other conditional statements.
348      partitioner: Optional callable that accepts a fully defined `TensorShape`
349        and dtype of the `Variable` to be created, and returns a list of
350        partitions for each axis (currently only one axis can be partitioned).
351      validate_shape: If False, allows the variable to be initialized with a
352        value of unknown shape. If True, the default, the shape of initial_value
353        must be known.
354      use_resource: If False, creates a regular Variable. If True, creates
355        instead an experimental ResourceVariable which has well-defined
356        semantics. Defaults to False (will later change to True). When eager
357        execution is enabled this argument is always forced to be true.
358      custom_getter: Callable that takes as a first argument the true getter,
359        and allows overwriting the internal get_variable method. The signature
360        of `custom_getter` should match that of this method,
361        but the most future-proof version will allow for changes: `def
362          custom_getter(getter, *args, **kwargs)`.  Direct access to
363        all `get_variable` parameters is also allowed: `def
364          custom_getter(getter, name, *args, **kwargs)`.  A simple identity
365        custom getter that simply creates variables with modified names is:
366          ```python
367        def custom_getter(getter, name, *args, **kwargs):
368          return getter(name + '_suffix', *args, **kwargs) ```
369      constraint: An optional projection function to be applied to the variable
370        after being updated by an `Optimizer` (e.g. used to implement norm
371        constraints or value constraints for layer weights). The function must
372        take as input the unprojected Tensor representing the value of the
373        variable and return the Tensor for the projected value (which must have
374        the same shape). Constraints are not safe to use when doing asynchronous
375        distributed training.
376      synchronization: Indicates when a distributed variable will be
377        aggregated. Accepted values are constants defined in the class
378        `tf.VariableSynchronization`. By default the synchronization is set to
379        `AUTO` and the current `DistributionStrategy` chooses when to
380        synchronize.
381      aggregation: Indicates how a distributed variable will be aggregated.
382        Accepted values are constants defined in the class
383        `tf.VariableAggregation`.
384
385    Returns:
386      The created or existing `Variable` (or `PartitionedVariable`, if a
387      partitioner was used).
388
389    Raises:
390      ValueError: when creating a new variable and shape is not declared,
391        when reusing a variable and specifying a conflicting shape,
392        or when violating reuse during variable creation.
393      RuntimeError: when eager execution is enabled and not called from an
394        EagerVariableStore.
395    """
396    if custom_getter is not None and not callable(custom_getter):
397      raise ValueError("Passed a custom_getter which is not callable: %s" %
398                       custom_getter)
399
400    with ops.init_scope():
401      if context.executing_eagerly():
402        # Variable creation and initialization takes place in `init_scope`s;
403        # as such, if an `init_scope` lifts us into the eager context, then we
404        # need to use `ResourceVariable`s.
405        use_resource = True
406
407    # Note that it's fine to reuse eager variables whose initialization was
408    # lifted from a function-building graph into the eager context (that's why
409    # the following clause is not wrapped in an `init_scope`); lifted variables
410    # are tracked by the graph's `VariableStore`.
411    if context.executing_eagerly():
412      if not self._store_eager_variables and reuse:
413        raise RuntimeError(
414            "When eager execution is enabled variable reuse is only supported"
415            " when an EagerVariableStore is active. See the documentation on"
416            " EagerVariableStore for example usage.")
417      if self._store_eager_variables:
418        reuse = AUTO_REUSE
419
420    # If a *_ref type is passed in an error would be triggered further down the
421    # stack. We prevent this using base_dtype to get a non-ref version of the
422    # type, before doing anything else. When _ref types are removed in favor of
423    # resources, this line can be removed.
424    try:
425      dtype = dtype.base_dtype
426    except AttributeError:
427      # .base_dtype not existing means that we will try and use the raw dtype
428      # which was passed in - this might be a NumPy type which is valid.
429      pass
430
431    # This is the main logic of get_variable.  However, custom_getter
432    # may override this logic.  So we save it as a callable and pass
433    # it to custom_getter.
434    # Note: the parameters of _true_getter, and their documentation, match
435    # *exactly* item-for-item with the docstring of this method.
436    def _true_getter(  # pylint: disable=missing-docstring
437        name,
438        shape=None,
439        dtype=dtypes.float32,
440        initializer=None,
441        regularizer=None,
442        reuse=None,
443        trainable=None,
444        collections=None,
445        caching_device=None,
446        partitioner=None,
447        validate_shape=True,
448        use_resource=None,
449        constraint=None,
450        synchronization=VariableSynchronization.AUTO,
451        aggregation=VariableAggregation.NONE):
452      is_scalar = (
453          shape is not None and isinstance(shape, collections_lib.Sequence) and
454          not shape)
455      # Partitioned variable case
456      if partitioner is not None and not is_scalar:
457        if not callable(partitioner):
458          raise ValueError("Partitioner must be callable, but received: %s" %
459                           partitioner)
460        with ops.name_scope(None):
461          return self._get_partitioned_variable(
462              name=name,
463              shape=shape,
464              dtype=dtype,
465              initializer=initializer,
466              regularizer=regularizer,
467              reuse=reuse,
468              trainable=trainable,
469              collections=collections,
470              caching_device=caching_device,
471              partitioner=partitioner,
472              validate_shape=validate_shape,
473              use_resource=use_resource,
474              constraint=constraint,
475              synchronization=synchronization,
476              aggregation=aggregation)
477
478      # Special case for partitioned variable to allow reuse without having to
479      # specify partitioner.
480      if (reuse is True and partitioner is None
481          and name in self._partitioned_vars):
482        return self._get_partitioned_variable(
483            name=name,
484            shape=shape,
485            dtype=dtype,
486            initializer=initializer,
487            regularizer=regularizer,
488            reuse=reuse,
489            trainable=trainable,
490            collections=collections,
491            caching_device=caching_device,
492            partitioner=None,
493            validate_shape=validate_shape,
494            use_resource=use_resource,
495            constraint=constraint,
496            synchronization=synchronization,
497            aggregation=aggregation)
498
499      # Single variable case
500      if "%s/part_0" % name in self._vars:
501        raise ValueError(
502            "No partitioner was provided, but a partitioned version of the "
503            "variable was found: %s/part_0. Perhaps a variable of the same "
504            "name was already created with partitioning?" % name)
505
506      return self._get_single_variable(
507          name=name,
508          shape=shape,
509          dtype=dtype,
510          initializer=initializer,
511          regularizer=regularizer,
512          reuse=reuse,
513          trainable=trainable,
514          collections=collections,
515          caching_device=caching_device,
516          validate_shape=validate_shape,
517          use_resource=use_resource,
518          constraint=constraint,
519          synchronization=synchronization,
520          aggregation=aggregation)
521
522    synchronization, aggregation, trainable = (
523        variables.validate_synchronization_aggregation_trainable(
524            synchronization, aggregation, trainable, name))
525
526    if custom_getter is not None:
527      # Handle backwards compatibility with getter arguments that were added
528      # to the API after users started writing custom getters.
529      custom_getter_kwargs = {
530          "getter": _true_getter,
531          "name": name,
532          "shape": shape,
533          "dtype": dtype,
534          "initializer": initializer,
535          "regularizer": regularizer,
536          "reuse": reuse,
537          "trainable": trainable,
538          "collections": collections,
539          "caching_device": caching_device,
540          "partitioner": partitioner,
541          "validate_shape": validate_shape,
542          "use_resource": use_resource,
543          "synchronization": synchronization,
544          "aggregation": aggregation,
545      }
546      # `fn_args` and `has_kwargs` can handle functions, `functools.partial`,
547      # `lambda`.
548      if ("constraint" in function_utils.fn_args(custom_getter) or
549          function_utils.has_kwargs(custom_getter)):
550        custom_getter_kwargs["constraint"] = constraint
551      return custom_getter(**custom_getter_kwargs)
552    else:
553      return _true_getter(
554          name,
555          shape=shape,
556          dtype=dtype,
557          initializer=initializer,
558          regularizer=regularizer,
559          reuse=reuse,
560          trainable=trainable,
561          collections=collections,
562          caching_device=caching_device,
563          partitioner=partitioner,
564          validate_shape=validate_shape,
565          use_resource=use_resource,
566          constraint=constraint,
567          synchronization=synchronization,
568          aggregation=aggregation)
569
570  def _get_partitioned_variable(self,
571                                name,
572                                partitioner,
573                                shape=None,
574                                dtype=dtypes.float32,
575                                initializer=None,
576                                regularizer=None,
577                                reuse=None,
578                                trainable=None,
579                                collections=None,
580                                caching_device=None,
581                                validate_shape=True,
582                                use_resource=None,
583                                constraint=None,
584                                synchronization=VariableSynchronization.AUTO,
585                                aggregation=VariableAggregation.NONE):
586    """Gets or creates a sharded variable list with these parameters.
587
588    The `partitioner` must be a callable that accepts a fully defined
589    `TensorShape` and returns a sequence of integers (the `partitions`).
590    These integers describe how to partition the given sharded `Variable`
591    along the given dimension.  That is, `partitions[1] = 3` means split
592    the `Variable` into 3 shards along dimension 1.  Currently, sharding along
593    only one axis is supported.
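
    For example, an illustrative partitioner for a two-dimensional variable
    that requests three shards along dimension 1 could look like:

    ```python
    def partitioner(shape, dtype):
      del shape, dtype  # A fixed partitioning, regardless of shape/dtype.
      return [1, 3]  # One "slice" along dim 0, three shards along dim 1.
    ```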
594
595    If the list of variables with the given name (prefix) is already stored,
596    we return the stored variables. Otherwise, we create a new one.
597
598    Set `reuse` to `True` when you only want to reuse existing Variables.
599    Set `reuse` to `False` when you only want to create new Variables.
600    Set `reuse` to None (the default) or tf.compat.v1.AUTO_REUSE when you want
601    variables to be created if they don't exist or returned if they do.
602
603    If initializer is `None` (the default), the default initializer passed in
604    the constructor is used. If that one is `None` too, we use a new
605    `glorot_uniform_initializer`. If initializer is a Tensor, we use
606    it as a value and derive the shape from the initializer.
607
608    If the initializer is a callable, then it will be called for each
609    shard.  Otherwise the initializer should match the shape of the entire
610    sharded Variable, and it will be sliced accordingly for each shard.
611
612    Some useful partitioners are available.  See, e.g.,
613    `variable_axis_size_partitioner` and `min_max_variable_partitioner`.
614
615    Args:
616      name: the name of the new or existing sharded variable.
617      partitioner: Optional callable that accepts a fully defined `TensorShape`
618        and `dtype` of the Variable to be created, and returns a list of
619        partitions for each axis (currently only one axis can be partitioned).
620      shape: shape of the new or existing sharded variable.
621      dtype: type of the new or existing sharded variable (defaults to
622        `DT_FLOAT`).
623      initializer: initializer for the sharded variable.
624      regularizer: a (Tensor -> Tensor or None) function; the result of applying
625        it to a newly created variable will be added to the collection
626        GraphKeys.REGULARIZATION_LOSSES and can be used for regularization.
627      reuse: a Boolean, None, or tf.AUTO_REUSE. Controls reuse or creation of
628        variables.
629      trainable: If `True` also add the variable to the graph collection
630        `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
631      collections: List of graph collections keys to add the Variable to.
632        Defaults to `[GraphKeys.GLOBAL_VARIABLES]` (see `tf.Variable`).
633      caching_device: Optional device string or function describing where the
634        Variable should be cached for reading.  Defaults to the Variable's
635        device.  If not `None`, caches on another device.  Typical use is to
636        cache on the device where the Ops using the Variable reside, to
637        deduplicate copying through `Switch` and other conditional statements.
638      validate_shape: If False, allows the variable to be initialized with a
639        value of unknown shape. If True, the default, the shape of initial_value
640        must be known.
641      use_resource: If False, creates a regular Variable. If True, creates an
642        experimental ResourceVariable which has well-defined semantics. Defaults
643        to False (will later change to True).
644      constraint: An optional projection function to be applied to the variable
645        after being updated by an `Optimizer` (e.g. used to implement norm
646        constraints or value constraints for layer weights). The function must
647        take as input the unprojected Tensor representing the value of the
648        variable and return the Tensor for the projected value (which must have
649        the same shape). Constraints are not safe to use when doing asynchronous
650        distributed training.
651      synchronization: Indicates when a distributed variable will be
652        aggregated. Accepted values are constants defined in the class
653        `tf.VariableSynchronization`. By default the synchronization is set to
654        `AUTO` and the current `DistributionStrategy` chooses when to
655        synchronize.
656      aggregation: Indicates how a distributed variable will be aggregated.
657        Accepted values are constants defined in the class
658        `tf.VariableAggregation`.
659
660    Returns:
661      A `PartitionedVariable` object.
662
663    Raises:
664      ValueError: when creating a new variable and shape is not declared,
665        when reusing a variable and specifying a conflicting shape,
666        when violating reuse during variable creation, or if an existing
667        sharded variable exists for the given name but with different sharding.
668    """
669    initializing_from_value = initializer is not None and isinstance(
670        initializer, ops.Tensor)
671    if name in self._vars:
672      raise ValueError(
673          "A partitioner was provided, but an unpartitioned version of the "
674          "variable was found: %s.  Perhaps a variable of the same name was "
675          "already created without partitioning?" % name)
676
677    shape = tensor_shape.as_shape(shape)
678    if initializing_from_value:
679      shape = shape.merge_with(initializer.get_shape())
680
681    partitions = None
682    if not reuse or partitioner:
683      partitions = _call_partitioner(partitioner, shape, dtype)
684
685    if name in self._partitioned_vars:
686      if reuse is False:
687        raise ValueError(
688            "Partitioned variable with name %s already exists. Did you mean to "
689            "set reuse=True or reuse=tf.AUTO_REUSE in VarScope?" % name)
690
691      existing_var = self._partitioned_vars[name]
692      if not shape.is_compatible_with(existing_var.get_shape()):
693        raise ValueError(
694            "Trying to reuse partitioned variable %s, but specified shape %s "
695            "and found shape %s." % (name, shape, existing_var.get_shape()))
696      if not dtype.is_compatible_with(existing_var.dtype):
697        raise ValueError(
698            "Trying to reuse partitioned variable %s, but specified dtype %s "
699            "and found dtype %s." % (name, dtype.name, existing_var.dtype.name))
700
701      # pylint: disable=protected-access
702      if (partitions is not None and
703          existing_var._get_partitions() != partitions):
704        raise ValueError(
705            "Trying to reuse partitioned variable %s, but specified partitions "
706            "%s and found partitions %s." %
707            (name, partitions, existing_var._get_partitions()))
708      # pylint: enable=protected-access
709
710      return existing_var
711
712    if reuse is True:
713      raise ValueError("PartitionedVariable %s does not exist, or was not "
714                       "created with tf.get_variable(). Did you mean to set "
715                       "reuse=False or reuse=tf.AUTO_REUSE in VarScope?" % name)
716
717    slice_dim, num_slices = _get_slice_dim_and_num_slices(partitions)
718
719    if "%s/part_0" % name in self._vars:
720      if "%s/part_%d" % (name, num_slices - 1) not in self._vars:
721        raise ValueError(
722            "Partitioner returned a different partitioning than what was "
723            "already found.  Partitioner returned %d shards, and shard "
724            "%s/part_0 was found, but %s/part_%d was not." %
725            (num_slices, name, name, num_slices - 1))
726      if "%s/part_%d" % (name, num_slices) in self._vars:
727        raise ValueError(
728            "Partitioner returned a different partitioning than what was "
729            "already found.  Partitioner returned %d shards, and shard "
730            "%s/part_0 was found, but so was the extra shard %s/part_%d." %
731            (num_slices, name, name, num_slices))
732
733    vs = []
734    for i, (var_offset, var_shape) in enumerate(
735        _iter_slices(shape.as_list(), num_slices, slice_dim)):
736      partition_info = _PartitionInfo(
737          full_shape=shape.as_list(), var_offset=var_offset)
738      var_full_name = "%s/part_%d" % (name, i)
739      with ops.name_scope(
740          var_full_name + "/PartitionedInitializer", skip_on_eager=False):
741        # Create the tensor to initialize the variable with default value.
742        if initializer is None:
743          init, initializing_from_value = self._get_default_initializer(
744              name=name, shape=shape, dtype=dtype)
745          if initializing_from_value:
746            init_shape = None
747          else:
748            init_shape = var_shape
749        elif callable(initializer):
750          init = initializer
751          init_shape = var_shape
752        elif isinstance(initializer, ops.Tensor):
753          init = array_ops.slice(initializer, var_offset, var_shape)
754          # Use the dtype of the given tensor.
755          dtype = init.dtype.base_dtype
756          init_shape = None
757        else:
758          init = ops.convert_to_tensor(initializer, dtype=dtype)
759          init = array_ops.slice(init, var_offset, var_shape)
760          init_shape = None
761
762      with ops.name_scope(None):
763        var = self._get_single_variable(
764            name=var_full_name,
765            shape=init_shape,
766            dtype=dtype,
767            initializer=init,
768            partition_info=partition_info,
769            regularizer=regularizer,
770            reuse=reuse,
771            trainable=trainable,
772            collections=collections,
773            caching_device=caching_device,
774            validate_shape=validate_shape,
775            use_resource=use_resource,
776            constraint=constraint,
777            synchronization=synchronization,
778            aggregation=aggregation)
779
780      # pylint: disable=protected-access
781      var._set_save_slice_info(
782          variables.Variable.SaveSliceInfo(name, shape.as_list(), var_offset,
783                                           var_shape))
784      vs.append(var)
785      # pylint: enable=protected-access
786
787    partitioned_var = variables.PartitionedVariable(
788        name=name,
789        shape=shape,
790        dtype=dtype,
791        variable_list=vs,
792        partitions=partitions)
793    if not context.executing_eagerly() or self._store_eager_variables:
794      self._partitioned_vars[name] = partitioned_var
795    return partitioned_var
796
797  def _get_single_variable(self,
798                           name,
799                           shape=None,
800                           dtype=dtypes.float32,
801                           initializer=None,
802                           regularizer=None,
803                           partition_info=None,
804                           reuse=None,
805                           trainable=None,
806                           collections=None,
807                           caching_device=None,
808                           validate_shape=True,
809                           use_resource=None,
810                           constraint=None,
811                           synchronization=VariableSynchronization.AUTO,
812                           aggregation=VariableAggregation.NONE):
813    """Get or create a single Variable (e.g.
814
815    a shard or entire variable).
816
817    See the documentation of get_variable above (ignore partitioning components)
818    for details.
819
820    Args:
821      name: see get_variable.
822      shape: see get_variable.
823      dtype: see get_variable.
824      initializer: see get_variable.
825      regularizer: see get_variable.
826      partition_info: _PartitionInfo object.
827      reuse: see get_variable.
828      trainable: see get_variable.
829      collections: see get_variable.
830      caching_device: see get_variable.
831      validate_shape: see get_variable.
832      use_resource: see get_variable.
833      constraint: see get_variable.
834      synchronization: see get_variable.
835      aggregation: see get_variable.
836
837    Returns:
838      A Variable.  See documentation of get_variable above.
839
840    Raises:
841      ValueError: See documentation of get_variable above.
842    """
843    # Set to true if initializer is a constant.
844    initializing_from_value = False
845    if initializer is not None and not callable(initializer):
846      initializing_from_value = True
847    if shape is not None and initializing_from_value:
848      raise ValueError("If initializer is a constant, do not specify shape.")
849
850    dtype = dtypes.as_dtype(dtype)
851    shape = tensor_shape.as_shape(shape)
852
853    if name in self._vars:
854      # Here we handle the case when returning an existing variable.
855      if reuse is False:
856        var = self._vars[name]
857        err_msg = ("Variable %s already exists, disallowed."
858                   " Did you mean to set reuse=True or "
859                   "reuse=tf.AUTO_REUSE in VarScope?" % name)
860        # ResourceVariables have no op associated with them, so no traceback.
861        if isinstance(var, resource_variable_ops.ResourceVariable):
862          raise ValueError(err_msg)
863        tb = var.op.traceback[::-1]
864        # Throw away internal tf entries and only take a few lines. In some
865        # cases the traceback can be longer (e.g. if someone uses factory
866        # functions to create variables) so we take more than needed in the
867        # default case.
868        tb = [x for x in tb if "tensorflow/python" not in x[0]][:5]
869        raise ValueError("%s Originally defined at:\n\n%s" %
870                         (err_msg, "".join(traceback.format_list(tb))))
871      found_var = self._vars[name]
872      if not shape.is_compatible_with(found_var.get_shape()):
873        raise ValueError("Trying to share variable %s, but specified shape %s"
874                         " and found shape %s." %
875                         (name, shape, found_var.get_shape()))
876      if not dtype.is_compatible_with(found_var.dtype):
877        dtype_str = dtype.name
878        found_type_str = found_var.dtype.name
879        raise ValueError("Trying to share variable %s, but specified dtype %s"
880                         " and found dtype %s." %
881                         (name, dtype_str, found_type_str))
882      return found_var
883
884    # The code below handles only the case of creating a new variable.
885    if reuse is True:
886      raise ValueError("Variable %s does not exist, or was not created with "
887                       "tf.get_variable(). Did you mean to set "
888                       "reuse=tf.AUTO_REUSE in VarScope?" % name)
889
890    # Create the tensor to initialize the variable with default value.
891    if initializer is None:
892      initializer, initializing_from_value = self._get_default_initializer(
893          name=name, shape=shape, dtype=dtype)
894    # Enter an init scope when creating the initializer.
895    with ops.init_scope():
896      if initializing_from_value:
897        init_val = initializer
898        variable_dtype = None
899      else:
900        # Instantiate initializer if provided initializer is a type object.
901        if tf_inspect.isclass(initializer):
902          initializer = initializer()
903        if shape is not None and shape.is_fully_defined():
904          if "partition_info" in tf_inspect.getargspec(initializer).args:
905            init_val = lambda: initializer(  # pylint: disable=g-long-lambda
906                shape.as_list(),
907                dtype=dtype,
908                partition_info=partition_info)
909          else:
910            init_val = lambda: initializer(  # pylint: disable=g-long-lambda
911                shape.as_list(), dtype=dtype)
912          variable_dtype = dtype.base_dtype
913        elif len(tf_inspect.getargspec(initializer).args) == len(
914            tf_inspect.getargspec(initializer).defaults or []):
915          init_val = initializer
916          variable_dtype = None
917        else:
918          raise ValueError("The initializer passed is not valid. It should "
919                           "be a callable with no arguments and the "
920                           "shape should not be provided or an instance of "
921                           "`tf.keras.initializers.*' and `shape` should be "
922                           "fully defined.")
923
924    # Create the variable.
925    if use_resource is None:
926      # Set the default value if unspecified.
927      use_resource = _DEFAULT_USE_RESOURCE
928    v = variables.VariableV1(
929        initial_value=init_val,
930        name=name,
931        trainable=trainable,
932        collections=collections,
933        caching_device=caching_device,
934        dtype=variable_dtype,
935        validate_shape=validate_shape,
936        constraint=constraint,
937        use_resource=use_resource,
938        synchronization=synchronization,
939        aggregation=aggregation)
940    if context.executing_eagerly() and self._store_eager_variables:
941      if collections:
942        ops.add_to_collections(collections, v)
943      else:
944        ops.add_to_collection(ops.GraphKeys.GLOBAL_VARIABLES, v)
945      if trainable:
946        ops.add_to_collection(ops.GraphKeys.TRAINABLE_VARIABLES, v)
947
948    if not context.executing_eagerly() or self._store_eager_variables:
949      # In eager mode we do not want to keep default references to Variable
950      # objects as this will prevent their memory from being released.
951      self._vars[name] = v
952    logging.vlog(1, "Created variable %s with shape %s and init %s", v.name,
953                 format(shape), initializer)
954
955    # Run the regularizer if requested and save the resulting loss.
956    if regularizer:
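      # Wrap the regularizer in a thunk so the regularization loss is rebuilt
      # lazily whenever the collection entry is converted to a Tensor.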
957      def make_regularizer_op():
958        with ops.colocate_with(v):
959          with ops.name_scope(name + "/Regularizer/"):
960            return regularizer(v)
961
962      if regularizer(v) is not None:
963        lazy_eval_tensor = _LazyEvalTensor(make_regularizer_op)
964        ops.add_to_collection(ops.GraphKeys.REGULARIZATION_LOSSES,
965                              lazy_eval_tensor)
966
967    return v
968
969  # Initialize variable when no initializer provided
970  def _get_default_initializer(self, name, shape=None, dtype=dtypes.float32):
971    """Provide a default initializer and a corresponding value.
972
973    Args:
974      name: see get_variable.
975      shape: see get_variable.
976      dtype: see get_variable.
977
978    Returns:
979      initializer and initializing_from_value. See get_variable above.
980
981    Raises:
982      ValueError: When giving unsupported dtype.
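
    For example (illustrative):
      dtype=dtypes.float32   -> (glorot_uniform_initializer(), False)
      dtype=dtypes.int32     -> (zeros_initializer(), False)
      dtype=dtypes.complex64 -> raises ValueError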
983    """
984    del shape
985    # If dtype is DT_FLOAT, provide a uniform unit scaling initializer
986    if dtype.is_floating:
987      initializer = init_ops.glorot_uniform_initializer()
988      initializing_from_value = False
989    # If dtype is DT_INT/DT_UINT, provide a default value `zero`
990    # If dtype is DT_BOOL, provide a default value `FALSE`
991    elif (dtype.is_integer or dtype.is_unsigned or dtype.is_bool or
992          dtype == dtypes.string):
993      initializer = init_ops.zeros_initializer()
994      initializing_from_value = False
995    # NOTE: Do we need to support handling DT_STRING and DT_COMPLEX here?
996    else:
997      raise ValueError("An initializer for variable %s of %s is required" %
998                       (name, dtype.base_dtype))
999
1000    return initializer, initializing_from_value
1001
1002
1003class _LazyEvalTensor(object):
1004  """A Tensor-like object that only evaluates its thunk when used."""
1005
1006  def __init__(self, thunk):
1007    """Initializes a _LazyEvalTensor object.
1008
1009    Args:
1010      thunk: A callable. A thunk which computes the value of the tensor.
1011    """
1012    self._thunk = thunk
1013    self._master_tensor = thunk()
1014
1015  def _as_tensor(self, dtype=None, name=None, as_ref=False):
1016    del name
1017    assert not as_ref
1018    assert dtype in [None, self.dtype]
1019
1020    return self._thunk()
1021
1022
1023def _make_master_property(name):
1024  @property
1025  def prop(self):
1026    return getattr(self._master_tensor, name)  # pylint: disable=protected-access
1027  return prop
1028
1029_master_property_list = ("device", "dtype", "graph", "name", "op", "shape",
1030                         "value_index")
1031for _name in _master_property_list:
1032  setattr(_LazyEvalTensor, _name, _make_master_property(_name))
1033
1034
1035def _make_master_method(name):
1036  def method(self, *args, **kwargs):
1037    return getattr(self._master_tensor, name)(*args, **kwargs)  # pylint: disable=protected-access
1038  return method
1039
1040_master_method_list = ("get_shape", "__str__", "shape_as_list")
1041for _name in _master_method_list:
1042  setattr(_LazyEvalTensor, _name, _make_master_method(_name))
1043
1044
1045def _make_op_method(name):
1046  def method(self, *args, **kwargs):
1047    return getattr(self._as_tensor(), name)(*args, **kwargs)  # pylint: disable=protected-access
1048  return method
1049
1050_op_list = ("__abs__", "__add__", "__and__", "__bool__", "__div__", "__eq__",
1051            "__floordiv__", "__ge__", "__getitem__", "__gt__", "__invert__",
1052            "__iter__", "__le__", "__len__", "__lt__", "__matmul__", "__mod__",
1053            "__mul__", "__ne__", "__neg__", "__nonzero__", "__or__", "__pow__",
1054            "__radd__", "__rand__", "__rdiv__", "__rfloordiv__", "__rmatmul__",
1055            "__rmod__", "__rmul__", "__ror__", "__rpow__", "__rsub__",
1056            "__rtruediv__", "__rxor__", "__sub__", "__truediv__", "__xor__",
1057            "eval", "numpy")
1058for _name in _op_list:
1059  setattr(_LazyEvalTensor, _name, _make_op_method(_name))
1060
1061
1062ops.register_tensor_conversion_function(
1063    _LazyEvalTensor,
1064    lambda val, dtype, name, as_ref: val._as_tensor(dtype, name, as_ref)  # pylint: disable=protected-access
1065    )
1066
1067session.register_session_run_conversion_functions(
1068    _LazyEvalTensor,
1069    lambda fetch: ([fetch._master_tensor], lambda fetched_vals: fetched_vals[0])  # pylint: disable=protected-access
1070    )
1071
1072ops.register_dense_tensor_like_type(_LazyEvalTensor)
1073
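# Illustrative sketch of how _LazyEvalTensor defers work (hypothetical thunk):
#
#   lazy = _LazyEvalTensor(lambda: regularizer(v))
#
# __init__ runs the thunk once to build `_master_tensor`, which backs the
# forwarded properties (dtype, shape, op, ...) and session.run fetches, while
# converting the object to a Tensor (ops.convert_to_tensor) re-invokes the
# thunk via _as_tensor.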
1074
1075# To stop regularization, use this regularizer
1076@tf_export(v1=["no_regularizer"])
1077def no_regularizer(_):
1078  """Use this function to prevent regularization of variables."""
1079  return None
1080
1081
1082# TODO(alive): support caching devices and partitioned variables in Eager mode.
1083@tf_export(v1=["VariableScope"])
1084class VariableScope(object):
1085  """Variable scope object to carry defaults to provide to `get_variable`.
1086
1087  Many of the arguments we need for `get_variable` in a variable store are most
1088  easily handled with a context. This object is used for the defaults.
1089
1090  Attributes:
1091    name: name of the current scope, used as prefix in get_variable.
1092    initializer: default initializer passed to get_variable.
1093    regularizer: default regularizer passed to get_variable.
1094    reuse: Boolean, None, or tf.compat.v1.AUTO_REUSE, setting the reuse in
1095      get_variable. When eager execution is enabled this argument is always
1096      forced to be False.
1097    caching_device: string, callable, or None: the caching device passed to
1098      get_variable.
1099    partitioner: callable or `None`: the partitioner passed to `get_variable`.
1100    custom_getter: default custom getter passed to get_variable.
1101    name_scope: The name passed to `tf.name_scope`.
1102    dtype: default type passed to get_variable (defaults to DT_FLOAT).
1103    use_resource: if False, create a normal Variable; if True create an
1104      experimental ResourceVariable with well-defined semantics. Defaults to
1105      False (will later change to True). When eager execution is enabled this
1106      argument is always forced to be True.
1107    constraint: An optional projection function to be applied to the variable
1108      after being updated by an `Optimizer` (e.g. used to implement norm
1109      constraints or value constraints for layer weights). The function must
1110      take as input the unprojected Tensor representing the value of the
1111      variable and return the Tensor for the projected value (which must have
1112      the same shape). Constraints are not safe to use when doing asynchronous
1113      distributed training.
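
  A small sketch of how these defaults flow into `get_variable` (illustrative):

  ```python
  with tf.compat.v1.variable_scope(
      "layer", initializer=tf.compat.v1.zeros_initializer()):
    scope = tf.compat.v1.get_variable_scope()  # The current VariableScope.
    w = tf.compat.v1.get_variable("w", [2])  # Picks up the zeros initializer.
  ```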
1114  """
1115
1116  def __init__(self,
1117               reuse,
1118               name="",
1119               initializer=None,
1120               regularizer=None,
1121               caching_device=None,
1122               partitioner=None,
1123               custom_getter=None,
1124               name_scope="",
1125               dtype=dtypes.float32,
1126               use_resource=None,
1127               constraint=None):
1128    """Creates a new VariableScope with the given properties."""
1129    self._name = name
1130    self._initializer = initializer
1131    self._regularizer = regularizer
1132    self._reuse = reuse
1133    self._caching_device = caching_device
1134    self._partitioner = partitioner
1135    self._custom_getter = custom_getter
1136    self._name_scope = name_scope
1137    self._dtype = dtype
1138    self._use_resource = use_resource
1139    self._constraint = constraint
1140    if context.executing_eagerly():
1141      if self._caching_device is not None:
1142        raise NotImplementedError("Caching devices is not yet supported "
1143                                  "when eager execution is enabled.")
1144      self._reuse = AUTO_REUSE
1145      self._use_resource = True
1146
1147  @property
1148  def name(self):
1149    return self._name
1150
1151  @property
1152  def original_name_scope(self):
1153    return self._name_scope
1154
1155  @property
1156  def reuse(self):
1157    return self._reuse
1158
1159  @property
1160  def initializer(self):
1161    return self._initializer
1162
1163  @property
1164  def dtype(self):
1165    return self._dtype
1166
1167  @property
1168  def use_resource(self):
1169    return self._use_resource
1170
1171  @property
1172  def regularizer(self):
1173    return self._regularizer
1174
1175  @property
1176  def caching_device(self):
1177    return self._caching_device
1178
1179  @property
1180  def partitioner(self):
1181    return self._partitioner
1182
1183  @property
1184  def custom_getter(self):
1185    return self._custom_getter
1186
1187  @property
1188  def constraint(self):
1189    return self._constraint
1190
1191  def reuse_variables(self):
1192    """Reuse variables in this scope."""
1193    self._reuse = True
1194
1195  def set_initializer(self, initializer):
1196    """Set initializer for this scope."""
1197    self._initializer = initializer
1198
1199  def set_dtype(self, dtype):
1200    """Set data type for this scope."""
1201    self._dtype = dtype
1202
1203  def set_use_resource(self, use_resource):
1204    """Sets whether to use ResourceVariables for this scope."""
1205    if context.executing_eagerly() and not use_resource:
1206      raise ValueError("When eager execution is enabled, "
1207                       "use_resource cannot be set to false.")
1208    self._use_resource = use_resource
1209
1210  def set_regularizer(self, regularizer):
1211    """Set regularizer for this scope."""
1212    self._regularizer = regularizer
1213
1214  def set_caching_device(self, caching_device):
1215    """Set caching_device for this scope."""
1216    if context.executing_eagerly():
1217      raise NotImplementedError("Caching devices are not yet supported "
1218                                "when eager execution is enabled.")
1219    self._caching_device = caching_device
1220
1221  def set_partitioner(self, partitioner):
1222    """Set partitioner for this scope."""
1223    self._partitioner = partitioner
1224
1225  def set_custom_getter(self, custom_getter):
1226    """Set custom getter for this scope."""
1227    self._custom_getter = custom_getter
1228
1229  def get_collection(self, name):
1230    """Get this scope's variables."""
1231    scope = self._name + "/" if self._name else ""
1232    return ops.get_collection(name, scope)
1233
1234  def trainable_variables(self):
1235    """Get this scope's trainable variables."""
1236    return self.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
1237
1238  def global_variables(self):
1239    """Get this scope's global variables."""
1240    return self.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
1241
1242  def local_variables(self):
1243    """Get this scope's local variables."""
1244    return self.get_collection(ops.GraphKeys.LOCAL_VARIABLES)
1245
1246  def get_variable(self,
1247                   var_store,
1248                   name,
1249                   shape=None,
1250                   dtype=None,
1251                   initializer=None,
1252                   regularizer=None,
1253                   reuse=None,
1254                   trainable=None,
1255                   collections=None,
1256                   caching_device=None,
1257                   partitioner=None,
1258                   validate_shape=True,
1259                   use_resource=None,
1260                   custom_getter=None,
1261                   constraint=None,
1262                   synchronization=VariableSynchronization.AUTO,
1263                   aggregation=VariableAggregation.NONE):
1264    """Gets an existing variable with this name or create a new one."""
1265    if regularizer is None:
1266      regularizer = self._regularizer
1267    if caching_device is None:
1268      caching_device = self._caching_device
1269    if partitioner is None:
1270      partitioner = self._partitioner
1271    if custom_getter is None:
1272      custom_getter = self._custom_getter
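    # In eager mode the scope's stored defaults are ignored: `reuse` is forced
    # to False (an active EagerVariableStore may later rewrite it to
    # AUTO_REUSE) and resource variables are always used.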
1273    if context.executing_eagerly():
1274      reuse = False
1275      use_resource = True
1276    else:
1277      if reuse is None:
1278        reuse = self._reuse
1279      if use_resource is None:
1280        use_resource = self._use_resource
1281
1282    full_name = self.name + "/" + name if self.name else name
1283    # Variable names only depend on variable_scope (full_name here),
1284    # not name_scope, so we reset it below for the time of variable creation.
1285    with ops.name_scope(None, skip_on_eager=False):
1286      # Check that `initializer` dtype and `dtype` are consistent before
1287      # replacing them with defaults.
1288      if (dtype is not None and initializer is not None and
1289          not callable(initializer)):
1290        init_dtype = ops.convert_to_tensor(initializer).dtype.base_dtype
1291        if init_dtype != dtype:
1292          raise ValueError("Initializer type '%s' and explicit dtype '%s' "
1293                           "don't match." % (init_dtype, dtype))
1294      if initializer is None:
1295        initializer = self._initializer
1296      if constraint is None:
1297        constraint = self._constraint
1298      if dtype is None:
1299        dtype = self._dtype
1300      return var_store.get_variable(
1301          full_name,
1302          shape=shape,
1303          dtype=dtype,
1304          initializer=initializer,
1305          regularizer=regularizer,
1306          reuse=reuse,
1307          trainable=trainable,
1308          collections=collections,
1309          caching_device=caching_device,
1310          partitioner=partitioner,
1311          validate_shape=validate_shape,
1312          use_resource=use_resource,
1313          custom_getter=custom_getter,
1314          constraint=constraint,
1315          synchronization=synchronization,
1316          aggregation=aggregation)
1317
1318  def _get_partitioned_variable(self,
1319                                var_store,
1320                                name,
1321                                shape=None,
1322                                dtype=None,
1323                                initializer=None,
1324                                regularizer=None,
1325                                trainable=None,
1326                                collections=None,
1327                                caching_device=None,
1328                                partitioner=None,
1329                                validate_shape=True,
1330                                use_resource=None,
1331                                constraint=None,
1332                                synchronization=VariableSynchronization.AUTO,
1333                                aggregation=VariableAggregation.NONE):
1334    """Gets an existing variable with this name or creates a new one."""
1335    if initializer is None:
1336      initializer = self._initializer
1337    if regularizer is None:
1338      regularizer = self._regularizer
1339    if constraint is None:
1340      constraint = self._constraint
1341    if caching_device is None:
1342      caching_device = self._caching_device
1343    if partitioner is None:
1344      partitioner = self._partitioner
1345    if dtype is None:
1346      dtype = self._dtype
1347    if use_resource is None:
1348      use_resource = self._use_resource
1349
1350    if self._custom_getter is not None:
1351      raise ValueError(
1352          "Private access to _get_partitioned_variable is not allowed when "
1353          "a custom getter is set.  Current custom getter: %s.  "
1354          "It is likely that you're using create_partitioned_variables.  "
1355          "If so, consider instead using get_variable with a non-empty "
1356          "partitioner parameter." % self._custom_getter)
1357
1358    if partitioner is None:
1359      raise ValueError("No partitioner was specified")
1360
1361    # This allows the variable scope name to be used as the variable name if
1362    # this function is invoked with an empty name arg, for backward
1363    # compatibility with create_partitioned_variables().
1364    full_name_list = []
1365    if self.name:
1366      full_name_list.append(self.name)
1367    if name:
1368      full_name_list.append(name)
1369    full_name = "/".join(full_name_list)
1370
1371    # Variable names only depend on variable_scope (full_name here),
1372    # not name_scope, so we reset it below for the time of variable creation.
1373    with ops.name_scope(None, skip_on_eager=False):
1374      # pylint: disable=protected-access
1375      return var_store._get_partitioned_variable(
1376          full_name,
1377          shape=shape,
1378          dtype=dtype,
1379          initializer=initializer,
1380          regularizer=regularizer,
1381          reuse=self.reuse,
1382          trainable=trainable,
1383          collections=collections,
1384          caching_device=caching_device,
1385          partitioner=partitioner,
1386          validate_shape=validate_shape,
1387          use_resource=use_resource,
1388          constraint=constraint,
1389          synchronization=synchronization,
1390          aggregation=aggregation)
1391      # pylint: enable=protected-access
1392
1393
1394_VARSTORE_KEY = ("__variable_store",)
1395_VARSCOPESTORE_KEY = ("__varscope",)
1396
1397
1398class _VariableScopeStore(threading.local):
1399  """A thread local store for the current variable scope and scope counts."""
1400
1401  def __init__(self):
1402    super(_VariableScopeStore, self).__init__()
1403    self.current_scope = VariableScope(False)
1404    self.variable_scopes_count = {}
1405
1406  def open_variable_scope(self, scope_name):
1407    if scope_name in self.variable_scopes_count:
1408      self.variable_scopes_count[scope_name] += 1
1409    else:
1410      self.variable_scopes_count[scope_name] = 1
1411
1412  def close_variable_subscopes(self, scope_name):
1413    for k in list(self.variable_scopes_count.keys()):
1414      if scope_name is None or k.startswith(scope_name + "/"):
1415        self.variable_scopes_count[k] = 0
1416
1417  def variable_scope_count(self, scope_name):
1418    return self.variable_scopes_count.get(scope_name, 0)
1419
1420
1421def get_variable_scope_store():
1422  """Returns the variable scope store for current thread."""
1423  scope_store = ops.get_collection(_VARSCOPESTORE_KEY)
1424
1425  if not scope_store:
1426    scope_store = _VariableScopeStore()
1427    ops.add_to_collection(_VARSCOPESTORE_KEY, scope_store)
1428  else:
1429    scope_store = scope_store[0]
1430
1431  return scope_store
1432
1433
1434@tf_export(v1=["get_variable_scope"])
1435def get_variable_scope():
1436  """Returns the current variable scope."""
1437  return get_variable_scope_store().current_scope
1438
1439
1440def _get_default_variable_store():
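  """Returns the default variable store, creating one if necessary."""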
1441  store = ops.get_collection(_VARSTORE_KEY)
1442  if store:
1443    return store[0]
1444  store = _VariableStore()
1445  ops.add_to_collection(_VARSTORE_KEY, store)
1446  return store
1447
1448
1449@tf_contextlib.contextmanager
1450def with_variable_store(store):
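  """Temporarily installs `store` as the default variable store."""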
1451  store_collection = ops.get_collection_ref(_VARSTORE_KEY)
1452  old = list(store_collection)
1453  store_collection[:] = [store]
1454  try:
1455    yield
1456  finally:
1457    store_collection[:] = old
1458
1459
1460class EagerVariableStore(object):
1461  """Wrapper allowing functional layers to be used with eager execution.
1462
1463  When eager execution is enabled, Variables get deleted when they go out of
1464  scope and are not stored in global collections by default. A lot of code
1465  (mostly the functional layers in tf.layers) assumes that variables are kept in
1466  a global list.
1467
1468  EagerVariableStore can be used in conjunction with this code to make it
1469  eager-friendly. For example, to create a dense layer, use:
1470
1471  ```
1472    container = tfe.EagerVariableStore()
1473    for input in dataset_iterator:
1474      with container.as_default():
1475        x = tf.compat.v1.layers.dense(input, name="l1")
1476    print(container.variables())  # Should print the variables used in the layer.
1477  ```
1478  """
1479
1480  def __init__(self, store=None):
1481    if store is not None:
1482      if not store._store_eager_variables:  # pylint: disable=protected-access
1483        raise ValueError("Cannot construct EagerVariableStore from a "
1484                         "VariableStore object that does not hold eager "
1485                         "variables.")
1486      self._store = store
1487    else:
1488      self._store = _VariableStore()
1489    self._store._store_eager_variables = True  # pylint: disable=protected-access
1490
1491  def as_default(self):
1492    return with_variable_store(self._store)
1493
1494  def variables(self):
1495    return sorted(self._store._vars.values(), key=lambda x: x.name)  # pylint: disable=protected-access
1496
1497  def trainable_variables(self):
1498    # pylint: disable=protected-access
1499    return sorted([x for x in self._store._vars.values() if x.trainable],
1500                  key=lambda x: x.name)
1501    # pylint: enable=protected-access
1502
1503  def non_trainable_variables(self):
1504    # pylint: disable=protected-access
1505    return sorted([x for x in self._store._vars.values() if not x.trainable],
1506                  key=lambda x: x.name)
1507    # pylint: enable=protected-access
1508
1509  def copy(self):
1510    """Copy this variable store and all of its contents.
1511
1512    Variables contained in this store will be copied over to the new variable
1513    store, meaning that they can be modified without affecting the variables in
1514    this store.
1515
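    For example, a minimal sketch (assumes eager execution and the
    `tf.compat.v1` API; the variable name and value are illustrative):

    ```python
    store = EagerVariableStore()
    with store.as_default():
      v = tf.compat.v1.get_variable("v", initializer=tf.constant(1.0))

    snapshot = store.copy()
    # `snapshot` holds an independent copy of `v`; mutating one copy does not
    # affect the other.
    ```
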
1516    Returns:
1517      A new EagerVariableStore instance containing copied variables.
1518    """
1519    # pylint: disable=protected-access
1520    new_store = EagerVariableStore()
1521    for key, var in iteritems(self._store._vars):
1522      # Strip device out of variable name.
1523      try:
1524        index = var.name.index(":")
1525      except ValueError:
1526        stripped_var_name = var.name
1527      else:
1528        stripped_var_name = var.name[:index]
1529
1530      # Create new variable with same value, name, and "trainable" flag.
1531      new_var = resource_variable_ops.ResourceVariable(
1532          var.read_value(), name=stripped_var_name, trainable=var.trainable)
1533      new_store._store._vars[key] = new_var
1534    return new_store
1535    # pylint: enable=protected-access
1536
1537
1538# The argument list for get_variable must match arguments to get_local_variable.
1539# So, if you are updating the arguments, also update arguments to
1540# get_local_variable below.
1541@tf_export(v1=["get_variable"])
1542def get_variable(name,
1543                 shape=None,
1544                 dtype=None,
1545                 initializer=None,
1546                 regularizer=None,
1547                 trainable=None,
1548                 collections=None,
1549                 caching_device=None,
1550                 partitioner=None,
1551                 validate_shape=True,
1552                 use_resource=None,
1553                 custom_getter=None,
1554                 constraint=None,
1555                 synchronization=VariableSynchronization.AUTO,
1556                 aggregation=VariableAggregation.NONE):
1557  return get_variable_scope().get_variable(
1558      _get_default_variable_store(),
1559      name,
1560      shape=shape,
1561      dtype=dtype,
1562      initializer=initializer,
1563      regularizer=regularizer,
1564      trainable=trainable,
1565      collections=collections,
1566      caching_device=caching_device,
1567      partitioner=partitioner,
1568      validate_shape=validate_shape,
1569      use_resource=use_resource,
1570      custom_getter=custom_getter,
1571      constraint=constraint,
1572      synchronization=synchronization,
1573      aggregation=aggregation)
1574
1575
1576get_variable_or_local_docstring = ("""%s
1577
1578%sThis function prefixes the name with the current variable scope
1579and performs reuse checks. See the
1580[Variable Scope How To](https://tensorflow.org/guide/variables)
1581for an extensive description of how reusing works. Here is a basic example:
1582
1583```python
1584def foo():
1585  with tf.variable_scope("foo", reuse=tf.AUTO_REUSE):
1586    v = tf.get_variable("v", [1])
1587  return v
1588
1589v1 = foo()  # Creates v.
1590v2 = foo()  # Gets the same, existing v.
1591assert v1 == v2
1592```
1593
1594If initializer is `None` (the default), the default initializer passed in
1595the variable scope will be used. If that one is `None` too, a
1596`glorot_uniform_initializer` will be used. The initializer can also be
1597a Tensor, in which case the variable is initialized to this value and shape.
1598
1599Similarly, if the regularizer is `None` (the default), the default regularizer
1600passed in the variable scope will be used (if that is `None` too,
1601then by default no regularization is performed).
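
For instance, a minimal sketch of attaching a regularizer through the scope
(`l2_reg` and the 1e-4 weight are illustrative, not part of this API):

```python
def l2_reg(weights):
  # The returned scalar loss is added to the collection
  # `tf.GraphKeys.REGULARIZATION_LOSSES`.
  return 1e-4 * tf.nn.l2_loss(weights)

with tf.variable_scope("layer", regularizer=l2_reg):
  w = tf.get_variable("w", [3, 3])

reg_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
```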
1602
1603If a partitioner is provided, a `PartitionedVariable` is returned.
1604Accessing this object as a `Tensor` returns the shards concatenated along
1605the partition axis.
1606
1607Some useful partitioners are available.  See, e.g.,
1608`variable_axis_size_partitioner` and `min_max_variable_partitioner`.
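
For example, a hedged sketch that shards a variable with
`tf.fixed_size_partitioner` (the shape and shard count are illustrative):

```python
with tf.variable_scope("embedding"):
  w = tf.get_variable(
      "w", shape=[1000, 64],
      partitioner=tf.fixed_size_partitioner(4, axis=0))
# Reading `w` as a single Tensor concatenates the 4 shards along axis 0.
```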
1609
1610Args:
1611  name: The name of the new or existing variable.
1612  shape: Shape of the new or existing variable.
1613  dtype: Type of the new or existing variable (defaults to `DT_FLOAT`).
1614  initializer: Initializer for the variable if one is created. Can either be
1615    an initializer object or a Tensor. If it's a Tensor, its shape must be known
1616    unless validate_shape is False.
1617  regularizer: A (Tensor -> Tensor or None) function; the result of
1618    applying it on a newly created variable will be added to the collection
1619    `tf.GraphKeys.REGULARIZATION_LOSSES` and can be used for regularization.
1620  %scollections: List of graph collections keys to add the Variable to.
1621    Defaults to `[%s]` (see `tf.Variable`).
1622  caching_device: Optional device string or function describing where the
1623    Variable should be cached for reading.  Defaults to the Variable's
1624    device.  If not `None`, caches on another device.  Typical use is to
1625    cache on the device where the Ops using the Variable reside, to
1626    deduplicate copying through `Switch` and other conditional statements.
1627  partitioner: Optional callable that accepts a fully defined `TensorShape`
1628    and `dtype` of the Variable to be created, and returns a list of
1629    partitions for each axis (currently only one axis can be partitioned).
1630  validate_shape: If False, allows the variable to be initialized with a
1631    value of unknown shape. If True, the default, the shape of initial_value
1632    must be known. For this to be used the initializer must be a Tensor and
1633    not an initializer object.
1634  use_resource: If False, creates a regular Variable. If true, creates an
1635    experimental ResourceVariable instead with well-defined semantics.
1636    Defaults to False (will later change to True). When eager execution is
1637    enabled this argument is always forced to be True.
1638  custom_getter: Callable that takes as a first argument the true getter, and
1639    allows overwriting the internal get_variable method.
1640    The signature of `custom_getter` should match that of this method,
1641    but the most future-proof version will allow for changes:
1642    `def custom_getter(getter, *args, **kwargs)`.  Direct access to
1643    all `get_variable` parameters is also allowed:
1644    `def custom_getter(getter, name, *args, **kwargs)`.  A simple custom
1645    getter that creates variables with modified names is:
1646    ```python
1647    def custom_getter(getter, name, *args, **kwargs):
1648      return getter(name + '_suffix', *args, **kwargs)
1649    ```
1650  constraint: An optional projection function to be applied to the variable
1651    after being updated by an `Optimizer` (e.g. used to implement norm
1652    constraints or value constraints for layer weights). The function must
1653    take as input the unprojected Tensor representing the value of the
1654    variable and return the Tensor for the projected value
1655    (which must have the same shape). Constraints are not safe to
1656    use when doing asynchronous distributed training.
1657  synchronization: Indicates when a distributed variable will be
1658    aggregated. Accepted values are constants defined in the class
1659    `tf.VariableSynchronization`. By default the synchronization is set to
1660    `AUTO` and the current `DistributionStrategy` chooses
1661    when to synchronize.
1662  aggregation: Indicates how a distributed variable will be aggregated.
1663    Accepted values are constants defined in the class
1664    `tf.VariableAggregation`.
1665
1666Returns:
1667  The created or existing `Variable` (or `PartitionedVariable`, if a
1668  partitioner was used).
1669
1670Raises:
1671  ValueError: when creating a new variable and shape is not declared,
1672    when violating reuse during variable creation, or when `initializer` dtype
1673    and `dtype` don't match. Reuse is set inside `variable_scope`.
1674""")
1675get_variable.__doc__ = get_variable_or_local_docstring % (
1676    "Gets an existing variable with these parameters or creates a new one.", "",
1677    "trainable: If `True` also add the variable to the graph collection\n"
1678    "    `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).\n  ",
1679    "GraphKeys.GLOBAL_VARIABLES")
1680
1681
1682# The argument list for get_local_variable must match arguments to get_variable.
1683# So, if you are updating the arguments, also update arguments to get_variable.
1684@tf_export(v1=["get_local_variable"])
1685def get_local_variable(  # pylint: disable=missing-docstring
1686    name,
1687    shape=None,
1688    dtype=None,
1689    initializer=None,
1690    regularizer=None,
1691    trainable=False,  # pylint: disable=unused-argument
1692    collections=None,
1693    caching_device=None,
1694    partitioner=None,
1695    validate_shape=True,
1696    use_resource=None,
1697    custom_getter=None,
1698    constraint=None,
1699    synchronization=VariableSynchronization.AUTO,
1700    aggregation=VariableAggregation.NONE):
1701  if collections:
1702    collections += [ops.GraphKeys.LOCAL_VARIABLES]
1703  else:
1704    collections = [ops.GraphKeys.LOCAL_VARIABLES]
1705  return get_variable(
1706      name,
1707      shape=shape,
1708      dtype=dtype,
1709      initializer=initializer,
1710      regularizer=regularizer,
1711      trainable=False,
1712      collections=collections,
1713      caching_device=caching_device,
1714      partitioner=partitioner,
1715      validate_shape=validate_shape,
1716      use_resource=use_resource,
1717      synchronization=synchronization,
1718      aggregation=aggregation,
1719      custom_getter=custom_getter,
1720      constraint=constraint)
1721
1722
1723get_local_variable.__doc__ = get_variable_or_local_docstring % (
1724    "Gets an existing *local* variable or creates a new one.",
1725    "Behavior is the same as in `get_variable`, except that variables are\n"
1726    "added to the `LOCAL_VARIABLES` collection and `trainable` is set to\n"
1727    "`False`.\n", "", "GraphKeys.LOCAL_VARIABLES")
1728
1729
1730def _get_partitioned_variable(name,
1731                              shape=None,
1732                              dtype=None,
1733                              initializer=None,
1734                              regularizer=None,
1735                              trainable=True,
1736                              collections=None,
1737                              caching_device=None,
1738                              partitioner=None,
1739                              validate_shape=True,
1740                              use_resource=None,
1741                              constraint=None,
1742                              synchronization=VariableSynchronization.AUTO,
1743                              aggregation=VariableAggregation.NONE):
1744  """Gets or creates a sharded variable list with these parameters.
1745
1746  The `partitioner` must be a callable that accepts a fully defined
1747  `TensorShape` and returns a sequence of integers (the `partitions`).
1748  These integers describe how to partition the given sharded `Variable`
1749  along the given dimension.  That is, `partitions[1] = 3` means split
1750  the `Variable` into 3 shards along dimension 1.  Currently, sharding along
1751  only one axis is supported.
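
  As an illustration only, a partitioner that always splits axis 1 into 3
  shards could look like this (a sketch, not an existing helper):

  ```python
  def split_axis_1_into_3(shape, dtype):
    del dtype  # Unused in this sketch.
    partitions = [1] * shape.ndims
    partitions[1] = 3
    return partitions
  ```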
1752
1753  If the list of variables with the given name (prefix) is already stored,
1754  we return the stored variables. Otherwise, we create a new one.
1755
1756  If initializer is `None` (the default), the default initializer passed in
1757  the constructor is used. If that one is `None` too, we use a new
1758  `glorot_uniform_initializer`. If initializer is a Tensor, we use
1759  it as a value and derive the shape from the initializer.
1760
1761  If the initializer is a callable, then it will be called for each
1762  shard.  Otherwise the initializer should match the shape of the entire
1763  sharded Variable, and it will be sliced accordingly for each shard.
1764
1765  Some useful partitioners are available.  See, e.g.,
1766  `variable_axis_size_partitioner` and `min_max_variable_partitioner`.
1767
1768  Args:
1769    name: The name of the new or existing variable.
1770    shape: Shape of the new or existing variable.
1771    dtype: Type of the new or existing variable (defaults to `DT_FLOAT`).
1772    initializer: Initializer for the variable if one is created.
1773    regularizer: A (Tensor -> Tensor or None) function; the result of applying
1774      it on a newly created variable will be added to the collection
1775      GraphKeys.REGULARIZATION_LOSSES and can be used for regularization.
1776    trainable: If `True` also add the variable to the graph collection
1777      `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
1778    collections: List of graph collections keys to add the Variable to. Defaults
1779      to `[GraphKeys.GLOBAL_VARIABLES]` (see `tf.Variable`).
1780    caching_device: Optional device string or function describing where the
1781      Variable should be cached for reading.  Defaults to the Variable's device.
1782      If not `None`, caches on another device.  Typical use is to cache on the
1783      device where the Ops using the Variable reside, to deduplicate copying
1784      through `Switch` and other conditional statements.
1785    partitioner: Optional callable that accepts a fully defined `TensorShape`
1786      and `dtype` of the Variable to be created, and returns a list of
1787      partitions for each axis (currently only one axis can be partitioned).
1788    validate_shape: If False, allows the variable to be initialized with a value
1789      of unknown shape. If True, the default, the shape of initial_value must be
1790      known.
1791    use_resource: If False, creates a regular Variable. If True, creates an
1792      experimental ResourceVariable instead which has well-defined semantics.
1793      Defaults to False (will later change to True).
1794    constraint: An optional projection function to be applied to the variable
1795      after being updated by an `Optimizer` (e.g. used to implement norm
1796      constraints or value constraints for layer weights). The function must
1797      take as input the unprojected Tensor representing the value of the
1798      variable and return the Tensor for the projected value (which must have
1799      the same shape). Constraints are not safe to use when doing asynchronous
1800      distributed training.
1801    synchronization: Indicates when a distributed variable will be aggregated.
1802      Accepted values are constants defined in the class
1803      `tf.VariableSynchronization`. By default the synchronization is set to
1804      `AUTO` and the current `DistributionStrategy` chooses when to synchronize.
1805    aggregation: Indicates how a distributed variable will be aggregated.
1806      Accepted values are constants defined in the class
1807      `tf.VariableAggregation`.
1808
1809  Returns:
1810    A tuple `(shards, partitions)` where `shards` is the list of `Variable`
1811    shards and `partitions` is the output of the partitioner on the input
1812    shape.
1813
1814  Raises:
1815    ValueError: when creating a new variable and shape is not declared,
1816      or when violating reuse during variable creation. Reuse is set inside
1817      `variable_scope`.
1818  """
1819  # pylint: disable=protected-access
1820  scope = get_variable_scope()
1821  if scope.custom_getter is not None:
1822    raise ValueError(
1823        "Private access to _get_partitioned_variable is not allowed when "
1824        "a custom getter is set.  Current custom getter: %s.  "
1825        "It is likely that you're using create_partitioned_variables.  "
1826        "If so, consider instead using get_variable with a non-empty "
1827        "partitioner parameter." % scope.custom_getter)
1828  return scope._get_partitioned_variable(
1829      _get_default_variable_store(),
1830      name,
1831      shape=shape,
1832      dtype=dtype,
1833      initializer=initializer,
1834      regularizer=regularizer,
1835      trainable=trainable,
1836      collections=collections,
1837      caching_device=caching_device,
1838      partitioner=partitioner,
1839      validate_shape=validate_shape,
1840      use_resource=use_resource,
1841      constraint=constraint,
1842      synchronization=synchronization,
1843      aggregation=aggregation)
1844  # pylint: enable=protected-access
1845
1846
1847# Named like a function for compatibility with the previous
1848# @tf_contextlib.contextmanager definition.
1849class _pure_variable_scope(object):  # pylint: disable=invalid-name
1850  """A context for the variable_scope, see `variable_scope` for docs."""
1851
1852  def __init__(self,
1853               name_or_scope,
1854               reuse=None,
1855               initializer=None,
1856               regularizer=None,
1857               caching_device=None,
1858               partitioner=None,
1859               custom_getter=None,
1860               old_name_scope=None,
1861               dtype=dtypes.float32,
1862               use_resource=None,
1863               constraint=None):
1864    """Creates a context for the variable_scope, see `variable_scope` for docs.
1865
1866    Note: this does not create a name scope.
1867
1868    Args:
1869      name_or_scope: `string` or `VariableScope`: the scope to open.
1870      reuse: `True`, None, or tf.compat.v1.AUTO_REUSE; if `None`, we inherit
1871        the parent scope's reuse flag.
1872      initializer: default initializer for variables within this scope.
1873      regularizer: default regularizer for variables within this scope.
1874      caching_device: default caching device for variables within this scope.
1875      partitioner: default partitioner for variables within this scope.
1876      custom_getter: default custom getter for variables within this scope.
1877      old_name_scope: the original name scope when re-entering a variable scope.
1878      dtype: type of the variables within this scope (defaults to `DT_FLOAT`).
1879      use_resource: If False, variables in this scope will be regular Variables.
1880        If True, experimental ResourceVariables will be created instead, with
1881        well-defined semantics. Defaults to False (will later change to True).
1882      constraint: An optional projection function to be applied to the variable
1883        after being updated by an `Optimizer` (e.g. used to implement norm
1884        constraints or value constraints for layer weights). The function must
1885        take as input the unprojected Tensor representing the value of the
1886        variable and return the Tensor for the projected value (which must have
1887        the same shape). Constraints are not safe to use when doing asynchronous
1888        distributed training.
1889    """
1890    self._name_or_scope = name_or_scope
1891    self._reuse = reuse
1892    self._initializer = initializer
1893    self._regularizer = regularizer
1894    self._caching_device = caching_device
1895    self._partitioner = partitioner
1896    self._custom_getter = custom_getter
1897    self._old_name_scope = old_name_scope
1898    self._dtype = dtype
1899    self._use_resource = use_resource
1900    self._constraint = constraint
1901    self._var_store = _get_default_variable_store()
1902    self._var_scope_store = get_variable_scope_store()
1903    self._last_variable_scope_object = None
1904    if isinstance(self._name_or_scope, VariableScope):
1905      self._new_name = self._name_or_scope.name
1906      name_scope = self._name_or_scope._name_scope  # pylint: disable=protected-access
1907      # Handler for the case when we jump to a shared scope.  We create a new
1908      #   VariableScope (self._var_scope_object) that contains a copy of the
1909      #   provided shared scope, possibly with changed reuse and initializer, if
1910      #   the user requested this.
1911      variable_scope_object = VariableScope(
1912          self._name_or_scope.reuse if not self._reuse else self._reuse,
1913          name=self._new_name,
1914          initializer=self._name_or_scope.initializer,
1915          regularizer=self._name_or_scope.regularizer,
1916          caching_device=self._name_or_scope.caching_device,
1917          partitioner=self._name_or_scope.partitioner,
1918          dtype=self._name_or_scope.dtype,
1919          custom_getter=self._name_or_scope.custom_getter,
1920          name_scope=name_scope,
1921          use_resource=self._name_or_scope.use_resource,
1922          constraint=self._constraint)
1923      if self._initializer is not None:
1924        variable_scope_object.set_initializer(self._initializer)
1925      if self._regularizer is not None:
1926        variable_scope_object.set_regularizer(self._regularizer)
1927      if self._caching_device is not None:
1928        variable_scope_object.set_caching_device(self._caching_device)
1929      if self._partitioner is not None:
1930        variable_scope_object.set_partitioner(self._partitioner)
1931      if self._custom_getter is not None:
1932        variable_scope_object.set_custom_getter(
1933            _maybe_wrap_custom_getter(self._custom_getter,
1934                                      self._name_or_scope.custom_getter))
1935      if self._dtype is not None:
1936        variable_scope_object.set_dtype(self._dtype)
1937      if self._use_resource is not None:
1938        variable_scope_object.set_use_resource(self._use_resource)
1939      self._cached_variable_scope_object = variable_scope_object
1940
1941  def __enter__(self):
1942    """Begins the scope block.
1943
1944    Returns:
1945      A VariableScope.
1946    Raises:
1947      ValueError: when trying to reuse within a create scope, or create within
1948        a reuse scope, or if reuse is not `None` or `True`.
1949      TypeError: when the types of some arguments are not appropriate.
1950    """
1951    self._old = self._var_scope_store.current_scope
1952    if isinstance(self._name_or_scope, VariableScope):
1953      self._var_scope_store.open_variable_scope(self._new_name)
1954      self._old_subscopes = copy.copy(
1955          self._var_scope_store.variable_scopes_count)
1956      variable_scope_object = self._cached_variable_scope_object
1957    else:
1958      # Handler for the case when we just prolong current variable scope.
1959      #   VariableScope with name extended by the provided one, and inherited
1960      #   reuse and initializer (except if the user provided values to set).
1961      self._new_name = (
1962          self._old.name + "/" +
1963          self._name_or_scope if self._old.name else self._name_or_scope)
1964      self._reuse = (self._reuse or
1965                     self._old.reuse)  # Re-using is inherited by sub-scopes.
1966      if self._old_name_scope is None:
1967        name_scope = self._name_or_scope
1968      else:
1969        name_scope = self._old_name_scope
1970      variable_scope_object = VariableScope(
1971          self._reuse,
1972          name=self._new_name,
1973          initializer=self._old.initializer,
1974          regularizer=self._old.regularizer,
1975          caching_device=self._old.caching_device,
1976          partitioner=self._old.partitioner,
1977          dtype=self._old.dtype,
1978          use_resource=self._old.use_resource,
1979          custom_getter=self._old.custom_getter,
1980          name_scope=name_scope,
1981          constraint=self._constraint)
1982      if self._initializer is not None:
1983        variable_scope_object.set_initializer(self._initializer)
1984      if self._regularizer is not None:
1985        variable_scope_object.set_regularizer(self._regularizer)
1986      if self._caching_device is not None:
1987        variable_scope_object.set_caching_device(self._caching_device)
1988      if self._partitioner is not None:
1989        variable_scope_object.set_partitioner(self._partitioner)
1990      if self._custom_getter is not None:
1991        variable_scope_object.set_custom_getter(
1992            _maybe_wrap_custom_getter(self._custom_getter,
1993                                      self._old.custom_getter))
1994      if self._dtype is not None:
1995        variable_scope_object.set_dtype(self._dtype)
1996      if self._use_resource is not None:
1997        variable_scope_object.set_use_resource(self._use_resource)
1998      self._var_scope_store.open_variable_scope(self._new_name)
1999    self._var_scope_store.current_scope = variable_scope_object
2000    self._last_variable_scope_object = variable_scope_object
2001    return variable_scope_object
2002
2003  def __exit__(self, type_arg, value_arg, traceback_arg):
2004    if (self._var_scope_store.current_scope is
2005        not self._last_variable_scope_object):
2006      raise RuntimeError("Improper nesting of variable_scope.")
2007    # If jumping out from a non-prolonged scope, restore counts.
2008    if isinstance(self._name_or_scope, VariableScope):
2009      self._var_scope_store.variable_scopes_count = self._old_subscopes
2010    else:
2011      self._var_scope_store.close_variable_subscopes(self._new_name)
2012    self._var_scope_store.current_scope = self._old
2013
2014
2015def _maybe_wrap_custom_getter(custom_getter, old_getter):
2016  """Wrap a call to a custom_getter to use the old_getter internally."""
2017  if old_getter is None:
2018    return custom_getter
2019
2020  # The new custom_getter should call the old one
2021  def wrapped_custom_getter(getter, *args, **kwargs):
2022    # Call:
2023    #  custom_getter(
2024    #    lambda: old_getter(true_getter, ...), *args, **kwargs)
2025    # which means custom_getter will call old_getter, which
2026    # will call the true_getter, perform any intermediate
2027    # processing, and return the results to the current
2028    # getter, which will also perform additional processing.
2029    return custom_getter(functools.partial(old_getter, getter), *args, **kwargs)
2030
2031  return wrapped_custom_getter
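
# A hedged usage sketch for the wrapper above (the getter names below are
# illustrative, not part of this module):
#
#   def add_suffix(getter, name, *args, **kwargs):
#     return getter(name + "_suffix", *args, **kwargs)
#
#   def log_requests(getter, name, *args, **kwargs):
#     logging.info("requesting variable %s", name)
#     return getter(name, *args, **kwargs)
#
#   combined = _maybe_wrap_custom_getter(log_requests, add_suffix)
#
# `combined(true_getter, name, ...)` calls `log_requests`, whose `getter`
# argument routes through `add_suffix` before reaching `true_getter`.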
2032
2033
2034def _get_unique_variable_scope(prefix):
2035  """Get a name with the given prefix unique in the current variable scope."""
2036  var_scope_store = get_variable_scope_store()
2037  current_scope = get_variable_scope()
2038  name = current_scope.name + "/" + prefix if current_scope.name else prefix
2039  if var_scope_store.variable_scope_count(name) == 0:
2040    return prefix
2041  idx = 1
2042  while var_scope_store.variable_scope_count(name + ("_%d" % idx)) > 0:
2043    idx += 1
2044  return prefix + ("_%d" % idx)
2045
2046
2047# Named like a function for backwards compatibility with the
2048# @tf_contextlib.contextmanager version, which was switched to a class to avoid
2049# some object creation overhead.
2050@tf_export(v1=["variable_scope"])  # pylint: disable=invalid-name
2051class variable_scope(object):
2052  """A context manager for defining ops that create variables (layers).
2053
2054  This context manager validates that the (optional) `values` are from the same
2055  graph, ensures that graph is the default graph, and pushes a name scope and a
2056  variable scope.
2057
2058  If `name_or_scope` is not None, it is used as is. If `name_or_scope` is None,
2059  then `default_name` is used.  In that case, if the same name has been
2060  previously used in the same scope, it will be made unique by appending `_N`
2061  to it.
2062
2063  Variable scope allows you to create new variables and to share already created
2064  ones while providing checks to not create or share by accident. For details,
2065  see the [Variable Scope How To](https://tensorflow.org/guide/variables), here
2066  we present only a few basic examples.
2067
2068  Simple example of how to create a new variable:
2069
2070  ```python
2071  with tf.compat.v1.variable_scope("foo"):
2072      with tf.compat.v1.variable_scope("bar"):
2073          v = tf.compat.v1.get_variable("v", [1])
2074          assert v.name == "foo/bar/v:0"
2075  ```
2076
2077  Simple example of how to reenter a premade variable scope safely:
2078
2079  ```python
2080  with tf.compat.v1.variable_scope("foo") as vs:
2081    pass
2082
2083  # Re-enter the variable scope.
2084  with tf.compat.v1.variable_scope(vs,
2085                         auxiliary_name_scope=False) as vs1:
2086    # Restore the original name_scope.
2087    with tf.name_scope(vs1.original_name_scope):
2088        v = tf.compat.v1.get_variable("v", [1])
2089        assert v.name == "foo/v:0"
2090        c = tf.constant([1], name="c")
2091        assert c.name == "foo/c:0"
2092  ```
2093
2094  Keep in mind that the counters for `default_name` are discarded once the
2095  parent scope is exited. Therefore when the code re-enters the scope (for
2096  instance by saving it), all nested default_name counters will be restarted.
2097
2098  For instance:
2099
2100  ```python
2101  with tf.compat.v1.variable_scope("foo") as vs:
2102    with tf.compat.v1.variable_scope(None, default_name="bar"):
2103      v = tf.compat.v1.get_variable("a", [1])
2104      assert v.name == "foo/bar/a:0", v.name
2105    with tf.compat.v1.variable_scope(None, default_name="bar"):
2106      v = tf.compat.v1.get_variable("b", [1])
2107      assert v.name == "foo/bar_1/b:0"
2108
2109  with tf.compat.v1.variable_scope(vs):
2110    with tf.compat.v1.variable_scope(None, default_name="bar"):
2111      v = tf.compat.v1.get_variable("c", [1])
2112      assert v.name == "foo/bar/c:0"   # Uses bar instead of bar_2!
2113  ```
2114
2115  Basic example of sharing a variable with AUTO_REUSE:
2116
2117  ```python
2118  def foo():
2119    with tf.compat.v1.variable_scope("foo", reuse=tf.compat.v1.AUTO_REUSE):
2120      v = tf.compat.v1.get_variable("v", [1])
2121    return v
2122
2123  v1 = foo()  # Creates v.
2124  v2 = foo()  # Gets the same, existing v.
2125  assert v1 == v2
2126  ```
2127
2128  Basic example of sharing a variable with reuse=True:
2129
2130  ```python
2131  with tf.compat.v1.variable_scope("foo"):
2132      v = tf.compat.v1.get_variable("v", [1])
2133  with tf.compat.v1.variable_scope("foo", reuse=True):
2134      v1 = tf.compat.v1.get_variable("v", [1])
2135  assert v1 == v
2136  ```
2137
2138  Sharing a variable by capturing a scope and setting reuse:
2139
2140  ```python
2141  with tf.compat.v1.variable_scope("foo") as scope:
2142      v = tf.compat.v1.get_variable("v", [1])
2143      scope.reuse_variables()
2144      v1 = tf.compat.v1.get_variable("v", [1])
2145  assert v1 == v
2146  ```
2147
2148  To prevent accidental sharing of variables, we raise an exception when getting
2149  an existing variable in a non-reusing scope.
2150
2151  ```python
2152  with tf.compat.v1.variable_scope("foo"):
2153      v = tf.compat.v1.get_variable("v", [1])
2154      v1 = tf.compat.v1.get_variable("v", [1])
2155      #  Raises ValueError("... v already exists ...").
2156  ```
2157
2158  Similarly, we raise an exception when trying to get a variable that does not
2159  exist in reuse mode.
2160
2161  ```python
2162  with tf.compat.v1.variable_scope("foo", reuse=True):
2163      v = tf.compat.v1.get_variable("v", [1])
2164      #  Raises ValueError("... v does not exist ...").
2165  ```
2166
2167  Note that the `reuse` flag is inherited: if we open a reusing scope, then all
2168  its sub-scopes become reusing as well.
2169
2170  A note about name scoping: Setting `reuse` does not impact the naming of other
2171  ops such as multiplication. See related discussion on
2172  [github#6189](https://github.com/tensorflow/tensorflow/issues/6189)
2173
2174  Note that up to and including version 1.0, it was allowed (though explicitly
2175  discouraged) to pass False to the reuse argument, yielding undocumented
2176  behaviour slightly different from None. Starting at 1.1.0 passing None and
2177  False as reuse has exactly the same effect.
2178
2179  A note about using variable scopes in multi-threaded environment: Variable
2180  scopes are thread local, so one thread will not see another thread's current
2181  scope. Also, when using `default_name`, unique scope names are generated
2182  only on a per-thread basis. If the same name was used within a different
2183  thread, that doesn't prevent a new thread from creating the same scope.
2184  However, the underlying variable store is shared across threads (within the
2185  same graph). As such, if another thread tries to create a new variable with
2186  the same name as a variable created by a previous thread, it will fail unless
2187  reuse is True.
2188
2189  Further, each thread starts with an empty variable scope. So if you wish to
2190  preserve name prefixes from a scope from the main thread, you should capture
2191  the main thread's scope and re-enter it in each thread. For example:
2192
2193  ```
2194  main_thread_scope = variable_scope.get_variable_scope()
2195
2196  # Thread's target function:
2197  def thread_target_fn(captured_scope):
2198    with variable_scope.variable_scope(captured_scope):
2199      # .... regular code for this thread
2200
2201
2202  thread = threading.Thread(target=thread_target_fn, args=(main_thread_scope,))
2203  ```
2204  """
2205
2206  def __init__(self,
2207               name_or_scope,
2208               default_name=None,
2209               values=None,
2210               initializer=None,
2211               regularizer=None,
2212               caching_device=None,
2213               partitioner=None,
2214               custom_getter=None,
2215               reuse=None,
2216               dtype=None,
2217               use_resource=None,
2218               constraint=None,
2219               auxiliary_name_scope=True):
2220    """Initialize the context manager.
2221
2222    Args:
2223      name_or_scope: `string` or `VariableScope`: the scope to open.
2224      default_name: The default name to use if the `name_or_scope` argument is
2225        `None`; this name will be uniquified. If name_or_scope is provided, it
2226        won't be used and therefore is not required and can be None.
2227      values: The list of `Tensor` arguments that are passed to the op function.
2228      initializer: default initializer for variables within this scope.
2229      regularizer: default regularizer for variables within this scope.
2230      caching_device: default caching device for variables within this scope.
2231      partitioner: default partitioner for variables within this scope.
2232      custom_getter: default custom getter for variables within this scope.
2233      reuse: `True`, None, or tf.compat.v1.AUTO_REUSE; if `True`, we go into
2234        reuse mode for this scope as well as all sub-scopes; if
2235        tf.compat.v1.AUTO_REUSE, we create variables if they do not exist, and
2236        return them otherwise; if None, we inherit the parent scope's reuse
2237        flag. When eager execution is enabled, new variables are always created
2238        unless an EagerVariableStore or template is currently active.
2239      dtype: type of variables created in this scope (defaults to the type in
2240        the passed scope, or inherited from parent scope).
2241      use_resource: If False, all variables will be regular Variables. If True,
2242        experimental ResourceVariables with well-defined semantics will be used
2243        instead. Defaults to False (will later change to True). When eager
2244        execution is enabled this argument is always forced to be True.
2245      constraint: An optional projection function to be applied to the variable
2246        after being updated by an `Optimizer` (e.g. used to implement norm
2247        constraints or value constraints for layer weights). The function must
2248        take as input the unprojected Tensor representing the value of the
2249        variable and return the Tensor for the projected value (which must have
2250        the same shape). Constraints are not safe to use when doing asynchronous
2251        distributed training.
2252      auxiliary_name_scope: If `True`, we create an auxiliary name scope with
2253        the scope. If `False`, we don't create it. Note that the argument is not
2254        inherited, and it only takes effect once, when the scope is created. You
2255        should only use it for re-entering a premade variable scope.
2256
2257    Returns:
2258      A scope that can be captured and reused.
2259
2260    Raises:
2261      ValueError: when trying to reuse within a create scope, or create within
2262        a reuse scope.
2263      TypeError: when the types of some arguments are not appropriate.
2264    """
2265    self._name_or_scope = name_or_scope
2266    self._default_name = default_name
2267    self._values = values
2268    self._initializer = initializer
2269    self._regularizer = regularizer
2270    self._caching_device = caching_device
2271    self._partitioner = partitioner
2272    self._custom_getter = custom_getter
2273    self._reuse = reuse
2274    self._dtype = dtype
2275    self._use_resource = use_resource
2276    self._constraint = constraint
2277    if self._default_name is None and self._name_or_scope is None:
2278      raise TypeError("If default_name is None then name_or_scope is required")
2279    if self._reuse is False:
2280      # We don't allow non-inheriting scopes, False = None here.
2281      self._reuse = None
2282    if not (self._reuse is True
2283            or self._reuse is None
2284            or self._reuse is AUTO_REUSE):
2285      raise ValueError("The reuse parameter must be True or False or None "
                       "or tf.compat.v1.AUTO_REUSE.")
2286    if self._values is None:
2287      self._values = []
2288    self._in_graph_mode = not context.executing_eagerly()
2289    if self._in_graph_mode:
2290      self._graph = ops._get_graph_from_inputs(self._values)  # pylint: disable=protected-access
2291    self._cached_pure_variable_scope = None
2292    self._current_name_scope = None
2293    if not isinstance(auxiliary_name_scope, bool):
2294      raise TypeError("The auxiliary_name_scope must be `True` or `False`, "
2295                      "but got {}".format(auxiliary_name_scope))
2296    self._auxiliary_name_scope = auxiliary_name_scope
2297
2298  def __enter__(self):
2299    # If the default graph is building a function, then we should not replace it
2300    # with the cached graph.
2301    if ops.get_default_graph().building_function:
2302      self._building_function = True
2303    else:
2304      self._building_function = False
2305    if self._in_graph_mode and not self._building_function:
2306      self._graph_context_manager = self._graph.as_default()
2307      self._graph_context_manager.__enter__()
2308    if self._cached_pure_variable_scope is not None:
2309      # Fast path for re-entering variable_scopes. We've held on to the pure
2310      # variable scope from a previous successful __enter__, so we avoid some
2311      # overhead by re-using that object.
2312      if self._current_name_scope is not None:
2313        self._current_name_scope.__enter__()
2314      return self._cached_pure_variable_scope.__enter__()
2315
2316    try:
2317      return self._enter_scope_uncached()
2318    except:
2319      if (self._in_graph_mode and not self._building_function and
2320          self._graph_context_manager is not None):
2321        self._graph_context_manager.__exit__(*sys.exc_info())
2322      raise
2323
2324  def _enter_scope_uncached(self):
2325    """Enters the context manager when there is no cached scope yet.
2326
2327    Returns:
2328      The entered variable scope.
2329
2330    Raises:
2331      TypeError: A wrong type is passed as `scope` at __init__().
2332      ValueError: `reuse` is incorrectly set at __init__().
2333    """
2334    if self._auxiliary_name_scope:
2335      # Create a new name scope later
2336      current_name_scope = None
2337    else:
2338      # Reenter the current name scope
2339      name_scope = ops.get_name_scope()
2340      if name_scope:
2341        # Hack to reenter
2342        name_scope += "/"
2343        current_name_scope = ops.name_scope(name_scope, skip_on_eager=False)
2344      else:
2345        # Root scope
2346        current_name_scope = ops.name_scope(name_scope, skip_on_eager=False)
2347
2348    # IMPORTANT: Only assign to self._cached_pure_variable_scope and
2349    # self._current_name_scope after successful __enter__() calls.
2350    if self._name_or_scope is not None:
2351      if not isinstance(self._name_or_scope,
2352                        (VariableScope,) + six.string_types):
2353        raise TypeError("VariableScope: name_or_scope must be a string or "
2354                        "VariableScope.")
2355      if isinstance(self._name_or_scope, six.string_types):
2356        name_scope = self._name_or_scope
2357      else:
2358        name_scope = self._name_or_scope.name.split("/")[-1]
2359      if name_scope or current_name_scope:
2360        current_name_scope = current_name_scope or ops.name_scope(
2361            name_scope, skip_on_eager=False)
2362        try:
2363          current_name_scope_name = current_name_scope.__enter__()
2364        except:
2365          current_name_scope.__exit__(*sys.exc_info())
2366          raise
2367        self._current_name_scope = current_name_scope
2368        if isinstance(self._name_or_scope, six.string_types):
2369          old_name_scope = current_name_scope_name
2370        else:
2371          old_name_scope = self._name_or_scope.original_name_scope
2372        pure_variable_scope = _pure_variable_scope(
2373            self._name_or_scope,
2374            reuse=self._reuse,
2375            initializer=self._initializer,
2376            regularizer=self._regularizer,
2377            caching_device=self._caching_device,
2378            partitioner=self._partitioner,
2379            custom_getter=self._custom_getter,
2380            old_name_scope=old_name_scope,
2381            dtype=self._dtype,
2382            use_resource=self._use_resource,
2383            constraint=self._constraint)
2384        try:
2385          entered_pure_variable_scope = pure_variable_scope.__enter__()
2386        except:
2387          pure_variable_scope.__exit__(*sys.exc_info())
2388          raise
2389        self._cached_pure_variable_scope = pure_variable_scope
2390        return entered_pure_variable_scope
2391      else:
2392        self._current_name_scope = None
2393        # This can only happen if someone is entering the root variable scope.
2394        pure_variable_scope = _pure_variable_scope(
2395            self._name_or_scope,
2396            reuse=self._reuse,
2397            initializer=self._initializer,
2398            regularizer=self._regularizer,
2399            caching_device=self._caching_device,
2400            partitioner=self._partitioner,
2401            custom_getter=self._custom_getter,
2402            dtype=self._dtype,
2403            use_resource=self._use_resource,
2404            constraint=self._constraint)
2405        try:
2406          entered_pure_variable_scope = pure_variable_scope.__enter__()
2407        except:
2408          pure_variable_scope.__exit__(*sys.exc_info())
2409          raise
2410        self._cached_pure_variable_scope = pure_variable_scope
2411        return entered_pure_variable_scope
2412
2413    else:  # Here name_or_scope is None. Using default name, but made unique.
2414      if self._reuse:
2415        raise ValueError("reuse=True cannot be used without a name_or_scope")
2416      current_name_scope = current_name_scope or ops.name_scope(
2417          self._default_name, skip_on_eager=False)
2418      try:
2419        current_name_scope_name = current_name_scope.__enter__()
2420      except:
2421        current_name_scope.__exit__(*sys.exc_info())
2422        raise
2423      self._current_name_scope = current_name_scope
2424      unique_default_name = _get_unique_variable_scope(self._default_name)
2425      pure_variable_scope = _pure_variable_scope(
2426          unique_default_name,
2427          initializer=self._initializer,
2428          regularizer=self._regularizer,
2429          caching_device=self._caching_device,
2430          partitioner=self._partitioner,
2431          custom_getter=self._custom_getter,
2432          old_name_scope=current_name_scope_name,
2433          dtype=self._dtype,
2434          use_resource=self._use_resource,
2435          constraint=self._constraint)
2436      try:
2437        entered_pure_variable_scope = pure_variable_scope.__enter__()
2438      except:
2439        pure_variable_scope.__exit__(*sys.exc_info())
2440        raise
2441      self._cached_pure_variable_scope = pure_variable_scope
2442      return entered_pure_variable_scope
2443
2444  def __exit__(self, type_arg, value_arg, traceback_arg):
2445    try:
2446      self._cached_pure_variable_scope.__exit__(type_arg, value_arg,
2447                                                traceback_arg)
2448    finally:
2449      try:
2450        if self._current_name_scope:
2451          self._current_name_scope.__exit__(type_arg, value_arg,
2452                                            traceback_arg)
2453      finally:
2454        if self._in_graph_mode and not self._building_function:
2455          self._graph_context_manager.__exit__(type_arg, value_arg,
2456                                               traceback_arg)
2457
2458
2459# pylint: disable=g-doc-return-or-yield
2460@tf_export(v1=["variable_op_scope"])
2461@tf_contextlib.contextmanager
2462def variable_op_scope(values,
2463                      name_or_scope,
2464                      default_name=None,
2465                      initializer=None,
2466                      regularizer=None,
2467                      caching_device=None,
2468                      partitioner=None,
2469                      custom_getter=None,
2470                      reuse=None,
2471                      dtype=None,
2472                      use_resource=None,
2473                      constraint=None):
2474  """Deprecated: context manager for defining an op that creates variables."""
2475  logging.warn("tf.variable_op_scope(values, name, default_name) is deprecated,"
2476               " use tf.variable_scope(name, default_name, values)")
2477  with variable_scope(
2478      name_or_scope,
2479      default_name=default_name,
2480      values=values,
2481      initializer=initializer,
2482      regularizer=regularizer,
2483      caching_device=caching_device,
2484      partitioner=partitioner,
2485      custom_getter=custom_getter,
2486      reuse=reuse,
2487      dtype=dtype,
2488      use_resource=use_resource,
2489      constraint=constraint) as scope:
2490    yield scope
2491
2492
2493def _call_partitioner(partitioner, shape, dtype):
2494  """Call partitioner validating its inputs/output.
2495
2496  Args:
2497    partitioner: a function mapping `Tensor` shape and dtype to a list of
2498      partitions.
2499    shape: shape of the `Tensor` to partition, must have at least two
2500      dimensions.
2501    dtype: dtype of the elements in the `Tensor`.
2502
2503  Returns:
2504    A list with elements >= 1 and at most one element > 1. The index of that
2505    element, if any, corresponds to the partitioning axis.
2506  """
2507  if not shape.is_fully_defined():
2508    raise ValueError("Shape of a new partitioned variable must be "
2509                     "fully defined, but instead was %s." % (shape,))
2510  if shape.ndims < 1:
2511    raise ValueError("A partitioned Variable must have rank at least 1, "
2512                     "shape: %s" % shape)
2513
2514  slicing = partitioner(shape=shape, dtype=dtype)
2515  if not isinstance(slicing, collections_lib.Sequence):
2516    raise ValueError("Partitioner must return a sequence, but saw: %s" %
2517                     slicing)
2518  if len(slicing) != shape.ndims:
2519    raise ValueError(
2520        "Partitioner returned a partition list that does not match the "
2521        "Variable's rank: %s vs. %s" % (slicing, shape))
2522  if any(p < 1 for p in slicing):
2523    raise ValueError("Partitioner returned zero partitions for some axes: %s" %
2524                     slicing)
2525  if sum(p > 1 for p in slicing) > 1:
2526    raise ValueError("Can only slice a variable along one dimension: "
2527                     "shape: %s, partitioning: %s" % (shape, slicing))
2528  return slicing
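
# Illustration (a sketch, not used by this module): a partitioner satisfying
# the contract checked above returns one entry per dimension, every entry
# >= 1, and at most one entry > 1. For a fully defined [10, 20] shape, a toy
# partitioner splitting the second axis into four pieces could look like:
#
#   def _toy_partitioner(shape, dtype):   # hypothetical helper
#     del dtype  # Unused in this sketch.
#     return [1, 4]
#
#   _call_partitioner(_toy_partitioner,
#                     tensor_shape.TensorShape([10, 20]),
#                     dtypes.float32)  # -> [1, 4]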


# TODO(slebedev): could be inlined, but
# `_VariableStore._get_partitioned_variable` is too complex even
# without this logic.
def _get_slice_dim_and_num_slices(slicing):
  """Get slicing dimension and number of slices from the partitioner output."""
  for slice_dim, num_slices in enumerate(slicing):
    if num_slices > 1:
      break
  else:
    # Degenerate case: no partitioning applied.
    slice_dim = 0
    num_slices = 1
  return slice_dim, num_slices


def _iter_slices(full_shape, num_slices, slice_dim):
  """Slices a given shape along the specified dimension."""
  num_slices_with_excess = full_shape[slice_dim] % num_slices
  offset = [0] * len(full_shape)
  min_slice_len = full_shape[slice_dim] // num_slices
  for i in xrange(num_slices):
    shape = full_shape[:]
    shape[slice_dim] = min_slice_len + bool(i < num_slices_with_excess)
    yield offset[:], shape
    offset[slice_dim] += shape[slice_dim]
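
# Worked illustration (a sketch, not executed here): together the two helpers
# above turn a partitioner's output into per-slice offsets and shapes. With
# slicing = [1, 4] over a full shape of [10, 20]:
#
#   _get_slice_dim_and_num_slices([1, 4])  # -> (1, 4)
#   list(_iter_slices([10, 20], num_slices=4, slice_dim=1))
#   # -> [([0, 0], [10, 5]), ([0, 5], [10, 5]),
#   #     ([0, 10], [10, 5]), ([0, 15], [10, 5])]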


def default_variable_creator(next_creator=None, **kwargs):
  """Default variable creator."""
  assert next_creator is None
  initial_value = kwargs.get("initial_value", None)
  trainable = kwargs.get("trainable", None)
  collections = kwargs.get("collections", None)
  validate_shape = kwargs.get("validate_shape", True)
  caching_device = kwargs.get("caching_device", None)
  name = kwargs.get("name", None)
  variable_def = kwargs.get("variable_def", None)
  dtype = kwargs.get("dtype", None)
  expected_shape = kwargs.get("expected_shape", None)
  import_scope = kwargs.get("import_scope", None)
  constraint = kwargs.get("constraint", None)
  use_resource = kwargs.get("use_resource", None)
  synchronization = kwargs.get("synchronization", None)
  aggregation = kwargs.get("aggregation", None)
  shape = kwargs.get("shape", None)

  # Resolve `use_resource`: an explicit argument wins, then the enclosing
  # variable scope's setting, then the process-wide default; eager execution
  # always uses resource variables.
  if use_resource is None:
    use_resource = get_variable_scope().use_resource
  if use_resource is None:
    use_resource = _DEFAULT_USE_RESOURCE
  use_resource = use_resource or context.executing_eagerly()
  if use_resource:
    distribute_strategy = kwargs.get("distribute_strategy", None)
    return resource_variable_ops.ResourceVariable(
        initial_value=initial_value,
        trainable=trainable,
        collections=collections,
        validate_shape=validate_shape,
        caching_device=caching_device,
        name=name,
        dtype=dtype,
        constraint=constraint,
        variable_def=variable_def,
        import_scope=import_scope,
        distribute_strategy=distribute_strategy,
        synchronization=synchronization,
        aggregation=aggregation,
        shape=shape)
  else:
    return variables.RefVariable(
        initial_value=initial_value,
        trainable=trainable,
        collections=collections,
        validate_shape=validate_shape,
        caching_device=caching_device,
        name=name,
        dtype=dtype,
        constraint=constraint,
        variable_def=variable_def,
        expected_shape=expected_shape,
        import_scope=import_scope,
        synchronization=synchronization,
        aggregation=aggregation,
        shape=shape)


def default_variable_creator_v2(next_creator=None, **kwargs):
  """Default variable creator."""
  assert next_creator is None
  initial_value = kwargs.get("initial_value", None)
  trainable = kwargs.get("trainable", None)
  validate_shape = kwargs.get("validate_shape", True)
  caching_device = kwargs.get("caching_device", None)
  name = kwargs.get("name", None)
  variable_def = kwargs.get("variable_def", None)
  dtype = kwargs.get("dtype", None)
  import_scope = kwargs.get("import_scope", None)
  constraint = kwargs.get("constraint", None)
  distribute_strategy = kwargs.get("distribute_strategy", None)
  synchronization = kwargs.get("synchronization", None)
  aggregation = kwargs.get("aggregation", None)
  shape = kwargs.get("shape", None)

  return resource_variable_ops.ResourceVariable(
      initial_value=initial_value,
      trainable=trainable,
      validate_shape=validate_shape,
      caching_device=caching_device,
      name=name,
      dtype=dtype,
      constraint=constraint,
      variable_def=variable_def,
      import_scope=import_scope,
      distribute_strategy=distribute_strategy,
      synchronization=synchronization,
      aggregation=aggregation,
      shape=shape)


# Hook the default creators into the `variables` module so that `Variable`
# construction is routed through them.
variables.default_variable_creator = default_variable_creator
variables.default_variable_creator_v2 = default_variable_creator_v2


def _make_getter(captured_getter, captured_previous):
  """Returns a getter closure; avoids Python's late binding of loop variables."""
  return lambda **kwargs: captured_getter(captured_previous, **kwargs)
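
# Illustration of the pitfall `_make_getter` avoids (a sketch, not executed):
# Python closures capture variables by name, so lambdas created directly in a
# loop would all see the loop variables' final values:
#
#   getters = []
#   for g, prev in pairs:                          # `pairs` is hypothetical
#     getters.append(lambda **kw: g(prev, **kw))   # buggy: all share g, prev
#
# Calling `_make_getter(g, prev)` inside the loop instead gives each returned
# lambda its own enclosing scope, freezing the pair it was built with.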


# TODO(apassos) remove forwarding symbol
variable = variables.VariableV1


@tf_export(v1=["variable_creator_scope"])
@tf_contextlib.contextmanager
def variable_creator_scope_v1(variable_creator):
  """Scope which defines a variable creation function to be used by variable().

  variable_creator is expected to be a function with the following signature:

  ```
    def variable_creator(next_creator, **kwargs)
  ```

  A creator that wants to create a variable should eventually call
  `next_creator` rather than instantiating `Variable` or `ResourceVariable`
  directly; this keeps creators composable. A creator may choose to create
  multiple variables, return already existing variables, or simply register
  that a variable was created and defer to the next creators in line. Creators
  can also modify the keyword arguments seen by the next creators.

  Custom getters in the variable scope will eventually resolve down to these
  custom creators when they do create variables.

  The valid keyword arguments in `kwargs` are:

   * initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
        which is the initial value for the Variable. The initial value must have
        a shape specified unless `validate_shape` is set to False. Can also be a
        callable with no argument that returns the initial value when called. In
        that case, `dtype` must be specified. (Note that initializer functions
        from init_ops.py must first be bound to a shape before being used here.)
   * trainable: If `True`, the default, also adds the variable to the graph
        collection `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as
        the default list of variables to use by the `Optimizer` classes.
        `trainable` defaults to `True`, unless `synchronization` is
        set to `ON_READ`, in which case it defaults to `False`.
   * collections: List of graph collections keys. The new variable is added to
        these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
   * validate_shape: If `False`, allows the variable to be initialized with a
        value of unknown shape. If `True`, the default, the shape of
        `initial_value` must be known.
   * caching_device: Optional device string describing where the Variable
        should be cached for reading.  Defaults to the Variable's device.
        If not `None`, caches on another device.  Typical use is to cache
        on the device where the Ops using the Variable reside, to deduplicate
        copying through `Switch` and other conditional statements.
   * name: Optional name for the variable. Defaults to `'Variable'` and gets
        uniquified automatically.
   * dtype: If set, initial_value will be converted to the given type.
        If `None`, either the datatype will be kept (if `initial_value` is
        a Tensor), or `convert_to_tensor` will decide.
   * constraint: A constraint function to be applied to the variable after
        updates by some algorithms.
   * use_resource: If `True`, a ResourceVariable is always created.
   * synchronization: Indicates when a distributed variable will be
        aggregated. Accepted values are constants defined in the class
        `tf.VariableSynchronization`. By default the synchronization is set to
        `AUTO` and the current `DistributionStrategy` chooses
        when to synchronize.
   * aggregation: Indicates how a distributed variable will be aggregated.
        Accepted values are constants defined in the class
        `tf.VariableAggregation`.

  This set may grow over time, so it's important that creators accept the
  signature shown above.

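  For example (a minimal sketch; `logging_creator` and the variable shown are
  illustrative, not part of this module), a creator that logs each variable's
  name and then defers to the rest of the chain:

  ```
    def logging_creator(next_creator, **kwargs):
      print("Creating:", kwargs.get("name"))
      return next_creator(**kwargs)

    with tf.compat.v1.variable_creator_scope(logging_creator):
      v = tf.compat.v1.Variable(1.0, name="v")
  ```
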
  Args:
    variable_creator: the passed creator

  Yields:
    A scope in which the creator is active
  """
  with ops.get_default_graph()._variable_creator_scope(variable_creator):  # pylint: disable=protected-access
    yield


# Note: only the docstrings differ between this and v1.
@tf_export("variable_creator_scope", v1=[])
@tf_contextlib.contextmanager
def variable_creator_scope(variable_creator):
  """Scope which defines a variable creation function to be used by variable().

  variable_creator is expected to be a function with the following signature:

  ```
    def variable_creator(next_creator, **kwargs)
  ```

  A creator that wants to create a variable should eventually call
  `next_creator` rather than instantiating `Variable` or `ResourceVariable`
  directly; this keeps creators composable. A creator may choose to create
  multiple variables, return already existing variables, or simply register
  that a variable was created and defer to the next creators in line. Creators
  can also modify the keyword arguments seen by the next creators.

  Custom getters in the variable scope will eventually resolve down to these
  custom creators when they do create variables.

  The valid keyword arguments in `kwargs` are:

   * initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
        which is the initial value for the Variable. The initial value must have
        a shape specified unless `validate_shape` is set to False. Can also be a
        callable with no argument that returns the initial value when called. In
        that case, `dtype` must be specified. (Note that initializer functions
        from init_ops.py must first be bound to a shape before being used here.)
   * trainable: If `True`, the default, GradientTapes automatically watch
        uses of this Variable.
   * validate_shape: If `False`, allows the variable to be initialized with a
        value of unknown shape. If `True`, the default, the shape of
        `initial_value` must be known.
   * caching_device: Optional device string describing where the Variable
        should be cached for reading.  Defaults to the Variable's device.
        If not `None`, caches on another device.  Typical use is to cache
        on the device where the Ops using the Variable reside, to deduplicate
        copying through `Switch` and other conditional statements.
   * name: Optional name for the variable. Defaults to `'Variable'` and gets
        uniquified automatically.
   * dtype: If set, initial_value will be converted to the given type.
        If `None`, either the datatype will be kept (if `initial_value` is
        a Tensor), or `convert_to_tensor` will decide.
   * constraint: A constraint function to be applied to the variable after
        updates by some algorithms.
   * synchronization: Indicates when a distributed variable will be
        aggregated. Accepted values are constants defined in the class
        `tf.VariableSynchronization`. By default the synchronization is set to
        `AUTO` and the current `DistributionStrategy` chooses
        when to synchronize.
   * aggregation: Indicates how a distributed variable will be aggregated.
        Accepted values are constants defined in the class
        `tf.VariableAggregation`.

  This set may grow over time, so it's important that creators accept the
  signature shown above.

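  For example (a minimal sketch; `no_trainable_creator` is illustrative, not
  part of this module), a creator that overrides the `trainable` keyword
  argument before deferring to the next creator in line:

  ```
    def no_trainable_creator(next_creator, **kwargs):
      kwargs["trainable"] = False
      return next_creator(**kwargs)

    with tf.variable_creator_scope(no_trainable_creator):
      v = tf.Variable(1.0)  # Created as a non-trainable variable.
  ```
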
  Args:
    variable_creator: the passed creator

  Yields:
    A scope in which the creator is active
  """
  with ops.get_default_graph()._variable_creator_scope(variable_creator):  # pylint: disable=protected-access
    yield
