1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""A class to store named variables and a scope operator to manage sharing."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import copy
22import enum  # pylint: disable=g-bad-import-order
23import functools
24import sys
25import threading
26import traceback
27
28import six
29from six import iteritems
30from six.moves import xrange, zip  # pylint: disable=redefined-builtin
31
32from tensorflow.python import tf2
33from tensorflow.python.client import session
34from tensorflow.python.eager import context
35from tensorflow.python.eager import monitoring
36from tensorflow.python.framework import dtypes
37from tensorflow.python.framework import ops
38from tensorflow.python.framework import tensor_shape
39from tensorflow.python.ops import array_ops
40from tensorflow.python.ops import init_ops
41from tensorflow.python.ops import resource_variable_ops
42from tensorflow.python.ops import variables
43from tensorflow.python.platform import tf_logging as logging
44from tensorflow.python.types import core
45from tensorflow.python.util import deprecation
46from tensorflow.python.util import function_utils
47from tensorflow.python.util import tf_contextlib
48from tensorflow.python.util import tf_inspect
49from tensorflow.python.util.compat import collections_abc
50from tensorflow.python.util.tf_export import tf_export
51
52__all__ = [
53    "AUTO_REUSE", "VariableScope", "get_variable_scope", "get_variable",
54    "get_local_variable", "variable_scope", "variable_op_scope",
55    "no_regularizer", "VariableSynchronization", "VariableAggregation"
56]
57
58_api_usage_gauge = monitoring.BoolGauge(
59    "/tensorflow/api/resource_variables",
60    "Whether variable_scope.enable_resource_variables() is called.")
61
62
63class _PartitionInfo(object):
64  """Holds partition info used by initializer functions."""
65
66  __slots__ = ["_full_shape", "_var_offset"]
67
68  def __init__(self, full_shape, var_offset):
69    """Constructor.
70
71    Args:
72      full_shape: Tuple or list of `int` indicating the full combined shape of
73        the partitioned variables.
74      var_offset: Tuple or list of `int` specifying offset of this partition
75        with respect to the full variable for each dimension.
76
77    Raises:
78      TypeError: If `full_shape` or `var_offset` is not a sequence.
79      ValueError: If `full_shape` or `var_offset` differ in length. If
80        `var_offset` exceeds `full_shape` in any dimension.
81    """
82    if not isinstance(full_shape, collections_abc.Sequence) or isinstance(
83        full_shape, six.string_types):
84      raise TypeError(
85          "`full_shape` must be a sequence (like tuple or list) instead of " +
86          type(full_shape).__name__)
87
88    if not isinstance(var_offset, collections_abc.Sequence) or isinstance(
89        var_offset, six.string_types):
90      raise TypeError(
91          "`var_offset` must be a sequence (like tuple or list) instead of " +
92          type(var_offset).__name__)
93
94    if len(var_offset) != len(full_shape):
95      raise ValueError(
96          "Expected equal length, but `var_offset` is of length {} while "
97          "full_shape is of length {}.".format(
98              len(var_offset), len(full_shape)))
99
100    for offset, shape in zip(var_offset, full_shape):
101      if offset < 0 or offset >= shape:
102        raise ValueError(
103            "Expected 0 <= offset < shape but found offset={}, shape={} for "
104            "var_offset={}, full_shape={}".format(offset, shape, var_offset,
105                                                  full_shape))
106
107    self._full_shape = full_shape
108    self._var_offset = var_offset
109
110  @property
111  def full_shape(self):
112    return self._full_shape
113
114  @property
115  def var_offset(self):
116    return self._var_offset
117
118  def single_offset(self, shape):
119    """Returns the offset when the variable is partitioned in at most one dim.
120
121    Args:
122      shape: Tuple or list of `int` indicating the shape of one specific
123        variable partition.
124
125    Returns:
126      `int` representing the offset in the dimension along which the variable is
127       partitioned. Returns 0 if the variable is not being partitioned.
128
129    Raises:
130      ValueError: Depending on self.single_slice_dim().
131    """
132
133    single_slice_dim = self.single_slice_dim(shape)
134    # If this variable is not being partitioned at all, single_slice_dim() could
135    # return None.
136    if single_slice_dim is None:
137      return 0
138    return self.var_offset[single_slice_dim]
139
140  def single_slice_dim(self, shape):
141    """Returns the slice dim when the variable is partitioned only in one dim.
142
143    Args:
144      shape: Tuple or list of `int` indicating the shape of one specific
145        variable partition.
146
147    Returns:
148      `int` representing the dimension that the variable is partitioned in, or
149      `None` if the variable doesn't seem to be partitioned at all.
150
151    Raises:
152      TypeError: If `shape` is not a sequence.
153      ValueError: If `shape` is not the same length as `self.full_shape`. If
154        the variable is partitioned in more than one dimension.
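
    For illustration, a small sketch with made-up shapes: a variable of full
    shape `[10, 20]` split along dimension 1, where this partition starts at
    offset 5 and holds 5 columns.

    ```python
    info = _PartitionInfo(full_shape=[10, 20], var_offset=[0, 5])
    info.single_slice_dim([10, 5])  # -> 1 (only dimension 1 is sliced)
    info.single_offset([10, 5])     # -> 5 (offset along that dimension)
    ```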
155    """
156    if not isinstance(shape, collections_abc.Sequence) or isinstance(
157        shape, six.string_types):
158      raise TypeError(
159          "`shape` must be a sequence (like tuple or list) instead of " +
160          type(shape).__name__)
161
162    if len(shape) != len(self.full_shape):
163      raise ValueError(
164          "Expected equal length, but received shape={} of length {} while "
165          "self.full_shape={} is of length {}.".format(shape, len(shape),
166                                                       self.full_shape,
167                                                       len(self.full_shape)))
168
169    for i in xrange(len(shape)):
170      if self.var_offset[i] + shape[i] > self.full_shape[i]:
171        raise ValueError(
172            "With self.var_offset={}, a partition of shape={} would exceed "
173            "self.full_shape={} in dimension {}.".format(
174                self.var_offset, shape, self.full_shape, i))
175
176    slice_dim = None
177    for i in xrange(len(shape)):
178      if shape[i] == self.full_shape[i]:
179        continue
180      if slice_dim is not None:
181        raise ValueError(
182            "Cannot use single_slice_dim() with shape={} and "
183            "self.full_shape={} since slice dim could be either dimension {} "
184            "or {}.".format(shape, self.full_shape, i, slice_dim))
185      slice_dim = i
186
187    return slice_dim
188
189
190class _ReuseMode(enum.Enum):
191  """Mode for variable access within a variable scope."""
192
193  # Indicates that variables are to be fetched if they already exist or
194  # otherwise created.
195  AUTO_REUSE = 1
196
197  # TODO(alive): For TensorFlow 2.0, Deprecate True/False/None API in favor of
198  #              enum values.
199  # REUSE_FALSE = 2
200  # REUSE_TRUE = 3
201
202
203# TODO(apassos) remove these forwarding symbols.
204VariableSynchronization = variables.VariableSynchronization  # pylint: disable=invalid-name
205VariableAggregation = variables.VariableAggregation  # pylint: disable=invalid-name
206
207AUTO_REUSE = _ReuseMode.AUTO_REUSE
208tf_export(v1=["AUTO_REUSE"]).export_constant(__name__, "AUTO_REUSE")
209AUTO_REUSE.__doc__ = """
210When passed in as the value for the `reuse` flag, AUTO_REUSE indicates that
211get_variable() should create the requested variable if it doesn't exist or, if
212it does exist, simply return it.
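
For illustration only (the scope and variable names below are arbitrary), a
minimal sketch of the intended usage with the TF 1.x compatibility API:

```python
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()
with tf.variable_scope("layer", reuse=tf.AUTO_REUSE):
  w = tf.get_variable("w", shape=[10])   # created on the first call
with tf.variable_scope("layer", reuse=tf.AUTO_REUSE):
  w2 = tf.get_variable("w", shape=[10])  # the existing variable is returned
assert w is w2
```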
213"""
214
215_DEFAULT_USE_RESOURCE = tf2.enabled()
216
217
218@tf_export(v1=["enable_resource_variables"])
219def enable_resource_variables():
220  """Creates resource variables by default.
221
222  Resource variables are improved versions of TensorFlow variables with a
223  well-defined memory model. Accessing a resource variable reads its value, and
224  all ops which access a specific read value of the variable are guaranteed to
225  see the same value for that tensor. Writes which happen after a read (by
226  having a control or data dependency on the read) are guaranteed not to affect
227  the value of the read tensor, and similarly writes which happen before a read
228  are guaranteed to affect the value. No guarantees are made about unordered
229  read/write pairs.
230
231  Calling tf.enable_resource_variables() lets you opt in to this TensorFlow 2.0
232  feature.
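
  For illustration, a minimal opt-in sketch using the TF 1.x compatibility API
  (the variable name and shape below are arbitrary):

  ```python
  import tensorflow.compat.v1 as tf

  tf.disable_eager_execution()
  tf.enable_resource_variables()
  v = tf.get_variable("v", shape=[2], initializer=tf.zeros_initializer())
  print(tf.resource_variables_enabled())  # True
  # `v` is now a resource variable rather than a reference variable.
  ```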
233  """
234  global _DEFAULT_USE_RESOURCE
235  _DEFAULT_USE_RESOURCE = True
236  logging.vlog(1, "Enabling resource variables")
237  _api_usage_gauge.get_cell().set(True)
238
239
240@tf_export(v1=["resource_variables_enabled"])
241def resource_variables_enabled():
242  """Returns `True` if resource variables are enabled.
243
244  Resource variables are improved versions of TensorFlow variables with a
245  well-defined memory model. Accessing a resource variable reads its value, and
246  all ops which access a specific read value of the variable are guaranteed to
247  see the same value for that tensor. Writes which happen after a read (by
248  having a control or data dependency on the read) are guaranteed not to affect
249  the value of the read tensor, and similarly writes which happen before a read
250  are guaranteed to affect the value. No guarantees are made about unordered
251  read/write pairs.
252
253  Calling tf.enable_resource_variables() lets you opt in to this TensorFlow 2.0
254  feature.
255  """
256  global _DEFAULT_USE_RESOURCE
257  return _DEFAULT_USE_RESOURCE
258
259
260@deprecation.deprecated(
261    None, "non-resource variables are not supported in the long term")
262@tf_export(v1=["disable_resource_variables"])
263def disable_resource_variables():
264  """Opts out of resource variables.
265
266  If your code needs tf.disable_resource_variables() to be called to work
267  properly please file a bug.
268  """
269  global _DEFAULT_USE_RESOURCE
270  _DEFAULT_USE_RESOURCE = False
271  logging.vlog(1, "Disabling resource variables")
272  _api_usage_gauge.get_cell().set(False)
273
274
275def _needs_no_arguments(python_callable):
276  """Returns true if the callable needs no arguments to call."""
277  # TODO(bfontain): Switch to inspect.signature when we are python 3 only.
278  # signature = inspect.signature(python_callable)
279  # return not [1 for param in signature.parameters.values()
280  #             if param.default == param.empty]
281  num_arguments = len(tf_inspect.getargspec(python_callable).args)
282  if not tf_inspect.isfunction(python_callable) and not isinstance(
283      python_callable, functools.partial):
284    # getargspec includes self for callables that aren't plain functions or
285    # functools.partial (e.g. bound methods). This has no default so we need to
286    # remove it. It is not even an argument, so it's odd that getargspec returns it.
287    # Note that this is fixed with inspect.signature in Python 3.
288    num_arguments -= 1
289  return num_arguments == len(
290      tf_inspect.getargspec(python_callable).defaults or [])
291
292
293class _VariableStore(object):
294  """Variable store that carries a number of named Variables.
295
296  New variable names and new variables can be created; all stored
297  variables are initialized with the initializer passed to __init__.
298
299  Attributes:
300    vars: a dictionary with string names (same as passed in GetVar) as keys and
301      the corresponding TensorFlow Variables as values.
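
  For illustration, a minimal sketch of the caching behavior (graph mode; the
  variable name, shape and initializer below are arbitrary):

  ```python
  store = _VariableStore()
  v1 = store.get_variable(
      "w", shape=[2], initializer=init_ops.zeros_initializer())
  v2 = store.get_variable("w", shape=[2], reuse=True)  # returns the cached v1
  assert v1 is v2
  ```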
302  """
303
304  __slots__ = ["_vars", "_partitioned_vars", "_store_eager_variables"]
305
306  def __init__(self):
307    """Create a variable store."""
308    self._vars = {}  # A dictionary of the stored TensorFlow variables.
309    self._partitioned_vars = {}  # A dict of the stored PartitionedVariables.
310    self._store_eager_variables = False
311
312  def get_variable(self,
313                   name,
314                   shape=None,
315                   dtype=dtypes.float32,
316                   initializer=None,
317                   regularizer=None,
318                   reuse=None,
319                   trainable=None,
320                   collections=None,
321                   caching_device=None,
322                   partitioner=None,
323                   validate_shape=True,
324                   use_resource=None,
325                   custom_getter=None,
326                   constraint=None,
327                   synchronization=VariableSynchronization.AUTO,
328                   aggregation=VariableAggregation.NONE):
329    """Gets an existing variable with these parameters or creates a new one.
330
331    If a variable with the given name is already stored, we return the stored
332    variable. Otherwise, we create a new one.
333
334    Set `reuse` to `True` when you only want to reuse existing Variables.
335    Set `reuse` to `False` when you only want to create new Variables.
336    Set `reuse` to None (the default) or tf.compat.v1.AUTO_REUSE when you want
337    variables to be created if they don't exist or returned if they do.
338
339    If initializer is `None` (the default), the default initializer passed in
340    the constructor is used. If that one is `None` too, we use a new
341    `glorot_uniform_initializer`. If initializer is a Tensor, we use
342    it as a value and derive the shape from the initializer.
343
344    If a partitioner is provided, a `PartitionedVariable` is returned.
345    Accessing this object as a `Tensor` returns the shards concatenated along
346    the partition axis.
347
348    Some useful partitioners are available.  See, e.g.,
349    `variable_axis_size_partitioner` and `min_max_variable_partitioner`.
350
351    Args:
352      name: The name of the new or existing variable.
353      shape: Shape of the new or existing variable.
354      dtype: Type of the new or existing variable (defaults to `DT_FLOAT`).
355      initializer: Initializer for the variable.
356      regularizer: A (Tensor -> Tensor or None) function; the result of applying
357        it to a newly created variable will be added to the collection
358        GraphKeys.REGULARIZATION_LOSSES and can be used for regularization.
359      reuse: a Boolean, None, or tf.AUTO_REUSE. Controls reuse or creation of
360        variables. When eager execution is enabled this argument is always
361        forced to be False.
362      trainable: If `True` also add the variable to the graph collection
363        `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). `trainable`
364        defaults to `True`, unless `synchronization` is set to `ON_READ`, in
365        which case it defaults to `False`.
366      collections: List of graph collections keys to add the `Variable` to.
367        Defaults to `[GraphKeys.GLOBAL_VARIABLES]` (see `tf.Variable`).
368      caching_device: Optional device string or function describing where the
369        Variable should be cached for reading.  Defaults to the Variable's
370        device.  If not `None`, caches on another device.  Typical use is to
371        cache on the device where the Ops using the `Variable` reside, to
372        deduplicate copying through `Switch` and other conditional statements.
373      partitioner: Optional callable that accepts a fully defined `TensorShape`
374        and dtype of the `Variable` to be created, and returns a list of
375        partitions for each axis (currently only one axis can be partitioned).
376      validate_shape: If False, allows the variable to be initialized with a
377        value of unknown shape. If True, the default, the shape of initial_value
378        must be known.
379      use_resource: If False, creates a regular Variable. If True, creates
380        instead an experimental ResourceVariable which has well-defined
381        semantics. Defaults to False (will later change to True). When eager
382        execution is enabled this argument is always forced to be true.
383      custom_getter: Callable that takes as a first argument the true getter,
384        and allows overwriting the internal get_variable method. The signature of
385        `custom_getter` should match that of this method, but the most future-proof
386        version will allow for changes: `def custom_getter(getter, *args, **kwargs)`.
387        Direct access to all `get_variable` parameters is also allowed:
388        `def custom_getter(getter, name, *args, **kwargs)`. A simple identity custom
389        getter that creates variables with modified names is:
390        ```python
391        def custom_getter(getter, name, *args, **kwargs):
392          return getter(name + '_suffix', *args, **kwargs)
393        ```
394      constraint: An optional projection function to be applied to the variable
395        after being updated by an `Optimizer` (e.g. used to implement norm
396        constraints or value constraints for layer weights). The function must
397        take as input the unprojected Tensor representing the value of the
398        variable and return the Tensor for the projected value (which must have
399        the same shape). Constraints are not safe to use when doing asynchronous
400        distributed training.
401      synchronization: Indicates when a distributed variable will be
402        aggregated. Accepted values are constants defined in the class
403        `tf.VariableSynchronization`. By default the synchronization is set to
404        `AUTO` and the current `DistributionStrategy` chooses when to
405        synchronize.
406      aggregation: Indicates how a distributed variable will be aggregated.
407        Accepted values are constants defined in the class
408        `tf.VariableAggregation`.
409
410    Returns:
411      The created or existing `Variable` (or `PartitionedVariable`, if a
412      partitioner was used).
413
414    Raises:
415      ValueError: when creating a new variable and shape is not declared,
416        when reusing a variable and specifying a conflicting shape,
417        or when violating reuse during variable creation.
418      RuntimeError: when eager execution is enabled and not called from an
419        EagerVariableStore.
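
    For illustration, a sketch of the `regularizer` and `constraint` arguments
    through the public `tf.compat.v1.get_variable` wrapper, which ultimately
    calls into this store (names and coefficients below are arbitrary):

    ```python
    import tensorflow.compat.v1 as tf

    tf.disable_eager_execution()
    w = tf.get_variable(
        "w", shape=[3, 3],
        regularizer=lambda t: 0.01 * tf.nn.l2_loss(t),
        constraint=lambda t: tf.clip_by_value(t, -1.0, 1.0))
    reg_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
    # reg_losses now holds the (lazily evaluated) regularization term for `w`.
    ```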
420    """
421    if custom_getter is not None and not callable(custom_getter):
422      raise ValueError("Passed a custom_getter which is not callable: %s" %
423                       custom_getter)
424
425    with ops.init_scope():
426      if context.executing_eagerly():
427        # Variable creation and initialization takes place in `init_scope`s;
428        # as such, if an `init_scope` lifts us into the eager context, then we
429        # need to use `ResourceVariable`s.
430        use_resource = True
431
432    # Note that it's fine to reuse eager variables whose initialization was
433    # lifted from a function-building graph into the eager context (that's why
434    # the following clause is not wrapped in an `init_scope`); lifted variables
435    # are tracked by the graph's `VariableStore`.
436    if context.executing_eagerly():
437      if not self._store_eager_variables and reuse:
438        raise RuntimeError(
439            "When eager execution is enabled variable reuse is only supported"
440            " when an EagerVariableStore is active. See the documentation on"
441            " EagerVariableStore for example usage.")
442      if self._store_eager_variables:
443        reuse = AUTO_REUSE
444
445    # If a *_ref type is passed in, an error would be triggered further down the
446    # stack. We prevent this using base_dtype to get a non-ref version of the
447    # type, before doing anything else. When _ref types are removed in favor of
448    # resources, this line can be removed.
449    try:
450      dtype = dtype.base_dtype
451    except AttributeError:
452      # .base_dtype not existing means that we will try and use the raw dtype
453      # which was passed in - this might be a NumPy type which is valid.
454      pass
455
456    # This is the main logic of get_variable.  However, custom_getter
457    # may override this logic.  So we save it as a callable and pass
458    # it to custom_getter.
459    # Note: the parameters of _true_getter, and their documentation, match
460    # *exactly* item-for-item with the docstring of this method.
461    def _true_getter(  # pylint: disable=missing-docstring
462        name,
463        shape=None,
464        dtype=dtypes.float32,
465        initializer=None,
466        regularizer=None,
467        reuse=None,
468        trainable=None,
469        collections=None,
470        caching_device=None,
471        partitioner=None,
472        validate_shape=True,
473        use_resource=None,
474        constraint=None,
475        synchronization=VariableSynchronization.AUTO,
476        aggregation=VariableAggregation.NONE):
477      is_scalar = (
478          shape is not None and isinstance(shape, collections_abc.Sequence) and
479          not shape)
480      # Partitioned variable case
481      if partitioner is not None and not is_scalar:
482        if not callable(partitioner):
483          raise ValueError("Partitioner must be callable, but received: %s" %
484                           partitioner)
485        with ops.name_scope(None):
486          return self._get_partitioned_variable(
487              name=name,
488              shape=shape,
489              dtype=dtype,
490              initializer=initializer,
491              regularizer=regularizer,
492              reuse=reuse,
493              trainable=trainable,
494              collections=collections,
495              caching_device=caching_device,
496              partitioner=partitioner,
497              validate_shape=validate_shape,
498              use_resource=use_resource,
499              constraint=constraint,
500              synchronization=synchronization,
501              aggregation=aggregation)
502
503      # Special case for partitioned variable to allow reuse without having to
504      # specify partitioner.
505      if (reuse is True and partitioner is None
506          and name in self._partitioned_vars):
507        return self._get_partitioned_variable(
508            name=name,
509            shape=shape,
510            dtype=dtype,
511            initializer=initializer,
512            regularizer=regularizer,
513            reuse=reuse,
514            trainable=trainable,
515            collections=collections,
516            caching_device=caching_device,
517            partitioner=None,
518            validate_shape=validate_shape,
519            use_resource=use_resource,
520            constraint=constraint,
521            synchronization=synchronization,
522            aggregation=aggregation)
523
524      # Single variable case
525      if "%s/part_0" % name in self._vars:
526        raise ValueError(
527            "No partitioner was provided, but a partitioned version of the "
528            "variable was found: %s/part_0. Perhaps a variable of the same "
529            "name was already created with partitioning?" % name)
530
531      return self._get_single_variable(
532          name=name,
533          shape=shape,
534          dtype=dtype,
535          initializer=initializer,
536          regularizer=regularizer,
537          reuse=reuse,
538          trainable=trainable,
539          collections=collections,
540          caching_device=caching_device,
541          validate_shape=validate_shape,
542          use_resource=use_resource,
543          constraint=constraint,
544          synchronization=synchronization,
545          aggregation=aggregation)
546
547    synchronization, aggregation, trainable = (
548        variables.validate_synchronization_aggregation_trainable(
549            synchronization, aggregation, trainable, name))
550
551    if custom_getter is not None:
552      # Handle backwards compatibility with getter arguments that were added
553      # to the API after users started writing custom getters.
554      custom_getter_kwargs = {
555          "getter": _true_getter,
556          "name": name,
557          "shape": shape,
558          "dtype": dtype,
559          "initializer": initializer,
560          "regularizer": regularizer,
561          "reuse": reuse,
562          "trainable": trainable,
563          "collections": collections,
564          "caching_device": caching_device,
565          "partitioner": partitioner,
566          "validate_shape": validate_shape,
567          "use_resource": use_resource,
568          "synchronization": synchronization,
569          "aggregation": aggregation,
570      }
571      # `fn_args` and `has_kwargs` can handle functions, `functools.partial`,
572      # `lambda`.
573      if ("constraint" in function_utils.fn_args(custom_getter) or
574          function_utils.has_kwargs(custom_getter)):
575        custom_getter_kwargs["constraint"] = constraint
576      return custom_getter(**custom_getter_kwargs)
577    else:
578      return _true_getter(
579          name,
580          shape=shape,
581          dtype=dtype,
582          initializer=initializer,
583          regularizer=regularizer,
584          reuse=reuse,
585          trainable=trainable,
586          collections=collections,
587          caching_device=caching_device,
588          partitioner=partitioner,
589          validate_shape=validate_shape,
590          use_resource=use_resource,
591          constraint=constraint,
592          synchronization=synchronization,
593          aggregation=aggregation)
594
595  def _get_partitioned_variable(self,
596                                name,
597                                partitioner,
598                                shape=None,
599                                dtype=dtypes.float32,
600                                initializer=None,
601                                regularizer=None,
602                                reuse=None,
603                                trainable=None,
604                                collections=None,
605                                caching_device=None,
606                                validate_shape=True,
607                                use_resource=None,
608                                constraint=None,
609                                synchronization=VariableSynchronization.AUTO,
610                                aggregation=VariableAggregation.NONE):
611    """Gets or creates a sharded variable list with these parameters.
612
613    The `partitioner` must be a callable that accepts a fully defined
614    `TensorShape` and returns a sequence of integers (the `partitions`).
615    These integers describe how to partition the given sharded `Variable`
616    along the given dimension.  That is, `partitions[1] = 3` means split
617    the `Variable` into 3 shards along dimension 1.  Currently, sharding along
618    only one axis is supported.
619
620    If the list of variables with the given name (prefix) is already stored,
621    we return the stored variables. Otherwise, we create a new one.
622
623    Set `reuse` to `True` when you only want to reuse existing Variables.
624    Set `reuse` to `False` when you only want to create new Variables.
625    Set `reuse` to None (the default) or tf.compat.v1.AUTO_REUSE when you want
626    variables to be created if they don't exist or returned if they do.
627
628    If initializer is `None` (the default), the default initializer passed in
629    the constructor is used. If that one is `None` too, we use a new
630    `glorot_uniform_initializer`. If initializer is a Tensor, we use
631    it as a value and derive the shape from the initializer.
632
633    If the initializer is a callable, then it will be called for each
634    shard.  Otherwise the initializer should match the shape of the entire
635    sharded Variable, and it will be sliced accordingly for each shard.
636
637    Some useful partitioners are available.  See, e.g.,
638    `variable_axis_size_partitioner` and `min_max_variable_partitioner`.
639
640    Args:
641      name: the name of the new or existing sharded variable.
642      partitioner: Optional callable that accepts a fully defined `TensorShape`
643        and `dtype` of the Variable to be created, and returns a list of
644        partitions for each axis (currently only one axis can be partitioned).
645      shape: shape of the new or existing sharded variable.
646      dtype: type of the new or existing sharded variable (defaults to
647        `DT_FLOAT`).
648      initializer: initializer for the sharded variable.
649      regularizer: a (Tensor -> Tensor or None) function; the result of applying
650        it to a newly created variable will be added to the collection
651        GraphKeys.REGULARIZATION_LOSSES and can be used for regularization.
652      reuse: a Boolean, None, or tf.AUTO_REUSE. Controls reuse or creation of
653        variables.
654      trainable: If `True` also add the variable to the graph collection
655        `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
656      collections: List of graph collections keys to add the Variable to.
657        Defaults to `[GraphKeys.GLOBAL_VARIABLES]` (see `tf.Variable`).
658      caching_device: Optional device string or function describing where the
659        Variable should be cached for reading.  Defaults to the Variable's
660        device.  If not `None`, caches on another device.  Typical use is to
661        cache on the device where the Ops using the Variable reside, to
662        deduplicate copying through `Switch` and other conditional statements.
663      validate_shape: If False, allows the variable to be initialized with a
664        value of unknown shape. If True, the default, the shape of initial_value
665        must be known.
666      use_resource: If False, creates a regular Variable. If True, creates an
667        experimental ResourceVariable which has well-defined semantics. Defaults
668        to False (will later change to True).
669      constraint: An optional projection function to be applied to the variable
670        after being updated by an `Optimizer` (e.g. used to implement norm
671        constraints or value constraints for layer weights). The function must
672        take as input the unprojected Tensor representing the value of the
673        variable and return the Tensor for the projected value (which must have
674        the same shape). Constraints are not safe to use when doing asynchronous
675        distributed training.
676      synchronization: Indicates when a distributed variable will be
677        aggregated. Accepted values are constants defined in the class
678        `tf.VariableSynchronization`. By default the synchronization is set to
679        `AUTO` and the current `DistributionStrategy` chooses when to
680        synchronize.
681      aggregation: Indicates how a distributed variable will be aggregated.
682        Accepted values are constants defined in the class
683        `tf.VariableAggregation`.
684
685    Returns:
686      A `PartitionedVariable` object.
687
688    Raises:
689      ValueError: when creating a new variable and shape is not declared,
690        when reusing a variable and specifying a conflicting shape,
691        when violating reuse during variable creation, or if an existing
692        sharded variable exists for the given name but with different sharding.
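
    For illustration, a sketch through the public TF 1.x API; the partitioner
    and shapes below are arbitrary (`tf.compat.v1.fixed_size_partitioner`
    splits a single axis into a fixed number of shards):

    ```python
    import tensorflow.compat.v1 as tf

    tf.disable_eager_execution()
    with tf.variable_scope("embed",
                           partitioner=tf.fixed_size_partitioner(3)):
      table = tf.get_variable("table", shape=[9, 4])
    # `table` is a PartitionedVariable backed by 3 shards of shape [3, 4].
    ```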
693    """
694    initializing_from_value = initializer is not None and isinstance(
695        initializer, ops.Tensor)
696    if name in self._vars:
697      raise ValueError(
698          "A partitioner was provided, but an unpartitioned version of the "
699          "variable was found: %s.  Perhaps a variable of the same name was "
700          "already created without partitioning?" % name)
701
702    shape = tensor_shape.as_shape(shape)
703    if initializing_from_value:
704      shape = shape.merge_with(initializer.get_shape())
705
706    partitions = None
707    if not reuse or partitioner:
708      partitions = _call_partitioner(partitioner, shape, dtype)
709
710    if name in self._partitioned_vars:
711      if reuse is False:
712        raise ValueError(
713            "Partitioned variable with name %s already exists. Did you mean to "
714            "set reuse=True or reuse=tf.AUTO_REUSE in VarScope?" % name)
715
716      existing_var = self._partitioned_vars[name]
717      if not shape.is_compatible_with(existing_var.get_shape()):
718        raise ValueError(
719            "Trying to reuse partitioned variable %s, but specified shape %s "
720            "and found shape %s." % (name, shape, existing_var.get_shape()))
721      if not dtype.is_compatible_with(existing_var.dtype):
722        raise ValueError(
723            "Trying to reuse partitioned variable %s, but specified dtype %s "
724            "and found dtype %s." % (name, dtype.name, existing_var.dtype.name))
725
726      # pylint: disable=protected-access
727      if (partitions is not None and
728          existing_var._get_partitions() != partitions):
729        raise ValueError(
730            "Trying to reuse partitioned variable %s, but specified partitions "
731            "%s and found partitions %s." %
732            (name, partitions, existing_var._get_partitions()))
733      # pylint: enable=protected-access
734
735      return existing_var
736
737    if reuse is True:
738      raise ValueError("PartitionedVariable %s does not exist, or was not "
739                       "created with tf.get_variable(). Did you mean to set "
740                       "reuse=False or reuse=tf.AUTO_REUSE in VarScope?" % name)
741
742    slice_dim, num_slices = _get_slice_dim_and_num_slices(partitions)
743
744    if "%s/part_0" % name in self._vars:
745      if "%s/part_%d" % (name, num_slices - 1) not in self._vars:
746        raise ValueError(
747            "Partitioner returned a different partitioning than what was "
748            "already found.  Partitioner returned %d shards, and shard "
749            "%s/part_0 was found, but %s/part_%d was not." %
750            (num_slices, name, name, num_slices - 1))
751      if "%s/part_%d" % (name, num_slices) in self._vars:
752        raise ValueError(
753            "Partitioner returned a different partitioning than what was "
754            "already found.  Partitioner returned %d shards, and shard "
755            "%s/part_0 was found, but so was the extra shard %s/part_%d." %
756            (num_slices, name, name, num_slices))
757
758    vs = []
759    for i, (var_offset, var_shape) in enumerate(
760        _iter_slices(shape.as_list(), num_slices, slice_dim)):
761      partition_info = _PartitionInfo(
762          full_shape=shape.as_list(), var_offset=var_offset)
763      var_full_name = "%s/part_%d" % (name, i)
764      with ops.name_scope(
765          var_full_name + "/PartitionedInitializer", skip_on_eager=False):
766        # Create the tensor to initialize the variable with default value.
767        if initializer is None:
768          init, initializing_from_value = self._get_default_initializer(
769              name=name, shape=shape, dtype=dtype)
770          if initializing_from_value:
771            init_shape = None
772          else:
773            init_shape = var_shape
774        elif callable(initializer):
775          init = initializer
776          init_shape = var_shape
777        elif isinstance(initializer, ops.Tensor):
778          init = array_ops.slice(initializer, var_offset, var_shape)
779          # Use the dtype of the given tensor.
780          dtype = init.dtype.base_dtype
781          init_shape = None
782        else:
783          init = ops.convert_to_tensor(initializer, dtype=dtype)
784          init = array_ops.slice(init, var_offset, var_shape)
785          init_shape = None
786
787      with ops.name_scope(None):
788        var = self._get_single_variable(
789            name=var_full_name,
790            shape=init_shape,
791            dtype=dtype,
792            initializer=init,
793            partition_info=partition_info,
794            regularizer=regularizer,
795            reuse=reuse,
796            trainable=trainable,
797            collections=collections,
798            caching_device=caching_device,
799            validate_shape=validate_shape,
800            use_resource=use_resource,
801            constraint=constraint,
802            synchronization=synchronization,
803            aggregation=aggregation)
804
805      # pylint: disable=protected-access
806      var._set_save_slice_info(
807          variables.Variable.SaveSliceInfo(name, shape.as_list(), var_offset,
808                                           var_shape))
809      vs.append(var)
810      # pylint: enable=protected-access
811
812    partitioned_var = variables.PartitionedVariable(
813        name=name,
814        shape=shape,
815        dtype=dtype,
816        variable_list=vs,
817        partitions=partitions)
818    if not context.executing_eagerly() or self._store_eager_variables:
819      self._partitioned_vars[name] = partitioned_var
820    return partitioned_var
821
822  def _get_single_variable(self,
823                           name,
824                           shape=None,
825                           dtype=dtypes.float32,
826                           initializer=None,
827                           regularizer=None,
828                           partition_info=None,
829                           reuse=None,
830                           trainable=None,
831                           collections=None,
832                           caching_device=None,
833                           validate_shape=True,
834                           use_resource=None,
835                           constraint=None,
836                           synchronization=VariableSynchronization.AUTO,
837                           aggregation=VariableAggregation.NONE):
838    """Get or create a single Variable (e.g. a shard or entire variable).
839
842    See the documentation of get_variable above (ignore partitioning components)
843    for details.
844
845    Args:
846      name: see get_variable.
847      shape: see get_variable.
848      dtype: see get_variable.
849      initializer: see get_variable.
850      regularizer: see get_variable.
851      partition_info: _PartitionInfo object.
852      reuse: see get_variable.
853      trainable: see get_variable.
854      collections: see get_variable.
855      caching_device: see get_variable.
856      validate_shape: see get_variable.
857      use_resource: see get_variable.
858      constraint: see get_variable.
859      synchronization: see get_variable.
860      aggregation: see get_variable.
861
862    Returns:
863      A Variable.  See documentation of get_variable above.
864
865    Raises:
866      ValueError: See documentation of get_variable above.
867    """
868    # Set to true if initializer is a constant.
869    initializing_from_value = False
870    if initializer is not None and not callable(initializer):
871      initializing_from_value = True
872    if shape is not None and initializing_from_value:
873      raise ValueError("If initializer is a constant, do not specify shape.")
874
875    dtype = dtypes.as_dtype(dtype)
876    shape = tensor_shape.as_shape(shape)
877
878    if name in self._vars:
879      # Here we handle the case when returning an existing variable.
880      if reuse is False:
881        var = self._vars[name]
882        err_msg = ("Variable %s already exists, disallowed."
883                   " Did you mean to set reuse=True or "
884                   "reuse=tf.AUTO_REUSE in VarScope?" % name)
885        # ResourceVariables don't have an op associated with them, so no traceback.
886        if isinstance(var, resource_variable_ops.ResourceVariable):
887          raise ValueError(err_msg)
888        tb = var.op.traceback[::-1]
889        # Throw away internal tf entries and only take a few lines. In some
890        # cases the traceback can be longer (e.g. if someone uses factory
891        # functions to create variables) so we take more than needed in the
892        # default case.
893        tb = [x for x in tb if "tensorflow/python" not in x[0]][:5]
894        raise ValueError("%s Originally defined at:\n\n%s" %
895                         (err_msg, "".join(traceback.format_list(tb))))
896      found_var = self._vars[name]
897      if not shape.is_compatible_with(found_var.get_shape()):
898        raise ValueError("Trying to share variable %s, but specified shape %s"
899                         " and found shape %s." %
900                         (name, shape, found_var.get_shape()))
901      if not dtype.is_compatible_with(found_var.dtype):
902        dtype_str = dtype.name
903        found_type_str = found_var.dtype.name
904        raise ValueError("Trying to share variable %s, but specified dtype %s"
905                         " and found dtype %s." %
906                         (name, dtype_str, found_type_str))
907      return found_var
908
909    # The code below handles only the case of creating a new variable.
910    if reuse is True:
911      raise ValueError("Variable %s does not exist, or was not created with "
912                       "tf.get_variable(). Did you mean to set "
913                       "reuse=tf.AUTO_REUSE in VarScope?" % name)
914
915    # Create the tensor to initialize the variable with default value.
916    if initializer is None:
917      initializer, initializing_from_value = self._get_default_initializer(
918          name=name, shape=shape, dtype=dtype)
919    # Enter an init scope when creating the initializer.
920    with ops.init_scope():
921      if initializing_from_value:
922        init_val = initializer
923        variable_dtype = None
924      else:
925        # Instantiate initializer if provided initializer is a type object.
926        if tf_inspect.isclass(initializer):
927          initializer = initializer()
928        if shape.is_fully_defined():
929          if "partition_info" in tf_inspect.getargspec(initializer).args:
930            init_val = functools.partial(initializer,
931                                         shape.as_list(),
932                                         dtype=dtype,
933                                         partition_info=partition_info)
934          else:
935            init_val = functools.partial(initializer,
936                                         shape.as_list(), dtype=dtype)
937          variable_dtype = dtype.base_dtype
938        elif _needs_no_arguments(initializer):
939          init_val = initializer
940          variable_dtype = None
941        else:
942          raise ValueError("The initializer passed is not valid. It should "
943                           "either be a callable with no arguments (in which "
944                           "case `shape` should not be provided), or an "
945                           "instance of `tf.keras.initializers.*` (in which "
946                           "case `shape` must be fully defined).")
947
948    # Create the variable.
949    if use_resource is None:
950      # Set the default value if unspecified.
951      use_resource = _DEFAULT_USE_RESOURCE
952    v = variables.VariableV1(
953        initial_value=init_val,
954        name=name,
955        trainable=trainable,
956        collections=collections,
957        caching_device=caching_device,
958        dtype=variable_dtype,
959        validate_shape=validate_shape,
960        constraint=constraint,
961        use_resource=use_resource,
962        synchronization=synchronization,
963        aggregation=aggregation)
964    if context.executing_eagerly() and self._store_eager_variables:
965      if collections:
966        ops.add_to_collections(collections, v)
967      else:
968        ops.add_to_collection(ops.GraphKeys.GLOBAL_VARIABLES, v)
969      if trainable:
970        ops.add_to_collection(ops.GraphKeys.TRAINABLE_VARIABLES, v)
971
972    if not context.executing_eagerly() or self._store_eager_variables:
973      # In eager mode we do not want to keep default references to Variable
974      # objects as this will prevent their memory from being released.
975      self._vars[name] = v
976    logging.vlog(1, "Created variable %s with shape %s and init %s", v.name,
977                 format(shape), initializer)
978
979    # Run the regularizer if requested and save the resulting loss.
980    if regularizer:
981      def make_regularizer_op():
982        with ops.colocate_with(v):
983          with ops.name_scope(name + "/Regularizer/"):
984            return regularizer(v)
985
986      if regularizer(v) is not None:
987        lazy_eval_tensor = _LazyEvalTensor(make_regularizer_op)
988        ops.add_to_collection(ops.GraphKeys.REGULARIZATION_LOSSES,
989                              lazy_eval_tensor)
990
991    return v
992
993  # Default initializer to use when none is provided.
994  def _get_default_initializer(self, name, shape=None, dtype=dtypes.float32):
995    """Provide a default initializer and a corresponding value.
996
997    Args:
998      name: see get_variable.
999      shape: see get_variable.
1000      dtype: see get_variable.
1001
1002    Returns:
1003      initializer and initializing_from_value. See get_variable above.
1004
1005    Raises:
1006      ValueError: When giving unsupported dtype.
1007    """
1008    del shape
1009    # If dtype is DT_FLOAT, provide a glorot uniform initializer
1010    if dtype.is_floating:
1011      initializer = init_ops.glorot_uniform_initializer()
1012      initializing_from_value = False
1013    # If dtype is DT_INT/DT_UINT or DT_BOOL, provide a default value of zero/False
1014    # If dtype is DT_STRING, provide a default value of empty strings
1015    elif (dtype.is_integer or dtype.is_unsigned or dtype.is_bool or
1016          dtype == dtypes.string):
1017      initializer = init_ops.zeros_initializer()
1018      initializing_from_value = False
1019    # NOTE: Do we need to support handling DT_STRING and DT_COMPLEX here?
1020    else:
1021      raise ValueError("An initializer for variable %s of %s is required" %
1022                       (name, dtype.base_dtype))
1023
1024    return initializer, initializing_from_value
1025
1026
1027class _LazyEvalTensor(core.Tensor):
1028  """A Tensor-like object that only evaluates its thunk when used."""
1029
1030  def __init__(self, thunk):
1031    """Initializes a _LazyEvalTensor object.
1032
1033    Args:
1034      thunk: A callable. A thunk which computes the value of the tensor.
1035    """
1036    self._thunk = thunk
1037    self._master_tensor = thunk()
1038
1039  def _as_tensor(self, dtype=None, name=None, as_ref=False):
1040    del name
1041    assert not as_ref
1042    assert dtype in [None, self.dtype]
1043
1044    return self._thunk()
1045
1046
1047def _make_master_property(name):
1048  @property
1049  def prop(self):
1050    return getattr(self._master_tensor, name)  # pylint: disable=protected-access
1051  return prop
1052
1053_master_property_list = ("device", "dtype", "graph", "name", "op", "shape",
1054                         "value_index")
1055for _name in _master_property_list:
1056  setattr(_LazyEvalTensor, _name, _make_master_property(_name))
1057
1058
1059def _make_master_method(name):
1060  def method(self, *args, **kwargs):
1061    return getattr(self._master_tensor, name)(*args, **kwargs)  # pylint: disable=protected-access
1062  return method
1063
1064_master_method_list = ("get_shape", "__str__", "shape_as_list")
1065for _name in _master_method_list:
1066  setattr(_LazyEvalTensor, _name, _make_master_method(_name))
1067
1068
1069def _make_op_method(name):
1070  def method(self, *args, **kwargs):
1071    return getattr(self._as_tensor(), name)(*args, **kwargs)  # pylint: disable=protected-access
1072  return method
1073
1074_op_list = ("__abs__", "__add__", "__and__", "__bool__", "__div__", "__eq__",
1075            "__floordiv__", "__ge__", "__getitem__", "__gt__", "__invert__",
1076            "__iter__", "__le__", "__len__", "__lt__", "__matmul__", "__mod__",
1077            "__mul__", "__ne__", "__neg__", "__nonzero__", "__or__", "__pow__",
1078            "__radd__", "__rand__", "__rdiv__", "__rfloordiv__", "__rmatmul__",
1079            "__rmod__", "__rmul__", "__ror__", "__rpow__", "__rsub__",
1080            "__rtruediv__", "__rxor__", "__sub__", "__truediv__", "__xor__",
1081            "eval", "numpy")
1082for _name in _op_list:
1083  setattr(_LazyEvalTensor, _name, _make_op_method(_name))
1084
1085
1086ops.register_tensor_conversion_function(
1087    _LazyEvalTensor,
1088    lambda val, dtype, name, as_ref: val._as_tensor(dtype, name, as_ref)  # pylint: disable=protected-access
1089    )
1090
1091session.register_session_run_conversion_functions(
1092    _LazyEvalTensor,
1093    lambda fetch: ([fetch._master_tensor], lambda fetched_vals: fetched_vals[0])  # pylint: disable=protected-access
1094    )
1095
1096
1097# To stop regularization, use this regularizer
1098@tf_export(v1=["no_regularizer"])
1099def no_regularizer(_):
1100  """Use this function to prevent regularization of variables."""
1101  return None
1102
1103
1104# TODO(alive): support caching devices and partitioned variables in Eager mode.
1105@tf_export(v1=["VariableScope"])
1106class VariableScope(object):
1107  """Variable scope object to carry defaults to provide to `get_variable`.
1108
1109  Many of the arguments we need for `get_variable` in a variable store are most
1110  easily handled with a context. This object is used for the defaults.
1111
1112  Attributes:
1113    name: name of the current scope, used as prefix in get_variable.
1114    initializer: default initializer passed to get_variable.
1115    regularizer: default regularizer passed to get_variable.
1116    reuse: Boolean, None, or tf.compat.v1.AUTO_REUSE, setting the reuse in
1117      get_variable. When eager execution is enabled this argument is always
1118      forced to be False.
1119    caching_device: string, callable, or None: the caching device passed to
1120      get_variable.
1121    partitioner: callable or `None`: the partitioner passed to `get_variable`.
1122    custom_getter: default custom getter passed to get_variable.
1123    name_scope: The name passed to `tf.name_scope`.
1124    dtype: default type passed to get_variable (defaults to DT_FLOAT).
1125    use_resource: if False, create a normal Variable; if True create an
1126      experimental ResourceVariable with well-defined semantics. Defaults to
1127      False (will later change to True). When eager execution is enabled this
1128      argument is always forced to be True.
1129    constraint: An optional projection function to be applied to the variable
1130      after being updated by an `Optimizer` (e.g. used to implement norm
1131      constraints or value constraints for layer weights). The function must
1132      take as input the unprojected Tensor representing the value of the
1133      variable and return the Tensor for the projected value (which must have
1134      the same shape). Constraints are not safe to use when doing asynchronous
1135      distributed training.
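
  For illustration, a sketch of how scope defaults flow into `get_variable`
  (the scope name, initializer and dtype below are arbitrary):

  ```python
  import tensorflow.compat.v1 as tf

  tf.disable_eager_execution()
  with tf.variable_scope("block", initializer=tf.ones_initializer(),
                         dtype=tf.float64):
    scope = tf.get_variable_scope()       # the active VariableScope object
    v = tf.get_variable("v", shape=[2])   # picks up the scope's defaults
  # `v` has dtype float64 and is initialized to ones.
  ```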
1136  """
1137
1138  def __init__(self,
1139               reuse,
1140               name="",
1141               initializer=None,
1142               regularizer=None,
1143               caching_device=None,
1144               partitioner=None,
1145               custom_getter=None,
1146               name_scope="",
1147               dtype=dtypes.float32,
1148               use_resource=None,
1149               constraint=None):
1150    """Creates a new VariableScope with the given properties."""
1151    self._name = name
1152    self._initializer = initializer
1153    self._regularizer = regularizer
1154    self._reuse = reuse
1155    self._caching_device = caching_device
1156    self._partitioner = partitioner
1157    self._custom_getter = custom_getter
1158    self._name_scope = name_scope
1159    self._dtype = dtype
1160    self._use_resource = use_resource
1161    self._constraint = constraint
1162    if context.executing_eagerly():
1163      if self._caching_device is not None:
1164        raise NotImplementedError("Caching devices are not yet supported "
1165                                  "when eager execution is enabled.")
1166      self._reuse = AUTO_REUSE
1167      self._use_resource = True
1168
1169  @property
1170  def name(self):
1171    return self._name
1172
1173  @property
1174  def original_name_scope(self):
1175    return self._name_scope
1176
1177  @property
1178  def reuse(self):
1179    return self._reuse
1180
1181  @property
1182  def initializer(self):
1183    return self._initializer
1184
1185  @property
1186  def dtype(self):
1187    return self._dtype
1188
1189  @property
1190  def use_resource(self):
1191    return self._use_resource
1192
1193  @property
1194  def regularizer(self):
1195    return self._regularizer
1196
1197  @property
1198  def caching_device(self):
1199    return self._caching_device
1200
1201  @property
1202  def partitioner(self):
1203    return self._partitioner
1204
1205  @property
1206  def custom_getter(self):
1207    return self._custom_getter
1208
1209  @property
1210  def constraint(self):
1211    return self._constraint
1212
1213  def reuse_variables(self):
1214    """Reuse variables in this scope."""
1215    self._reuse = True
1216
1217  def set_initializer(self, initializer):
1218    """Set initializer for this scope."""
1219    self._initializer = initializer
1220
1221  def set_dtype(self, dtype):
1222    """Set data type for this scope."""
1223    self._dtype = dtype
1224
1225  def set_use_resource(self, use_resource):
1226    """Sets whether to use ResourceVariables for this scope."""
1227    if context.executing_eagerly() and not use_resource:
1228      raise ValueError("When eager execution is enabled, "
1229                       "use_resource cannot be set to false.")
1230    self._use_resource = use_resource
1231
1232  def set_regularizer(self, regularizer):
1233    """Set regularizer for this scope."""
1234    self._regularizer = regularizer
1235
1236  def set_caching_device(self, caching_device):
1237    """Set caching_device for this scope."""
1238    if context.executing_eagerly():
1239      raise NotImplementedError("Caching devices are not yet supported "
1240                                "when eager execution is enabled.")
1241    self._caching_device = caching_device
1242
1243  def set_partitioner(self, partitioner):
1244    """Set partitioner for this scope."""
1245    self._partitioner = partitioner
1246
1247  def set_custom_getter(self, custom_getter):
1248    """Set custom getter for this scope."""
1249    self._custom_getter = custom_getter
1250
1251  def get_collection(self, name):
1252    """Get this scope's variables."""
1253    scope = self._name + "/" if self._name else ""
1254    return ops.get_collection(name, scope)
1255
1256  def trainable_variables(self):
1257    """Get this scope's trainable variables."""
1258    return self.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
1259
1260  def global_variables(self):
1261    """Get this scope's global variables."""
1262    return self.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
1263
1264  def local_variables(self):
1265    """Get this scope's local variables."""
1266    return self.get_collection(ops.GraphKeys.LOCAL_VARIABLES)
1267
1268  def get_variable(self,
1269                   var_store,
1270                   name,
1271                   shape=None,
1272                   dtype=None,
1273                   initializer=None,
1274                   regularizer=None,
1275                   reuse=None,
1276                   trainable=None,
1277                   collections=None,
1278                   caching_device=None,
1279                   partitioner=None,
1280                   validate_shape=True,
1281                   use_resource=None,
1282                   custom_getter=None,
1283                   constraint=None,
1284                   synchronization=VariableSynchronization.AUTO,
1285                   aggregation=VariableAggregation.NONE):
1286    """Gets an existing variable with this name or create a new one."""
1287    if regularizer is None:
1288      regularizer = self._regularizer
1289    if caching_device is None:
1290      caching_device = self._caching_device
1291    if partitioner is None:
1292      partitioner = self._partitioner
1293    if custom_getter is None:
1294      custom_getter = self._custom_getter
1295    if context.executing_eagerly():
1296      reuse = False
1297      use_resource = True
1298    else:
1299      if reuse is None:
1300        reuse = self._reuse
1301      if use_resource is None:
1302        use_resource = self._use_resource
1303
1304    full_name = self.name + "/" + name if self.name else name
1305    # Variable names only depend on variable_scope (full_name here),
1306    # not name_scope, so we reset it below for the time of variable creation.
1307    with ops.name_scope(None, skip_on_eager=False):
1308      # Check that `initializer` dtype and `dtype` are consistent before
1309      # replacing them with defaults.
1310      if (dtype is not None and initializer is not None and
1311          not callable(initializer)):
1312        init_dtype = ops.convert_to_tensor(initializer).dtype.base_dtype
1313        if init_dtype != dtype:
1314          raise ValueError("Initializer type '%s' and explicit dtype '%s' "
1315                           "don't match." % (init_dtype, dtype))
1316      if initializer is None:
1317        initializer = self._initializer
1318      if constraint is None:
1319        constraint = self._constraint
1320      if dtype is None:
1321        dtype = self._dtype
1322      return var_store.get_variable(
1323          full_name,
1324          shape=shape,
1325          dtype=dtype,
1326          initializer=initializer,
1327          regularizer=regularizer,
1328          reuse=reuse,
1329          trainable=trainable,
1330          collections=collections,
1331          caching_device=caching_device,
1332          partitioner=partitioner,
1333          validate_shape=validate_shape,
1334          use_resource=use_resource,
1335          custom_getter=custom_getter,
1336          constraint=constraint,
1337          synchronization=synchronization,
1338          aggregation=aggregation)
1339
1340  def _get_partitioned_variable(self,
1341                                var_store,
1342                                name,
1343                                shape=None,
1344                                dtype=None,
1345                                initializer=None,
1346                                regularizer=None,
1347                                trainable=None,
1348                                collections=None,
1349                                caching_device=None,
1350                                partitioner=None,
1351                                validate_shape=True,
1352                                use_resource=None,
1353                                constraint=None,
1354                                synchronization=VariableSynchronization.AUTO,
1355                                aggregation=VariableAggregation.NONE):
1356    """Gets an existing variable with this name or create a new one."""
1357    if initializer is None:
1358      initializer = self._initializer
1359    if regularizer is None:
1360      regularizer = self._regularizer
1361    if constraint is None:
1362      constraint = self._constraint
1363    if caching_device is None:
1364      caching_device = self._caching_device
1365    if partitioner is None:
1366      partitioner = self._partitioner
1367    if dtype is None:
1368      dtype = self._dtype
1369    if use_resource is None:
1370      use_resource = self._use_resource
1371
1372    if self._custom_getter is not None:
1373      raise ValueError(
1374          "Private access to _get_partitioned_variable is not allowed when "
1375          "a custom getter is set.  Current custom getter: %s.  "
1376          "It is likely that you're using create_partitioned_variables.  "
1377          "If so, consider instead using get_variable with a non-empty "
1378          "partitioner parameter instead." % self._custom_getter)
1379
1380    if partitioner is None:
1381      raise ValueError("No partitioner was specified")
1382
1383    # This allows the variable scope name to be used as the variable name if
1384    # this function is invoked with an empty name arg, for backward
1385    # compatibility with create_partitioned_variables().
1386    full_name_list = []
1387    if self.name:
1388      full_name_list.append(self.name)
1389    if name:
1390      full_name_list.append(name)
1391    full_name = "/".join(full_name_list)
1392
1393    # Variable names only depend on variable_scope (full_name here),
1394    # not name_scope, so we reset it below for the time of variable creation.
1395    with ops.name_scope(None, skip_on_eager=False):
1396      # pylint: disable=protected-access
1397      return var_store._get_partitioned_variable(
1398          full_name,
1399          shape=shape,
1400          dtype=dtype,
1401          initializer=initializer,
1402          regularizer=regularizer,
1403          reuse=self.reuse,
1404          trainable=trainable,
1405          collections=collections,
1406          caching_device=caching_device,
1407          partitioner=partitioner,
1408          validate_shape=validate_shape,
1409          use_resource=use_resource,
1410          constraint=constraint,
1411          synchronization=synchronization,
1412          aggregation=aggregation)
1413      # pylint: enable=protected-access
1414
1415
1416_VARSTORE_KEY = ("__variable_store",)
1417_VARSCOPESTORE_KEY = ("__varscope",)
1418
1419
1420class _VariableScopeStore(threading.local):
1421  """A thread local store for the current variable scope and scope counts."""
1422
1423  def __init__(self):
1424    super(_VariableScopeStore, self).__init__()
1425    self.current_scope = VariableScope(False)
1426    self.variable_scopes_count = {}
1427
1428  def open_variable_scope(self, scope_name):
1429    if scope_name in self.variable_scopes_count:
1430      self.variable_scopes_count[scope_name] += 1
1431    else:
1432      self.variable_scopes_count[scope_name] = 1
1433
1434  def close_variable_subscopes(self, scope_name):
1435    for k in list(self.variable_scopes_count.keys()):
1436      if scope_name is None or k.startswith(scope_name + "/"):
1437        self.variable_scopes_count[k] = 0
1438
1439  def variable_scope_count(self, scope_name):
1440    return self.variable_scopes_count.get(scope_name, 0)
1441
1442
1443def get_variable_scope_store():
1444  """Returns the variable scope store for current thread."""
1445  scope_store = ops.get_collection(_VARSCOPESTORE_KEY)
1446
1447  if not scope_store:
1448    scope_store = _VariableScopeStore()
1449    ops.add_to_collection(_VARSCOPESTORE_KEY, scope_store)
1450  else:
1451    scope_store = scope_store[0]
1452
1453  return scope_store
1454
1455
1456@tf_export(v1=["get_variable_scope"])
1457def get_variable_scope():
1458  """Returns the current variable scope."""
1459  return get_variable_scope_store().current_scope
1460
1461
1462def _get_default_variable_store():
1463  store = ops.get_collection(_VARSTORE_KEY)
1464  if store:
1465    return store[0]
1466  store = _VariableStore()
1467  ops.add_to_collection(_VARSTORE_KEY, store)
1468  return store
1469
1470
1471@tf_contextlib.contextmanager
1472def with_variable_store(store):
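  """Scopes `store` as the default variable store within this context."""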
1473  store_collection = ops.get_collection_ref(_VARSTORE_KEY)
1474  old = list(store_collection)
1475  store_collection[:] = [store]
1476  try:
1477    yield
1478  finally:
1479    store_collection[:] = old
1480
1481
1482class EagerVariableStore(object):
1483  """Wrapper allowing functional layers to be used with eager execution.
1484
1485  When eager execution is enabled, Variables are deleted when they go out of
1486  scope and are not stored in global collections by default. A lot of code
1487  (mostly the functional layers in tf.layers) assumes that variables are kept in
1488  a global list.
1489
1490  EagerVariableStore can be used in conjunction with this code to make it
1491  eager-friendly. For example, to create a dense layer, use:
1492
1493  ```
1494    container = tfe.EagerVariableStore()
1495    for input in dataset_iterator:
1496      with container.as_default():
1497        x = tf.compat.v1.layers.dense(input, name="l1")
1498    print(container.variables)  # Should print the variables used in the layer.
1499  ```
1500  """
1501
1502  def __init__(self, store=None):
1503    if store is not None:
1504      if not store._store_eager_variables:  # pylint: disable=protected-access
1505        raise ValueError("Cannot construct EagerVariableStore from a "
1506                         "VariableStore object that does not hold eager "
1507                         "variables.")
1508      self._store = store
1509    else:
1510      self._store = _VariableStore()
1511    self._store._store_eager_variables = True  # pylint: disable=protected-access
1512
1513  def as_default(self):
1514    return with_variable_store(self._store)
1515
1516  def variables(self):
1517    return sorted(self._store._vars.values(), key=lambda x: x.name)  # pylint: disable=protected-access
1518
1519  def trainable_variables(self):
1520    # pylint: disable=protected-access
1521    return sorted([x for x in self._store._vars.values() if x.trainable],
1522                  key=lambda x: x.name)
1523    # pylint: enable=protected-access
1524
1525  def non_trainable_variables(self):
1526    # pylint: disable=protected-access
1527    return sorted([x for x in self._store._vars.values() if not x.trainable],
1528                  key=lambda x: x.name)
1529    # pylint: enable=protected-access
1530
1531  def copy(self):
1532    """Copy this variable store and all of its contents.
1533
1534    Variables contained in this store will be copied over to the new variable
1535    store, meaning that they can be modified without affecting the variables in
1536    this store.
1537
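    For illustration, a minimal sketch (mirroring the class docstring's
    example; `tfe`, `inputs`, and the dense-layer call are placeholders of
    that example):

    ```python
    container = tfe.EagerVariableStore()
    with container.as_default():
      x = tf.compat.v1.layers.dense(inputs, 10, name="l1")
    snapshot = container.copy()
    # Changing variables in `container` does not affect `snapshot`.
    ```
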
1538    Returns:
1539      A new EagerVariableStore instance containing copied variables.
1540    """
1541    # pylint: disable=protected-access
1542    new_store = EagerVariableStore()
1543    for key, var in iteritems(self._store._vars):
1544      # Strip device out of variable name.
1545      try:
1546        index = var.name.index(":")
1547      except ValueError:
1548        stripped_var_name = var.name
1549      else:
1550        stripped_var_name = var.name[:index]
1551
1552      # Create new variable with same value, name, and "trainable" flag.
1553      new_var = resource_variable_ops.ResourceVariable(
1554          var.read_value(), name=stripped_var_name, trainable=var.trainable)
1555      new_store._store._vars[key] = new_var
1556    return new_store
1557    # pylint: enable=protected-access
1558
1559
1560# The argument list for get_variable must match arguments to get_local_variable.
1561# So, if you are updating the arguments, also update arguments to
1562# get_local_variable below.
1563@tf_export(v1=["get_variable"])
1564def get_variable(name,
1565                 shape=None,
1566                 dtype=None,
1567                 initializer=None,
1568                 regularizer=None,
1569                 trainable=None,
1570                 collections=None,
1571                 caching_device=None,
1572                 partitioner=None,
1573                 validate_shape=True,
1574                 use_resource=None,
1575                 custom_getter=None,
1576                 constraint=None,
1577                 synchronization=VariableSynchronization.AUTO,
1578                 aggregation=VariableAggregation.NONE):
1579  return get_variable_scope().get_variable(
1580      _get_default_variable_store(),
1581      name,
1582      shape=shape,
1583      dtype=dtype,
1584      initializer=initializer,
1585      regularizer=regularizer,
1586      trainable=trainable,
1587      collections=collections,
1588      caching_device=caching_device,
1589      partitioner=partitioner,
1590      validate_shape=validate_shape,
1591      use_resource=use_resource,
1592      custom_getter=custom_getter,
1593      constraint=constraint,
1594      synchronization=synchronization,
1595      aggregation=aggregation)
1596
1597
1598get_variable_or_local_docstring = ("""%s
1599
1600%sThis function prefixes the name with the current variable scope
1601and performs reuse checks. See the
1602[Variable Scope How To](https://tensorflow.org/guide/variables)
1603for an extensive description of how reusing works. Here is a basic example:
1604
1605```python
1606def foo():
1607  with tf.variable_scope("foo", reuse=tf.AUTO_REUSE):
1608    v = tf.get_variable("v", [1])
1609  return v
1610
1611v1 = foo()  # Creates v.
1612v2 = foo()  # Gets the same, existing v.
1613assert v1 == v2
1614```
1615
1616If initializer is `None` (the default), the default initializer passed in
1617the variable scope will be used. If that one is `None` too, a
1618`glorot_uniform_initializer` will be used. The initializer can also be
1619a Tensor, in which case the variable is initialized to this value and shape.
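
For example, a minimal sketch of passing a Tensor initializer (here the
variable's shape is inferred from the initializer, so `shape` is omitted):

```python
b = tf.get_variable("b", initializer=tf.constant([0.0, 0.0, 0.0]))
```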
1620
1621Similarly, if the regularizer is `None` (the default), the default regularizer
1622passed in the variable scope will be used (if that is `None` too,
1623then by default no regularization is performed).
1624
1625If a partitioner is provided, a `PartitionedVariable` is returned.
1626Accessing this object as a `Tensor` returns the shards concatenated along
1627the partition axis.
1628
1629Some useful partitioners are available.  See, e.g.,
1630`variable_axis_size_partitioner` and `min_max_variable_partitioner`.
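
For example, a minimal sketch (using `tf.fixed_size_partitioner`, which is
not described here) that shards a `[10, 20]` variable into 3 pieces along
axis 0:

```python
w = tf.get_variable(
    "w", shape=[10, 20],
    partitioner=tf.fixed_size_partitioner(3))
# Reading `w` as a Tensor concatenates the shards along axis 0.
```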
1631
1632Args:
1633  name: The name of the new or existing variable.
1634  shape: Shape of the new or existing variable.
1635  dtype: Type of the new or existing variable (defaults to `DT_FLOAT`).
1636  initializer: Initializer for the variable if one is created. Can either be
1637    an initializer object or a Tensor. If it's a Tensor, its shape must be known
1638    unless validate_shape is False.
1639  regularizer: A (Tensor -> Tensor or None) function; the result of
1640    applying it on a newly created variable will be added to the collection
1641    `tf.GraphKeys.REGULARIZATION_LOSSES` and can be used for regularization.
1642  %scollections: List of graph collection keys to add the Variable to.
1643    Defaults to `[%s]` (see `tf.Variable`).
1644  caching_device: Optional device string or function describing where the
1645    Variable should be cached for reading.  Defaults to the Variable's
1646    device.  If not `None`, caches on another device.  Typical use is to
1647    cache on the device where the Ops using the Variable reside, to
1648    deduplicate copying through `Switch` and other conditional statements.
1649  partitioner: Optional callable that accepts a fully defined `TensorShape`
1650    and `dtype` of the Variable to be created, and returns a list of
1651    partitions for each axis (currently only one axis can be partitioned).
1652  validate_shape: If False, allows the variable to be initialized with a
1653      value of unknown shape. If True, the default, the shape of initial_value
1654      must be known. For this to be used the initializer must be a Tensor and
1655      not an initializer object.
1656  use_resource: If False, creates a regular Variable. If True, creates an
1657    experimental ResourceVariable with well-defined semantics instead.
1658    Defaults to False (will later change to True). When eager execution is
1659    enabled, this argument is always forced to be True.
1660  custom_getter: Callable that takes as a first argument the true getter, and
1661    allows overwriting the internal get_variable method.
1662    The signature of `custom_getter` should match that of this method,
1663    but the most future-proof version will allow for changes:
1664    `def custom_getter(getter, *args, **kwargs)`.  Direct access to
1665    all `get_variable` parameters is also allowed:
1666    `def custom_getter(getter, name, *args, **kwargs)`.  A simple custom
1667    getter that creates variables with modified names is:
1668    ```python
1669    def custom_getter(getter, name, *args, **kwargs):
1670      return getter(name + '_suffix', *args, **kwargs)
1671    ```
1672  constraint: An optional projection function to be applied to the variable
1673    after being updated by an `Optimizer` (e.g. used to implement norm
1674    constraints or value constraints for layer weights). The function must
1675    take as input the unprojected Tensor representing the value of the
1676    variable and return the Tensor for the projected value
1677    (which must have the same shape). Constraints are not safe to
1678    use when doing asynchronous distributed training.
1679  synchronization: Indicates when a distributed variable will be
1680    aggregated. Accepted values are constants defined in the class
1681    `tf.VariableSynchronization`. By default the synchronization is set to
1682    `AUTO` and the current `DistributionStrategy` chooses
1683    when to synchronize.
1684  aggregation: Indicates how a distributed variable will be aggregated.
1685    Accepted values are constants defined in the class
1686    `tf.VariableAggregation`.
1687
1688Returns:
1689  The created or existing `Variable` (or `PartitionedVariable`, if a
1690  partitioner was used).
1691
1692Raises:
1693  ValueError: when creating a new variable and shape is not declared,
1694    when violating reuse during variable creation, or when `initializer` dtype
1695    and `dtype` don't match. Reuse is set inside `variable_scope`.
1696""")
1697get_variable.__doc__ = get_variable_or_local_docstring % (
1698    "Gets an existing variable with these parameters or create a new one.", "",
1699    "trainable: If `True` also add the variable to the graph collection\n"
1700    "    `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).\n  ",
1701    "GraphKeys.GLOBAL_VARIABLES")
1702
1703
1704# The argument list for get_local_variable must match arguments to get_variable.
1705# So, if you are updating the arguments, also update arguments to get_variable.
1706@tf_export(v1=["get_local_variable"])
1707def get_local_variable(  # pylint: disable=missing-docstring
1708    name,
1709    shape=None,
1710    dtype=None,
1711    initializer=None,
1712    regularizer=None,
1713    trainable=False,  # pylint: disable=unused-argument
1714    collections=None,
1715    caching_device=None,
1716    partitioner=None,
1717    validate_shape=True,
1718    use_resource=None,
1719    custom_getter=None,
1720    constraint=None,
1721    synchronization=VariableSynchronization.AUTO,
1722    aggregation=VariableAggregation.NONE):
1723  if collections:
1724    collections += [ops.GraphKeys.LOCAL_VARIABLES]
1725  else:
1726    collections = [ops.GraphKeys.LOCAL_VARIABLES]
1727  return get_variable(
1728      name,
1729      shape=shape,
1730      dtype=dtype,
1731      initializer=initializer,
1732      regularizer=regularizer,
1733      trainable=False,
1734      collections=collections,
1735      caching_device=caching_device,
1736      partitioner=partitioner,
1737      validate_shape=validate_shape,
1738      use_resource=use_resource,
1739      synchronization=synchronization,
1740      aggregation=aggregation,
1741      custom_getter=custom_getter,
1742      constraint=constraint)
1743
1744
1745get_local_variable.__doc__ = get_variable_or_local_docstring % (
1746    "Gets an existing *local* variable or creates a new one.",
1747    "Behavior is the same as in `get_variable`, except that variables are\n"
1748    "added to the `LOCAL_VARIABLES` collection and `trainable` is set to\n"
1749    "`False`.\n", "", "GraphKeys.LOCAL_VARIABLES")
1750
1751
1752def _get_partitioned_variable(name,
1753                              shape=None,
1754                              dtype=None,
1755                              initializer=None,
1756                              regularizer=None,
1757                              trainable=True,
1758                              collections=None,
1759                              caching_device=None,
1760                              partitioner=None,
1761                              validate_shape=True,
1762                              use_resource=None,
1763                              constraint=None,
1764                              synchronization=VariableSynchronization.AUTO,
1765                              aggregation=VariableAggregation.NONE):
1766  """Gets or creates a sharded variable list with these parameters.
1767
1768  The `partitioner` must be a callable that accepts a fully defined
1769  `TensorShape` and returns a sequence of integers (the `partitions`).
1770  These integers describe how to partition the given sharded `Variable`
1771  along the given dimension.  That is, `partitions[1] = 3` means split
1772  the `Variable` into 3 shards along dimension 1.  Currently, sharding along
1773  only one axis is supported.
1774
1775  If the list of variables with the given name (prefix) is already stored,
1776  we return the stored variables. Otherwise, we create new ones.
1777
1778  If initializer is `None` (the default), the default initializer passed in
1779  the constructor is used. If that one is `None` too, we use a new
1780  `glorot_uniform_initializer`. If initializer is a Tensor, we use
1781  it as a value and derive the shape from the initializer.
1782
1783  If the initializer is a callable, then it will be called for each
1784  shard.  Otherwise the initializer should match the shape of the entire
1785  sharded Variable, and it will be sliced accordingly for each shard.
1786
1787  Some useful partitioners are available.  See, e.g.,
1788  `variable_axis_size_partitioner` and `min_max_variable_partitioner`.
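
  For illustration, a minimal sketch of a partitioner callable (hypothetical,
  not part of this module) that splits a rank-2 variable into 3 shards along
  axis 1, i.e. returns `[1, 3]`:

  ```python
  def three_way_partitioner(shape, dtype):
    del shape, dtype  # a real partitioner would inspect these
    return [1, 3]
  ```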
1789
1790  Args:
1791    name: The name of the new or existing variable.
1792    shape: Shape of the new or existing variable.
1793    dtype: Type of the new or existing variable (defaults to `DT_FLOAT`).
1794    initializer: Initializer for the variable if one is created.
1795    regularizer: A (Tensor -> Tensor or None) function; the result of applying
1796      it on a newly created variable will be added to the collection
1797      GraphKeys.REGULARIZATION_LOSSES and can be used for regularization.
1798    trainable: If `True` also add the variable to the graph collection
1799      `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
1800    collections: List of graph collection keys to add the Variable to. Defaults
1801      to `[GraphKeys.GLOBAL_VARIABLES]` (see `tf.Variable`).
1802    caching_device: Optional device string or function describing where the
1803      Variable should be cached for reading.  Defaults to the Variable's device.
1804      If not `None`, caches on another device.  Typical use is to cache on the
1805      device where the Ops using the Variable reside, to deduplicate copying
1806      through `Switch` and other conditional statements.
1807    partitioner: Optional callable that accepts a fully defined `TensorShape`
1808      and `dtype` of the Variable to be created, and returns a list of
1809      partitions for each axis (currently only one axis can be partitioned).
1810    validate_shape: If False, allows the variable to be initialized with a value
1811      of unknown shape. If True, the default, the shape of initial_value must be
1812      known.
1813    use_resource: If False, creates a regular Variable. If True, creates an
1814      experimental ResourceVariable instead which has well-defined semantics.
1815      Defaults to False (will later change to True).
1816    constraint: An optional projection function to be applied to the variable
1817      after being updated by an `Optimizer` (e.g. used to implement norm
1818      constraints or value constraints for layer weights). The function must
1819      take as input the unprojected Tensor representing the value of the
1820      variable and return the Tensor for the projected value (which must have
1821      the same shape). Constraints are not safe to use when doing asynchronous
1822      distributed training.
1823    synchronization: Indicates when a distributed variable will be aggregated.
1824      Accepted values are constants defined in the class
1825      `tf.VariableSynchronization`. By default the synchronization is set to
1826      `AUTO` and the current `DistributionStrategy` chooses when to synchronize.
1827    aggregation: Indicates how a distributed variable will be aggregated.
1828      Accepted values are constants defined in the class
1829      `tf.VariableAggregation`.
1830
1831  Returns:
1832    A tuple `(shards, partitions)` where `shards` is the list of `Variable`
1833    shards and `partitions` is the output of the partitioner on the input
1834    shape.
1835
1836  Raises:
1837    ValueError: when creating a new variable and shape is not declared,
1838      or when violating reuse during variable creation. Reuse is set inside
1839      `variable_scope`.
1840  """
1841  # pylint: disable=protected-access
1842  scope = get_variable_scope()
1843  if scope.custom_getter is not None:
1844    raise ValueError(
1845        "Private access to _get_partitioned_variable is not allowed when "
1846        "a custom getter is set.  Current custom getter: %s.  "
1847        "It is likely that you're using create_partitioned_variables.  "
1848        "If so, consider instead using get_variable with a non-empty "
1849        "partitioner parameter instead." % scope.custom_getter)
1850  return scope._get_partitioned_variable(
1851      _get_default_variable_store(),
1852      name,
1853      shape=shape,
1854      dtype=dtype,
1855      initializer=initializer,
1856      regularizer=regularizer,
1857      trainable=trainable,
1858      collections=collections,
1859      caching_device=caching_device,
1860      partitioner=partitioner,
1861      validate_shape=validate_shape,
1862      use_resource=use_resource,
1863      constraint=constraint,
1864      synchronization=synchronization,
1865      aggregation=aggregation)
1866  # pylint: enable=protected-access
1867
1868
1869# Named like a function for compatibility with the previous
1870# @tf_contextlib.contextmanager definition.
1871class _pure_variable_scope(object):  # pylint: disable=invalid-name
1872  """A context for the variable_scope, see `variable_scope` for docs."""
1873
1874  def __init__(self,
1875               name_or_scope,
1876               reuse=None,
1877               initializer=None,
1878               regularizer=None,
1879               caching_device=None,
1880               partitioner=None,
1881               custom_getter=None,
1882               old_name_scope=None,
1883               dtype=dtypes.float32,
1884               use_resource=None,
1885               constraint=None):
1886    """Creates a context for the variable_scope, see `variable_scope` for docs.
1887
1888    Note: this does not create a name scope.
1889
1890    Args:
1891      name_or_scope: `string` or `VariableScope`: the scope to open.
1892      reuse: `True`, None, or tf.compat.v1.AUTO_REUSE; if `None`, we inherit
1893        the parent scope's reuse flag.
1894      initializer: default initializer for variables within this scope.
1895      regularizer: default regularizer for variables within this scope.
1896      caching_device: default caching device for variables within this scope.
1897      partitioner: default partitioner for variables within this scope.
1898      custom_getter: default custom getter for variables within this scope.
1899      old_name_scope: the original name scope when re-entering a variable scope.
1900      dtype: type of the variables within this scope (defaults to `DT_FLOAT`).
1901      use_resource: If False, variables in this scope will be regular Variables.
1902        If True, experimental ResourceVariables will be created instead, with
1903        well-defined semantics. Defaults to False (will later change to True).
1904      constraint: An optional projection function to be applied to the variable
1905        after being updated by an `Optimizer` (e.g. used to implement norm
1906        constraints or value constraints for layer weights). The function must
1907        take as input the unprojected Tensor representing the value of the
1908        variable and return the Tensor for the projected value (which must have
1909        the same shape). Constraints are not safe to use when doing asynchronous
1910        distributed training.
1911    """
1912    self._name_or_scope = name_or_scope
1913    self._reuse = reuse
1914    self._initializer = initializer
1915    self._regularizer = regularizer
1916    self._caching_device = caching_device
1917    self._partitioner = partitioner
1918    self._custom_getter = custom_getter
1919    self._old_name_scope = old_name_scope
1920    self._dtype = dtype
1921    self._use_resource = use_resource
1922    self._constraint = constraint
1923    self._var_store = _get_default_variable_store()
1924    self._var_scope_store = get_variable_scope_store()
1925    self._last_variable_scope_object = None
1926    if isinstance(self._name_or_scope, VariableScope):
1927      self._new_name = self._name_or_scope.name
1928      name_scope = self._name_or_scope._name_scope  # pylint: disable=protected-access
1929      # Handler for the case when we jump to a shared scope.  We create a new
1930      #   VariableScope (self._var_scope_object) that contains a copy of the
1931      #   provided shared scope, possibly with changed reuse and initializer, if
1932      #   the user requested this.
1933      variable_scope_object = VariableScope(
1934          self._name_or_scope.reuse if not self._reuse else self._reuse,
1935          name=self._new_name,
1936          initializer=self._name_or_scope.initializer,
1937          regularizer=self._name_or_scope.regularizer,
1938          caching_device=self._name_or_scope.caching_device,
1939          partitioner=self._name_or_scope.partitioner,
1940          dtype=self._name_or_scope.dtype,
1941          custom_getter=self._name_or_scope.custom_getter,
1942          name_scope=name_scope,
1943          use_resource=self._name_or_scope.use_resource,
1944          constraint=self._constraint)
1945      if self._initializer is not None:
1946        variable_scope_object.set_initializer(self._initializer)
1947      if self._regularizer is not None:
1948        variable_scope_object.set_regularizer(self._regularizer)
1949      if self._caching_device is not None:
1950        variable_scope_object.set_caching_device(self._caching_device)
1951      if self._partitioner is not None:
1952        variable_scope_object.set_partitioner(self._partitioner)
1953      if self._custom_getter is not None:
1954        variable_scope_object.set_custom_getter(
1955            _maybe_wrap_custom_getter(self._custom_getter,
1956                                      self._name_or_scope.custom_getter))
1957      if self._dtype is not None:
1958        variable_scope_object.set_dtype(self._dtype)
1959      if self._use_resource is not None:
1960        variable_scope_object.set_use_resource(self._use_resource)
1961      self._cached_variable_scope_object = variable_scope_object
1962
1963  def __enter__(self):
1964    """Begins the scope block.
1965
1966    Returns:
1967      A VariableScope.
1968    Raises:
1969      ValueError: when trying to reuse within a create scope, or create within
1970        a reuse scope, or if reuse is not `None` or `True`.
1971      TypeError: when the types of some arguments are not appropriate.
1972    """
1973    self._old = self._var_scope_store.current_scope
1974    if isinstance(self._name_or_scope, VariableScope):
1975      self._var_scope_store.open_variable_scope(self._new_name)
1976      self._old_subscopes = copy.copy(
1977          self._var_scope_store.variable_scopes_count)
1978      variable_scope_object = self._cached_variable_scope_object
1979    else:
1980      # Handler for the case when we just prolong current variable scope.
1981      #   VariableScope with name extended by the provided one, and inherited
1982      #   reuse and initializer (except if the user provided values to set).
1983      self._new_name = (
1984          self._old.name + "/" +
1985          self._name_or_scope if self._old.name else self._name_or_scope)
1986      self._reuse = (self._reuse or
1987                     self._old.reuse)  # Re-using is inherited by sub-scopes.
1988      if self._old_name_scope is None:
1989        name_scope = self._name_or_scope
1990      else:
1991        name_scope = self._old_name_scope
1992      variable_scope_object = VariableScope(
1993          self._reuse,
1994          name=self._new_name,
1995          initializer=self._old.initializer,
1996          regularizer=self._old.regularizer,
1997          caching_device=self._old.caching_device,
1998          partitioner=self._old.partitioner,
1999          dtype=self._old.dtype,
2000          use_resource=self._old.use_resource,
2001          custom_getter=self._old.custom_getter,
2002          name_scope=name_scope,
2003          constraint=self._constraint)
2004      if self._initializer is not None:
2005        variable_scope_object.set_initializer(self._initializer)
2006      if self._regularizer is not None:
2007        variable_scope_object.set_regularizer(self._regularizer)
2008      if self._caching_device is not None:
2009        variable_scope_object.set_caching_device(self._caching_device)
2010      if self._partitioner is not None:
2011        variable_scope_object.set_partitioner(self._partitioner)
2012      if self._custom_getter is not None:
2013        variable_scope_object.set_custom_getter(
2014            _maybe_wrap_custom_getter(self._custom_getter,
2015                                      self._old.custom_getter))
2016      if self._dtype is not None:
2017        variable_scope_object.set_dtype(self._dtype)
2018      if self._use_resource is not None:
2019        variable_scope_object.set_use_resource(self._use_resource)
2020      self._var_scope_store.open_variable_scope(self._new_name)
2021    self._var_scope_store.current_scope = variable_scope_object
2022    self._last_variable_scope_object = variable_scope_object
2023    return variable_scope_object
2024
2025  def __exit__(self, type_arg, value_arg, traceback_arg):
2026    if (self._var_scope_store.current_scope is
2027        not self._last_variable_scope_object):
2028      raise RuntimeError("Improper nesting of variable_scope.")
2029    # If jumping out from a non-prolonged scope, restore counts.
2030    if isinstance(self._name_or_scope, VariableScope):
2031      self._var_scope_store.variable_scopes_count = self._old_subscopes
2032    else:
2033      self._var_scope_store.close_variable_subscopes(self._new_name)
2034    self._var_scope_store.current_scope = self._old
2035
2036
2037def _maybe_wrap_custom_getter(custom_getter, old_getter):
2038  """Wrap a call to a custom_getter to use the old_getter internally."""
2039  if old_getter is None:
2040    return custom_getter
2041
2042  # The new custom_getter should call the old one
2043  def wrapped_custom_getter(getter, *args, **kwargs):
2044    # Call:
2045    #  custom_getter(
2046    #    lambda: old_getter(true_getter, ...), *args, **kwargs)
2047    # which means custom_getter will call old_getter, which
2048    # will call the true_getter, perform any intermediate
2049    # processing, and return the results to the current
2050    # getter, which will also perform additional processing.
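    #
    # Illustrative sketch of the order: if a parent scope installed
    # `getter_a` and this scope installs `getter_b`, a variable lookup
    # resolves as
    #   getter_b(functools.partial(getter_a, true_getter), name, ...)
    # i.e. the most recently installed custom getter runs first and
    # delegates outward toward the true getter.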
2051    return custom_getter(functools.partial(old_getter, getter), *args, **kwargs)
2052
2053  return wrapped_custom_getter
2054
2055
2056def _get_unique_variable_scope(prefix):
2057  """Get a name with the given prefix unique in the current variable scope."""
2058  var_scope_store = get_variable_scope_store()
2059  current_scope = get_variable_scope()
2060  name = current_scope.name + "/" + prefix if current_scope.name else prefix
2061  if var_scope_store.variable_scope_count(name) == 0:
2062    return prefix
2063  idx = 1
2064  while var_scope_store.variable_scope_count(name + ("_%d" % idx)) > 0:
2065    idx += 1
2066  return prefix + ("_%d" % idx)
2067
2068
2069# Named like a function for backwards compatibility with the
2070# @tf_contextlib.contextmanager version, which was switched to a class to avoid
2071# some object creation overhead.
2072@tf_export(v1=["variable_scope"])  # pylint: disable=invalid-name
2073class variable_scope(object):
2074  """A context manager for defining ops that creates variables (layers).
2075
2076  This context manager validates that the (optional) `values` are from the same
2077  graph, ensures that graph is the default graph, and pushes a name scope and a
2078  variable scope.
2079
2080  If `name_or_scope` is not None, it is used as is. If `name_or_scope` is None,
2081  then `default_name` is used.  In that case, if the same name has been
2082  previously used in the same scope, it will be made unique by appending `_N`
2083  to it.
2084
2085  Variable scope allows you to create new variables and to share already created
2086  ones while providing checks to not create or share by accident. For details,
2087  see the [Variable Scope How To](https://tensorflow.org/guide/variables);
2088  here we present only a few basic examples.
2089
2090  Variable scope works as expected when eager execution is disabled:
2091
2092  ```python
2093  tf.compat.v1.disable_eager_execution()
2094  ```
2095
2096  Simple example of how to create a new variable:
2097
2098  ```python
2099  with tf.compat.v1.variable_scope("foo"):
2100      with tf.compat.v1.variable_scope("bar"):
2101          v = tf.compat.v1.get_variable("v", [1])
2102          assert v.name == "foo/bar/v:0"
2103  ```
2104
2105  Simple example of how to reenter a premade variable scope safely:
2106
2107  ```python
2108  with tf.compat.v1.variable_scope("foo") as vs:
2109    pass
2110
2111  # Re-enter the variable scope.
2112  with tf.compat.v1.variable_scope(vs,
2113                         auxiliary_name_scope=False) as vs1:
2114    # Restore the original name_scope.
2115    with tf.name_scope(vs1.original_name_scope):
2116        v = tf.compat.v1.get_variable("v", [1])
2117        assert v.name == "foo/v:0"
2118        c = tf.constant([1], name="c")
2119        assert c.name == "foo/c:0"
2120  ```
2121
2122  Keep in mind that the counters for `default_name` are discarded once the
2123  parent scope is exited. Therefore when the code re-enters the scope (for
2124  instance by saving it), all nested default_name counters will be restarted.
2125
2126  For instance:
2127
2128  ```python
2129  with tf.compat.v1.variable_scope("foo") as vs:
2130    with tf.compat.v1.variable_scope(None, default_name="bar"):
2131      v = tf.compat.v1.get_variable("a", [1])
2132      assert v.name == "foo/bar/a:0", v.name
2133    with tf.compat.v1.variable_scope(None, default_name="bar"):
2134      v = tf.compat.v1.get_variable("b", [1])
2135      assert v.name == "foo/bar_1/b:0"
2136
2137  with tf.compat.v1.variable_scope(vs):
2138    with tf.compat.v1.variable_scope(None, default_name="bar"):
2139      v = tf.compat.v1.get_variable("c", [1])
2140      assert v.name == "foo/bar/c:0"   # Uses bar instead of bar_2!
2141  ```
2142
2143  Basic example of sharing a variable with AUTO_REUSE:
2144
2145  ```python
2146  def foo():
2147    with tf.compat.v1.variable_scope("foo", reuse=tf.compat.v1.AUTO_REUSE):
2148      v = tf.compat.v1.get_variable("v", [1])
2149    return v
2150
2151  v1 = foo()  # Creates v.
2152  v2 = foo()  # Gets the same, existing v.
2153  assert v1 == v2
2154  ```
2155
2156  Basic example of sharing a variable with reuse=True:
2157
2158  ```python
2159  with tf.compat.v1.variable_scope("foo"):
2160      v = tf.compat.v1.get_variable("v", [1])
2161  with tf.compat.v1.variable_scope("foo", reuse=True):
2162      v1 = tf.compat.v1.get_variable("v", [1])
2163  assert v1 == v
2164  ```
2165
2166  Sharing a variable by capturing a scope and setting reuse:
2167
2168  ```python
2169  with tf.compat.v1.variable_scope("foo") as scope:
2170      v = tf.compat.v1.get_variable("v", [1])
2171      scope.reuse_variables()
2172      v1 = tf.compat.v1.get_variable("v", [1])
2173  assert v1 == v
2174  ```
2175
2176  To prevent accidental sharing of variables, we raise an exception when getting
2177  an existing variable in a non-reusing scope.
2178
2179  ```python
2180  with tf.compat.v1.variable_scope("foo"):
2181      v = tf.compat.v1.get_variable("v", [1])
2182      v1 = tf.compat.v1.get_variable("v", [1])
2183      #  Raises ValueError("... v already exists ...").
2184  ```
2185
2186  Similarly, we raise an exception when trying to get a variable that does not
2187  exist in reuse mode.
2188
2189  ```python
2190  with tf.compat.v1.variable_scope("foo", reuse=True):
2191      v = tf.compat.v1.get_variable("v", [1])
2192      #  Raises ValueError("... v does not exist ...").
2193  ```
2194
2195  Note that the `reuse` flag is inherited: if we open a reusing scope, then all
2196  its sub-scopes become reusing as well.
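
  For instance, a minimal sketch (assuming a variable "foo/bar/v" was created
  earlier):

  ```python
  with tf.compat.v1.variable_scope("foo", reuse=True):
    with tf.compat.v1.variable_scope("bar"):  # inherits reuse=True
      v = tf.compat.v1.get_variable("v", [1])  # looks up existing "foo/bar/v"
  ```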
2197
2198  A note about name scoping: Setting `reuse` does not impact the naming of
2199  other ops, such as multiplication ops. See the related discussion on
2200  [github#6189](https://github.com/tensorflow/tensorflow/issues/6189).
2201
2202  Note that up to and including version 1.0, it was allowed (though explicitly
2203  discouraged) to pass False to the reuse argument, yielding undocumented
2204  behaviour slightly different from None. Starting with 1.1.0, passing None
2205  and False as reuse has exactly the same effect.
2206
2207  A note about using variable scopes in a multi-threaded environment: Variable
2208  scopes are thread local, so one thread will not see another thread's current
2209  scope. Also, when using `default_name`, unique scope names are generated only
2210  on a per-thread basis, so using a name in one thread does not prevent another
2211  thread from creating a scope with the same name.
2212  However, the underlying variable store is shared across threads (within the
2213  same graph). As such, if another thread tries to create a new variable with
2214  the same name as a variable created by a previous thread, it will fail unless
2215  reuse is True.
2216
2217  Further, each thread starts with an empty variable scope. So if you wish to
2218  preserve name prefixes from a scope from the main thread, you should capture
2219  the main thread's scope and re-enter it in each thread. For example:
2220
2221  ```
2222  main_thread_scope = variable_scope.get_variable_scope()
2223
2224  # Thread's target function:
2225  def thread_target_fn(captured_scope):
2226    with variable_scope.variable_scope(captured_scope):
2227      # .... regular code for this thread
2228
2229
2230  thread = threading.Thread(target=thread_target_fn, args=(main_thread_scope,))
2231  ```
2232  """
2233
2234  def __init__(self,
2235               name_or_scope,
2236               default_name=None,
2237               values=None,
2238               initializer=None,
2239               regularizer=None,
2240               caching_device=None,
2241               partitioner=None,
2242               custom_getter=None,
2243               reuse=None,
2244               dtype=None,
2245               use_resource=None,
2246               constraint=None,
2247               auxiliary_name_scope=True):
2248    """Initialize the context manager.
2249
2250    Args:
2251      name_or_scope: `string` or `VariableScope`: the scope to open.
2252      default_name: The default name to use if the `name_or_scope` argument is
2253        `None`; this name will be uniquified. If name_or_scope is provided, it
2254        won't be used, and therefore it is not required and can be None.
2255      values: The list of `Tensor` arguments that are passed to the op function.
2256      initializer: default initializer for variables within this scope.
2257      regularizer: default regularizer for variables within this scope.
2258      caching_device: default caching device for variables within this scope.
2259      partitioner: default partitioner for variables within this scope.
2260      custom_getter: default custom getter for variables within this scope.
2261      reuse: `True`, None, or tf.compat.v1.AUTO_REUSE; if `True`, we go into
2262        reuse mode for this scope as well as all sub-scopes; if
2263        tf.compat.v1.AUTO_REUSE, we create variables if they do not exist, and
2264        return them otherwise; if None, we inherit the parent scope's reuse
2265        flag. When eager execution is enabled, new variables are always created
2266        unless an EagerVariableStore or template is currently active.
2267      dtype: type of variables created in this scope (defaults to the type in
2268        the passed scope, or inherited from parent scope).
2269      use_resource: If False, all variables will be regular Variables. If True,
2270        experimental ResourceVariables with well-defined semantics will be used
2271        instead. Defaults to False (will later change to True). When eager
2272        execution is enabled this argument is always forced to be True.
2273      constraint: An optional projection function to be applied to the variable
2274        after being updated by an `Optimizer` (e.g. used to implement norm
2275        constraints or value constraints for layer weights). The function must
2276        take as input the unprojected Tensor representing the value of the
2277        variable and return the Tensor for the projected value (which must have
2278        the same shape). Constraints are not safe to use when doing asynchronous
2279        distributed training.
2280      auxiliary_name_scope: If `True`, we create an auxiliary name scope with
2281        the scope. If `False`, we don't create it. Note that the argument is not
2282        inherited, and it only takes effect for once when creating. You should
2283        only use it for re-entering a premade variable scope.
2284
2285    Returns:
2286      A scope that can be captured and reused.
2287
2288    Raises:
2289      ValueError: when trying to reuse within a create scope, or create within
2290        a reuse scope.
2291      TypeError: when the types of some arguments are not appropriate.
2292    """
2293    self._name_or_scope = name_or_scope
2294    self._default_name = default_name
2295    self._values = values
2296    self._initializer = initializer
2297    self._regularizer = regularizer
2298    self._caching_device = caching_device
2299    self._partitioner = partitioner
2300    self._custom_getter = custom_getter
2301    self._reuse = reuse
2302    self._dtype = dtype
2303    self._use_resource = use_resource
2304    self._constraint = constraint
2305    if self._default_name is None and self._name_or_scope is None:
2306      raise TypeError("If default_name is None then name_or_scope is required")
2307    if self._reuse is False:
2308      # We don't allow non-inheriting scopes, False = None here.
2309      self._reuse = None
2310    if not (self._reuse is True
2311            or self._reuse is None
2312            or self._reuse is AUTO_REUSE):
2313      raise ValueError("The reuse parameter must be True, False, None, or AUTO_REUSE.")
2314    if self._values is None:
2315      self._values = []
2316    self._in_graph_mode = not context.executing_eagerly()
2317    if self._in_graph_mode:
2318      self._graph = ops._get_graph_from_inputs(self._values)  # pylint: disable=protected-access
2319    self._cached_pure_variable_scope = None
2320    self._current_name_scope = None
2321    if not isinstance(auxiliary_name_scope, bool):
2322      raise TypeError("The auxiliary_name_scope must be `True` or `False`, "
2323                      "while get {}".format(auxiliary_name_scope))
2324    self._auxiliary_name_scope = auxiliary_name_scope
2325
2326  def __enter__(self):
2327    # If the default graph is building a function, then we should not replace it
2328    # with the cached graph.
2329    if ops.get_default_graph().building_function:
2330      self._building_function = True
2331    else:
2332      self._building_function = False
2333    if self._in_graph_mode and not self._building_function:
2334      self._graph_context_manager = self._graph.as_default()
2335      self._graph_context_manager.__enter__()
2336    if self._cached_pure_variable_scope is not None:
2337      # Fast path for re-entering variable_scopes. We've held on to the pure
2338      # variable scope from a previous successful __enter__, so we avoid some
2339      # overhead by re-using that object.
2340      if self._current_name_scope is not None:
2341        self._current_name_scope.__enter__()
2342      return self._cached_pure_variable_scope.__enter__()
2343
2344    try:
2345      return self._enter_scope_uncached()
2346    except:
2347      if (self._in_graph_mode and not self._building_function and
2348          self._graph_context_manager is not None):
2349        self._graph_context_manager.__exit__(*sys.exc_info())
2350      raise
2351
2352  def _enter_scope_uncached(self):
2353    """Enters the context manager when there is no cached scope yet.
2354
2355    Returns:
2356      The entered variable scope.
2357
2358    Raises:
2359      TypeError: A wrong type is passed as `scope` at __init__().
2360      ValueError: `reuse` is incorrectly set at __init__().
2361    """
2362    if self._auxiliary_name_scope:
2363      # Create a new name scope later
2364      current_name_scope = None
2365    else:
2366      # Reenter the current name scope
2367      name_scope = ops.get_name_scope()
2368      if name_scope:
2369        # Hack to reenter
2370        name_scope += "/"
2371        current_name_scope = ops.name_scope(name_scope, skip_on_eager=False)
2372      else:
2373        # Root scope
2374        current_name_scope = ops.name_scope(name_scope, skip_on_eager=False)
2375
2376    # IMPORTANT: Only assign to self._cached_pure_variable_scope and
2377    # self._current_name_scope after successful __enter__() calls.
2378    if self._name_or_scope is not None:
2379      if not isinstance(self._name_or_scope,
2380                        (VariableScope,) + six.string_types):
2381        raise TypeError("VariableScope: name_or_scope must be a string or "
2382                        "VariableScope.")
2383      if isinstance(self._name_or_scope, six.string_types):
2384        name_scope = self._name_or_scope
2385      else:
2386        name_scope = self._name_or_scope.name.split("/")[-1]
2387      if name_scope or current_name_scope:
2388        current_name_scope = current_name_scope or ops.name_scope(
2389            name_scope, skip_on_eager=False)
2390        try:
2391          current_name_scope_name = current_name_scope.__enter__()
2392        except:
2393          current_name_scope.__exit__(*sys.exc_info())
2394          raise
2395        self._current_name_scope = current_name_scope
2396        if isinstance(self._name_or_scope, six.string_types):
2397          old_name_scope = current_name_scope_name
2398        else:
2399          old_name_scope = self._name_or_scope.original_name_scope
2400        pure_variable_scope = _pure_variable_scope(
2401            self._name_or_scope,
2402            reuse=self._reuse,
2403            initializer=self._initializer,
2404            regularizer=self._regularizer,
2405            caching_device=self._caching_device,
2406            partitioner=self._partitioner,
2407            custom_getter=self._custom_getter,
2408            old_name_scope=old_name_scope,
2409            dtype=self._dtype,
2410            use_resource=self._use_resource,
2411            constraint=self._constraint)
2412        try:
2413          entered_pure_variable_scope = pure_variable_scope.__enter__()
2414        except:
2415          pure_variable_scope.__exit__(*sys.exc_info())
2416          raise
2417        self._cached_pure_variable_scope = pure_variable_scope
2418        return entered_pure_variable_scope
2419      else:
2420        self._current_name_scope = None
2421        # This can only happen if someone is entering the root variable scope.
2422        pure_variable_scope = _pure_variable_scope(
2423            self._name_or_scope,
2424            reuse=self._reuse,
2425            initializer=self._initializer,
2426            regularizer=self._regularizer,
2427            caching_device=self._caching_device,
2428            partitioner=self._partitioner,
2429            custom_getter=self._custom_getter,
2430            dtype=self._dtype,
2431            use_resource=self._use_resource,
2432            constraint=self._constraint)
2433        try:
2434          entered_pure_variable_scope = pure_variable_scope.__enter__()
2435        except:
2436          pure_variable_scope.__exit__(*sys.exc_info())
2437          raise
2438        self._cached_pure_variable_scope = pure_variable_scope
2439        return entered_pure_variable_scope
2440
2441    else:  # Here name_or_scope is None. Using default name, but made unique.
2442      if self._reuse:
2443        raise ValueError("reuse=True cannot be used without a name_or_scope")
2444      current_name_scope = current_name_scope or ops.name_scope(
2445          self._default_name, skip_on_eager=False)
2446      try:
2447        current_name_scope_name = current_name_scope.__enter__()
2448      except:
2449        current_name_scope.__exit__(*sys.exc_info())
2450        raise
2451      self._current_name_scope = current_name_scope
2452      unique_default_name = _get_unique_variable_scope(self._default_name)
2453      pure_variable_scope = _pure_variable_scope(
2454          unique_default_name,
2455          initializer=self._initializer,
2456          regularizer=self._regularizer,
2457          caching_device=self._caching_device,
2458          partitioner=self._partitioner,
2459          custom_getter=self._custom_getter,
2460          old_name_scope=current_name_scope_name,
2461          dtype=self._dtype,
2462          use_resource=self._use_resource,
2463          constraint=self._constraint)
2464      try:
2465        entered_pure_variable_scope = pure_variable_scope.__enter__()
2466      except:
2467        pure_variable_scope.__exit__(*sys.exc_info())
2468        raise
2469      self._cached_pure_variable_scope = pure_variable_scope
2470      return entered_pure_variable_scope
2471
2472  def __exit__(self, type_arg, value_arg, traceback_arg):
2473    try:
2474      self._cached_pure_variable_scope.__exit__(type_arg, value_arg,
2475                                                traceback_arg)
2476    finally:
2477      try:
2478        if self._current_name_scope:
2479          self._current_name_scope.__exit__(type_arg, value_arg,
2480                                            traceback_arg)
2481      finally:
2482        if self._in_graph_mode and not self._building_function:
2483          self._graph_context_manager.__exit__(type_arg, value_arg,
2484                                               traceback_arg)
2485
2486
2487# pylint: disable=g-doc-return-or-yield
2488@tf_export(v1=["variable_op_scope"])
2489@tf_contextlib.contextmanager
2490def variable_op_scope(values,
2491                      name_or_scope,
2492                      default_name=None,
2493                      initializer=None,
2494                      regularizer=None,
2495                      caching_device=None,
2496                      partitioner=None,
2497                      custom_getter=None,
2498                      reuse=None,
2499                      dtype=None,
2500                      use_resource=None,
2501                      constraint=None):
2502  """Deprecated: context manager for defining an op that creates variables."""
2503  logging.warn("tf.variable_op_scope(values, name, default_name) is deprecated,"
2504               " use tf.variable_scope(name, default_name, values)")
2505  with variable_scope(
2506      name_or_scope,
2507      default_name=default_name,
2508      values=values,
2509      initializer=initializer,
2510      regularizer=regularizer,
2511      caching_device=caching_device,
2512      partitioner=partitioner,
2513      custom_getter=custom_getter,
2514      reuse=reuse,
2515      dtype=dtype,
2516      use_resource=use_resource,
2517      constraint=constraint) as scope:
2518    yield scope
2519
2520
def _call_partitioner(partitioner, shape, dtype):
  """Calls the partitioner, validating its inputs and output.

  Args:
    partitioner: a function mapping `Tensor` shape and dtype to a list of
      partitions.
    shape: shape of the `Tensor` to partition; must be fully defined and have
      at least one dimension.
    dtype: dtype of the elements in the `Tensor`.

  Returns:
    A list with elements >= 1 and at most one element > 1. The index of that
    element corresponds to the partitioning axis.
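
  For illustration, a hypothetical partitioner (not part of this module) that
  always splits the first axis into four pieces could look as follows; for a
  fully defined shape such as [20, 5] it returns [4, 1], which passes the
  checks below because every entry is >= 1 and only one entry is > 1:

  ```python
  def split_axis_0(shape, dtype):
    del dtype  # Unused by this toy partitioner.
    return [4] + [1] * (shape.ndims - 1)
  ```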
  """
  if not shape.is_fully_defined():
    raise ValueError("Shape of a new partitioned variable must be "
                     "fully defined, but instead was %s." % (shape,))
  if shape.ndims < 1:
    raise ValueError("A partitioned Variable must have rank at least 1, "
                     "shape: %s" % shape)

  slicing = partitioner(shape=shape, dtype=dtype)
  if not isinstance(slicing, collections_abc.Sequence):
    raise ValueError("Partitioner must return a sequence, but saw: %s" %
                     slicing)
  if len(slicing) != shape.ndims:
    raise ValueError(
        "Partitioner returned a partition list that does not match the "
        "Variable's rank: %s vs. %s" % (slicing, shape))
  if any(p < 1 for p in slicing):
    raise ValueError("Partitioner returned zero partitions for some axes: %s" %
                     slicing)
  if sum(p > 1 for p in slicing) > 1:
    raise ValueError("Can only slice a variable along one dimension: "
                     "shape: %s, partitioning: %s" % (shape, slicing))
  return slicing


# TODO(slebedev): could be inlined, but
# `_VariableStore._get_partitioned_variable` is too complex even
# without this logic.
def _get_slice_dim_and_num_slices(slicing):
  """Get slicing dimension and number of slices from the partitioner output.
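
  As a small illustration (values assumed, not from this module):

  ```python
  _get_slice_dim_and_num_slices([1, 4, 1])  # -> (1, 4): split axis 1 in four.
  _get_slice_dim_and_num_slices([1, 1])     # -> (0, 1): no partitioning.
  ```
  """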
  for slice_dim, num_slices in enumerate(slicing):
    if num_slices > 1:
      break
  else:
    # Degenerate case: no partitioning applied.
    slice_dim = 0
    num_slices = 1
  return slice_dim, num_slices


def _iter_slices(full_shape, num_slices, slice_dim):
  """Slices a given shape along the specified dimension.
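
  Any remainder is handed out one element at a time to the leading slices, so
  slice lengths differ by at most one. A small sketch (values assumed, not
  from this module):

  ```python
  list(_iter_slices([10, 4], num_slices=3, slice_dim=0))
  # -> [([0, 0], [4, 4]), ([4, 0], [3, 4]), ([7, 0], [3, 4])]
  ```
  """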
  num_slices_with_excess = full_shape[slice_dim] % num_slices
  offset = [0] * len(full_shape)
  min_slice_len = full_shape[slice_dim] // num_slices
  for i in xrange(num_slices):
    shape = full_shape[:]
    shape[slice_dim] = min_slice_len + bool(i < num_slices_with_excess)
    yield offset[:], shape
    offset[slice_dim] += shape[slice_dim]


def default_variable_creator(next_creator=None, **kwargs):
  """Default variable creator."""
  assert next_creator is None
  initial_value = kwargs.get("initial_value", None)
  trainable = kwargs.get("trainable", None)
  collections = kwargs.get("collections", None)
  validate_shape = kwargs.get("validate_shape", True)
  caching_device = kwargs.get("caching_device", None)
  name = kwargs.get("name", None)
  variable_def = kwargs.get("variable_def", None)
  dtype = kwargs.get("dtype", None)
  expected_shape = kwargs.get("expected_shape", None)
  import_scope = kwargs.get("import_scope", None)
  constraint = kwargs.get("constraint", None)
  use_resource = kwargs.get("use_resource", None)
  synchronization = kwargs.get("synchronization", None)
  aggregation = kwargs.get("aggregation", None)
  shape = kwargs.get("shape", None)

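  # `use_resource` resolution (descriptive note): the explicit kwarg is used
  # if given, then the enclosing variable scope's setting, then the module
  # default; eager execution always forces resource variables.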
  if use_resource is None:
    use_resource = get_variable_scope().use_resource
  if use_resource is None:
    use_resource = _DEFAULT_USE_RESOURCE
  use_resource = use_resource or context.executing_eagerly()
  if use_resource:
    distribute_strategy = kwargs.get("distribute_strategy", None)
    return resource_variable_ops.ResourceVariable(
        initial_value=initial_value,
        trainable=trainable,
        collections=collections,
        validate_shape=validate_shape,
        caching_device=caching_device,
        name=name,
        dtype=dtype,
        constraint=constraint,
        variable_def=variable_def,
        import_scope=import_scope,
        distribute_strategy=distribute_strategy,
        synchronization=synchronization,
        aggregation=aggregation,
        shape=shape)
  else:
    return variables.RefVariable(
        initial_value=initial_value,
        trainable=trainable,
        collections=collections,
        validate_shape=validate_shape,
        caching_device=caching_device,
        name=name,
        dtype=dtype,
        constraint=constraint,
        variable_def=variable_def,
        expected_shape=expected_shape,
        import_scope=import_scope,
        synchronization=synchronization,
        aggregation=aggregation,
        shape=shape)


def default_variable_creator_v2(next_creator=None, **kwargs):
  """Default variable creator."""
  assert next_creator is None
  initial_value = kwargs.get("initial_value", None)
  trainable = kwargs.get("trainable", None)
  validate_shape = kwargs.get("validate_shape", True)
  caching_device = kwargs.get("caching_device", None)
  name = kwargs.get("name", None)
  variable_def = kwargs.get("variable_def", None)
  dtype = kwargs.get("dtype", None)
  import_scope = kwargs.get("import_scope", None)
  constraint = kwargs.get("constraint", None)
  distribute_strategy = kwargs.get("distribute_strategy", None)
  synchronization = kwargs.get("synchronization", None)
  aggregation = kwargs.get("aggregation", None)
  shape = kwargs.get("shape", None)

  return resource_variable_ops.ResourceVariable(
      initial_value=initial_value,
      trainable=trainable,
      validate_shape=validate_shape,
      caching_device=caching_device,
      name=name,
      dtype=dtype,
      constraint=constraint,
      variable_def=variable_def,
      import_scope=import_scope,
      distribute_strategy=distribute_strategy,
      synchronization=synchronization,
      aggregation=aggregation,
      shape=shape)


variables.default_variable_creator = default_variable_creator
variables.default_variable_creator_v2 = default_variable_creator_v2


def _make_getter(captured_getter, captured_previous):
  """Returns a closure, avoiding Python's late binding of loop variables."""
  return lambda **kwargs: captured_getter(captured_previous, **kwargs)


# TODO(apassos) remove forwarding symbol
variable = variables.VariableV1


@tf_export(v1=["variable_creator_scope"])
@tf_contextlib.contextmanager
def variable_creator_scope_v1(variable_creator):
  """Scope which defines a variable creation function to be used by variable().

  variable_creator is expected to be a function with the following signature:

  ```
    def variable_creator(next_creator, **kwargs)
  ```

  A creator that does want to create a variable should eventually call
  `next_creator` rather than instantiating `Variable` or `ResourceVariable`
  directly; this keeps creators composable. A creator may choose to create
  multiple variables, return already existing variables, or simply register
  that a variable was created and defer to the next creators in line. Creators
  can also modify the keyword arguments seen by the next creators.

  Custom getters in the variable scope will eventually resolve down to these
  custom creators when they do create variables.

  The valid keyword arguments in `kwargs` are:

   * initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
        which is the initial value for the Variable. The initial value must have
        a shape specified unless `validate_shape` is set to False. Can also be a
        callable with no argument that returns the initial value when called. In
        that case, `dtype` must be specified. (Note that initializer functions
        from init_ops.py must first be bound to a shape before being used here.)
   * trainable: If `True`, the default, also adds the variable to the graph
        collection `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as
        the default list of variables to use by the `Optimizer` classes.
        `trainable` defaults to `True`, unless `synchronization` is
        set to `ON_READ`, in which case it defaults to `False`.
   * collections: List of graph collections keys. The new variable is added to
        these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
   * validate_shape: If `False`, allows the variable to be initialized with a
        value of unknown shape. If `True`, the default, the shape of
        `initial_value` must be known.
   * caching_device: Optional device string describing where the Variable
        should be cached for reading.  Defaults to the Variable's device.
        If not `None`, caches on another device.  Typical use is to cache
        on the device where the Ops using the Variable reside, to deduplicate
        copying through `Switch` and other conditional statements.
   * name: Optional name for the variable. Defaults to `'Variable'` and gets
        uniquified automatically.
   * dtype: If set, initial_value will be converted to the given type.
        If `None`, either the datatype will be kept (if `initial_value` is
        a Tensor), or `convert_to_tensor` will decide.
   * constraint: A constraint function to be applied to the variable after
        updates by some algorithms.
   * use_resource: if True, a ResourceVariable is always created.
   * synchronization: Indicates when a distributed variable will be
        aggregated. Accepted values are constants defined in the class
        `tf.VariableSynchronization`. By default the synchronization is set to
        `AUTO` and the current `DistributionStrategy` chooses
        when to synchronize.
   * aggregation: Indicates how a distributed variable will be aggregated.
        Accepted values are constants defined in the class
        `tf.VariableAggregation`.

  This set may grow over time, so it's important that creators accept the
  generic `**kwargs` signature shown above.

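  A minimal sketch of a composable creator (illustrative; the name
  `my_variable_creator` is not part of the API):

  ```python
  def my_variable_creator(next_creator, **kwargs):
    # Creators may adjust the kwargs before delegating down the chain.
    kwargs["trainable"] = False
    return next_creator(**kwargs)

  with tf.compat.v1.variable_creator_scope(my_variable_creator):
    v = tf.Variable(1.0)  # Created non-trainable by `my_variable_creator`.
  ```
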
  Args:
    variable_creator: the creator function to install for this scope.

  Yields:
    A scope in which the creator is active
  """
  with ops.get_default_graph()._variable_creator_scope(variable_creator):  # pylint: disable=protected-access
    yield


# Note: only the docstrings differ between this and v1.
@tf_export("variable_creator_scope", v1=[])
@tf_contextlib.contextmanager
def variable_creator_scope(variable_creator):
  """Scope which defines a variable creation function to be used by variable().

  variable_creator is expected to be a function with the following signature:

  ```
    def variable_creator(next_creator, **kwargs)
  ```

  A creator that does want to create a variable should eventually call
  `next_creator` rather than instantiating `Variable` or `ResourceVariable`
  directly; this keeps creators composable. A creator may choose to create
  multiple variables, return already existing variables, or simply register
  that a variable was created and defer to the next creators in line. Creators
  can also modify the keyword arguments seen by the next creators.

  Custom getters in the variable scope will eventually resolve down to these
  custom creators when they do create variables.

  The valid keyword arguments in `kwargs` are:

   * initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
        which is the initial value for the Variable. The initial value must have
        a shape specified unless `validate_shape` is set to False. Can also be a
        callable with no argument that returns the initial value when called. In
        that case, `dtype` must be specified. (Note that initializer functions
        from init_ops.py must first be bound to a shape before being used here.)
   * trainable: If `True`, the default, GradientTapes automatically watch
        uses of this Variable.
   * validate_shape: If `False`, allows the variable to be initialized with a
        value of unknown shape. If `True`, the default, the shape of
        `initial_value` must be known.
   * caching_device: Optional device string describing where the Variable
        should be cached for reading.  Defaults to the Variable's device.
        If not `None`, caches on another device.  Typical use is to cache
        on the device where the Ops using the Variable reside, to deduplicate
        copying through `Switch` and other conditional statements.
   * name: Optional name for the variable. Defaults to `'Variable'` and gets
        uniquified automatically.
   * dtype: If set, initial_value will be converted to the given type.
        If `None`, either the datatype will be kept (if `initial_value` is
        a Tensor), or `convert_to_tensor` will decide.
   * constraint: A constraint function to be applied to the variable after
        updates by some algorithms.
   * synchronization: Indicates when a distributed variable will be
        aggregated. Accepted values are constants defined in the class
        `tf.VariableSynchronization`. By default the synchronization is set to
        `AUTO` and the current `DistributionStrategy` chooses
        when to synchronize.
   * aggregation: Indicates how a distributed variable will be aggregated.
        Accepted values are constants defined in the class
        `tf.VariableAggregation`.

  This set may grow over time, so it's important that creators accept the
  generic `**kwargs` signature shown above.

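  A minimal usage sketch (the creator name below is illustrative, not part of
  the API):

  ```python
  def logging_creator(next_creator, **kwargs):
    print("Creating variable:", kwargs.get("name"))
    return next_creator(**kwargs)

  with tf.variable_creator_scope(logging_creator):
    v = tf.Variable(1.0, name="v")  # `logging_creator` runs first.
  ```
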
  Args:
    variable_creator: the creator function to install for this scope.

  Yields:
    A scope in which the creator is active
  """
  with ops.get_default_graph()._variable_creator_scope(variable_creator):  # pylint: disable=protected-access
    yield
