# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A class to store named variables and a scope operator to manage sharing."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy
import enum  # pylint: disable=g-bad-import-order
import functools
import sys
import threading
import traceback

import six
from six import iteritems
from six.moves import xrange, zip  # pylint: disable=redefined-builtin

from tensorflow.python import tf2
from tensorflow.python.client import session
from tensorflow.python.eager import context
from tensorflow.python.eager import monitoring
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.types import core
from tensorflow.python.util import deprecation
from tensorflow.python.util import function_utils
from tensorflow.python.util import tf_contextlib
from tensorflow.python.util import tf_inspect
from tensorflow.python.util.compat import collections_abc
from tensorflow.python.util.tf_export import tf_export

__all__ = [
    "AUTO_REUSE", "VariableScope", "get_variable_scope", "get_variable",
    "get_local_variable", "variable_scope", "variable_op_scope",
    "no_regularizer", "VariableSynchronization", "VariableAggregation"
]

_api_usage_gauge = monitoring.BoolGauge(
    "/tensorflow/api/resource_variables",
    "Whether variable_scope.enable_resource_variables() is called.")


class _PartitionInfo(object):
  """Holds partition info used by initializer functions."""

  __slots__ = ["_full_shape", "_var_offset"]

  def __init__(self, full_shape, var_offset):
    """Constructor.

    Args:
      full_shape: Tuple or list of `int` indicating the full combined shape of
        the partitioned variables.
      var_offset: Tuple or list of `int` specifying offset of this partition
        with respect to the full variable for each dimension.

    Raises:
      TypeError: If `full_shape` or `var_offset` is not a sequence.
      ValueError: If `full_shape` or `var_offset` differ in length. If
        `var_offset` exceeds `full_shape` in any dimension.
81 """ 82 if not isinstance(full_shape, collections_abc.Sequence) or isinstance( 83 full_shape, six.string_types): 84 raise TypeError( 85 "`full_shape` must be a sequence (like tuple or list) instead of " + 86 type(full_shape).__name__) 87 88 if not isinstance(var_offset, collections_abc.Sequence) or isinstance( 89 var_offset, six.string_types): 90 raise TypeError( 91 "`var_offset` must be a sequence (like tuple or list) instead of " + 92 type(var_offset).__name__) 93 94 if len(var_offset) != len(full_shape): 95 raise ValueError( 96 "Expected equal length, but `var_offset` is of length {} while " 97 "full_shape is of length {}.".format( 98 len(var_offset), len(full_shape))) 99 100 for offset, shape in zip(var_offset, full_shape): 101 if offset < 0 or offset >= shape: 102 raise ValueError( 103 "Expected 0 <= offset < shape but found offset={}, shape={} for " 104 "var_offset={}, full_shape={}".format(offset, shape, var_offset, 105 full_shape)) 106 107 self._full_shape = full_shape 108 self._var_offset = var_offset 109 110 @property 111 def full_shape(self): 112 return self._full_shape 113 114 @property 115 def var_offset(self): 116 return self._var_offset 117 118 def single_offset(self, shape): 119 """Returns the offset when the variable is partitioned in at most one dim. 120 121 Args: 122 shape: Tuple or list of `int` indicating the shape of one specific 123 variable partition. 124 125 Returns: 126 `int` representing the offset in the dimension along which the variable is 127 partitioned. Returns 0 if the variable is not being partitioned. 128 129 Raises: 130 ValueError: Depending on self.single_slice_dim(). 131 """ 132 133 single_slice_dim = self.single_slice_dim(shape) 134 # If this variable is not being partitioned at all, single_slice_dim() could 135 # return None. 136 if single_slice_dim is None: 137 return 0 138 return self.var_offset[single_slice_dim] 139 140 def single_slice_dim(self, shape): 141 """Returns the slice dim when the variable is partitioned only in one dim. 142 143 Args: 144 shape: Tuple or list of `int` indicating the shape of one specific 145 variable partition. 146 147 Returns: 148 `int` representing the dimension that the variable is partitioned in, or 149 `None` if the variable doesn't seem to be partitioned at all. 150 151 Raises: 152 TypeError: If `shape` is not a sequence. 153 ValueError: If `shape` is not the same length as `self.full_shape`. If 154 the variable is partitioned in more than one dimension. 
155 """ 156 if not isinstance(shape, collections_abc.Sequence) or isinstance( 157 shape, six.string_types): 158 raise TypeError( 159 "`shape` must be a sequence (like tuple or list) instead of " + 160 type(shape).__name__) 161 162 if len(shape) != len(self.full_shape): 163 raise ValueError( 164 "Expected equal length, but received shape={} of length {} while " 165 "self.full_shape={} is of length {}.".format(shape, len(shape), 166 self.full_shape, 167 len(self.full_shape))) 168 169 for i in xrange(len(shape)): 170 if self.var_offset[i] + shape[i] > self.full_shape[i]: 171 raise ValueError( 172 "With self.var_offset={}, a partition of shape={} would exceed " 173 "self.full_shape={} in dimension {}.".format( 174 self.var_offset, shape, self.full_shape, i)) 175 176 slice_dim = None 177 for i in xrange(len(shape)): 178 if shape[i] == self.full_shape[i]: 179 continue 180 if slice_dim is not None: 181 raise ValueError( 182 "Cannot use single_slice_dim() with shape={} and " 183 "self.full_shape={} since slice dim could be either dimension {} " 184 "or {}.".format(shape, self.full_shape, i, slice_dim)) 185 slice_dim = i 186 187 return slice_dim 188 189 190class _ReuseMode(enum.Enum): 191 """Mode for variable access within a variable scope.""" 192 193 # Indicates that variables are to be fetched if they already exist or 194 # otherwise created. 195 AUTO_REUSE = 1 196 197 # TODO(alive): For TensorFlow 2.0, Deprecate True/False/None API in favor of 198 # enum values. 199 # REUSE_FALSE = 2 200 # REUSE_TRUE = 3 201 202 203# TODO(apassos) remove these forwarding symbols. 204VariableSynchronization = variables.VariableSynchronization # pylint: disable=invalid-name 205VariableAggregation = variables.VariableAggregation # pylint: disable=invalid-name 206 207AUTO_REUSE = _ReuseMode.AUTO_REUSE 208tf_export(v1=["AUTO_REUSE"]).export_constant(__name__, "AUTO_REUSE") 209AUTO_REUSE.__doc__ = """ 210When passed in as the value for the `reuse` flag, AUTO_REUSE indicates that 211get_variable() should create the requested variable if it doesn't exist or, if 212it does exist, simply return it. 213""" 214 215_DEFAULT_USE_RESOURCE = tf2.enabled() 216 217 218@tf_export(v1=["enable_resource_variables"]) 219def enable_resource_variables(): 220 """Creates resource variables by default. 221 222 Resource variables are improved versions of TensorFlow variables with a 223 well-defined memory model. Accessing a resource variable reads its value, and 224 all ops which access a specific read value of the variable are guaranteed to 225 see the same value for that tensor. Writes which happen after a read (by 226 having a control or data dependency on the read) are guaranteed not to affect 227 the value of the read tensor, and similarly writes which happen before a read 228 are guaranteed to affect the value. No guarantees are made about unordered 229 read/write pairs. 230 231 Calling tf.enable_resource_variables() lets you opt-in to this TensorFlow 2.0 232 feature. 233 """ 234 global _DEFAULT_USE_RESOURCE 235 _DEFAULT_USE_RESOURCE = True 236 logging.vlog(1, "Enabling resource variables") 237 _api_usage_gauge.get_cell().set(True) 238 239 240@tf_export(v1=["resource_variables_enabled"]) 241def resource_variables_enabled(): 242 """Returns `True` if resource variables are enabled. 243 244 Resource variables are improved versions of TensorFlow variables with a 245 well-defined memory model. 
  all ops which access a specific read value of the variable are guaranteed to
  see the same value for that tensor. Writes which happen after a read (by
  having a control or data dependency on the read) are guaranteed not to affect
  the value of the read tensor, and similarly writes which happen before a read
  are guaranteed to affect the value. No guarantees are made about unordered
  read/write pairs.

  Calling tf.enable_resource_variables() lets you opt in to this TensorFlow 2.0
  feature.
  """
  global _DEFAULT_USE_RESOURCE
  return _DEFAULT_USE_RESOURCE


@deprecation.deprecated(
    None, "non-resource variables are not supported in the long term")
@tf_export(v1=["disable_resource_variables"])
def disable_resource_variables():
  """Opts out of resource variables.

  If your code needs tf.disable_resource_variables() to be called to work
  properly, please file a bug.
  """
  global _DEFAULT_USE_RESOURCE
  _DEFAULT_USE_RESOURCE = False
  logging.vlog(1, "Disabling resource variables")
  _api_usage_gauge.get_cell().set(False)


def _needs_no_arguments(python_callable):
  """Returns true if the callable needs no arguments to call."""
  # TODO(bfontain): Switch to inspect.signature when we are python 3 only.
  # signature = inspect.signature(python_callable)
  # return not [1 for param in signature.parameters.values()
  #             if param.default == param.empty]
  num_arguments = len(tf_inspect.getargspec(python_callable).args)
  if not tf_inspect.isfunction(python_callable) and not isinstance(
      python_callable, functools.partial):
    # getargspec includes self for function objects (which aren't
    # functools.partial). This has no default, so we need to remove it.
    # It is not even an argument, so it's odd that getargspec returns it.
    # Note that this is fixed with inspect.signature in Python 3.
    num_arguments -= 1
  return num_arguments == len(
      tf_inspect.getargspec(python_callable).defaults or [])


class _VariableStore(object):
  """Variable store that carries a number of named Variables.

  New variable names and new variables can be created; all stored
  variables are initialized with the initializer passed to __init__.

  Attributes:
    vars: a dictionary with string names (same as passed in GetVar) as keys and
      the corresponding TensorFlow Variables as values.
  """

  __slots__ = ["_vars", "_partitioned_vars", "_store_eager_variables"]

  def __init__(self):
    """Create a variable store."""
    self._vars = {}  # A dictionary of the stored TensorFlow variables.
    self._partitioned_vars = {}  # A dict of the stored PartitionedVariables.
    self._store_eager_variables = False

  def get_variable(self,
                   name,
                   shape=None,
                   dtype=dtypes.float32,
                   initializer=None,
                   regularizer=None,
                   reuse=None,
                   trainable=None,
                   collections=None,
                   caching_device=None,
                   partitioner=None,
                   validate_shape=True,
                   use_resource=None,
                   custom_getter=None,
                   constraint=None,
                   synchronization=VariableSynchronization.AUTO,
                   aggregation=VariableAggregation.NONE):
    """Gets an existing variable with these parameters or creates a new one.

    If a variable with the given name is already stored, we return the stored
    variable. Otherwise, we create a new one.

    Set `reuse` to `True` when you only want to reuse existing Variables.
    Set `reuse` to `False` when you only want to create new Variables.
    Set `reuse` to None (the default) or tf.compat.v1.AUTO_REUSE when you want
    variables to be created if they don't exist or returned if they do.

    If initializer is `None` (the default), the default initializer passed in
    the constructor is used. If that one is `None` too, we use a new
    `glorot_uniform_initializer`. If initializer is a Tensor, we use
    it as a value and derive the shape from the initializer.

    If a partitioner is provided, a `PartitionedVariable` is returned.
    Accessing this object as a `Tensor` returns the shards concatenated along
    the partition axis.

    Some useful partitioners are available. See, e.g.,
    `variable_axis_size_partitioner` and `min_max_variable_partitioner`.

    Args:
      name: The name of the new or existing variable.
      shape: Shape of the new or existing variable.
      dtype: Type of the new or existing variable (defaults to `DT_FLOAT`).
      initializer: Initializer for the variable.
      regularizer: A (Tensor -> Tensor or None) function; the result of
        applying it on a newly created variable will be added to the collection
        GraphKeys.REGULARIZATION_LOSSES and can be used for regularization.
      reuse: a Boolean, None, or tf.AUTO_REUSE. Controls reuse or creation of
        variables. When eager execution is enabled this argument is always
        forced to be False.
      trainable: If `True` also add the variable to the graph collection
        `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). `trainable`
        defaults to `True`, unless `synchronization` is set to `ON_READ`, in
        which case it defaults to `False`.
      collections: List of graph collections keys to add the `Variable` to.
        Defaults to `[GraphKeys.GLOBAL_VARIABLES]` (see `tf.Variable`).
      caching_device: Optional device string or function describing where the
        Variable should be cached for reading. Defaults to the Variable's
        device. If not `None`, caches on another device. Typical use is to
        cache on the device where the Ops using the `Variable` reside, to
        deduplicate copying through `Switch` and other conditional statements.
      partitioner: Optional callable that accepts a fully defined `TensorShape`
        and dtype of the `Variable` to be created, and returns a list of
        partitions for each axis (currently only one axis can be partitioned).
      validate_shape: If False, allows the variable to be initialized with a
        value of unknown shape. If True, the default, the shape of
        initial_value must be known.
      use_resource: If False, creates a regular Variable. If True, creates
        instead an experimental ResourceVariable which has well-defined
        semantics. Defaults to False (will later change to True). When eager
        execution is enabled this argument is always forced to be True.
      custom_getter: Callable that takes as a first argument the true getter,
        and allows overwriting the internal get_variable method. The signature
        of `custom_getter` should match that of this method, but the most
        future-proof version will allow for changes:
        `def custom_getter(getter, *args, **kwargs)`. Direct access to all
        `get_variable` parameters is also allowed:
        `def custom_getter(getter, name, *args, **kwargs)`.
        A simple identity custom getter that simply creates variables with
        modified names is:

        ```python
        def custom_getter(getter, name, *args, **kwargs):
          return getter(name + '_suffix', *args, **kwargs)
        ```
      constraint: An optional projection function to be applied to the variable
        after being updated by an `Optimizer` (e.g. used to implement norm
        constraints or value constraints for layer weights). The function must
        take as input the unprojected Tensor representing the value of the
        variable and return the Tensor for the projected value (which must have
        the same shape). Constraints are not safe to use when doing
        asynchronous distributed training.
      synchronization: Indicates when a distributed variable will be
        aggregated. Accepted values are constants defined in the class
        `tf.VariableSynchronization`. By default the synchronization is set to
        `AUTO` and the current `DistributionStrategy` chooses when to
        synchronize.
      aggregation: Indicates how a distributed variable will be aggregated.
        Accepted values are constants defined in the class
        `tf.VariableAggregation`.

    Returns:
      The created or existing `Variable` (or `PartitionedVariable`, if a
      partitioner was used).

    Raises:
      ValueError: when creating a new variable and shape is not declared,
        when reusing a variable and specifying a conflicting shape,
        or when violating reuse during variable creation.
      RuntimeError: when eager execution is enabled and not called from an
        EagerVariableStore.
    """
    if custom_getter is not None and not callable(custom_getter):
      raise ValueError("Passed a custom_getter which is not callable: %s" %
                       custom_getter)

    with ops.init_scope():
      if context.executing_eagerly():
        # Variable creation and initialization takes place in `init_scope`s;
        # as such, if an `init_scope` lifts us into the eager context, then we
        # need to use `ResourceVariable`s.
        use_resource = True

    # Note that it's fine to reuse eager variables whose initialization was
    # lifted from a function-building graph into the eager context (that's why
    # the following clause is not wrapped in an `init_scope`); lifted variables
    # are tracked by the graph's `VariableStore`.
    if context.executing_eagerly():
      if not self._store_eager_variables and reuse:
        raise RuntimeError(
            "When eager execution is enabled variable reuse is only supported"
            " when an EagerVariableStore is active. See the documentation on"
            " EagerVariableStore for example usage.")
      if self._store_eager_variables:
        reuse = AUTO_REUSE

    # If a *_ref type is passed in, an error would be triggered further down
    # the stack. We prevent this using base_dtype to get a non-ref version of
    # the type, before doing anything else. When _ref types are removed in
    # favor of resources, this line can be removed.
    try:
      dtype = dtype.base_dtype
    except AttributeError:
      # .base_dtype not existing means that we will try and use the raw dtype
      # which was passed in - this might be a NumPy type which is valid.
      pass

    # This is the main logic of get_variable. However, custom_getter
    # may override this logic. So we save it as a callable and pass
    # it to custom_getter.
    # Note: the parameters of _true_getter, and their documentation, match
    # *exactly* item-for-item with the docstring of this method.
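    # _true_getter dispatches between three paths: partitioned creation (a
    # partitioner was supplied and the shape is not scalar), partitioned reuse
    # (reuse=True and a PartitionedVariable with this name already exists),
    # and the single-variable path handled by _get_single_variable.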
    def _true_getter(  # pylint: disable=missing-docstring
        name,
        shape=None,
        dtype=dtypes.float32,
        initializer=None,
        regularizer=None,
        reuse=None,
        trainable=None,
        collections=None,
        caching_device=None,
        partitioner=None,
        validate_shape=True,
        use_resource=None,
        constraint=None,
        synchronization=VariableSynchronization.AUTO,
        aggregation=VariableAggregation.NONE):
      is_scalar = (
          shape is not None and isinstance(shape, collections_abc.Sequence) and
          not shape)
      # Partitioned variable case
      if partitioner is not None and not is_scalar:
        if not callable(partitioner):
          raise ValueError("Partitioner must be callable, but received: %s" %
                           partitioner)
        with ops.name_scope(None):
          return self._get_partitioned_variable(
              name=name,
              shape=shape,
              dtype=dtype,
              initializer=initializer,
              regularizer=regularizer,
              reuse=reuse,
              trainable=trainable,
              collections=collections,
              caching_device=caching_device,
              partitioner=partitioner,
              validate_shape=validate_shape,
              use_resource=use_resource,
              constraint=constraint,
              synchronization=synchronization,
              aggregation=aggregation)

      # Special case for partitioned variable to allow reuse without having to
      # specify partitioner.
      if (reuse is True and partitioner is None
          and name in self._partitioned_vars):
        return self._get_partitioned_variable(
            name=name,
            shape=shape,
            dtype=dtype,
            initializer=initializer,
            regularizer=regularizer,
            reuse=reuse,
            trainable=trainable,
            collections=collections,
            caching_device=caching_device,
            partitioner=None,
            validate_shape=validate_shape,
            use_resource=use_resource,
            constraint=constraint,
            synchronization=synchronization,
            aggregation=aggregation)

      # Single variable case
      if "%s/part_0" % name in self._vars:
        raise ValueError(
            "No partitioner was provided, but a partitioned version of the "
            "variable was found: %s/part_0. Perhaps a variable of the same "
            "name was already created with partitioning?" % name)

      return self._get_single_variable(
          name=name,
          shape=shape,
          dtype=dtype,
          initializer=initializer,
          regularizer=regularizer,
          reuse=reuse,
          trainable=trainable,
          collections=collections,
          caching_device=caching_device,
          validate_shape=validate_shape,
          use_resource=use_resource,
          constraint=constraint,
          synchronization=synchronization,
          aggregation=aggregation)

    synchronization, aggregation, trainable = (
        variables.validate_synchronization_aggregation_trainable(
            synchronization, aggregation, trainable, name))

    if custom_getter is not None:
      # Handle backwards compatibility with getter arguments that were added
      # to the API after users started writing custom getters.
      custom_getter_kwargs = {
          "getter": _true_getter,
          "name": name,
          "shape": shape,
          "dtype": dtype,
          "initializer": initializer,
          "regularizer": regularizer,
          "reuse": reuse,
          "trainable": trainable,
          "collections": collections,
          "caching_device": caching_device,
          "partitioner": partitioner,
          "validate_shape": validate_shape,
          "use_resource": use_resource,
          "synchronization": synchronization,
          "aggregation": aggregation,
      }
      # `fn_args` and `has_kwargs` can handle functions, `functools.partial`,
      # `lambda`.
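      # `constraint` was added to this API after custom getters were already
      # in use, so it is only forwarded when the custom getter can accept it
      # (either as a named argument or via **kwargs).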
573 if ("constraint" in function_utils.fn_args(custom_getter) or 574 function_utils.has_kwargs(custom_getter)): 575 custom_getter_kwargs["constraint"] = constraint 576 return custom_getter(**custom_getter_kwargs) 577 else: 578 return _true_getter( 579 name, 580 shape=shape, 581 dtype=dtype, 582 initializer=initializer, 583 regularizer=regularizer, 584 reuse=reuse, 585 trainable=trainable, 586 collections=collections, 587 caching_device=caching_device, 588 partitioner=partitioner, 589 validate_shape=validate_shape, 590 use_resource=use_resource, 591 constraint=constraint, 592 synchronization=synchronization, 593 aggregation=aggregation) 594 595 def _get_partitioned_variable(self, 596 name, 597 partitioner, 598 shape=None, 599 dtype=dtypes.float32, 600 initializer=None, 601 regularizer=None, 602 reuse=None, 603 trainable=None, 604 collections=None, 605 caching_device=None, 606 validate_shape=True, 607 use_resource=None, 608 constraint=None, 609 synchronization=VariableSynchronization.AUTO, 610 aggregation=VariableAggregation.NONE): 611 """Gets or creates a sharded variable list with these parameters. 612 613 The `partitioner` must be a callable that accepts a fully defined 614 `TensorShape` and returns a sequence of integers (the `partitions`). 615 These integers describe how to partition the given sharded `Variable` 616 along the given dimension. That is, `partitions[1] = 3` means split 617 the `Variable` into 3 shards along dimension 1. Currently, sharding along 618 only one axis is supported. 619 620 If the list of variables with the given name (prefix) is already stored, 621 we return the stored variables. Otherwise, we create a new one. 622 623 Set `reuse` to `True` when you only want to reuse existing Variables. 624 Set `reuse` to `False` when you only want to create new Variables. 625 Set `reuse` to None (the default) or tf.compat.v1.AUTO_REUSE when you want 626 variables to be created if they don't exist or returned if they do. 627 628 If initializer is `None` (the default), the default initializer passed in 629 the constructor is used. If that one is `None` too, we use a new 630 `glorot_uniform_initializer`. If initializer is a Tensor, we use 631 it as a value and derive the shape from the initializer. 632 633 If the initializer is a callable, then it will be called for each 634 shard. Otherwise the initializer should match the shape of the entire 635 sharded Variable, and it will be sliced accordingly for each shard. 636 637 Some useful partitioners are available. See, e.g., 638 `variable_axis_size_partitioner` and `min_max_variable_partitioner`. 639 640 Args: 641 name: the name of the new or existing sharded variable. 642 partitioner: Optional callable that accepts a fully defined `TensorShape` 643 and `dtype` of the Variable to be created, and returns a list of 644 partitions for each axis (currently only one axis can be partitioned). 645 shape: shape of the new or existing sharded variable. 646 dtype: type of the new or existing sharded variable (defaults to 647 `DT_FLOAT`). 648 initializer: initializer for the sharded variable. 649 regularizer: a (Tensor -> Tensor or None) function; the result of applying 650 it on a newly created variable will be added to the collection 651 GraphKeys.REGULARIZATION_LOSSES and can be used for regularization. 652 reuse: a Boolean, None, or tf.AUTO_REUSE. Controls reuse or creation of 653 variables. 654 trainable: If `True` also add the variable to the graph collection 655 `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). 
      collections: List of graph collections keys to add the Variable to.
        Defaults to `[GraphKeys.GLOBAL_VARIABLES]` (see `tf.Variable`).
      caching_device: Optional device string or function describing where the
        Variable should be cached for reading. Defaults to the Variable's
        device. If not `None`, caches on another device. Typical use is to
        cache on the device where the Ops using the Variable reside, to
        deduplicate copying through `Switch` and other conditional statements.
      validate_shape: If False, allows the variable to be initialized with a
        value of unknown shape. If True, the default, the shape of
        initial_value must be known.
      use_resource: If False, creates a regular Variable. If True, creates an
        experimental ResourceVariable which has well-defined semantics.
        Defaults to False (will later change to True).
      constraint: An optional projection function to be applied to the variable
        after being updated by an `Optimizer` (e.g. used to implement norm
        constraints or value constraints for layer weights). The function must
        take as input the unprojected Tensor representing the value of the
        variable and return the Tensor for the projected value (which must have
        the same shape). Constraints are not safe to use when doing
        asynchronous distributed training.
      synchronization: Indicates when a distributed variable will be
        aggregated. Accepted values are constants defined in the class
        `tf.VariableSynchronization`. By default the synchronization is set to
        `AUTO` and the current `DistributionStrategy` chooses when to
        synchronize.
      aggregation: Indicates how a distributed variable will be aggregated.
        Accepted values are constants defined in the class
        `tf.VariableAggregation`.

    Returns:
      A `PartitionedVariable` object.

    Raises:
      ValueError: when creating a new variable and shape is not declared,
        when reusing a variable and specifying a conflicting shape,
        when violating reuse during variable creation, or if an existing
        sharded variable exists for the given name but with different sharding.
    """
    initializing_from_value = initializer is not None and isinstance(
        initializer, ops.Tensor)
    if name in self._vars:
      raise ValueError(
          "A partitioner was provided, but an unpartitioned version of the "
          "variable was found: %s. Perhaps a variable of the same name was "
          "already created without partitioning?" % name)

    shape = tensor_shape.as_shape(shape)
    if initializing_from_value:
      shape = shape.merge_with(initializer.get_shape())

    partitions = None
    if not reuse or partitioner:
      partitions = _call_partitioner(partitioner, shape, dtype)

    if name in self._partitioned_vars:
      if reuse is False:
        raise ValueError(
            "Partitioned variable with name %s already exists. Did you mean to "
            "set reuse=True or reuse=tf.AUTO_REUSE in VarScope?" % name)

      existing_var = self._partitioned_vars[name]
      if not shape.is_compatible_with(existing_var.get_shape()):
        raise ValueError(
            "Trying to reuse partitioned variable %s, but specified shape %s "
            "and found shape %s." % (name, shape, existing_var.get_shape()))
      if not dtype.is_compatible_with(existing_var.dtype):
        raise ValueError(
            "Trying to reuse partitioned variable %s, but specified dtype %s "
            "and found dtype %s." % (name, dtype.name, existing_var.dtype.name))

      # pylint: disable=protected-access
      if (partitions is not None and
          existing_var._get_partitions() != partitions):
        raise ValueError(
            "Trying to reuse partitioned variable %s, but specified partitions "
            "%s and found partitions %s." %
            (name, partitions, existing_var._get_partitions()))
      # pylint: enable=protected-access

      return existing_var

    if reuse is True:
      raise ValueError("PartitionedVariable %s does not exist, or was not "
                       "created with tf.get_variable(). Did you mean to set "
                       "reuse=False or reuse=tf.AUTO_REUSE in VarScope?" % name)

    slice_dim, num_slices = _get_slice_dim_and_num_slices(partitions)

    if "%s/part_0" % name in self._vars:
      if "%s/part_%d" % (name, num_slices - 1) not in self._vars:
        raise ValueError(
            "Partitioner returned a different partitioning than what was "
            "already found. Partitioner returned %d shards, and shard "
            "%s/part_0 was found, but %s/part_%d was not." %
            (num_slices, name, name, num_slices - 1))
      if "%s/part_%d" % (name, num_slices) in self._vars:
        raise ValueError(
            "Partitioner returned a different partitioning than what was "
            "already found. Partitioner returned %d shards, and shard "
            "%s/part_0 was found, but so was the extra shard %s/part_%d." %
            (num_slices, name, name, num_slices))

    vs = []
    for i, (var_offset, var_shape) in enumerate(
        _iter_slices(shape.as_list(), num_slices, slice_dim)):
      partition_info = _PartitionInfo(
          full_shape=shape.as_list(), var_offset=var_offset)
      var_full_name = "%s/part_%d" % (name, i)
      with ops.name_scope(
          var_full_name + "/PartitionedInitializer", skip_on_eager=False):
        # Create the tensor to initialize the variable with default value.
        if initializer is None:
          init, initializing_from_value = self._get_default_initializer(
              name=name, shape=shape, dtype=dtype)
          if initializing_from_value:
            init_shape = None
          else:
            init_shape = var_shape
        elif callable(initializer):
          init = initializer
          init_shape = var_shape
        elif isinstance(initializer, ops.Tensor):
          init = array_ops.slice(initializer, var_offset, var_shape)
          # Use the dtype of the given tensor.
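          # A concrete initial value determines the dtype of the shard, so it
          # overrides the requested dtype here.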
          dtype = init.dtype.base_dtype
          init_shape = None
        else:
          init = ops.convert_to_tensor(initializer, dtype=dtype)
          init = array_ops.slice(init, var_offset, var_shape)
          init_shape = None

      with ops.name_scope(None):
        var = self._get_single_variable(
            name=var_full_name,
            shape=init_shape,
            dtype=dtype,
            initializer=init,
            partition_info=partition_info,
            regularizer=regularizer,
            reuse=reuse,
            trainable=trainable,
            collections=collections,
            caching_device=caching_device,
            validate_shape=validate_shape,
            use_resource=use_resource,
            constraint=constraint,
            synchronization=synchronization,
            aggregation=aggregation)

      # pylint: disable=protected-access
      var._set_save_slice_info(
          variables.Variable.SaveSliceInfo(name, shape.as_list(), var_offset,
                                           var_shape))
      vs.append(var)
      # pylint: enable=protected-access

    partitioned_var = variables.PartitionedVariable(
        name=name,
        shape=shape,
        dtype=dtype,
        variable_list=vs,
        partitions=partitions)
    if not context.executing_eagerly() or self._store_eager_variables:
      self._partitioned_vars[name] = partitioned_var
    return partitioned_var

  def _get_single_variable(self,
                           name,
                           shape=None,
                           dtype=dtypes.float32,
                           initializer=None,
                           regularizer=None,
                           partition_info=None,
                           reuse=None,
                           trainable=None,
                           collections=None,
                           caching_device=None,
                           validate_shape=True,
                           use_resource=None,
                           constraint=None,
                           synchronization=VariableSynchronization.AUTO,
                           aggregation=VariableAggregation.NONE):
    """Get or create a single Variable (e.g. a shard or entire variable).

    See the documentation of get_variable above (ignore partitioning
    components) for details.

    Args:
      name: see get_variable.
      shape: see get_variable.
      dtype: see get_variable.
      initializer: see get_variable.
      regularizer: see get_variable.
      partition_info: _PartitionInfo object.
      reuse: see get_variable.
      trainable: see get_variable.
      collections: see get_variable.
      caching_device: see get_variable.
      validate_shape: see get_variable.
      use_resource: see get_variable.
      constraint: see get_variable.
      synchronization: see get_variable.
      aggregation: see get_variable.

    Returns:
      A Variable. See documentation of get_variable above.

    Raises:
      ValueError: See documentation of get_variable above.
    """
    # Set to true if initializer is a constant.
    initializing_from_value = False
    if initializer is not None and not callable(initializer):
      initializing_from_value = True
    if shape is not None and initializing_from_value:
      raise ValueError("If initializer is a constant, do not specify shape.")

    dtype = dtypes.as_dtype(dtype)
    shape = tensor_shape.as_shape(shape)

    if name in self._vars:
      # Here we handle the case when returning an existing variable.
      if reuse is False:
        var = self._vars[name]
        err_msg = ("Variable %s already exists, disallowed."
                   " Did you mean to set reuse=True or "
                   "reuse=tf.AUTO_REUSE in VarScope?" % name)
        # ResourceVariables don't have an op associated with them, so no
        # traceback is available.
        if isinstance(var, resource_variable_ops.ResourceVariable):
          raise ValueError(err_msg)
        tb = var.op.traceback[::-1]
        # Throw away internal tf entries and only take a few lines. In some
        # cases the traceback can be longer (e.g. if someone uses factory
        # functions to create variables) so we take more than needed in the
        # default case.
        tb = [x for x in tb if "tensorflow/python" not in x[0]][:5]
        raise ValueError("%s Originally defined at:\n\n%s" %
                         (err_msg, "".join(traceback.format_list(tb))))
      found_var = self._vars[name]
      if not shape.is_compatible_with(found_var.get_shape()):
        raise ValueError("Trying to share variable %s, but specified shape %s"
                         " and found shape %s." %
                         (name, shape, found_var.get_shape()))
      if not dtype.is_compatible_with(found_var.dtype):
        dtype_str = dtype.name
        found_type_str = found_var.dtype.name
        raise ValueError("Trying to share variable %s, but specified dtype %s"
                         " and found dtype %s." %
                         (name, dtype_str, found_type_str))
      return found_var

    # The code below handles only the case of creating a new variable.
    if reuse is True:
      raise ValueError("Variable %s does not exist, or was not created with "
                       "tf.get_variable(). Did you mean to set "
                       "reuse=tf.AUTO_REUSE in VarScope?" % name)

    # Create the tensor to initialize the variable with default value.
    if initializer is None:
      initializer, initializing_from_value = self._get_default_initializer(
          name=name, shape=shape, dtype=dtype)
    # Enter an init scope when creating the initializer.
    with ops.init_scope():
      if initializing_from_value:
        init_val = initializer
        variable_dtype = None
      else:
        # Instantiate initializer if provided initializer is a type object.
        if tf_inspect.isclass(initializer):
          initializer = initializer()
        if shape.is_fully_defined():
          if "partition_info" in tf_inspect.getargspec(initializer).args:
            init_val = functools.partial(initializer,
                                         shape.as_list(),
                                         dtype=dtype,
                                         partition_info=partition_info)
          else:
            init_val = functools.partial(initializer,
                                         shape.as_list(), dtype=dtype)
          variable_dtype = dtype.base_dtype
        elif _needs_no_arguments(initializer):
          init_val = initializer
          variable_dtype = None
        else:
          raise ValueError("The initializer passed is not valid. It should "
                           "be a callable with no arguments and the "
                           "shape should not be provided, or an instance of "
                           "`tf.keras.initializers.*` and `shape` should be "
                           "fully defined.")

    # Create the variable.
    if use_resource is None:
      # Set the default value if unspecified.
      use_resource = _DEFAULT_USE_RESOURCE
    v = variables.VariableV1(
        initial_value=init_val,
        name=name,
        trainable=trainable,
        collections=collections,
        caching_device=caching_device,
        dtype=variable_dtype,
        validate_shape=validate_shape,
        constraint=constraint,
        use_resource=use_resource,
        synchronization=synchronization,
        aggregation=aggregation)
    if context.executing_eagerly() and self._store_eager_variables:
      if collections:
        ops.add_to_collections(collections, v)
      else:
        ops.add_to_collection(ops.GraphKeys.GLOBAL_VARIABLES, v)
      if trainable:
        ops.add_to_collection(ops.GraphKeys.TRAINABLE_VARIABLES, v)

    if not context.executing_eagerly() or self._store_eager_variables:
      # In eager mode we do not want to keep default references to Variable
      # objects as this will prevent their memory from being released.
      self._vars[name] = v
    logging.vlog(1, "Created variable %s with shape %s and init %s", v.name,
                 format(shape), initializer)

    # Run the regularizer if requested and save the resulting loss.
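    # The loss is wrapped in a _LazyEvalTensor (defined below) so that the
    # regularizer op is rebuilt, colocated with the variable, whenever the
    # collection entry is converted to a tensor.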
    if regularizer:
      def make_regularizer_op():
        with ops.colocate_with(v):
          with ops.name_scope(name + "/Regularizer/"):
            return regularizer(v)

      if regularizer(v) is not None:
        lazy_eval_tensor = _LazyEvalTensor(make_regularizer_op)
        ops.add_to_collection(ops.GraphKeys.REGULARIZATION_LOSSES,
                              lazy_eval_tensor)

    return v

  # Initialize variable when no initializer provided
  def _get_default_initializer(self, name, shape=None, dtype=dtypes.float32):
    """Provide a default initializer and a corresponding value.

    Args:
      name: see get_variable.
      shape: see get_variable.
      dtype: see get_variable.

    Returns:
      initializer and initializing_from_value. See get_variable above.

    Raises:
      ValueError: When giving unsupported dtype.
    """
    del shape
    # If dtype is DT_FLOAT, provide a uniform unit scaling initializer
    if dtype.is_floating:
      initializer = init_ops.glorot_uniform_initializer()
      initializing_from_value = False
    # If dtype is DT_INT/DT_UINT, provide a default value `zero`
    # If dtype is DT_BOOL, provide a default value `FALSE`
    elif (dtype.is_integer or dtype.is_unsigned or dtype.is_bool or
          dtype == dtypes.string):
      initializer = init_ops.zeros_initializer()
      initializing_from_value = False
    # NOTE: Do we need to support handling DT_STRING and DT_COMPLEX here?
    else:
      raise ValueError("An initializer for variable %s of %s is required" %
                       (name, dtype.base_dtype))

    return initializer, initializing_from_value


class _LazyEvalTensor(core.Tensor):
  """A Tensor-like object that only evaluates its thunk when used."""

  def __init__(self, thunk):
    """Initializes a _LazyEvalTensor object.

    Args:
      thunk: A callable. A thunk which computes the value of the tensor.
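
    Note: the thunk is evaluated once at construction time, to obtain a master
    tensor that supplies shape/dtype/device metadata, and again each time this
    object is converted to a `Tensor`.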
1035 """ 1036 self._thunk = thunk 1037 self._master_tensor = thunk() 1038 1039 def _as_tensor(self, dtype=None, name=None, as_ref=False): 1040 del name 1041 assert not as_ref 1042 assert dtype in [None, self.dtype] 1043 1044 return self._thunk() 1045 1046 1047def _make_master_property(name): 1048 @property 1049 def prop(self): 1050 return getattr(self._master_tensor, name) # pylint: disable=protected-access 1051 return prop 1052 1053_master_property_list = ("device", "dtype", "graph", "name", "op", "shape", 1054 "value_index") 1055for _name in _master_property_list: 1056 setattr(_LazyEvalTensor, _name, _make_master_property(_name)) 1057 1058 1059def _make_master_method(name): 1060 def method(self, *args, **kwargs): 1061 return getattr(self._master_tensor, name)(*args, **kwargs) # pylint: disable=protected-access 1062 return method 1063 1064_master_method_list = ("get_shape", "__str__", "shape_as_list") 1065for _name in _master_method_list: 1066 setattr(_LazyEvalTensor, _name, _make_master_method(_name)) 1067 1068 1069def _make_op_method(name): 1070 def method(self, *args, **kwargs): 1071 return getattr(self._as_tensor(), name)(*args, **kwargs) # pylint: disable=protected-access 1072 return method 1073 1074_op_list = ("__abs__", "__add__", "__and__", "__bool__", "__div__", "__eq__", 1075 "__floordiv__", "__ge__", "__getitem__", "__gt__", "__invert__", 1076 "__iter__", "__le__", "__len__", "__lt__", "__matmul__", "__mod__", 1077 "__mul__", "__ne__", "__neg__", "__nonzero__", "__or__", "__pow__", 1078 "__radd__", "__rand__", "__rdiv__", "__rfloordiv__", "__rmatmul__", 1079 "__rmod__", "__rmul__", "__ror__", "__rpow__", "__rsub__", 1080 "__rtruediv__", "__rxor__", "__sub__", "__truediv__", "__xor__", 1081 "eval", "numpy") 1082for _name in _op_list: 1083 setattr(_LazyEvalTensor, _name, _make_op_method(_name)) 1084 1085 1086ops.register_tensor_conversion_function( 1087 _LazyEvalTensor, 1088 lambda val, dtype, name, as_ref: val._as_tensor(dtype, name, as_ref) # pylint: disable=protected-access 1089 ) 1090 1091session.register_session_run_conversion_functions( 1092 _LazyEvalTensor, 1093 lambda fetch: ([fetch._master_tensor], lambda fetched_vals: fetched_vals[0]) # pylint: disable=protected-access 1094 ) 1095 1096 1097# To stop regularization, use this regularizer 1098@tf_export(v1=["no_regularizer"]) 1099def no_regularizer(_): 1100 """Use this function to prevent regularization of variables.""" 1101 return None 1102 1103 1104# TODO(alive): support caching devices and partitioned variables in Eager mode. 1105@tf_export(v1=["VariableScope"]) 1106class VariableScope(object): 1107 """Variable scope object to carry defaults to provide to `get_variable`. 1108 1109 Many of the arguments we need for `get_variable` in a variable store are most 1110 easily handled with a context. This object is used for the defaults. 1111 1112 Attributes: 1113 name: name of the current scope, used as prefix in get_variable. 1114 initializer: default initializer passed to get_variable. 1115 regularizer: default regularizer passed to get_variable. 1116 reuse: Boolean, None, or tf.compat.v1.AUTO_REUSE, setting the reuse in 1117 get_variable. When eager execution is enabled this argument is always 1118 forced to be False. 1119 caching_device: string, callable, or None: the caching device passed to 1120 get_variable. 1121 partitioner: callable or `None`: the partitioner passed to `get_variable`. 1122 custom_getter: default custom getter passed to get_variable. 1123 name_scope: The name passed to `tf.name_scope`. 
    dtype: default type passed to get_variable (defaults to DT_FLOAT).
    use_resource: if False, create a normal Variable; if True create an
      experimental ResourceVariable with well-defined semantics. Defaults to
      False (will later change to True). When eager execution is enabled this
      argument is always forced to be True.
    constraint: An optional projection function to be applied to the variable
      after being updated by an `Optimizer` (e.g. used to implement norm
      constraints or value constraints for layer weights). The function must
      take as input the unprojected Tensor representing the value of the
      variable and return the Tensor for the projected value (which must have
      the same shape). Constraints are not safe to use when doing asynchronous
      distributed training.
  """

  def __init__(self,
               reuse,
               name="",
               initializer=None,
               regularizer=None,
               caching_device=None,
               partitioner=None,
               custom_getter=None,
               name_scope="",
               dtype=dtypes.float32,
               use_resource=None,
               constraint=None):
    """Creates a new VariableScope with the given properties."""
    self._name = name
    self._initializer = initializer
    self._regularizer = regularizer
    self._reuse = reuse
    self._caching_device = caching_device
    self._partitioner = partitioner
    self._custom_getter = custom_getter
    self._name_scope = name_scope
    self._dtype = dtype
    self._use_resource = use_resource
    self._constraint = constraint
    if context.executing_eagerly():
      if self._caching_device is not None:
        raise NotImplementedError("Caching devices are not yet supported "
                                  "when eager execution is enabled.")
      self._reuse = AUTO_REUSE
      self._use_resource = True

  @property
  def name(self):
    return self._name

  @property
  def original_name_scope(self):
    return self._name_scope

  @property
  def reuse(self):
    return self._reuse

  @property
  def initializer(self):
    return self._initializer

  @property
  def dtype(self):
    return self._dtype

  @property
  def use_resource(self):
    return self._use_resource

  @property
  def regularizer(self):
    return self._regularizer

  @property
  def caching_device(self):
    return self._caching_device

  @property
  def partitioner(self):
    return self._partitioner

  @property
  def custom_getter(self):
    return self._custom_getter

  @property
  def constraint(self):
    return self._constraint

  def reuse_variables(self):
    """Reuse variables in this scope."""
    self._reuse = True

  def set_initializer(self, initializer):
    """Set initializer for this scope."""
    self._initializer = initializer

  def set_dtype(self, dtype):
    """Set data type for this scope."""
    self._dtype = dtype

  def set_use_resource(self, use_resource):
    """Sets whether to use ResourceVariables for this scope."""
    if context.executing_eagerly() and not use_resource:
      raise ValueError("When eager execution is enabled, "
                       "use_resource cannot be set to false.")
    self._use_resource = use_resource

  def set_regularizer(self, regularizer):
    """Set regularizer for this scope."""
    self._regularizer = regularizer

  def set_caching_device(self, caching_device):
    """Set caching_device for this scope."""
    if context.executing_eagerly():
      raise NotImplementedError("Caching devices are not yet supported "
                                "when eager execution is enabled.")
    self._caching_device = caching_device

  def set_partitioner(self, partitioner):
    """Set partitioner for this scope."""
    self._partitioner = partitioner

  def set_custom_getter(self, custom_getter):
    """Set custom getter for this scope."""
    self._custom_getter = custom_getter

  def get_collection(self, name):
    """Get this scope's variables."""
    scope = self._name + "/" if self._name else ""
    return ops.get_collection(name, scope)

  def trainable_variables(self):
    """Get this scope's trainable variables."""
    return self.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)

  def global_variables(self):
    """Get this scope's global variables."""
    return self.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)

  def local_variables(self):
    """Get this scope's local variables."""
    return self.get_collection(ops.GraphKeys.LOCAL_VARIABLES)

  def get_variable(self,
                   var_store,
                   name,
                   shape=None,
                   dtype=None,
                   initializer=None,
                   regularizer=None,
                   reuse=None,
                   trainable=None,
                   collections=None,
                   caching_device=None,
                   partitioner=None,
                   validate_shape=True,
                   use_resource=None,
                   custom_getter=None,
                   constraint=None,
                   synchronization=VariableSynchronization.AUTO,
                   aggregation=VariableAggregation.NONE):
    """Gets an existing variable with this name or creates a new one."""
    if regularizer is None:
      regularizer = self._regularizer
    if caching_device is None:
      caching_device = self._caching_device
    if partitioner is None:
      partitioner = self._partitioner
    if custom_getter is None:
      custom_getter = self._custom_getter
    if context.executing_eagerly():
      reuse = False
      use_resource = True
    else:
      if reuse is None:
        reuse = self._reuse
      if use_resource is None:
        use_resource = self._use_resource

    full_name = self.name + "/" + name if self.name else name
    # Variable names only depend on variable_scope (full_name here),
    # not name_scope, so we reset it below for the time of variable creation.
    with ops.name_scope(None, skip_on_eager=False):
      # Check that `initializer` dtype and `dtype` are consistent before
      # replacing them with defaults.
      if (dtype is not None and initializer is not None and
          not callable(initializer)):
        init_dtype = ops.convert_to_tensor(initializer).dtype.base_dtype
        if init_dtype != dtype:
          raise ValueError("Initializer type '%s' and explicit dtype '%s' "
                           "don't match." % (init_dtype, dtype))
      if initializer is None:
        initializer = self._initializer
      if constraint is None:
        constraint = self._constraint
      if dtype is None:
        dtype = self._dtype
      return var_store.get_variable(
          full_name,
          shape=shape,
          dtype=dtype,
          initializer=initializer,
          regularizer=regularizer,
          reuse=reuse,
          trainable=trainable,
          collections=collections,
          caching_device=caching_device,
          partitioner=partitioner,
          validate_shape=validate_shape,
          use_resource=use_resource,
          custom_getter=custom_getter,
          constraint=constraint,
          synchronization=synchronization,
          aggregation=aggregation)

  def _get_partitioned_variable(self,
                                var_store,
                                name,
                                shape=None,
                                dtype=None,
                                initializer=None,
                                regularizer=None,
                                trainable=None,
                                collections=None,
                                caching_device=None,
                                partitioner=None,
                                validate_shape=True,
                                use_resource=None,
                                constraint=None,
                                synchronization=VariableSynchronization.AUTO,
                                aggregation=VariableAggregation.NONE):
    """Gets an existing variable with this name or creates a new one."""
    if initializer is None:
      initializer = self._initializer
    if regularizer is None:
      regularizer = self._regularizer
    if constraint is None:
      constraint = self._constraint
    if caching_device is None:
      caching_device = self._caching_device
    if partitioner is None:
      partitioner = self._partitioner
    if dtype is None:
      dtype = self._dtype
    if use_resource is None:
      use_resource = self._use_resource

    if self._custom_getter is not None:
      raise ValueError(
          "Private access to _get_partitioned_variable is not allowed when "
          "a custom getter is set. Current custom getter: %s. "
          "It is likely that you're using create_partitioned_variables. "
          "If so, consider instead using get_variable with a non-empty "
          "partitioner parameter." % self._custom_getter)

    if partitioner is None:
      raise ValueError("No partitioner was specified")

    # This allows the variable scope name to be used as the variable name if
    # this function is invoked with an empty name arg, for backward
    # compatibility with create_partitioned_variables().
    full_name_list = []
    if self.name:
      full_name_list.append(self.name)
    if name:
      full_name_list.append(name)
    full_name = "/".join(full_name_list)

    # Variable names only depend on variable_scope (full_name here),
    # not name_scope, so we reset it below for the time of variable creation.
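    # Entering ops.name_scope(None) temporarily clears the current name scope,
    # so the ops created for the variable are named from full_name alone.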
    with ops.name_scope(None, skip_on_eager=False):
      # pylint: disable=protected-access
      return var_store._get_partitioned_variable(
          full_name,
          shape=shape,
          dtype=dtype,
          initializer=initializer,
          regularizer=regularizer,
          reuse=self.reuse,
          trainable=trainable,
          collections=collections,
          caching_device=caching_device,
          partitioner=partitioner,
          validate_shape=validate_shape,
          use_resource=use_resource,
          constraint=constraint,
          synchronization=synchronization,
          aggregation=aggregation)
      # pylint: enable=protected-access


_VARSTORE_KEY = ("__variable_store",)
_VARSCOPESTORE_KEY = ("__varscope",)


class _VariableScopeStore(threading.local):
  """A thread local store for the current variable scope and scope counts."""

  def __init__(self):
    super(_VariableScopeStore, self).__init__()
    self.current_scope = VariableScope(False)
    self.variable_scopes_count = {}

  def open_variable_scope(self, scope_name):
    if scope_name in self.variable_scopes_count:
      self.variable_scopes_count[scope_name] += 1
    else:
      self.variable_scopes_count[scope_name] = 1

  def close_variable_subscopes(self, scope_name):
    for k in list(self.variable_scopes_count.keys()):
      if scope_name is None or k.startswith(scope_name + "/"):
        self.variable_scopes_count[k] = 0

  def variable_scope_count(self, scope_name):
    return self.variable_scopes_count.get(scope_name, 0)


def get_variable_scope_store():
  """Returns the variable scope store for current thread."""
  scope_store = ops.get_collection(_VARSCOPESTORE_KEY)

  if not scope_store:
    scope_store = _VariableScopeStore()
    ops.add_to_collection(_VARSCOPESTORE_KEY, scope_store)
  else:
    scope_store = scope_store[0]

  return scope_store


@tf_export(v1=["get_variable_scope"])
def get_variable_scope():
  """Returns the current variable scope."""
  return get_variable_scope_store().current_scope


def _get_default_variable_store():
  store = ops.get_collection(_VARSTORE_KEY)
  if store:
    return store[0]
  store = _VariableStore()
  ops.add_to_collection(_VARSTORE_KEY, store)
  return store


@tf_contextlib.contextmanager
def with_variable_store(store):
  store_collection = ops.get_collection_ref(_VARSTORE_KEY)
  old = list(store_collection)
  store_collection[:] = [store]
  try:
    yield
  finally:
    store_collection[:] = old


class EagerVariableStore(object):
  """Wrapper allowing functional layers to be used with eager execution.

  When eager execution is enabled Variables get deleted when they go out of
  scope, and are not stored in global collections by default. A lot of code
  (mostly the functional layers in tf.layers) assumes that variables are kept
  in a global list.

  EagerVariableStore can be used in conjunction with this code to make it
  eager-friendly. For example, to create a dense layer, use:

  ```
  container = tfe.EagerVariableStore()
  for input in dataset_iterator:
    with container.as_default():
      x = tf.compat.v1.layers.dense(input, name="l1")
  print(container.variables)  # Should print the variables used in the layer.
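  # The same variables are reused on every iteration because the store forces
  # AUTO_REUSE for variables created while it is the default.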
  ```
  """

  def __init__(self, store=None):
    if store is not None:
      if not store._store_eager_variables:  # pylint: disable=protected-access
        raise ValueError("Cannot construct EagerVariableStore from a "
                         "VariableStore object that does not hold eager "
                         "variables.")
      self._store = store
    else:
      self._store = _VariableStore()
    self._store._store_eager_variables = True  # pylint: disable=protected-access

  def as_default(self):
    return with_variable_store(self._store)

  def variables(self):
    return sorted(self._store._vars.values(), key=lambda x: x.name)  # pylint: disable=protected-access

  def trainable_variables(self):
    # pylint: disable=protected-access
    return sorted([x for x in self._store._vars.values() if x.trainable],
                  key=lambda x: x.name)
    # pylint: enable=protected-access

  def non_trainable_variables(self):
    # pylint: disable=protected-access
    return sorted([x for x in self._store._vars.values() if not x.trainable],
                  key=lambda x: x.name)
    # pylint: enable=protected-access

  def copy(self):
    """Copy this variable store and all of its contents.

    Variables contained in this store will be copied over to the new variable
    store, meaning that they can be modified without affecting the variables in
    this store.

    Returns:
      A new EagerVariableStore instance containing copied variables.
    """
    # pylint: disable=protected-access
    new_store = EagerVariableStore()
    for key, var in iteritems(self._store._vars):
      # Strip device out of variable name.
      try:
        index = var.name.index(":")
      except ValueError:
        stripped_var_name = var.name
      else:
        stripped_var_name = var.name[:index]

      # Create new variable with same value, name, and "trainable" flag.
      new_var = resource_variable_ops.ResourceVariable(
          var.read_value(), name=stripped_var_name, trainable=var.trainable)
      new_store._store._vars[key] = new_var
    return new_store
    # pylint: enable=protected-access


# The argument list for get_variable must match arguments to get_local_variable.
# So, if you are updating the arguments, also update arguments to
# get_local_variable below.
@tf_export(v1=["get_variable"])
def get_variable(name,
                 shape=None,
                 dtype=None,
                 initializer=None,
                 regularizer=None,
                 trainable=None,
                 collections=None,
                 caching_device=None,
                 partitioner=None,
                 validate_shape=True,
                 use_resource=None,
                 custom_getter=None,
                 constraint=None,
                 synchronization=VariableSynchronization.AUTO,
                 aggregation=VariableAggregation.NONE):
  return get_variable_scope().get_variable(
      _get_default_variable_store(),
      name,
      shape=shape,
      dtype=dtype,
      initializer=initializer,
      regularizer=regularizer,
      trainable=trainable,
      collections=collections,
      caching_device=caching_device,
      partitioner=partitioner,
      validate_shape=validate_shape,
      use_resource=use_resource,
      custom_getter=custom_getter,
      constraint=constraint,
      synchronization=synchronization,
      aggregation=aggregation)


get_variable_or_local_docstring = ("""%s

%sThis function prefixes the name with the current variable scope
and performs reuse checks. See the
[Variable Scope How To](https://tensorflow.org/guide/variables)
for an extensive description of how reusing works. Here is a basic example:
Here is a basic example: 1604 1605```python 1606def foo(): 1607 with tf.variable_scope("foo", reuse=tf.AUTO_REUSE): 1608 v = tf.get_variable("v", [1]) 1609 return v 1610 1611v1 = foo() # Creates v. 1612v2 = foo() # Gets the same, existing v. 1613assert v1 == v2 1614``` 1615 1616If initializer is `None` (the default), the default initializer passed in 1617the variable scope will be used. If that one is `None` too, a 1618`glorot_uniform_initializer` will be used. The initializer can also be 1619a Tensor, in which case the variable is initialized to this value and shape. 1620 1621Similarly, if the regularizer is `None` (the default), the default regularizer 1622passed in the variable scope will be used (if that is `None` too, 1623then by default no regularization is performed). 1624 1625If a partitioner is provided, a `PartitionedVariable` is returned. 1626Accessing this object as a `Tensor` returns the shards concatenated along 1627the partition axis. 1628 1629Some useful partitioners are available. See, e.g., 1630`variable_axis_size_partitioner` and `min_max_variable_partitioner`. 1631 1632Args: 1633 name: The name of the new or existing variable. 1634 shape: Shape of the new or existing variable. 1635 dtype: Type of the new or existing variable (defaults to `DT_FLOAT`). 1636 initializer: Initializer for the variable if one is created. Can either be 1637 an initializer object or a Tensor. If it's a Tensor, its shape must be known 1638 unless validate_shape is False. 1639 regularizer: A (Tensor -> Tensor or None) function; the result of 1640 applying it on a newly created variable will be added to the collection 1641 `tf.GraphKeys.REGULARIZATION_LOSSES` and can be used for regularization. 1642 %scollections: List of graph collections keys to add the Variable to. 1643 Defaults to `[%s]` (see `tf.Variable`). 1644 caching_device: Optional device string or function describing where the 1645 Variable should be cached for reading. Defaults to the Variable's 1646 device. If not `None`, caches on another device. Typical use is to 1647 cache on the device where the Ops using the Variable reside, to 1648 deduplicate copying through `Switch` and other conditional statements. 1649 partitioner: Optional callable that accepts a fully defined `TensorShape` 1650 and `dtype` of the Variable to be created, and returns a list of 1651 partitions for each axis (currently only one axis can be partitioned). 1652 validate_shape: If False, allows the variable to be initialized with a 1653 value of unknown shape. If True, the default, the shape of initial_value 1654 must be known. For this to be used the initializer must be a Tensor and 1655 not an initializer object. 1656 use_resource: If False, creates a regular Variable. If true, creates an 1657 experimental ResourceVariable instead with well-defined semantics. 1658 Defaults to False (will later change to True). When eager execution is 1659 enabled this argument is always forced to be True. 1660 custom_getter: Callable that takes as a first argument the true getter, and 1661 allows overwriting the internal get_variable method. 1662 The signature of `custom_getter` should match that of this method, 1663 but the most future-proof version will allow for changes: 1664 `def custom_getter(getter, *args, **kwargs)`. Direct access to 1665 all `get_variable` parameters is also allowed: 1666 `def custom_getter(getter, name, *args, **kwargs)`. 
A simple
1667 custom getter that creates variables with modified names is:
1668 ```python
1669 def custom_getter(getter, name, *args, **kwargs):
1670 return getter(name + '_suffix', *args, **kwargs)
1671 ```
1672 constraint: An optional projection function to be applied to the variable
1673 after being updated by an `Optimizer` (e.g. used to implement norm
1674 constraints or value constraints for layer weights). The function must
1675 take as input the unprojected Tensor representing the value of the
1676 variable and return the Tensor for the projected value
1677 (which must have the same shape). Constraints are not safe to
1678 use when doing asynchronous distributed training.
1679 synchronization: Indicates when a distributed variable will be
1680 aggregated. Accepted values are constants defined in the class
1681 `tf.VariableSynchronization`. By default the synchronization is set to
1682 `AUTO` and the current `DistributionStrategy` chooses
1683 when to synchronize.
1684 aggregation: Indicates how a distributed variable will be aggregated.
1685 Accepted values are constants defined in the class
1686 `tf.VariableAggregation`.
1687
1688Returns:
1689 The created or existing `Variable` (or `PartitionedVariable`, if a
1690 partitioner was used).
1691
1692Raises:
1693 ValueError: when creating a new variable and shape is not declared,
1694 when violating reuse during variable creation, or when `initializer` dtype
1695 and `dtype` don't match. Reuse is set inside `variable_scope`.
1696""")
1697get_variable.__doc__ = get_variable_or_local_docstring % (
1698 "Gets an existing variable with these parameters or creates a new one.", "",
1699 "trainable: If `True` also add the variable to the graph collection\n"
1700 " `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).\n ",
1701 "GraphKeys.GLOBAL_VARIABLES")
1702
1703
1704# The argument list for get_local_variable must match arguments to get_variable.
1705# So, if you are updating the arguments, also update arguments to get_variable.
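# Illustrative equivalence (informal sketch, not part of the API): calling
# get_local_variable("x", [2]) behaves like calling
# get_variable("x", [2], trainable=False,
#              collections=[ops.GraphKeys.LOCAL_VARIABLES]),
# i.e. the variable is always non-trainable and is added to the
# LOCAL_VARIABLES collection instead of GLOBAL_VARIABLES.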
1706@tf_export(v1=["get_local_variable"]) 1707def get_local_variable( # pylint: disable=missing-docstring 1708 name, 1709 shape=None, 1710 dtype=None, 1711 initializer=None, 1712 regularizer=None, 1713 trainable=False, # pylint: disable=unused-argument 1714 collections=None, 1715 caching_device=None, 1716 partitioner=None, 1717 validate_shape=True, 1718 use_resource=None, 1719 custom_getter=None, 1720 constraint=None, 1721 synchronization=VariableSynchronization.AUTO, 1722 aggregation=VariableAggregation.NONE): 1723 if collections: 1724 collections += [ops.GraphKeys.LOCAL_VARIABLES] 1725 else: 1726 collections = [ops.GraphKeys.LOCAL_VARIABLES] 1727 return get_variable( 1728 name, 1729 shape=shape, 1730 dtype=dtype, 1731 initializer=initializer, 1732 regularizer=regularizer, 1733 trainable=False, 1734 collections=collections, 1735 caching_device=caching_device, 1736 partitioner=partitioner, 1737 validate_shape=validate_shape, 1738 use_resource=use_resource, 1739 synchronization=synchronization, 1740 aggregation=aggregation, 1741 custom_getter=custom_getter, 1742 constraint=constraint) 1743 1744 1745get_local_variable.__doc__ = get_variable_or_local_docstring % ( 1746 "Gets an existing *local* variable or creates a new one.", 1747 "Behavior is the same as in `get_variable`, except that variables are\n" 1748 "added to the `LOCAL_VARIABLES` collection and `trainable` is set to\n" 1749 "`False`.\n", "", "GraphKeys.LOCAL_VARIABLES") 1750 1751 1752def _get_partitioned_variable(name, 1753 shape=None, 1754 dtype=None, 1755 initializer=None, 1756 regularizer=None, 1757 trainable=True, 1758 collections=None, 1759 caching_device=None, 1760 partitioner=None, 1761 validate_shape=True, 1762 use_resource=None, 1763 constraint=None, 1764 synchronization=VariableSynchronization.AUTO, 1765 aggregation=VariableAggregation.NONE): 1766 """Gets or creates a sharded variable list with these parameters. 1767 1768 The `partitioner` must be a callable that accepts a fully defined 1769 `TensorShape` and returns a sequence of integers (the `partitions`). 1770 These integers describe how to partition the given sharded `Variable` 1771 along the given dimension. That is, `partitions[1] = 3` means split 1772 the `Variable` into 3 shards along dimension 1. Currently, sharding along 1773 only one axis is supported. 1774 1775 If the list of variables with the given name (prefix) is already stored, 1776 we return the stored variables. Otherwise, we create a new one. 1777 1778 If initializer is `None` (the default), the default initializer passed in 1779 the constructor is used. If that one is `None` too, we use a new 1780 `glorot_uniform_initializer`. If initializer is a Tensor, we use 1781 it as a value and derive the shape from the initializer. 1782 1783 If the initializer is a callable, then it will be called for each 1784 shard. Otherwise the initializer should match the shape of the entire 1785 sharded Variable, and it will be sliced accordingly for each shard. 1786 1787 Some useful partitioners are available. See, e.g., 1788 `variable_axis_size_partitioner` and `min_max_variable_partitioner`. 1789 1790 Args: 1791 name: The name of the new or existing variable. 1792 shape: Shape of the new or existing variable. 1793 dtype: Type of the new or existing variable (defaults to `DT_FLOAT`). 1794 initializer: Initializer for the variable if one is created. 
1795 regularizer: A (Tensor -> Tensor or None) function; the result of applying
1796 it on a newly created variable will be added to the collection
1797 GraphKeys.REGULARIZATION_LOSSES and can be used for regularization.
1798 trainable: If `True` also add the variable to the graph collection
1799 `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
1800 collections: List of graph collections keys to add the Variable to. Defaults
1801 to `[GraphKeys.GLOBAL_VARIABLES]` (see `tf.Variable`).
1802 caching_device: Optional device string or function describing where the
1803 Variable should be cached for reading. Defaults to the Variable's device.
1804 If not `None`, caches on another device. Typical use is to cache on the
1805 device where the Ops using the Variable reside, to deduplicate copying
1806 through `Switch` and other conditional statements.
1807 partitioner: Optional callable that accepts a fully defined `TensorShape`
1808 and `dtype` of the Variable to be created, and returns a list of
1809 partitions for each axis (currently only one axis can be partitioned).
1810 validate_shape: If False, allows the variable to be initialized with a value
1811 of unknown shape. If True, the default, the shape of initial_value must be
1812 known.
1813 use_resource: If False, creates a regular Variable. If True, creates an
1814 experimental ResourceVariable instead which has well-defined semantics.
1815 Defaults to False (will later change to True).
1816 constraint: An optional projection function to be applied to the variable
1817 after being updated by an `Optimizer` (e.g. used to implement norm
1818 constraints or value constraints for layer weights). The function must
1819 take as input the unprojected Tensor representing the value of the
1820 variable and return the Tensor for the projected value (which must have
1821 the same shape). Constraints are not safe to use when doing asynchronous
1822 distributed training.
1823 synchronization: Indicates when a distributed variable will be aggregated.
1824 Accepted values are constants defined in the class
1825 `tf.VariableSynchronization`. By default the synchronization is set to
1826 `AUTO` and the current `DistributionStrategy` chooses when to synchronize.
1827 aggregation: Indicates how a distributed variable will be aggregated.
1828 Accepted values are constants defined in the class
1829 `tf.VariableAggregation`.
1830
1831 Returns:
1832 A tuple `(shards, partitions)` where `shards` is the list of `Variable`
1833 shards and `partitions` is the output of the partitioner on the input
1834 shape.
1835
1836 Raises:
1837 ValueError: when creating a new variable and shape is not declared,
1838 or when violating reuse during variable creation. Reuse is set inside
1839 `variable_scope`.
1840 """
1841 # pylint: disable=protected-access
1842 scope = get_variable_scope()
1843 if scope.custom_getter is not None:
1844 raise ValueError(
1845 "Private access to _get_partitioned_variable is not allowed when "
1846 "a custom getter is set. Current custom getter: %s. "
1847 "It is likely that you're using create_partitioned_variables. "
1848 "If so, consider using get_variable with a non-empty "
1849 "partitioner parameter instead."
% scope.custom_getter) 1850 return scope._get_partitioned_variable( 1851 _get_default_variable_store(), 1852 name, 1853 shape=shape, 1854 dtype=dtype, 1855 initializer=initializer, 1856 regularizer=regularizer, 1857 trainable=trainable, 1858 collections=collections, 1859 caching_device=caching_device, 1860 partitioner=partitioner, 1861 validate_shape=validate_shape, 1862 use_resource=use_resource, 1863 constraint=constraint, 1864 synchronization=synchronization, 1865 aggregation=aggregation) 1866 # pylint: enable=protected-access 1867 1868 1869# Named like a function for compatibility with the previous 1870# @tf_contextlib.contextmanager definition. 1871class _pure_variable_scope(object): # pylint: disable=invalid-name 1872 """A context for the variable_scope, see `variable_scope` for docs.""" 1873 1874 def __init__(self, 1875 name_or_scope, 1876 reuse=None, 1877 initializer=None, 1878 regularizer=None, 1879 caching_device=None, 1880 partitioner=None, 1881 custom_getter=None, 1882 old_name_scope=None, 1883 dtype=dtypes.float32, 1884 use_resource=None, 1885 constraint=None): 1886 """Creates a context for the variable_scope, see `variable_scope` for docs. 1887 1888 Note: this does not create a name scope. 1889 1890 Args: 1891 name_or_scope: `string` or `VariableScope`: the scope to open. 1892 reuse: `True` or None, or tf.compat.v1.AUTO_REUSE; if `None`, we inherit 1893 the parent scope's reuse flag. 1894 initializer: default initializer for variables within this scope. 1895 regularizer: default regularizer for variables within this scope. 1896 caching_device: default caching device for variables within this scope. 1897 partitioner: default partitioner for variables within this scope. 1898 custom_getter: default custom getter for variables within this scope. 1899 old_name_scope: the original name scope when re-entering a variable scope. 1900 dtype: type of the variables within this scope (defaults to `DT_FLOAT`). 1901 use_resource: If False, variables in this scope will be regular Variables. 1902 If True, experimental ResourceVariables will be creates instead, with 1903 well-defined semantics. Defaults to False (will later change to True). 1904 constraint: An optional projection function to be applied to the variable 1905 after being updated by an `Optimizer` (e.g. used to implement norm 1906 constraints or value constraints for layer weights). The function must 1907 take as input the unprojected Tensor representing the value of the 1908 variable and return the Tensor for the projected value (which must have 1909 the same shape). Constraints are not safe to use when doing asynchronous 1910 distributed training. 1911 """ 1912 self._name_or_scope = name_or_scope 1913 self._reuse = reuse 1914 self._initializer = initializer 1915 self._regularizer = regularizer 1916 self._caching_device = caching_device 1917 self._partitioner = partitioner 1918 self._custom_getter = custom_getter 1919 self._old_name_scope = old_name_scope 1920 self._dtype = dtype 1921 self._use_resource = use_resource 1922 self._constraint = constraint 1923 self._var_store = _get_default_variable_store() 1924 self._var_scope_store = get_variable_scope_store() 1925 self._last_variable_scope_object = None 1926 if isinstance(self._name_or_scope, VariableScope): 1927 self._new_name = self._name_or_scope.name 1928 name_scope = self._name_or_scope._name_scope # pylint: disable=protected-access 1929 # Handler for the case when we jump to a shared scope. 
We create a new 1930 # VariableScope (self._var_scope_object) that contains a copy of the 1931 # provided shared scope, possibly with changed reuse and initializer, if 1932 # the user requested this. 1933 variable_scope_object = VariableScope( 1934 self._name_or_scope.reuse if not self._reuse else self._reuse, 1935 name=self._new_name, 1936 initializer=self._name_or_scope.initializer, 1937 regularizer=self._name_or_scope.regularizer, 1938 caching_device=self._name_or_scope.caching_device, 1939 partitioner=self._name_or_scope.partitioner, 1940 dtype=self._name_or_scope.dtype, 1941 custom_getter=self._name_or_scope.custom_getter, 1942 name_scope=name_scope, 1943 use_resource=self._name_or_scope.use_resource, 1944 constraint=self._constraint) 1945 if self._initializer is not None: 1946 variable_scope_object.set_initializer(self._initializer) 1947 if self._regularizer is not None: 1948 variable_scope_object.set_regularizer(self._regularizer) 1949 if self._caching_device is not None: 1950 variable_scope_object.set_caching_device(self._caching_device) 1951 if self._partitioner is not None: 1952 variable_scope_object.set_partitioner(self._partitioner) 1953 if self._custom_getter is not None: 1954 variable_scope_object.set_custom_getter( 1955 _maybe_wrap_custom_getter(self._custom_getter, 1956 self._name_or_scope.custom_getter)) 1957 if self._dtype is not None: 1958 variable_scope_object.set_dtype(self._dtype) 1959 if self._use_resource is not None: 1960 variable_scope_object.set_use_resource(self._use_resource) 1961 self._cached_variable_scope_object = variable_scope_object 1962 1963 def __enter__(self): 1964 """Begins the scope block. 1965 1966 Returns: 1967 A VariableScope. 1968 Raises: 1969 ValueError: when trying to reuse within a create scope, or create within 1970 a reuse scope, or if reuse is not `None` or `True`. 1971 TypeError: when the types of some arguments are not appropriate. 1972 """ 1973 self._old = self._var_scope_store.current_scope 1974 if isinstance(self._name_or_scope, VariableScope): 1975 self._var_scope_store.open_variable_scope(self._new_name) 1976 self._old_subscopes = copy.copy( 1977 self._var_scope_store.variable_scopes_count) 1978 variable_scope_object = self._cached_variable_scope_object 1979 else: 1980 # Handler for the case when we just prolong current variable scope. 1981 # VariableScope with name extended by the provided one, and inherited 1982 # reuse and initializer (except if the user provided values to set). 1983 self._new_name = ( 1984 self._old.name + "/" + 1985 self._name_or_scope if self._old.name else self._name_or_scope) 1986 self._reuse = (self._reuse or 1987 self._old.reuse) # Re-using is inherited by sub-scopes. 
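      # Pick the name scope to record on the new VariableScope: when
      # re-entering, keep the previously captured name scope; otherwise fall
      # back to the name that was passed in.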
1988 if self._old_name_scope is None: 1989 name_scope = self._name_or_scope 1990 else: 1991 name_scope = self._old_name_scope 1992 variable_scope_object = VariableScope( 1993 self._reuse, 1994 name=self._new_name, 1995 initializer=self._old.initializer, 1996 regularizer=self._old.regularizer, 1997 caching_device=self._old.caching_device, 1998 partitioner=self._old.partitioner, 1999 dtype=self._old.dtype, 2000 use_resource=self._old.use_resource, 2001 custom_getter=self._old.custom_getter, 2002 name_scope=name_scope, 2003 constraint=self._constraint) 2004 if self._initializer is not None: 2005 variable_scope_object.set_initializer(self._initializer) 2006 if self._regularizer is not None: 2007 variable_scope_object.set_regularizer(self._regularizer) 2008 if self._caching_device is not None: 2009 variable_scope_object.set_caching_device(self._caching_device) 2010 if self._partitioner is not None: 2011 variable_scope_object.set_partitioner(self._partitioner) 2012 if self._custom_getter is not None: 2013 variable_scope_object.set_custom_getter( 2014 _maybe_wrap_custom_getter(self._custom_getter, 2015 self._old.custom_getter)) 2016 if self._dtype is not None: 2017 variable_scope_object.set_dtype(self._dtype) 2018 if self._use_resource is not None: 2019 variable_scope_object.set_use_resource(self._use_resource) 2020 self._var_scope_store.open_variable_scope(self._new_name) 2021 self._var_scope_store.current_scope = variable_scope_object 2022 self._last_variable_scope_object = variable_scope_object 2023 return variable_scope_object 2024 2025 def __exit__(self, type_arg, value_arg, traceback_arg): 2026 if (self._var_scope_store.current_scope is 2027 not self._last_variable_scope_object): 2028 raise RuntimeError("Improper nesting of variable_scope.") 2029 # If jumping out from a non-prolonged scope, restore counts. 2030 if isinstance(self._name_or_scope, VariableScope): 2031 self._var_scope_store.variable_scopes_count = self._old_subscopes 2032 else: 2033 self._var_scope_store.close_variable_subscopes(self._new_name) 2034 self._var_scope_store.current_scope = self._old 2035 2036 2037def _maybe_wrap_custom_getter(custom_getter, old_getter): 2038 """Wrap a call to a custom_getter to use the old_getter internally.""" 2039 if old_getter is None: 2040 return custom_getter 2041 2042 # The new custom_getter should call the old one 2043 def wrapped_custom_getter(getter, *args, **kwargs): 2044 # Call: 2045 # custom_getter( 2046 # lambda: old_getter(true_getter, ...), *args, **kwargs) 2047 # which means custom_getter will call old_getter, which 2048 # will call the true_getter, perform any intermediate 2049 # processing, and return the results to the current 2050 # getter, which will also perform additional processing. 
2051 return custom_getter(functools.partial(old_getter, getter), *args, **kwargs) 2052 2053 return wrapped_custom_getter 2054 2055 2056def _get_unique_variable_scope(prefix): 2057 """Get a name with the given prefix unique in the current variable scope.""" 2058 var_scope_store = get_variable_scope_store() 2059 current_scope = get_variable_scope() 2060 name = current_scope.name + "/" + prefix if current_scope.name else prefix 2061 if var_scope_store.variable_scope_count(name) == 0: 2062 return prefix 2063 idx = 1 2064 while var_scope_store.variable_scope_count(name + ("_%d" % idx)) > 0: 2065 idx += 1 2066 return prefix + ("_%d" % idx) 2067 2068 2069# Named like a function for backwards compatibility with the 2070# @tf_contextlib.contextmanager version, which was switched to a class to avoid 2071# some object creation overhead. 2072@tf_export(v1=["variable_scope"]) # pylint: disable=invalid-name 2073class variable_scope(object): 2074 """A context manager for defining ops that creates variables (layers). 2075 2076 This context manager validates that the (optional) `values` are from the same 2077 graph, ensures that graph is the default graph, and pushes a name scope and a 2078 variable scope. 2079 2080 If `name_or_scope` is not None, it is used as is. If `name_or_scope` is None, 2081 then `default_name` is used. In that case, if the same name has been 2082 previously used in the same scope, it will be made unique by appending `_N` 2083 to it. 2084 2085 Variable scope allows you to create new variables and to share already created 2086 ones while providing checks to not create or share by accident. For details, 2087 see the [Variable Scope How To](https://tensorflow.org/guide/variables), here 2088 we present only a few basic examples. 2089 2090 The Variable Scope works as expected when the Eager Execution is Disabled. 2091 2092 ```python 2093 tf.compat.v1.disable_eager_execution() 2094 ``` 2095 2096 Simple example of how to create a new variable: 2097 2098 ```python 2099 with tf.compat.v1.variable_scope("foo"): 2100 with tf.compat.v1.variable_scope("bar"): 2101 v = tf.compat.v1.get_variable("v", [1]) 2102 assert v.name == "foo/bar/v:0" 2103 ``` 2104 2105 Simple example of how to reenter a premade variable scope safely: 2106 2107 ```python 2108 with tf.compat.v1.variable_scope("foo") as vs: 2109 pass 2110 2111 # Re-enter the variable scope. 2112 with tf.compat.v1.variable_scope(vs, 2113 auxiliary_name_scope=False) as vs1: 2114 # Restore the original name_scope. 2115 with tf.name_scope(vs1.original_name_scope): 2116 v = tf.compat.v1.get_variable("v", [1]) 2117 assert v.name == "foo/v:0" 2118 c = tf.constant([1], name="c") 2119 assert c.name == "foo/c:0" 2120 ``` 2121 2122 Keep in mind that the counters for `default_name` are discarded once the 2123 parent scope is exited. Therefore when the code re-enters the scope (for 2124 instance by saving it), all nested default_name counters will be restarted. 
2125 2126 For instance: 2127 2128 ```python 2129 with tf.compat.v1.variable_scope("foo") as vs: 2130 with tf.compat.v1.variable_scope(None, default_name="bar"): 2131 v = tf.compat.v1.get_variable("a", [1]) 2132 assert v.name == "foo/bar/a:0", v.name 2133 with tf.compat.v1.variable_scope(None, default_name="bar"): 2134 v = tf.compat.v1.get_variable("b", [1]) 2135 assert v.name == "foo/bar_1/b:0" 2136 2137 with tf.compat.v1.variable_scope(vs): 2138 with tf.compat.v1.variable_scope(None, default_name="bar"): 2139 v = tf.compat.v1.get_variable("c", [1]) 2140 assert v.name == "foo/bar/c:0" # Uses bar instead of bar_2! 2141 ``` 2142 2143 Basic example of sharing a variable AUTO_REUSE: 2144 2145 ```python 2146 def foo(): 2147 with tf.compat.v1.variable_scope("foo", reuse=tf.compat.v1.AUTO_REUSE): 2148 v = tf.compat.v1.get_variable("v", [1]) 2149 return v 2150 2151 v1 = foo() # Creates v. 2152 v2 = foo() # Gets the same, existing v. 2153 assert v1 == v2 2154 ``` 2155 2156 Basic example of sharing a variable with reuse=True: 2157 2158 ```python 2159 with tf.compat.v1.variable_scope("foo"): 2160 v = tf.compat.v1.get_variable("v", [1]) 2161 with tf.compat.v1.variable_scope("foo", reuse=True): 2162 v1 = tf.compat.v1.get_variable("v", [1]) 2163 assert v1 == v 2164 ``` 2165 2166 Sharing a variable by capturing a scope and setting reuse: 2167 2168 ```python 2169 with tf.compat.v1.variable_scope("foo") as scope: 2170 v = tf.compat.v1.get_variable("v", [1]) 2171 scope.reuse_variables() 2172 v1 = tf.compat.v1.get_variable("v", [1]) 2173 assert v1 == v 2174 ``` 2175 2176 To prevent accidental sharing of variables, we raise an exception when getting 2177 an existing variable in a non-reusing scope. 2178 2179 ```python 2180 with tf.compat.v1.variable_scope("foo"): 2181 v = tf.compat.v1.get_variable("v", [1]) 2182 v1 = tf.compat.v1.get_variable("v", [1]) 2183 # Raises ValueError("... v already exists ..."). 2184 ``` 2185 2186 Similarly, we raise an exception when trying to get a variable that does not 2187 exist in reuse mode. 2188 2189 ```python 2190 with tf.compat.v1.variable_scope("foo", reuse=True): 2191 v = tf.compat.v1.get_variable("v", [1]) 2192 # Raises ValueError("... v does not exists ..."). 2193 ``` 2194 2195 Note that the `reuse` flag is inherited: if we open a reusing scope, then all 2196 its sub-scopes become reusing as well. 2197 2198 A note about name scoping: Setting `reuse` does not impact the naming of other 2199 ops such as mult. See related discussion on 2200 [github#6189](https://github.com/tensorflow/tensorflow/issues/6189) 2201 2202 Note that up to and including version 1.0, it was allowed (though explicitly 2203 discouraged) to pass False to the reuse argument, yielding undocumented 2204 behaviour slightly different from None. Starting at 1.1.0 passing None and 2205 False as reuse has exactly the same effect. 2206 2207 A note about using variable scopes in multi-threaded environment: Variable 2208 scopes are thread local, so one thread will not see another thread's current 2209 scope. Also, when using `default_name`, unique scopes names are also generated 2210 only on a per thread basis. If the same name was used within a different 2211 thread, that doesn't prevent a new thread from creating the same scope. 2212 However, the underlying variable store is shared across threads (within the 2213 same graph). As such, if another thread tries to create a new variable with 2214 the same name as a variable created by a previous thread, it will fail unless 2215 reuse is True. 
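
  For instance, a minimal sketch (graph mode and a single shared default
  graph assumed; names are illustrative):

  ```python
  import threading

  with tf.compat.v1.variable_scope("shared"):
    v = tf.compat.v1.get_variable("v", [1])

  def worker():
    # The variable store is shared across threads in the same graph, so
    # without reuse=True this call would raise a ValueError.
    with tf.compat.v1.variable_scope("shared", reuse=True):
      v2 = tf.compat.v1.get_variable("v", [1])
      assert v2.name == "shared/v:0"

  t = threading.Thread(target=worker)
  t.start()
  t.join()
  ```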
2216 2217 Further, each thread starts with an empty variable scope. So if you wish to 2218 preserve name prefixes from a scope from the main thread, you should capture 2219 the main thread's scope and re-enter it in each thread. For e.g. 2220 2221 ``` 2222 main_thread_scope = variable_scope.get_variable_scope() 2223 2224 # Thread's target function: 2225 def thread_target_fn(captured_scope): 2226 with variable_scope.variable_scope(captured_scope): 2227 # .... regular code for this thread 2228 2229 2230 thread = threading.Thread(target=thread_target_fn, args=(main_thread_scope,)) 2231 ``` 2232 """ 2233 2234 def __init__(self, 2235 name_or_scope, 2236 default_name=None, 2237 values=None, 2238 initializer=None, 2239 regularizer=None, 2240 caching_device=None, 2241 partitioner=None, 2242 custom_getter=None, 2243 reuse=None, 2244 dtype=None, 2245 use_resource=None, 2246 constraint=None, 2247 auxiliary_name_scope=True): 2248 """Initialize the context manager. 2249 2250 Args: 2251 name_or_scope: `string` or `VariableScope`: the scope to open. 2252 default_name: The default name to use if the `name_or_scope` argument is 2253 `None`, this name will be uniquified. If name_or_scope is provided it 2254 won't be used and therefore it is not required and can be None. 2255 values: The list of `Tensor` arguments that are passed to the op function. 2256 initializer: default initializer for variables within this scope. 2257 regularizer: default regularizer for variables within this scope. 2258 caching_device: default caching device for variables within this scope. 2259 partitioner: default partitioner for variables within this scope. 2260 custom_getter: default custom getter for variables within this scope. 2261 reuse: `True`, None, or tf.compat.v1.AUTO_REUSE; if `True`, we go into 2262 reuse mode for this scope as well as all sub-scopes; if 2263 tf.compat.v1.AUTO_REUSE, we create variables if they do not exist, and 2264 return them otherwise; if None, we inherit the parent scope's reuse 2265 flag. When eager execution is enabled, new variables are always created 2266 unless an EagerVariableStore or template is currently active. 2267 dtype: type of variables created in this scope (defaults to the type in 2268 the passed scope, or inherited from parent scope). 2269 use_resource: If False, all variables will be regular Variables. If True, 2270 experimental ResourceVariables with well-defined semantics will be used 2271 instead. Defaults to False (will later change to True). When eager 2272 execution is enabled this argument is always forced to be True. 2273 constraint: An optional projection function to be applied to the variable 2274 after being updated by an `Optimizer` (e.g. used to implement norm 2275 constraints or value constraints for layer weights). The function must 2276 take as input the unprojected Tensor representing the value of the 2277 variable and return the Tensor for the projected value (which must have 2278 the same shape). Constraints are not safe to use when doing asynchronous 2279 distributed training. 2280 auxiliary_name_scope: If `True`, we create an auxiliary name scope with 2281 the scope. If `False`, we don't create it. Note that the argument is not 2282 inherited, and it only takes effect for once when creating. You should 2283 only use it for re-entering a premade variable scope. 2284 2285 Returns: 2286 A scope that can be captured and reused. 2287 2288 Raises: 2289 ValueError: when trying to reuse within a create scope, or create within 2290 a reuse scope. 
2291 TypeError: when the types of some arguments are not appropriate. 2292 """ 2293 self._name_or_scope = name_or_scope 2294 self._default_name = default_name 2295 self._values = values 2296 self._initializer = initializer 2297 self._regularizer = regularizer 2298 self._caching_device = caching_device 2299 self._partitioner = partitioner 2300 self._custom_getter = custom_getter 2301 self._reuse = reuse 2302 self._dtype = dtype 2303 self._use_resource = use_resource 2304 self._constraint = constraint 2305 if self._default_name is None and self._name_or_scope is None: 2306 raise TypeError("If default_name is None then name_or_scope is required") 2307 if self._reuse is False: 2308 # We don't allow non-inheriting scopes, False = None here. 2309 self._reuse = None 2310 if not (self._reuse is True 2311 or self._reuse is None 2312 or self._reuse is AUTO_REUSE): 2313 raise ValueError("The reuse parameter must be True or False or None.") 2314 if self._values is None: 2315 self._values = [] 2316 self._in_graph_mode = not context.executing_eagerly() 2317 if self._in_graph_mode: 2318 self._graph = ops._get_graph_from_inputs(self._values) # pylint: disable=protected-access 2319 self._cached_pure_variable_scope = None 2320 self._current_name_scope = None 2321 if not isinstance(auxiliary_name_scope, bool): 2322 raise TypeError("The auxiliary_name_scope must be `True` or `False`, " 2323 "while get {}".format(auxiliary_name_scope)) 2324 self._auxiliary_name_scope = auxiliary_name_scope 2325 2326 def __enter__(self): 2327 # If the default graph is building a function, then we should not replace it 2328 # with the cached graph. 2329 if ops.get_default_graph().building_function: 2330 self._building_function = True 2331 else: 2332 self._building_function = False 2333 if self._in_graph_mode and not self._building_function: 2334 self._graph_context_manager = self._graph.as_default() 2335 self._graph_context_manager.__enter__() 2336 if self._cached_pure_variable_scope is not None: 2337 # Fast path for re-entering variable_scopes. We've held on to the pure 2338 # variable scope from a previous successful __enter__, so we avoid some 2339 # overhead by re-using that object. 2340 if self._current_name_scope is not None: 2341 self._current_name_scope.__enter__() 2342 return self._cached_pure_variable_scope.__enter__() 2343 2344 try: 2345 return self._enter_scope_uncached() 2346 except: 2347 if (self._in_graph_mode and not self._building_function and 2348 self._graph_context_manager is not None): 2349 self._graph_context_manager.__exit__(*sys.exc_info()) 2350 raise 2351 2352 def _enter_scope_uncached(self): 2353 """Enters the context manager when there is no cached scope yet. 2354 2355 Returns: 2356 The entered variable scope. 2357 2358 Raises: 2359 TypeError: A wrong type is passed as `scope` at __init__(). 2360 ValueError: `reuse` is incorrectly set at __init__(). 2361 """ 2362 if self._auxiliary_name_scope: 2363 # Create a new name scope later 2364 current_name_scope = None 2365 else: 2366 # Reenter the current name scope 2367 name_scope = ops.get_name_scope() 2368 if name_scope: 2369 # Hack to reenter 2370 name_scope += "/" 2371 current_name_scope = ops.name_scope(name_scope, skip_on_eager=False) 2372 else: 2373 # Root scope 2374 current_name_scope = ops.name_scope(name_scope, skip_on_eager=False) 2375 2376 # IMPORTANT: Only assign to self._cached_pure_variable_scope and 2377 # self._current_name_scope after successful __enter__() calls. 
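    # (Otherwise a failed enter could leave the fast path in __enter__ above
    # re-entering a context manager that was never successfully entered.)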
2378 if self._name_or_scope is not None: 2379 if not isinstance(self._name_or_scope, 2380 (VariableScope,) + six.string_types): 2381 raise TypeError("VariableScope: name_or_scope must be a string or " 2382 "VariableScope.") 2383 if isinstance(self._name_or_scope, six.string_types): 2384 name_scope = self._name_or_scope 2385 else: 2386 name_scope = self._name_or_scope.name.split("/")[-1] 2387 if name_scope or current_name_scope: 2388 current_name_scope = current_name_scope or ops.name_scope( 2389 name_scope, skip_on_eager=False) 2390 try: 2391 current_name_scope_name = current_name_scope.__enter__() 2392 except: 2393 current_name_scope.__exit__(*sys.exc_info()) 2394 raise 2395 self._current_name_scope = current_name_scope 2396 if isinstance(self._name_or_scope, six.string_types): 2397 old_name_scope = current_name_scope_name 2398 else: 2399 old_name_scope = self._name_or_scope.original_name_scope 2400 pure_variable_scope = _pure_variable_scope( 2401 self._name_or_scope, 2402 reuse=self._reuse, 2403 initializer=self._initializer, 2404 regularizer=self._regularizer, 2405 caching_device=self._caching_device, 2406 partitioner=self._partitioner, 2407 custom_getter=self._custom_getter, 2408 old_name_scope=old_name_scope, 2409 dtype=self._dtype, 2410 use_resource=self._use_resource, 2411 constraint=self._constraint) 2412 try: 2413 entered_pure_variable_scope = pure_variable_scope.__enter__() 2414 except: 2415 pure_variable_scope.__exit__(*sys.exc_info()) 2416 raise 2417 self._cached_pure_variable_scope = pure_variable_scope 2418 return entered_pure_variable_scope 2419 else: 2420 self._current_name_scope = None 2421 # This can only happen if someone is entering the root variable scope. 2422 pure_variable_scope = _pure_variable_scope( 2423 self._name_or_scope, 2424 reuse=self._reuse, 2425 initializer=self._initializer, 2426 regularizer=self._regularizer, 2427 caching_device=self._caching_device, 2428 partitioner=self._partitioner, 2429 custom_getter=self._custom_getter, 2430 dtype=self._dtype, 2431 use_resource=self._use_resource, 2432 constraint=self._constraint) 2433 try: 2434 entered_pure_variable_scope = pure_variable_scope.__enter__() 2435 except: 2436 pure_variable_scope.__exit__(*sys.exc_info()) 2437 raise 2438 self._cached_pure_variable_scope = pure_variable_scope 2439 return entered_pure_variable_scope 2440 2441 else: # Here name_or_scope is None. Using default name, but made unique. 
2442 if self._reuse: 2443 raise ValueError("reuse=True cannot be used without a name_or_scope") 2444 current_name_scope = current_name_scope or ops.name_scope( 2445 self._default_name, skip_on_eager=False) 2446 try: 2447 current_name_scope_name = current_name_scope.__enter__() 2448 except: 2449 current_name_scope.__exit__(*sys.exc_info()) 2450 raise 2451 self._current_name_scope = current_name_scope 2452 unique_default_name = _get_unique_variable_scope(self._default_name) 2453 pure_variable_scope = _pure_variable_scope( 2454 unique_default_name, 2455 initializer=self._initializer, 2456 regularizer=self._regularizer, 2457 caching_device=self._caching_device, 2458 partitioner=self._partitioner, 2459 custom_getter=self._custom_getter, 2460 old_name_scope=current_name_scope_name, 2461 dtype=self._dtype, 2462 use_resource=self._use_resource, 2463 constraint=self._constraint) 2464 try: 2465 entered_pure_variable_scope = pure_variable_scope.__enter__() 2466 except: 2467 pure_variable_scope.__exit__(*sys.exc_info()) 2468 raise 2469 self._cached_pure_variable_scope = pure_variable_scope 2470 return entered_pure_variable_scope 2471 2472 def __exit__(self, type_arg, value_arg, traceback_arg): 2473 try: 2474 self._cached_pure_variable_scope.__exit__(type_arg, value_arg, 2475 traceback_arg) 2476 finally: 2477 try: 2478 if self._current_name_scope: 2479 self._current_name_scope.__exit__(type_arg, value_arg, 2480 traceback_arg) 2481 finally: 2482 if self._in_graph_mode and not self._building_function: 2483 self._graph_context_manager.__exit__(type_arg, value_arg, 2484 traceback_arg) 2485 2486 2487# pylint: disable=g-doc-return-or-yield 2488@tf_export(v1=["variable_op_scope"]) 2489@tf_contextlib.contextmanager 2490def variable_op_scope(values, 2491 name_or_scope, 2492 default_name=None, 2493 initializer=None, 2494 regularizer=None, 2495 caching_device=None, 2496 partitioner=None, 2497 custom_getter=None, 2498 reuse=None, 2499 dtype=None, 2500 use_resource=None, 2501 constraint=None): 2502 """Deprecated: context manager for defining an op that creates variables.""" 2503 logging.warn("tf.variable_op_scope(values, name, default_name) is deprecated," 2504 " use tf.variable_scope(name, default_name, values)") 2505 with variable_scope( 2506 name_or_scope, 2507 default_name=default_name, 2508 values=values, 2509 initializer=initializer, 2510 regularizer=regularizer, 2511 caching_device=caching_device, 2512 partitioner=partitioner, 2513 custom_getter=custom_getter, 2514 reuse=reuse, 2515 dtype=dtype, 2516 use_resource=use_resource, 2517 constraint=constraint) as scope: 2518 yield scope 2519 2520 2521def _call_partitioner(partitioner, shape, dtype): 2522 """Call partitioner validating its inputs/output. 2523 2524 Args: 2525 partitioner: a function mapping `Tensor` shape and dtype to a list of 2526 partitions. 2527 shape: shape of the `Tensor` to partition, must have at least two 2528 dimensions. 2529 dtype: dtype of the elements in the `Tensor`. 2530 2531 Returns: 2532 A list with elements >=1 and exactly one >1. The index of that 2533 element corresponds to the partitioning axis. 2534 """ 2535 if not shape.is_fully_defined(): 2536 raise ValueError("Shape of a new partitioned variable must be " 2537 "fully defined, but instead was %s." 
% (shape,)) 2538 if shape.ndims < 1: 2539 raise ValueError("A partitioned Variable must have rank at least 1, " 2540 "shape: %s" % shape) 2541 2542 slicing = partitioner(shape=shape, dtype=dtype) 2543 if not isinstance(slicing, collections_abc.Sequence): 2544 raise ValueError("Partitioner must return a sequence, but saw: %s" % 2545 slicing) 2546 if len(slicing) != shape.ndims: 2547 raise ValueError( 2548 "Partitioner returned a partition list that does not match the " 2549 "Variable's rank: %s vs. %s" % (slicing, shape)) 2550 if any(p < 1 for p in slicing): 2551 raise ValueError("Partitioner returned zero partitions for some axes: %s" % 2552 slicing) 2553 if sum(p > 1 for p in slicing) > 1: 2554 raise ValueError("Can only slice a variable along one dimension: " 2555 "shape: %s, partitioning: %s" % (shape, slicing)) 2556 return slicing 2557 2558 2559# TODO(slebedev): could be inlined, but 2560# `_VariableStore._get_partitioned_variable` is too complex even 2561# without this logic. 2562def _get_slice_dim_and_num_slices(slicing): 2563 """Get slicing dimension and number of slices from the partitioner output.""" 2564 for slice_dim, num_slices in enumerate(slicing): 2565 if num_slices > 1: 2566 break 2567 else: 2568 # Degenerate case: no partitioning applied. 2569 slice_dim = 0 2570 num_slices = 1 2571 return slice_dim, num_slices 2572 2573 2574def _iter_slices(full_shape, num_slices, slice_dim): 2575 """Slices a given a shape along the specified dimension.""" 2576 num_slices_with_excess = full_shape[slice_dim] % num_slices 2577 offset = [0] * len(full_shape) 2578 min_slice_len = full_shape[slice_dim] // num_slices 2579 for i in xrange(num_slices): 2580 shape = full_shape[:] 2581 shape[slice_dim] = min_slice_len + bool(i < num_slices_with_excess) 2582 yield offset[:], shape 2583 offset[slice_dim] += shape[slice_dim] 2584 2585 2586def default_variable_creator(next_creator=None, **kwargs): 2587 """Default variable creator.""" 2588 assert next_creator is None 2589 initial_value = kwargs.get("initial_value", None) 2590 trainable = kwargs.get("trainable", None) 2591 collections = kwargs.get("collections", None) 2592 validate_shape = kwargs.get("validate_shape", True) 2593 caching_device = kwargs.get("caching_device", None) 2594 name = kwargs.get("name", None) 2595 variable_def = kwargs.get("variable_def", None) 2596 dtype = kwargs.get("dtype", None) 2597 expected_shape = kwargs.get("expected_shape", None) 2598 import_scope = kwargs.get("import_scope", None) 2599 constraint = kwargs.get("constraint", None) 2600 use_resource = kwargs.get("use_resource", None) 2601 synchronization = kwargs.get("synchronization", None) 2602 aggregation = kwargs.get("aggregation", None) 2603 shape = kwargs.get("shape", None) 2604 2605 if use_resource is None: 2606 use_resource = get_variable_scope().use_resource 2607 if use_resource is None: 2608 use_resource = _DEFAULT_USE_RESOURCE 2609 use_resource = use_resource or context.executing_eagerly() 2610 if use_resource: 2611 distribute_strategy = kwargs.get("distribute_strategy", None) 2612 return resource_variable_ops.ResourceVariable( 2613 initial_value=initial_value, 2614 trainable=trainable, 2615 collections=collections, 2616 validate_shape=validate_shape, 2617 caching_device=caching_device, 2618 name=name, 2619 dtype=dtype, 2620 constraint=constraint, 2621 variable_def=variable_def, 2622 import_scope=import_scope, 2623 distribute_strategy=distribute_strategy, 2624 synchronization=synchronization, 2625 aggregation=aggregation, 2626 shape=shape) 2627 else: 2628 return 
variables.RefVariable( 2629 initial_value=initial_value, 2630 trainable=trainable, 2631 collections=collections, 2632 validate_shape=validate_shape, 2633 caching_device=caching_device, 2634 name=name, 2635 dtype=dtype, 2636 constraint=constraint, 2637 variable_def=variable_def, 2638 expected_shape=expected_shape, 2639 import_scope=import_scope, 2640 synchronization=synchronization, 2641 aggregation=aggregation, 2642 shape=shape) 2643 2644 2645def default_variable_creator_v2(next_creator=None, **kwargs): 2646 """Default variable creator.""" 2647 assert next_creator is None 2648 initial_value = kwargs.get("initial_value", None) 2649 trainable = kwargs.get("trainable", None) 2650 validate_shape = kwargs.get("validate_shape", True) 2651 caching_device = kwargs.get("caching_device", None) 2652 name = kwargs.get("name", None) 2653 variable_def = kwargs.get("variable_def", None) 2654 dtype = kwargs.get("dtype", None) 2655 import_scope = kwargs.get("import_scope", None) 2656 constraint = kwargs.get("constraint", None) 2657 distribute_strategy = kwargs.get("distribute_strategy", None) 2658 synchronization = kwargs.get("synchronization", None) 2659 aggregation = kwargs.get("aggregation", None) 2660 shape = kwargs.get("shape", None) 2661 2662 return resource_variable_ops.ResourceVariable( 2663 initial_value=initial_value, 2664 trainable=trainable, 2665 validate_shape=validate_shape, 2666 caching_device=caching_device, 2667 name=name, 2668 dtype=dtype, 2669 constraint=constraint, 2670 variable_def=variable_def, 2671 import_scope=import_scope, 2672 distribute_strategy=distribute_strategy, 2673 synchronization=synchronization, 2674 aggregation=aggregation, 2675 shape=shape) 2676 2677 2678variables.default_variable_creator = default_variable_creator 2679variables.default_variable_creator_v2 = default_variable_creator_v2 2680 2681 2682def _make_getter(captured_getter, captured_previous): 2683 """Gets around capturing loop variables in python being broken.""" 2684 return lambda **kwargs: captured_getter(captured_previous, **kwargs) 2685 2686 2687# TODO(apassos) remove forwarding symbol 2688variable = variables.VariableV1 2689 2690 2691@tf_export(v1=["variable_creator_scope"]) 2692@tf_contextlib.contextmanager 2693def variable_creator_scope_v1(variable_creator): 2694 """Scope which defines a variable creation function to be used by variable(). 2695 2696 variable_creator is expected to be a function with the following signature: 2697 2698 ``` 2699 def variable_creator(next_creator, **kwargs) 2700 ``` 2701 2702 The creator is supposed to eventually call the next_creator to create a 2703 variable if it does want to create a variable and not call Variable or 2704 ResourceVariable directly. This helps make creators composable. A creator may 2705 choose to create multiple variables, return already existing variables, or 2706 simply register that a variable was created and defer to the next creators in 2707 line. Creators can also modify the keyword arguments seen by the next 2708 creators. 2709 2710 Custom getters in the variable scope will eventually resolve down to these 2711 custom creators when they do create variables. 2712 2713 The valid keyword arguments in kwds are: 2714 2715 * initial_value: A `Tensor`, or Python object convertible to a `Tensor`, 2716 which is the initial value for the Variable. The initial value must have 2717 a shape specified unless `validate_shape` is set to False. Can also be a 2718 callable with no argument that returns the initial value when called. 
In
2719 that case, `dtype` must be specified. (Note that initializer functions
2720 from init_ops.py must first be bound to a shape before being used here.)
2721 * trainable: If `True`, the default, also adds the variable to the graph
2722 collection `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as
2723 the default list of variables to use by the `Optimizer` classes.
2724 `trainable` defaults to `True`, unless `synchronization` is
2725 set to `ON_READ`, in which case it defaults to `False`.
2726 * collections: List of graph collections keys. The new variable is added to
2727 these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
2728 * validate_shape: If `False`, allows the variable to be initialized with a
2729 value of unknown shape. If `True`, the default, the shape of
2730 `initial_value` must be known.
2731 * caching_device: Optional device string describing where the Variable
2732 should be cached for reading. Defaults to the Variable's device.
2733 If not `None`, caches on another device. Typical use is to cache
2734 on the device where the Ops using the Variable reside, to deduplicate
2735 copying through `Switch` and other conditional statements.
2736 * name: Optional name for the variable. Defaults to `'Variable'` and gets
2737 uniquified automatically.
2738 * dtype: If set, initial_value will be converted to the given type.
2739 If `None`, either the datatype will be kept (if `initial_value` is
2740 a Tensor), or `convert_to_tensor` will decide.
2741 * constraint: A constraint function to be applied to the variable after
2742 updates by some algorithms.
2743 * use_resource: if True, a ResourceVariable is always created.
2744 * synchronization: Indicates when a distributed variable will be
2745 aggregated. Accepted values are constants defined in the class
2746 `tf.VariableSynchronization`. By default the synchronization is set to
2747 `AUTO` and the current `DistributionStrategy` chooses
2748 when to synchronize.
2749 * aggregation: Indicates how a distributed variable will be aggregated.
2750 Accepted values are constants defined in the class
2751 `tf.VariableAggregation`.
2752
2753 This set may grow over time, so it's important that the signature of creators is as
2754 mentioned above.
2755
2756 Args:
2757 variable_creator: the passed creator
2758
2759 Yields:
2760 A scope in which the creator is active
2761 """
2762 with ops.get_default_graph()._variable_creator_scope(variable_creator): # pylint: disable=protected-access
2763 yield
2764
2765
2766# Note: only the docstrings differ between this and v1.
2767@tf_export("variable_creator_scope", v1=[])
2768@tf_contextlib.contextmanager
2769def variable_creator_scope(variable_creator):
2770 """Scope which defines a variable creation function to be used by variable().
2771
2772 variable_creator is expected to be a function with the following signature:
2773
2774 ```
2775 def variable_creator(next_creator, **kwargs)
2776 ```
2777
2778 The creator is supposed to eventually call the next_creator to create a
2779 variable if it does want to create a variable and not call Variable or
2780 ResourceVariable directly. This helps make creators composable. A creator may
2781 choose to create multiple variables, return already existing variables, or
2782 simply register that a variable was created and defer to the next creators in
2783 line. Creators can also modify the keyword arguments seen by the next
2784 creators.
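
  For example, a minimal sketch of a composable creator (the names here are
  illustrative only) that logs the requested variable name and then defers to
  the next creator:

  ```
  def logging_creator(next_creator, **kwargs):
    print("creating variable:", kwargs.get("name"))
    return next_creator(**kwargs)

  with tf.variable_creator_scope(logging_creator):
    v = tf.Variable(1.0, name="x")
  ```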
2785
2786 Custom getters in the variable scope will eventually resolve down to these
2787 custom creators when they do create variables.
2788
2789 The valid keyword arguments in kwds are:
2790
2791 * initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
2792 which is the initial value for the Variable. The initial value must have
2793 a shape specified unless `validate_shape` is set to False. Can also be a
2794 callable with no argument that returns the initial value when called. In
2795 that case, `dtype` must be specified. (Note that initializer functions
2796 from init_ops.py must first be bound to a shape before being used here.)
2797 * trainable: If `True`, the default, GradientTapes automatically watch
2798 uses of this Variable.
2799 * validate_shape: If `False`, allows the variable to be initialized with a
2800 value of unknown shape. If `True`, the default, the shape of
2801 `initial_value` must be known.
2802 * caching_device: Optional device string describing where the Variable
2803 should be cached for reading. Defaults to the Variable's device.
2804 If not `None`, caches on another device. Typical use is to cache
2805 on the device where the Ops using the Variable reside, to deduplicate
2806 copying through `Switch` and other conditional statements.
2807 * name: Optional name for the variable. Defaults to `'Variable'` and gets
2808 uniquified automatically.
2809 * dtype: If set, initial_value will be converted to the given type.
2810 If `None`, either the datatype will be kept (if `initial_value` is
2811 a Tensor), or `convert_to_tensor` will decide.
2812 * constraint: A constraint function to be applied to the variable after
2813 updates by some algorithms.
2814 * synchronization: Indicates when a distributed variable will be
2815 aggregated. Accepted values are constants defined in the class
2816 `tf.VariableSynchronization`. By default the synchronization is set to
2817 `AUTO` and the current `DistributionStrategy` chooses
2818 when to synchronize.
2819 * aggregation: Indicates how a distributed variable will be aggregated.
2820 Accepted values are constants defined in the class
2821 `tf.VariableAggregation`.
2822
2823 This set may grow over time, so it's important that the signature of creators is as
2824 mentioned above.
2825
2826 Args:
2827 variable_creator: the passed creator
2828
2829 Yields:
2830 A scope in which the creator is active
2831 """
2832 with ops.get_default_graph()._variable_creator_scope(variable_creator): # pylint: disable=protected-access
2833 yield
2834