# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A class to store named variables and a scope operator to manage sharing."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy
import enum  # pylint: disable=g-bad-import-order
import functools
import sys
import threading
import traceback

import six
from six import iteritems
from six.moves import xrange, zip  # pylint: disable=redefined-builtin

from tensorflow.python import tf2
from tensorflow.python.client import session
from tensorflow.python.eager import context
from tensorflow.python.eager import monitoring
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.types import core
from tensorflow.python.util import deprecation
from tensorflow.python.util import function_utils
from tensorflow.python.util import tf_contextlib
from tensorflow.python.util import tf_inspect
from tensorflow.python.util.compat import collections_abc
from tensorflow.python.util.tf_export import tf_export

__all__ = [
    "AUTO_REUSE", "VariableScope", "get_variable_scope", "get_variable",
    "get_local_variable", "variable_scope", "variable_op_scope",
    "no_regularizer", "VariableSynchronization", "VariableAggregation"
]

_api_usage_gauge = monitoring.BoolGauge(
    "/tensorflow/api/resource_variables",
    "Whether variable_scope.enable_resource_variables() is called.")


class _PartitionInfo(object):
  """Holds partition info used by initializer functions."""

  __slots__ = ["_full_shape", "_var_offset"]

  def __init__(self, full_shape, var_offset):
    """Constructor.

    Args:
      full_shape: Tuple or list of `int` indicating the full combined shape of
        the partitioned variables.
      var_offset: Tuple or list of `int` specifying offset of this partition
        with respect to the full variable for each dimension.

    Raises:
      TypeError: If `full_shape` or `var_offset` is not a sequence.
      ValueError: If `full_shape` or `var_offset` differ in length. If
        `var_offset` exceeds `full_shape` in any dimension.
81 """ 82 if not isinstance(full_shape, collections_abc.Sequence) or isinstance( 83 full_shape, six.string_types): 84 raise TypeError( 85 "`full_shape` must be a sequence (like tuple or list) instead of " + 86 type(full_shape).__name__) 87 88 if not isinstance(var_offset, collections_abc.Sequence) or isinstance( 89 var_offset, six.string_types): 90 raise TypeError( 91 "`var_offset` must be a sequence (like tuple or list) instead of " + 92 type(var_offset).__name__) 93 94 if len(var_offset) != len(full_shape): 95 raise ValueError( 96 "Expected equal length, but `var_offset` is of length {} while " 97 "full_shape is of length {}.".format( 98 len(var_offset), len(full_shape))) 99 100 for offset, shape in zip(var_offset, full_shape): 101 if offset < 0 or offset >= shape: 102 raise ValueError( 103 "Expected 0 <= offset < shape but found offset={}, shape={} for " 104 "var_offset={}, full_shape={}".format(offset, shape, var_offset, 105 full_shape)) 106 107 self._full_shape = full_shape 108 self._var_offset = var_offset 109 110 @property 111 def full_shape(self): 112 return self._full_shape 113 114 @property 115 def var_offset(self): 116 return self._var_offset 117 118 def single_offset(self, shape): 119 """Returns the offset when the variable is partitioned in at most one dim. 120 121 Args: 122 shape: Tuple or list of `int` indicating the shape of one specific 123 variable partition. 124 125 Returns: 126 `int` representing the offset in the dimension along which the variable is 127 partitioned. Returns 0 if the variable is not being partitioned. 128 129 Raises: 130 ValueError: Depending on self.single_slice_dim(). 131 """ 132 133 single_slice_dim = self.single_slice_dim(shape) 134 # If this variable is not being partitioned at all, single_slice_dim() could 135 # return None. 136 if single_slice_dim is None: 137 return 0 138 return self.var_offset[single_slice_dim] 139 140 def single_slice_dim(self, shape): 141 """Returns the slice dim when the variable is partitioned only in one dim. 142 143 Args: 144 shape: Tuple or list of `int` indicating the shape of one specific 145 variable partition. 146 147 Returns: 148 `int` representing the dimension that the variable is partitioned in, or 149 `None` if the variable doesn't seem to be partitioned at all. 150 151 Raises: 152 TypeError: If `shape` is not a sequence. 153 ValueError: If `shape` is not the same length as `self.full_shape`. If 154 the variable is partitioned in more than one dimension. 
155 """ 156 if not isinstance(shape, collections_abc.Sequence) or isinstance( 157 shape, six.string_types): 158 raise TypeError( 159 "`shape` must be a sequence (like tuple or list) instead of " + 160 type(shape).__name__) 161 162 if len(shape) != len(self.full_shape): 163 raise ValueError( 164 "Expected equal length, but received shape={} of length {} while " 165 "self.full_shape={} is of length {}.".format(shape, len(shape), 166 self.full_shape, 167 len(self.full_shape))) 168 169 for i in xrange(len(shape)): 170 if self.var_offset[i] + shape[i] > self.full_shape[i]: 171 raise ValueError( 172 "With self.var_offset={}, a partition of shape={} would exceed " 173 "self.full_shape={} in dimension {}.".format( 174 self.var_offset, shape, self.full_shape, i)) 175 176 slice_dim = None 177 for i in xrange(len(shape)): 178 if shape[i] == self.full_shape[i]: 179 continue 180 if slice_dim is not None: 181 raise ValueError( 182 "Cannot use single_slice_dim() with shape={} and " 183 "self.full_shape={} since slice dim could be either dimension {} " 184 "or {}.".format(shape, self.full_shape, i, slice_dim)) 185 slice_dim = i 186 187 return slice_dim 188 189 190class _ReuseMode(enum.Enum): 191 """Mode for variable access within a variable scope.""" 192 193 # Indicates that variables are to be fetched if they already exist or 194 # otherwise created. 195 AUTO_REUSE = 1 196 197 # TODO(alive): For TensorFlow 2.0, Deprecate True/False/None API in favor of 198 # enum values. 199 # REUSE_FALSE = 2 200 # REUSE_TRUE = 3 201 202 203# TODO(apassos) remove these forwarding symbols. 204VariableSynchronization = variables.VariableSynchronization # pylint: disable=invalid-name 205VariableAggregation = variables.VariableAggregation # pylint: disable=invalid-name 206 207AUTO_REUSE = _ReuseMode.AUTO_REUSE 208tf_export(v1=["AUTO_REUSE"]).export_constant(__name__, "AUTO_REUSE") 209AUTO_REUSE.__doc__ = """ 210When passed in as the value for the `reuse` flag, AUTO_REUSE indicates that 211get_variable() should create the requested variable if it doesn't exist or, if 212it does exist, simply return it. 213""" 214 215_DEFAULT_USE_RESOURCE = tf2.enabled() 216 217 218@tf_export(v1=["enable_resource_variables"]) 219def enable_resource_variables(): 220 """Creates resource variables by default. 221 222 Resource variables are improved versions of TensorFlow variables with a 223 well-defined memory model. Accessing a resource variable reads its value, and 224 all ops which access a specific read value of the variable are guaranteed to 225 see the same value for that tensor. Writes which happen after a read (by 226 having a control or data dependency on the read) are guaranteed not to affect 227 the value of the read tensor, and similarly writes which happen before a read 228 are guaranteed to affect the value. No guarantees are made about unordered 229 read/write pairs. 230 231 Calling tf.enable_resource_variables() lets you opt-in to this TensorFlow 2.0 232 feature. 233 """ 234 global _DEFAULT_USE_RESOURCE 235 _DEFAULT_USE_RESOURCE = True 236 _api_usage_gauge.get_cell().set(True) 237 238 239@tf_export(v1=["resource_variables_enabled"]) 240def resource_variables_enabled(): 241 """Returns `True` if resource variables are enabled. 242 243 Resource variables are improved versions of TensorFlow variables with a 244 well-defined memory model. Accessing a resource variable reads its value, and 245 all ops which access a specific read value of the variable are guaranteed to 246 see the same value for that tensor. 


@tf_export(v1=["resource_variables_enabled"])
def resource_variables_enabled():
  """Returns `True` if resource variables are enabled.

  Resource variables are improved versions of TensorFlow variables with a
  well-defined memory model. Accessing a resource variable reads its value, and
  all ops which access a specific read value of the variable are guaranteed to
  see the same value for that tensor. Writes which happen after a read (by
  having a control or data dependency on the read) are guaranteed not to affect
  the value of the read tensor, and similarly writes which happen before a read
  are guaranteed to affect the value. No guarantees are made about unordered
  read/write pairs.

  Calling tf.enable_resource_variables() lets you opt-in to this TensorFlow 2.0
  feature.
  """
  global _DEFAULT_USE_RESOURCE
  return _DEFAULT_USE_RESOURCE


@deprecation.deprecated(
    None, "non-resource variables are not supported in the long term")
@tf_export(v1=["disable_resource_variables"])
def disable_resource_variables():
  """Opts out of resource variables.

  If your code needs tf.disable_resource_variables() to be called to work
  properly please file a bug.
  """
  global _DEFAULT_USE_RESOURCE
  _DEFAULT_USE_RESOURCE = False
  _api_usage_gauge.get_cell().set(False)


def _needs_no_arguments(python_callable):
  """Returns true if the callable needs no arguments to call."""
  # TODO(bfontain): Switch to inspect.signature when we are python 3 only.
  # signature = inspect.signature(python_callable)
  # return not [1 for param in signature.parameters.values()
  #             if param.default == param.empty]
  num_arguments = len(tf_inspect.getargspec(python_callable).args)
  if not tf_inspect.isfunction(python_callable) and not isinstance(
      python_callable, functools.partial):
    # getargspec includes self for function objects (which aren't
    # functools.partial). This has no default so we need to remove it.
    # It is not even an argument, so it's odd that getargspec returns this.
    # Note that this is fixed with inspect.signature in Python 3.
    num_arguments -= 1
  return num_arguments == len(
      tf_inspect.getargspec(python_callable).defaults or [])


class _VariableStore(object):
  """Variable store that carries a number of named Variables.

  New variable names and new variables can be created; all stored
  variables are initialized with the initializer passed to __init__.

  Attributes:
    vars: a dictionary with string names (same as passed in GetVar) as keys and
      the corresponding TensorFlow Variables as values.
  """

  __slots__ = ["_vars", "_partitioned_vars", "_store_eager_variables"]

  def __init__(self):
    """Create a variable store."""
    self._vars = {}  # A dictionary of the stored TensorFlow variables.
    self._partitioned_vars = {}  # A dict of the stored PartitionedVariables.
    self._store_eager_variables = False

  def get_variable(self,
                   name,
                   shape=None,
                   dtype=dtypes.float32,
                   initializer=None,
                   regularizer=None,
                   reuse=None,
                   trainable=None,
                   collections=None,
                   caching_device=None,
                   partitioner=None,
                   validate_shape=True,
                   use_resource=None,
                   custom_getter=None,
                   constraint=None,
                   synchronization=VariableSynchronization.AUTO,
                   aggregation=VariableAggregation.NONE):
    """Gets an existing variable with these parameters or create a new one.

    If a variable with the given name is already stored, we return the stored
    variable. Otherwise, we create a new one.

    Set `reuse` to `True` when you only want to reuse existing Variables.
    Set `reuse` to `False` when you only want to create new Variables.
    Set `reuse` to None (the default) or tf.compat.v1.AUTO_REUSE when you want
    variables to be created if they don't exist or returned if they do.
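
    A minimal sketch of the reuse behavior (assumed graph-mode usage; the
    store and variable names are illustrative):

    ```python
    store = _VariableStore()
    v1 = store.get_variable("v", shape=[1])              # creates "v"
    v2 = store.get_variable("v", shape=[1], reuse=True)  # reuses "v"
    assert v1 is v2
    ```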

    If initializer is `None` (the default), the default initializer passed in
    the constructor is used. If that one is `None` too, we use a new
    `glorot_uniform_initializer`. If initializer is a Tensor, we use
    it as a value and derive the shape from the initializer.

    If a partitioner is provided, a `PartitionedVariable` is returned.
    Accessing this object as a `Tensor` returns the shards concatenated along
    the partition axis.

    Some useful partitioners are available. See, e.g.,
    `variable_axis_size_partitioner` and `min_max_variable_partitioner`.

    Args:
      name: The name of the new or existing variable.
      shape: Shape of the new or existing variable.
      dtype: Type of the new or existing variable (defaults to `DT_FLOAT`).
      initializer: Initializer for the variable.
      regularizer: A (Tensor -> Tensor or None) function; the result of applying
        it on a newly created variable will be added to the collection
        GraphKeys.REGULARIZATION_LOSSES and can be used for regularization.
      reuse: a Boolean, None, or tf.AUTO_REUSE. Controls reuse or creation of
        variables. When eager execution is enabled this argument is always
        forced to be False.
      trainable: If `True` also add the variable to the graph collection
        `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). `trainable`
        defaults to `True`, unless `synchronization` is set to `ON_READ`, in
        which case it defaults to `False`.
      collections: List of graph collections keys to add the `Variable` to.
        Defaults to `[GraphKeys.GLOBAL_VARIABLES]` (see `tf.Variable`).
      caching_device: Optional device string or function describing where the
        Variable should be cached for reading. Defaults to the Variable's
        device. If not `None`, caches on another device. Typical use is to
        cache on the device where the Ops using the `Variable` reside, to
        deduplicate copying through `Switch` and other conditional statements.
      partitioner: Optional callable that accepts a fully defined `TensorShape`
        and dtype of the `Variable` to be created, and returns a list of
        partitions for each axis (currently only one axis can be partitioned).
      validate_shape: If False, allows the variable to be initialized with a
        value of unknown shape. If True, the default, the shape of initial_value
        must be known.
      use_resource: If False, creates a regular Variable. If True, creates
        instead an experimental ResourceVariable which has well-defined
        semantics. Defaults to False (will later change to True). When eager
        execution is enabled this argument is always forced to be True.
      custom_getter: Callable that takes as a first argument the true getter,
        and allows overwriting the internal get_variable method. The signature
        of `custom_getter` should match that of this method,
        but the most future-proof version will allow for changes: `def
        custom_getter(getter, *args, **kwargs)`. Direct access to
        all `get_variable` parameters is also allowed: `def
        custom_getter(getter, name, *args, **kwargs)`. A simple identity
        custom getter that simply creates variables with modified names is:
        ```python
        def custom_getter(getter, name, *args, **kwargs):
          return getter(name + '_suffix', *args, **kwargs)
        ```
      constraint: An optional projection function to be applied to the variable
        after being updated by an `Optimizer` (e.g. used to implement norm
        constraints or value constraints for layer weights). The function must
        take as input the unprojected Tensor representing the value of the
        variable and return the Tensor for the projected value (which must have
        the same shape). Constraints are not safe to use when doing asynchronous
        distributed training.
      synchronization: Indicates when a distributed variable will be
        aggregated. Accepted values are constants defined in the class
        `tf.VariableSynchronization`. By default the synchronization is set to
        `AUTO` and the current `DistributionStrategy` chooses when to
        synchronize.
      aggregation: Indicates how a distributed variable will be aggregated.
        Accepted values are constants defined in the class
        `tf.VariableAggregation`.

    Returns:
      The created or existing `Variable` (or `PartitionedVariable`, if a
      partitioner was used).

    Raises:
      ValueError: when creating a new variable and shape is not declared,
        when reusing a variable and specifying a conflicting shape,
        or when violating reuse during variable creation.
      RuntimeError: when eager execution is enabled and not called from an
        EagerVariableStore.
    """
    if custom_getter is not None and not callable(custom_getter):
      raise ValueError("Passed a custom_getter which is not callable: %s" %
                       custom_getter)

    with ops.init_scope():
      if context.executing_eagerly():
        # Variable creation and initialization takes place in `init_scope`s;
        # as such, if an `init_scope` lifts us into the eager context, then we
        # need to use `ResourceVariable`s.
        use_resource = True

    # Note that it's fine to reuse eager variables whose initialization was
    # lifted from a function-building graph into the eager context (that's why
    # the following clause is not wrapped in an `init_scope`); lifted variables
    # are tracked by the graph's `VariableStore`.
    if context.executing_eagerly():
      if not self._store_eager_variables and reuse:
        raise RuntimeError(
            "When eager execution is enabled variable reuse is only supported"
            " when an EagerVariableStore is active. See the documentation on"
            " EagerVariableStore for example usage.")
      if self._store_eager_variables:
        reuse = AUTO_REUSE

    # If a *_ref type is passed in an error would be triggered further down the
    # stack. We prevent this using base_dtype to get a non-ref version of the
    # type, before doing anything else. When _ref types are removed in favor of
    # resources, this line can be removed.
    try:
      dtype = dtype.base_dtype
    except AttributeError:
      # .base_dtype not existing means that we will try and use the raw dtype
      # which was passed in - this might be a NumPy type which is valid.
      pass

    # This is the main logic of get_variable. However, custom_getter
    # may override this logic. So we save it as a callable and pass
    # it to custom_getter.
    # Note: the parameters of _true_getter, and their documentation, match
    # *exactly* item-for-item with the docstring of this method.
    def _true_getter(  # pylint: disable=missing-docstring
        name,
        shape=None,
        dtype=dtypes.float32,
        initializer=None,
        regularizer=None,
        reuse=None,
        trainable=None,
        collections=None,
        caching_device=None,
        partitioner=None,
        validate_shape=True,
        use_resource=None,
        constraint=None,
        synchronization=VariableSynchronization.AUTO,
        aggregation=VariableAggregation.NONE):
      is_scalar = (
          shape is not None and isinstance(shape, collections_abc.Sequence) and
          not shape)
      # Partitioned variable case
      if partitioner is not None and not is_scalar:
        if not callable(partitioner):
          raise ValueError("Partitioner must be callable, but received: %s" %
                           partitioner)
        with ops.name_scope(None):
          return self._get_partitioned_variable(
              name=name,
              shape=shape,
              dtype=dtype,
              initializer=initializer,
              regularizer=regularizer,
              reuse=reuse,
              trainable=trainable,
              collections=collections,
              caching_device=caching_device,
              partitioner=partitioner,
              validate_shape=validate_shape,
              use_resource=use_resource,
              constraint=constraint,
              synchronization=synchronization,
              aggregation=aggregation)

      # Special case for partitioned variable to allow reuse without having to
      # specify partitioner.
      if (reuse is True and partitioner is None
          and name in self._partitioned_vars):
        return self._get_partitioned_variable(
            name=name,
            shape=shape,
            dtype=dtype,
            initializer=initializer,
            regularizer=regularizer,
            reuse=reuse,
            trainable=trainable,
            collections=collections,
            caching_device=caching_device,
            partitioner=None,
            validate_shape=validate_shape,
            use_resource=use_resource,
            constraint=constraint,
            synchronization=synchronization,
            aggregation=aggregation)

      # Single variable case
      if "%s/part_0" % name in self._vars:
        raise ValueError(
            "No partitioner was provided, but a partitioned version of the "
            "variable was found: %s/part_0. Perhaps a variable of the same "
            "name was already created with partitioning?" % name)

      return self._get_single_variable(
          name=name,
          shape=shape,
          dtype=dtype,
          initializer=initializer,
          regularizer=regularizer,
          reuse=reuse,
          trainable=trainable,
          collections=collections,
          caching_device=caching_device,
          validate_shape=validate_shape,
          use_resource=use_resource,
          constraint=constraint,
          synchronization=synchronization,
          aggregation=aggregation)

    synchronization, aggregation, trainable = (
        variables.validate_synchronization_aggregation_trainable(
            synchronization, aggregation, trainable, name))

    if custom_getter is not None:
      # Handle backwards compatibility with getter arguments that were added
      # to the API after users started writing custom getters.
      custom_getter_kwargs = {
          "getter": _true_getter,
          "name": name,
          "shape": shape,
          "dtype": dtype,
          "initializer": initializer,
          "regularizer": regularizer,
          "reuse": reuse,
          "trainable": trainable,
          "collections": collections,
          "caching_device": caching_device,
          "partitioner": partitioner,
          "validate_shape": validate_shape,
          "use_resource": use_resource,
          "synchronization": synchronization,
          "aggregation": aggregation,
      }
      # `fn_args` and `has_kwargs` can handle functions, `functools.partial`,
      # `lambda`.
571 if ("constraint" in function_utils.fn_args(custom_getter) or 572 function_utils.has_kwargs(custom_getter)): 573 custom_getter_kwargs["constraint"] = constraint 574 return custom_getter(**custom_getter_kwargs) 575 else: 576 return _true_getter( 577 name, 578 shape=shape, 579 dtype=dtype, 580 initializer=initializer, 581 regularizer=regularizer, 582 reuse=reuse, 583 trainable=trainable, 584 collections=collections, 585 caching_device=caching_device, 586 partitioner=partitioner, 587 validate_shape=validate_shape, 588 use_resource=use_resource, 589 constraint=constraint, 590 synchronization=synchronization, 591 aggregation=aggregation) 592 593 def _get_partitioned_variable(self, 594 name, 595 partitioner, 596 shape=None, 597 dtype=dtypes.float32, 598 initializer=None, 599 regularizer=None, 600 reuse=None, 601 trainable=None, 602 collections=None, 603 caching_device=None, 604 validate_shape=True, 605 use_resource=None, 606 constraint=None, 607 synchronization=VariableSynchronization.AUTO, 608 aggregation=VariableAggregation.NONE): 609 """Gets or creates a sharded variable list with these parameters. 610 611 The `partitioner` must be a callable that accepts a fully defined 612 `TensorShape` and returns a sequence of integers (the `partitions`). 613 These integers describe how to partition the given sharded `Variable` 614 along the given dimension. That is, `partitions[1] = 3` means split 615 the `Variable` into 3 shards along dimension 1. Currently, sharding along 616 only one axis is supported. 617 618 If the list of variables with the given name (prefix) is already stored, 619 we return the stored variables. Otherwise, we create a new one. 620 621 Set `reuse` to `True` when you only want to reuse existing Variables. 622 Set `reuse` to `False` when you only want to create new Variables. 623 Set `reuse` to None (the default) or tf.compat.v1.AUTO_REUSE when you want 624 variables to be created if they don't exist or returned if they do. 625 626 If initializer is `None` (the default), the default initializer passed in 627 the constructor is used. If that one is `None` too, we use a new 628 `glorot_uniform_initializer`. If initializer is a Tensor, we use 629 it as a value and derive the shape from the initializer. 630 631 If the initializer is a callable, then it will be called for each 632 shard. Otherwise the initializer should match the shape of the entire 633 sharded Variable, and it will be sliced accordingly for each shard. 634 635 Some useful partitioners are available. See, e.g., 636 `variable_axis_size_partitioner` and `min_max_variable_partitioner`. 637 638 Args: 639 name: the name of the new or existing sharded variable. 640 partitioner: Optional callable that accepts a fully defined `TensorShape` 641 and `dtype` of the Variable to be created, and returns a list of 642 partitions for each axis (currently only one axis can be partitioned). 643 shape: shape of the new or existing sharded variable. 644 dtype: type of the new or existing sharded variable (defaults to 645 `DT_FLOAT`). 646 initializer: initializer for the sharded variable. 647 regularizer: a (Tensor -> Tensor or None) function; the result of applying 648 it on a newly created variable will be added to the collection 649 GraphKeys.REGULARIZATION_LOSSES and can be used for regularization. 650 reuse: a Boolean, None, or tf.AUTO_REUSE. Controls reuse or creation of 651 variables. 652 trainable: If `True` also add the variable to the graph collection 653 `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). 
      collections: List of graph collections keys to add the Variable to.
        Defaults to `[GraphKeys.GLOBAL_VARIABLES]` (see `tf.Variable`).
      caching_device: Optional device string or function describing where the
        Variable should be cached for reading. Defaults to the Variable's
        device. If not `None`, caches on another device. Typical use is to
        cache on the device where the Ops using the Variable reside, to
        deduplicate copying through `Switch` and other conditional statements.
      validate_shape: If False, allows the variable to be initialized with a
        value of unknown shape. If True, the default, the shape of initial_value
        must be known.
      use_resource: If False, creates a regular Variable. If True, creates an
        experimental ResourceVariable which has well-defined semantics. Defaults
        to False (will later change to True).
      constraint: An optional projection function to be applied to the variable
        after being updated by an `Optimizer` (e.g. used to implement norm
        constraints or value constraints for layer weights). The function must
        take as input the unprojected Tensor representing the value of the
        variable and return the Tensor for the projected value (which must have
        the same shape). Constraints are not safe to use when doing asynchronous
        distributed training.
      synchronization: Indicates when a distributed variable will be
        aggregated. Accepted values are constants defined in the class
        `tf.VariableSynchronization`. By default the synchronization is set to
        `AUTO` and the current `DistributionStrategy` chooses when to
        synchronize.
      aggregation: Indicates how a distributed variable will be aggregated.
        Accepted values are constants defined in the class
        `tf.VariableAggregation`.

    Returns:
      A `PartitionedVariable` object.

    Raises:
      ValueError: when creating a new variable and shape is not declared,
        when reusing a variable and specifying a conflicting shape,
        when violating reuse during variable creation, or if an existing
        sharded variable exists for the given name but with different sharding.
    """
    initializing_from_value = initializer is not None and isinstance(
        initializer, ops.Tensor)
    if name in self._vars:
      raise ValueError(
          "A partitioner was provided, but an unpartitioned version of the "
          "variable was found: %s. Perhaps a variable of the same name was "
          "already created without partitioning?" % name)

    shape = tensor_shape.as_shape(shape)
    if initializing_from_value:
      shape = shape.merge_with(initializer.get_shape())

    partitions = None
    if not reuse or partitioner:
      partitions = _call_partitioner(partitioner, shape, dtype)

    if name in self._partitioned_vars:
      if reuse is False:
        raise ValueError(
            "Partitioned variable with name %s already exists. Did you mean to "
            "set reuse=True or reuse=tf.AUTO_REUSE in VarScope?" % name)

      existing_var = self._partitioned_vars[name]
      if not shape.is_compatible_with(existing_var.get_shape()):
        raise ValueError(
            "Trying to reuse partitioned variable %s, but specified shape %s "
            "and found shape %s." % (name, shape, existing_var.get_shape()))
      if not dtype.is_compatible_with(existing_var.dtype):
        raise ValueError(
            "Trying to reuse partitioned variable %s, but specified dtype %s "
            "and found dtype %s." % (name, dtype.name, existing_var.dtype.name))

      # pylint: disable=protected-access
      if (partitions is not None and
          existing_var._get_partitions() != partitions):
        raise ValueError(
            "Trying to reuse partitioned variable %s, but specified partitions "
            "%s and found partitions %s." %
            (name, partitions, existing_var._get_partitions()))
      # pylint: enable=protected-access

      return existing_var

    if reuse is True:
      raise ValueError("PartitionedVariable %s does not exist, or was not "
                       "created with tf.get_variable(). Did you mean to set "
                       "reuse=False or reuse=tf.AUTO_REUSE in VarScope?" % name)

    slice_dim, num_slices = _get_slice_dim_and_num_slices(partitions)

    if "%s/part_0" % name in self._vars:
      if "%s/part_%d" % (name, num_slices - 1) not in self._vars:
        raise ValueError(
            "Partitioner returned a different partitioning than what was "
            "already found. Partitioner returned %d shards, and shard "
            "%s/part_0 was found, but %s/part_%d was not." %
            (num_slices, name, name, num_slices - 1))
      if "%s/part_%d" % (name, num_slices) in self._vars:
        raise ValueError(
            "Partitioner returned a different partitioning than what was "
            "already found. Partitioner returned %d shards, and shard "
            "%s/part_0 was found, but so was the extra shard %s/part_%d." %
            (num_slices, name, name, num_slices))

    vs = []
    for i, (var_offset, var_shape) in enumerate(
        _iter_slices(shape.as_list(), num_slices, slice_dim)):
      partition_info = _PartitionInfo(
          full_shape=shape.as_list(), var_offset=var_offset)
      var_full_name = "%s/part_%d" % (name, i)
      with ops.name_scope(
          var_full_name + "/PartitionedInitializer", skip_on_eager=False):
        # Create the tensor to initialize the variable with default value.
        if initializer is None:
          init, initializing_from_value = self._get_default_initializer(
              name=name, shape=shape, dtype=dtype)
          if initializing_from_value:
            init_shape = None
          else:
            init_shape = var_shape
        elif callable(initializer):
          init = initializer
          init_shape = var_shape
        elif isinstance(initializer, ops.Tensor):
          init = array_ops.slice(initializer, var_offset, var_shape)
          # Use the dtype of the given tensor.
          dtype = init.dtype.base_dtype
          init_shape = None
        else:
          init = ops.convert_to_tensor(initializer, dtype=dtype)
          init = array_ops.slice(init, var_offset, var_shape)
          init_shape = None

      with ops.name_scope(None):
        var = self._get_single_variable(
            name=var_full_name,
            shape=init_shape,
            dtype=dtype,
            initializer=init,
            partition_info=partition_info,
            regularizer=regularizer,
            reuse=reuse,
            trainable=trainable,
            collections=collections,
            caching_device=caching_device,
            validate_shape=validate_shape,
            use_resource=use_resource,
            constraint=constraint,
            synchronization=synchronization,
            aggregation=aggregation)

      # pylint: disable=protected-access
      var._set_save_slice_info(
          variables.Variable.SaveSliceInfo(name, shape.as_list(), var_offset,
                                           var_shape))
      vs.append(var)
      # pylint: enable=protected-access

    partitioned_var = variables.PartitionedVariable(
        name=name,
        shape=shape,
        dtype=dtype,
        variable_list=vs,
        partitions=partitions)
    if not context.executing_eagerly() or self._store_eager_variables:
      self._partitioned_vars[name] = partitioned_var
    return partitioned_var

  def _get_single_variable(self,
                           name,
                           shape=None,
                           dtype=dtypes.float32,
                           initializer=None,
                           regularizer=None,
                           partition_info=None,
                           reuse=None,
                           trainable=None,
                           collections=None,
                           caching_device=None,
                           validate_shape=True,
                           use_resource=None,
                           constraint=None,
                           synchronization=VariableSynchronization.AUTO,
                           aggregation=VariableAggregation.NONE):
    """Get or create a single Variable (e.g. a shard or entire variable).

    See the documentation of get_variable above (ignore partitioning components)
    for details.

    Args:
      name: see get_variable.
      shape: see get_variable.
      dtype: see get_variable.
      initializer: see get_variable.
      regularizer: see get_variable.
      partition_info: _PartitionInfo object.
      reuse: see get_variable.
      trainable: see get_variable.
      collections: see get_variable.
      caching_device: see get_variable.
      validate_shape: see get_variable.
      use_resource: see get_variable.
      constraint: see get_variable.
      synchronization: see get_variable.
      aggregation: see get_variable.

    Returns:
      A Variable. See documentation of get_variable above.

    Raises:
      ValueError: See documentation of get_variable above.
    """
    # Set to true if initializer is a constant.
    initializing_from_value = False
    if initializer is not None and not callable(initializer):
      initializing_from_value = True
    if shape is not None and initializing_from_value:
      raise ValueError("If initializer is a constant, do not specify shape.")

    dtype = dtypes.as_dtype(dtype)
    shape = tensor_shape.as_shape(shape)

    if name in self._vars:
      # Here we handle the case when returning an existing variable.
      if reuse is False:
        var = self._vars[name]
        err_msg = ("Variable %s already exists, disallowed."
                   " Did you mean to set reuse=True or "
                   "reuse=tf.AUTO_REUSE in VarScope?" % name)
        # ResourceVariables don't have an op associated with them, so no
        # traceback is available.
        if isinstance(var, resource_variable_ops.ResourceVariable):
          raise ValueError(err_msg)
        tb = var.op.traceback[::-1]
        # Throw away internal tf entries and only take a few lines. In some
        # cases the traceback can be longer (e.g. if someone uses factory
        # functions to create variables) so we take more than needed in the
        # default case.
        tb = [x for x in tb if "tensorflow/python" not in x[0]][:5]
        raise ValueError("%s Originally defined at:\n\n%s" %
                         (err_msg, "".join(traceback.format_list(tb))))
      found_var = self._vars[name]
      if not shape.is_compatible_with(found_var.get_shape()):
        raise ValueError("Trying to share variable %s, but specified shape %s"
                         " and found shape %s." %
                         (name, shape, found_var.get_shape()))
      if not dtype.is_compatible_with(found_var.dtype):
        dtype_str = dtype.name
        found_type_str = found_var.dtype.name
        raise ValueError("Trying to share variable %s, but specified dtype %s"
                         " and found dtype %s." %
                         (name, dtype_str, found_type_str))
      return found_var

    # The code below handles only the case of creating a new variable.
    if reuse is True:
      raise ValueError("Variable %s does not exist, or was not created with "
                       "tf.get_variable(). Did you mean to set "
                       "reuse=tf.AUTO_REUSE in VarScope?" % name)

    # Create the tensor to initialize the variable with default value.
    if initializer is None:
      initializer, initializing_from_value = self._get_default_initializer(
          name=name, shape=shape, dtype=dtype)
    # Enter an init scope when creating the initializer.
    with ops.init_scope():
      if initializing_from_value:
        init_val = initializer
        variable_dtype = None
      else:
        # Instantiate initializer if provided initializer is a type object.
        if tf_inspect.isclass(initializer):
          initializer = initializer()
        if shape.is_fully_defined():
          if "partition_info" in tf_inspect.getargspec(initializer).args:
            init_val = functools.partial(initializer,
                                         shape.as_list(),
                                         dtype=dtype,
                                         partition_info=partition_info)
          else:
            init_val = functools.partial(initializer,
                                         shape.as_list(), dtype=dtype)
          variable_dtype = dtype.base_dtype
        elif _needs_no_arguments(initializer):
          init_val = initializer
          variable_dtype = None
        else:
          raise ValueError("The initializer passed is not valid. It should "
                           "either be a callable with no arguments (and "
                           "`shape` should not be provided), or an instance "
                           "of `tf.keras.initializers.*` (and `shape` should "
                           "be fully defined).")

    # Create the variable.
    if use_resource is None:
      # Set the default value if unspecified.
      use_resource = _DEFAULT_USE_RESOURCE
    v = variables.VariableV1(
        initial_value=init_val,
        name=name,
        trainable=trainable,
        collections=collections,
        caching_device=caching_device,
        dtype=variable_dtype,
        validate_shape=validate_shape,
        constraint=constraint,
        use_resource=use_resource,
        synchronization=synchronization,
        aggregation=aggregation)
    if context.executing_eagerly() and self._store_eager_variables:
      if collections:
        ops.add_to_collections(collections, v)
      else:
        ops.add_to_collection(ops.GraphKeys.GLOBAL_VARIABLES, v)
      if trainable:
        ops.add_to_collection(ops.GraphKeys.TRAINABLE_VARIABLES, v)

    if not context.executing_eagerly() or self._store_eager_variables:
      # In eager mode we do not want to keep default references to Variable
      # objects as this will prevent their memory from being released.
      self._vars[name] = v
    logging.vlog(1, "Created variable %s with shape %s and init %s", v.name,
                 format(shape), initializer)

    # Run the regularizer if requested and save the resulting loss.
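    # Sketch of the effect (illustrative values): with a regularizer such as
    #   lambda t: 0.01 * tf.reduce_sum(tf.abs(t))
    # the callable built below is wrapped in a _LazyEvalTensor and added to
    # the ops.GraphKeys.REGULARIZATION_LOSSES collection.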
    if regularizer:
      def make_regularizer_op():
        with ops.colocate_with(v):
          with ops.name_scope(name + "/Regularizer/"):
            return regularizer(v)

      if regularizer(v) is not None:
        lazy_eval_tensor = _LazyEvalTensor(make_regularizer_op)
        ops.add_to_collection(ops.GraphKeys.REGULARIZATION_LOSSES,
                              lazy_eval_tensor)

    return v

  # Initialize variable when no initializer provided
  def _get_default_initializer(self, name, shape=None, dtype=dtypes.float32):
    """Provide a default initializer and a corresponding value.

    Args:
      name: see get_variable.
      shape: see get_variable.
      dtype: see get_variable.

    Returns:
      initializer and initializing_from_value. See get_variable above.

    Raises:
      ValueError: When giving unsupported dtype.
    """
    del shape
    # If dtype is DT_FLOAT, provide a uniform unit scaling initializer
    if dtype.is_floating:
      initializer = init_ops.glorot_uniform_initializer()
      initializing_from_value = False
    # If dtype is DT_INT/DT_UINT, provide a default value `zero`
    # If dtype is DT_BOOL, provide a default value `FALSE`
    elif (dtype.is_integer or dtype.is_unsigned or dtype.is_bool or
          dtype == dtypes.string):
      initializer = init_ops.zeros_initializer()
      initializing_from_value = False
    # NOTE: Do we need to support DT_STRING and DT_COMPLEX here?
    else:
      raise ValueError("An initializer for variable %s of %s is required" %
                       (name, dtype.base_dtype))

    return initializer, initializing_from_value
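

# Illustrative sketch only (assumed values, never called by the library): the
# default initializer picked above depends only on dtype. Floating-point
# variables get `glorot_uniform_initializer`; integer, unsigned, bool and
# string variables get `zeros_initializer`.
def _default_initializer_sketch():
  store = _VariableStore()
  # pylint: disable=protected-access
  float_init, from_value = store._get_default_initializer(
      "weights", dtype=dtypes.float32)
  int_init, _ = store._get_default_initializer("step", dtype=dtypes.int64)
  # pylint: enable=protected-access
  assert not from_value
  assert isinstance(float_init, init_ops.glorot_uniform_initializer)
  assert isinstance(int_init, init_ops.zeros_initializer)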
1033 """ 1034 self._thunk = thunk 1035 self._master_tensor = thunk() 1036 1037 def _as_tensor(self, dtype=None, name=None, as_ref=False): 1038 del name 1039 assert not as_ref 1040 assert dtype in [None, self.dtype] 1041 1042 return self._thunk() 1043 1044 1045def _make_master_property(name): 1046 @property 1047 def prop(self): 1048 return getattr(self._master_tensor, name) # pylint: disable=protected-access 1049 return prop 1050 1051_master_property_list = ("device", "dtype", "graph", "name", "op", "shape", 1052 "value_index") 1053for _name in _master_property_list: 1054 setattr(_LazyEvalTensor, _name, _make_master_property(_name)) 1055 1056 1057def _make_master_method(name): 1058 def method(self, *args, **kwargs): 1059 return getattr(self._master_tensor, name)(*args, **kwargs) # pylint: disable=protected-access 1060 return method 1061 1062_master_method_list = ("get_shape", "__str__", "shape_as_list") 1063for _name in _master_method_list: 1064 setattr(_LazyEvalTensor, _name, _make_master_method(_name)) 1065 1066 1067def _make_op_method(name): 1068 def method(self, *args, **kwargs): 1069 return getattr(self._as_tensor(), name)(*args, **kwargs) # pylint: disable=protected-access 1070 return method 1071 1072_op_list = ("__abs__", "__add__", "__and__", "__bool__", "__div__", "__eq__", 1073 "__floordiv__", "__ge__", "__getitem__", "__gt__", "__invert__", 1074 "__iter__", "__le__", "__len__", "__lt__", "__matmul__", "__mod__", 1075 "__mul__", "__ne__", "__neg__", "__nonzero__", "__or__", "__pow__", 1076 "__radd__", "__rand__", "__rdiv__", "__rfloordiv__", "__rmatmul__", 1077 "__rmod__", "__rmul__", "__ror__", "__rpow__", "__rsub__", 1078 "__rtruediv__", "__rxor__", "__sub__", "__truediv__", "__xor__", 1079 "eval", "numpy") 1080for _name in _op_list: 1081 setattr(_LazyEvalTensor, _name, _make_op_method(_name)) 1082 1083 1084ops.register_tensor_conversion_function( 1085 _LazyEvalTensor, 1086 lambda val, dtype, name, as_ref: val._as_tensor(dtype, name, as_ref) # pylint: disable=protected-access 1087 ) 1088 1089session.register_session_run_conversion_functions( 1090 _LazyEvalTensor, 1091 lambda fetch: ([fetch._master_tensor], lambda fetched_vals: fetched_vals[0]) # pylint: disable=protected-access 1092 ) 1093 1094 1095# To stop regularization, use this regularizer 1096@tf_export(v1=["no_regularizer"]) 1097def no_regularizer(_): 1098 """Use this function to prevent regularization of variables.""" 1099 return None 1100 1101 1102# TODO(alive): support caching devices and partitioned variables in Eager mode. 1103@tf_export(v1=["VariableScope"]) 1104class VariableScope(object): 1105 """Variable scope object to carry defaults to provide to `get_variable`. 1106 1107 Many of the arguments we need for `get_variable` in a variable store are most 1108 easily handled with a context. This object is used for the defaults. 1109 1110 Attributes: 1111 name: name of the current scope, used as prefix in get_variable. 1112 initializer: default initializer passed to get_variable. 1113 regularizer: default regularizer passed to get_variable. 1114 reuse: Boolean, None, or tf.compat.v1.AUTO_REUSE, setting the reuse in 1115 get_variable. When eager execution is enabled this argument is always 1116 forced to be False. 1117 caching_device: string, callable, or None: the caching device passed to 1118 get_variable. 1119 partitioner: callable or `None`: the partitioner passed to `get_variable`. 1120 custom_getter: default custom getter passed to get_variable. 1121 name_scope: The name passed to `tf.name_scope`. 
    dtype: default type passed to get_variable (defaults to DT_FLOAT).
    use_resource: if False, create a normal Variable; if True create an
      experimental ResourceVariable with well-defined semantics. Defaults to
      False (will later change to True). When eager execution is enabled this
      argument is always forced to be True.
    constraint: An optional projection function to be applied to the variable
      after being updated by an `Optimizer` (e.g. used to implement norm
      constraints or value constraints for layer weights). The function must
      take as input the unprojected Tensor representing the value of the
      variable and return the Tensor for the projected value (which must have
      the same shape). Constraints are not safe to use when doing asynchronous
      distributed training.
  """

  def __init__(self,
               reuse,
               name="",
               initializer=None,
               regularizer=None,
               caching_device=None,
               partitioner=None,
               custom_getter=None,
               name_scope="",
               dtype=dtypes.float32,
               use_resource=None,
               constraint=None):
    """Creates a new VariableScope with the given properties."""
    self._name = name
    self._initializer = initializer
    self._regularizer = regularizer
    self._reuse = reuse
    self._caching_device = caching_device
    self._partitioner = partitioner
    self._custom_getter = custom_getter
    self._name_scope = name_scope
    self._dtype = dtype
    self._use_resource = use_resource
    self._constraint = constraint
    if context.executing_eagerly():
      if self._caching_device is not None:
        raise NotImplementedError("Caching devices are not yet supported "
                                  "when eager execution is enabled.")
      self._reuse = AUTO_REUSE
      self._use_resource = True

  @property
  def name(self):
    return self._name

  @property
  def original_name_scope(self):
    return self._name_scope

  @property
  def reuse(self):
    return self._reuse

  @property
  def initializer(self):
    return self._initializer

  @property
  def dtype(self):
    return self._dtype

  @property
  def use_resource(self):
    return self._use_resource

  @property
  def regularizer(self):
    return self._regularizer

  @property
  def caching_device(self):
    return self._caching_device

  @property
  def partitioner(self):
    return self._partitioner

  @property
  def custom_getter(self):
    return self._custom_getter

  @property
  def constraint(self):
    return self._constraint

  def reuse_variables(self):
    """Reuse variables in this scope."""
    self._reuse = True

  def set_initializer(self, initializer):
    """Set initializer for this scope."""
    self._initializer = initializer

  def set_dtype(self, dtype):
    """Set data type for this scope."""
    self._dtype = dtype

  def set_use_resource(self, use_resource):
    """Sets whether to use ResourceVariables for this scope."""
    if context.executing_eagerly() and not use_resource:
      raise ValueError("When eager execution is enabled, "
                       "use_resource cannot be set to false.")
    self._use_resource = use_resource

  def set_regularizer(self, regularizer):
    """Set regularizer for this scope."""
    self._regularizer = regularizer

  def set_caching_device(self, caching_device):
    """Set caching_device for this scope."""
    if context.executing_eagerly():
1237 raise NotImplementedError("Caching devices are not yet supported " 1238 "when eager execution is enabled.") 1239 self._caching_device = caching_device 1240 1241 def set_partitioner(self, partitioner): 1242 """Set partitioner for this scope.""" 1243 self._partitioner = partitioner 1244 1245 def set_custom_getter(self, custom_getter): 1246 """Set custom getter for this scope.""" 1247 self._custom_getter = custom_getter 1248 1249 def get_collection(self, name): 1250 """Get this scope's variables.""" 1251 scope = self._name + "/" if self._name else "" 1252 return ops.get_collection(name, scope) 1253 1254 def trainable_variables(self): 1255 """Get this scope's trainable variables.""" 1256 return self.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES) 1257 1258 def global_variables(self): 1259 """Get this scope's global variables.""" 1260 return self.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) 1261 1262 def local_variables(self): 1263 """Get this scope's local variables.""" 1264 return self.get_collection(ops.GraphKeys.LOCAL_VARIABLES) 1265 1266 def get_variable(self, 1267 var_store, 1268 name, 1269 shape=None, 1270 dtype=None, 1271 initializer=None, 1272 regularizer=None, 1273 reuse=None, 1274 trainable=None, 1275 collections=None, 1276 caching_device=None, 1277 partitioner=None, 1278 validate_shape=True, 1279 use_resource=None, 1280 custom_getter=None, 1281 constraint=None, 1282 synchronization=VariableSynchronization.AUTO, 1283 aggregation=VariableAggregation.NONE): 1284 """Gets an existing variable with this name or create a new one.""" 1285 if regularizer is None: 1286 regularizer = self._regularizer 1287 if caching_device is None: 1288 caching_device = self._caching_device 1289 if partitioner is None: 1290 partitioner = self._partitioner 1291 if custom_getter is None: 1292 custom_getter = self._custom_getter 1293 if context.executing_eagerly(): 1294 reuse = False 1295 use_resource = True 1296 else: 1297 if reuse is None: 1298 reuse = self._reuse 1299 if use_resource is None: 1300 use_resource = self._use_resource 1301 1302 full_name = self.name + "/" + name if self.name else name 1303 # Variable names only depend on variable_scope (full_name here), 1304 # not name_scope, so we reset it below for the time of variable creation. 1305 with ops.name_scope(None, skip_on_eager=False): 1306 # Check that `initializer` dtype and `dtype` are consistent before 1307 # replacing them with defaults. 1308 if (dtype is not None and initializer is not None and 1309 not callable(initializer)): 1310 init_dtype = ops.convert_to_tensor(initializer).dtype.base_dtype 1311 if init_dtype != dtype: 1312 raise ValueError("Initializer type '%s' and explicit dtype '%s' " 1313 "don't match." 
                           "don't match." % (init_dtype, dtype))
      if initializer is None:
        initializer = self._initializer
      if constraint is None:
        constraint = self._constraint
      if dtype is None:
        dtype = self._dtype
      return var_store.get_variable(
          full_name,
          shape=shape,
          dtype=dtype,
          initializer=initializer,
          regularizer=regularizer,
          reuse=reuse,
          trainable=trainable,
          collections=collections,
          caching_device=caching_device,
          partitioner=partitioner,
          validate_shape=validate_shape,
          use_resource=use_resource,
          custom_getter=custom_getter,
          constraint=constraint,
          synchronization=synchronization,
          aggregation=aggregation)

  def _get_partitioned_variable(self,
                                var_store,
                                name,
                                shape=None,
                                dtype=None,
                                initializer=None,
                                regularizer=None,
                                trainable=None,
                                collections=None,
                                caching_device=None,
                                partitioner=None,
                                validate_shape=True,
                                use_resource=None,
                                constraint=None,
                                synchronization=VariableSynchronization.AUTO,
                                aggregation=VariableAggregation.NONE):
    """Gets an existing variable with this name or create a new one."""
    if initializer is None:
      initializer = self._initializer
    if regularizer is None:
      regularizer = self._regularizer
    if constraint is None:
      constraint = self._constraint
    if caching_device is None:
      caching_device = self._caching_device
    if partitioner is None:
      partitioner = self._partitioner
    if dtype is None:
      dtype = self._dtype
    if use_resource is None:
      use_resource = self._use_resource

    if self._custom_getter is not None:
      raise ValueError(
          "Private access to _get_partitioned_variable is not allowed when "
          "a custom getter is set. Current custom getter: %s. "
          "It is likely that you're using create_partitioned_variables. "
          "If so, consider instead using get_variable with a non-empty "
          "partitioner parameter." % self._custom_getter)

    if partitioner is None:
      raise ValueError("No partitioner was specified")

    # This allows the variable scope name to be used as the variable name if
    # this function is invoked with an empty name arg, for backward
    # compatibility with create_partitioned_variables().
    full_name_list = []
    if self.name:
      full_name_list.append(self.name)
    if name:
      full_name_list.append(name)
    full_name = "/".join(full_name_list)

    # Variable names only depend on variable_scope (full_name here),
    # not name_scope, so we reset it below for the time of variable creation.
    with ops.name_scope(None, skip_on_eager=False):
      # pylint: disable=protected-access
      return var_store._get_partitioned_variable(
          full_name,
          shape=shape,
          dtype=dtype,
          initializer=initializer,
          regularizer=regularizer,
          reuse=self.reuse,
          trainable=trainable,
          collections=collections,
          caching_device=caching_device,
          partitioner=partitioner,
          validate_shape=validate_shape,
          use_resource=use_resource,
          constraint=constraint,
          synchronization=synchronization,
          aggregation=aggregation)
      # pylint: enable=protected-access
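

# Illustrative sketch only (never called by the library): `reuse_variables()`
# switches the current scope to reuse mode, so a second `get_variable` call
# with the same name returns the existing variable instead of raising.
# `variable_scope` is the context manager defined later in this module.
def _reuse_variables_sketch():
  with ops.Graph().as_default():
    with variable_scope("layer") as vs:
      w = get_variable("w", shape=[3, 3])
      vs.reuse_variables()
      w_again = get_variable("w", shape=[3, 3])
    assert w is w_again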


_VARSTORE_KEY = ("__variable_store",)
_VARSCOPESTORE_KEY = ("__varscope",)


class _VariableScopeStore(threading.local):
  """A thread local store for the current variable scope and scope counts."""

  def __init__(self):
    super(_VariableScopeStore, self).__init__()
    self.current_scope = VariableScope(False)
    self.variable_scopes_count = {}

  def open_variable_scope(self, scope_name):
    if scope_name in self.variable_scopes_count:
      self.variable_scopes_count[scope_name] += 1
    else:
      self.variable_scopes_count[scope_name] = 1

  def close_variable_subscopes(self, scope_name):
    for k in list(self.variable_scopes_count.keys()):
      if scope_name is None or k.startswith(scope_name + "/"):
        self.variable_scopes_count[k] = 0

  def variable_scope_count(self, scope_name):
    return self.variable_scopes_count.get(scope_name, 0)


def get_variable_scope_store():
  """Returns the variable scope store for current thread."""
  scope_store = ops.get_collection(_VARSCOPESTORE_KEY)

  if not scope_store:
    scope_store = _VariableScopeStore()
    ops.add_to_collection(_VARSCOPESTORE_KEY, scope_store)
  else:
    scope_store = scope_store[0]

  return scope_store


@tf_export(v1=["get_variable_scope"])
def get_variable_scope():
  """Returns the current variable scope."""
  return get_variable_scope_store().current_scope


def _get_default_variable_store():
  store = ops.get_collection(_VARSTORE_KEY)
  if store:
    return store[0]
  store = _VariableStore()
  ops.add_to_collection(_VARSTORE_KEY, store)
  return store


@tf_contextlib.contextmanager
def with_variable_store(store):
  store_collection = ops.get_collection_ref(_VARSTORE_KEY)
  old = list(store_collection)
  store_collection[:] = [store]
  try:
    yield
  finally:
    store_collection[:] = old
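

# Illustrative sketch only (assumed graph-mode usage, never called here): the
# scope store is kept in a graph collection, so inside a fresh graph
# `get_variable_scope()` returns the root scope, which has an empty name and
# `reuse=False`.
def _scope_store_sketch():
  with ops.Graph().as_default():
    root = get_variable_scope()
    assert root.name == ""
    assert root.reuse is False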


class EagerVariableStore(object):
  """Wrapper allowing functional layers to be used with eager execution.

  When eager execution is enabled Variables get deleted when they go out of
  scope, and are not stored in global collections by default. A lot of code
  (mostly the functional layers in tf.layers) assumes that variables are kept in
  a global list.

  EagerVariableStore can be used in conjunction with this code to make it
  eager-friendly. For example, to create a dense layer, use:

  ```
  container = tfe.EagerVariableStore()
  for input in dataset_iterator:
    with container.as_default():
      x = tf.compat.v1.layers.dense(input, name="l1")
  print(container.variables)  # Should print the variables used in the layer.
  ```
  """

  def __init__(self, store=None):
    if store is not None:
      if not store._store_eager_variables:  # pylint: disable=protected-access
        raise ValueError("Cannot construct EagerVariableStore from a "
                         "VariableStore object that does not hold eager "
                         "variables.")
      self._store = store
    else:
      self._store = _VariableStore()
    self._store._store_eager_variables = True  # pylint: disable=protected-access

  def as_default(self):
    return with_variable_store(self._store)

  def variables(self):
    return sorted(self._store._vars.values(), key=lambda x: x.name)  # pylint: disable=protected-access

  def trainable_variables(self):
    # pylint: disable=protected-access
    return sorted([x for x in self._store._vars.values() if x.trainable],
                  key=lambda x: x.name)
    # pylint: enable=protected-access

  def non_trainable_variables(self):
    # pylint: disable=protected-access
    return sorted([x for x in self._store._vars.values() if not x.trainable],
                  key=lambda x: x.name)
    # pylint: enable=protected-access

  def copy(self):
    """Copy this variable store and all of its contents.

    Variables contained in this store will be copied over to the new variable
    store, meaning that they can be modified without affecting the variables in
    this store.

    Returns:
      A new EagerVariableStore instance containing copied variables.
    """
    # pylint: disable=protected-access
    new_store = EagerVariableStore()
    for key, var in iteritems(self._store._vars):
      # Strip device out of variable name.
      try:
        index = var.name.index(":")
      except ValueError:
        stripped_var_name = var.name
      else:
        stripped_var_name = var.name[:index]

      # Create new variable with same value, name, and "trainable" flag.
      new_var = resource_variable_ops.ResourceVariable(
          var.read_value(), name=stripped_var_name, trainable=var.trainable)
      new_store._store._vars[key] = new_var
    return new_store
    # pylint: enable=protected-access


# The argument list for get_variable must match arguments to get_local_variable.
# So, if you are updating the arguments, also update arguments to
# get_local_variable below.
@tf_export(v1=["get_variable"])
def get_variable(name,
                 shape=None,
                 dtype=None,
                 initializer=None,
                 regularizer=None,
                 trainable=None,
                 collections=None,
                 caching_device=None,
                 partitioner=None,
                 validate_shape=True,
                 use_resource=None,
                 custom_getter=None,
                 constraint=None,
                 synchronization=VariableSynchronization.AUTO,
                 aggregation=VariableAggregation.NONE):
  return get_variable_scope().get_variable(
      _get_default_variable_store(),
      name,
      shape=shape,
      dtype=dtype,
      initializer=initializer,
      regularizer=regularizer,
      trainable=trainable,
      collections=collections,
      caching_device=caching_device,
      partitioner=partitioner,
      validate_shape=validate_shape,
      use_resource=use_resource,
      custom_getter=custom_getter,
      constraint=constraint,
      synchronization=synchronization,
      aggregation=aggregation)


get_variable_or_local_docstring = ("""%s

%sThis function prefixes the name with the current variable scope
and performs reuse checks. See the
[Variable Scope How To](https://tensorflow.org/guide/variables)
Here is a basic example: 1602 1603```python 1604def foo(): 1605 with tf.variable_scope("foo", reuse=tf.AUTO_REUSE): 1606 v = tf.get_variable("v", [1]) 1607 return v 1608 1609v1 = foo() # Creates v. 1610v2 = foo() # Gets the same, existing v. 1611assert v1 == v2 1612``` 1613 1614If initializer is `None` (the default), the default initializer passed in 1615the variable scope will be used. If that one is `None` too, a 1616`glorot_uniform_initializer` will be used. The initializer can also be 1617a Tensor, in which case the variable is initialized to this value and shape. 1618 1619Similarly, if the regularizer is `None` (the default), the default regularizer 1620passed in the variable scope will be used (if that is `None` too, 1621then by default no regularization is performed). 1622 1623If a partitioner is provided, a `PartitionedVariable` is returned. 1624Accessing this object as a `Tensor` returns the shards concatenated along 1625the partition axis. 1626 1627Some useful partitioners are available. See, e.g., 1628`variable_axis_size_partitioner` and `min_max_variable_partitioner`. 1629 1630Args: 1631 name: The name of the new or existing variable. 1632 shape: Shape of the new or existing variable. 1633 dtype: Type of the new or existing variable (defaults to `DT_FLOAT`). 1634 initializer: Initializer for the variable if one is created. Can either be 1635 an initializer object or a Tensor. If it's a Tensor, its shape must be known 1636 unless validate_shape is False. 1637 regularizer: A (Tensor -> Tensor or None) function; the result of 1638 applying it on a newly created variable will be added to the collection 1639 `tf.GraphKeys.REGULARIZATION_LOSSES` and can be used for regularization. 1640 %scollections: List of graph collections keys to add the Variable to. 1641 Defaults to `[%s]` (see `tf.Variable`). 1642 caching_device: Optional device string or function describing where the 1643 Variable should be cached for reading. Defaults to the Variable's 1644 device. If not `None`, caches on another device. Typical use is to 1645 cache on the device where the Ops using the Variable reside, to 1646 deduplicate copying through `Switch` and other conditional statements. 1647 partitioner: Optional callable that accepts a fully defined `TensorShape` 1648 and `dtype` of the Variable to be created, and returns a list of 1649 partitions for each axis (currently only one axis can be partitioned). 1650 validate_shape: If False, allows the variable to be initialized with a 1651 value of unknown shape. If True, the default, the shape of initial_value 1652 must be known. For this to be used the initializer must be a Tensor and 1653 not an initializer object. 1654 use_resource: If False, creates a regular Variable. If true, creates an 1655 experimental ResourceVariable instead with well-defined semantics. 1656 Defaults to False (will later change to True). When eager execution is 1657 enabled this argument is always forced to be True. 1658 custom_getter: Callable that takes as a first argument the true getter, and 1659 allows overwriting the internal get_variable method. 1660 The signature of `custom_getter` should match that of this method, 1661 but the most future-proof version will allow for changes: 1662 `def custom_getter(getter, *args, **kwargs)`. Direct access to 1663 all `get_variable` parameters is also allowed: 1664 `def custom_getter(getter, name, *args, **kwargs)`. 
A simple identity 1665 custom getter that simply creates variables with modified names is: 1666 ```python 1667 def custom_getter(getter, name, *args, **kwargs): 1668 return getter(name + '_suffix', *args, **kwargs) 1669 ``` 1670 constraint: An optional projection function to be applied to the variable 1671 after being updated by an `Optimizer` (e.g. used to implement norm 1672 constraints or value constraints for layer weights). The function must 1673 take as input the unprojected Tensor representing the value of the 1674 variable and return the Tensor for the projected value 1675 (which must have the same shape). Constraints are not safe to 1676 use when doing asynchronous distributed training. 1677 synchronization: Indicates when a distributed a variable will be 1678 aggregated. Accepted values are constants defined in the class 1679 `tf.VariableSynchronization`. By default the synchronization is set to 1680 `AUTO` and the current `DistributionStrategy` chooses 1681 when to synchronize. 1682 aggregation: Indicates how a distributed variable will be aggregated. 1683 Accepted values are constants defined in the class 1684 `tf.VariableAggregation`. 1685 1686Returns: 1687 The created or existing `Variable` (or `PartitionedVariable`, if a 1688 partitioner was used). 1689 1690Raises: 1691 ValueError: when creating a new variable and shape is not declared, 1692 when violating reuse during variable creation, or when `initializer` dtype 1693 and `dtype` don't match. Reuse is set inside `variable_scope`. 1694""") 1695get_variable.__doc__ = get_variable_or_local_docstring % ( 1696 "Gets an existing variable with these parameters or create a new one.", "", 1697 "trainable: If `True` also add the variable to the graph collection\n" 1698 " `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).\n ", 1699 "GraphKeys.GLOBAL_VARIABLES") 1700 1701 1702# The argument list for get_local_variable must match arguments to get_variable. 1703# So, if you are updating the arguments, also update arguments to get_variable. 
1704@tf_export(v1=["get_local_variable"]) 1705def get_local_variable( # pylint: disable=missing-docstring 1706 name, 1707 shape=None, 1708 dtype=None, 1709 initializer=None, 1710 regularizer=None, 1711 trainable=False, # pylint: disable=unused-argument 1712 collections=None, 1713 caching_device=None, 1714 partitioner=None, 1715 validate_shape=True, 1716 use_resource=None, 1717 custom_getter=None, 1718 constraint=None, 1719 synchronization=VariableSynchronization.AUTO, 1720 aggregation=VariableAggregation.NONE): 1721 if collections: 1722 collections += [ops.GraphKeys.LOCAL_VARIABLES] 1723 else: 1724 collections = [ops.GraphKeys.LOCAL_VARIABLES] 1725 return get_variable( 1726 name, 1727 shape=shape, 1728 dtype=dtype, 1729 initializer=initializer, 1730 regularizer=regularizer, 1731 trainable=False, 1732 collections=collections, 1733 caching_device=caching_device, 1734 partitioner=partitioner, 1735 validate_shape=validate_shape, 1736 use_resource=use_resource, 1737 synchronization=synchronization, 1738 aggregation=aggregation, 1739 custom_getter=custom_getter, 1740 constraint=constraint) 1741 1742 1743get_local_variable.__doc__ = get_variable_or_local_docstring % ( 1744 "Gets an existing *local* variable or creates a new one.", 1745 "Behavior is the same as in `get_variable`, except that variables are\n" 1746 "added to the `LOCAL_VARIABLES` collection and `trainable` is set to\n" 1747 "`False`.\n", "", "GraphKeys.LOCAL_VARIABLES") 1748 1749 1750def _get_partitioned_variable(name, 1751 shape=None, 1752 dtype=None, 1753 initializer=None, 1754 regularizer=None, 1755 trainable=True, 1756 collections=None, 1757 caching_device=None, 1758 partitioner=None, 1759 validate_shape=True, 1760 use_resource=None, 1761 constraint=None, 1762 synchronization=VariableSynchronization.AUTO, 1763 aggregation=VariableAggregation.NONE): 1764 """Gets or creates a sharded variable list with these parameters. 1765 1766 The `partitioner` must be a callable that accepts a fully defined 1767 `TensorShape` and returns a sequence of integers (the `partitions`). 1768 These integers describe how to partition the given sharded `Variable` 1769 along the given dimension. That is, `partitions[1] = 3` means split 1770 the `Variable` into 3 shards along dimension 1. Currently, sharding along 1771 only one axis is supported. 1772 1773 If the list of variables with the given name (prefix) is already stored, 1774 we return the stored variables. Otherwise, we create a new one. 1775 1776 If initializer is `None` (the default), the default initializer passed in 1777 the constructor is used. If that one is `None` too, we use a new 1778 `glorot_uniform_initializer`. If initializer is a Tensor, we use 1779 it as a value and derive the shape from the initializer. 1780 1781 If the initializer is a callable, then it will be called for each 1782 shard. Otherwise the initializer should match the shape of the entire 1783 sharded Variable, and it will be sliced accordingly for each shard. 1784 1785 Some useful partitioners are available. See, e.g., 1786 `variable_axis_size_partitioner` and `min_max_variable_partitioner`. 1787 1788 Args: 1789 name: The name of the new or existing variable. 1790 shape: Shape of the new or existing variable. 1791 dtype: Type of the new or existing variable (defaults to `DT_FLOAT`). 1792 initializer: Initializer for the variable if one is created. 
    regularizer: A (Tensor -> Tensor or None) function; the result of applying
      it on a newly created variable will be added to the collection
      GraphKeys.REGULARIZATION_LOSSES and can be used for regularization.
    trainable: If `True` also add the variable to the graph collection
      `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
    collections: List of graph collections keys to add the Variable to. Defaults
      to `[GraphKeys.GLOBAL_VARIABLES]` (see `tf.Variable`).
    caching_device: Optional device string or function describing where the
      Variable should be cached for reading. Defaults to the Variable's device.
      If not `None`, caches on another device. Typical use is to cache on the
      device where the Ops using the Variable reside, to deduplicate copying
      through `Switch` and other conditional statements.
    partitioner: Optional callable that accepts a fully defined `TensorShape`
      and `dtype` of the Variable to be created, and returns a list of
      partitions for each axis (currently only one axis can be partitioned).
    validate_shape: If False, allows the variable to be initialized with a value
      of unknown shape. If True, the default, the shape of initial_value must be
      known.
    use_resource: If False, creates a regular Variable. If True, creates an
      experimental ResourceVariable instead which has well-defined semantics.
      Defaults to False (will later change to True).
    constraint: An optional projection function to be applied to the variable
      after being updated by an `Optimizer` (e.g. used to implement norm
      constraints or value constraints for layer weights). The function must
      take as input the unprojected Tensor representing the value of the
      variable and return the Tensor for the projected value (which must have
      the same shape). Constraints are not safe to use when doing asynchronous
      distributed training.
    synchronization: Indicates when a distributed variable will be aggregated.
      Accepted values are constants defined in the class
      `tf.VariableSynchronization`. By default the synchronization is set to
      `AUTO` and the current `DistributionStrategy` chooses when to synchronize.
    aggregation: Indicates how a distributed variable will be aggregated.
      Accepted values are constants defined in the class
      `tf.VariableAggregation`.

  Returns:
    A tuple `(shards, partitions)` where `shards` is the list of `Variable`
    shards and `partitions` is the output of the partitioner on the input
    shape.

  Raises:
    ValueError: when creating a new variable and shape is not declared,
      or when violating reuse during variable creation. Reuse is set inside
      `variable_scope`.
  """
  # pylint: disable=protected-access
  scope = get_variable_scope()
  if scope.custom_getter is not None:
    raise ValueError(
        "Private access to _get_partitioned_variable is not allowed when "
        "a custom getter is set. Current custom getter: %s. "
        "It is likely that you're using create_partitioned_variables. "
        "If so, consider instead using get_variable with a non-empty "
        "partitioner parameter."
% scope.custom_getter) 1848 return scope._get_partitioned_variable( 1849 _get_default_variable_store(), 1850 name, 1851 shape=shape, 1852 dtype=dtype, 1853 initializer=initializer, 1854 regularizer=regularizer, 1855 trainable=trainable, 1856 collections=collections, 1857 caching_device=caching_device, 1858 partitioner=partitioner, 1859 validate_shape=validate_shape, 1860 use_resource=use_resource, 1861 constraint=constraint, 1862 synchronization=synchronization, 1863 aggregation=aggregation) 1864 # pylint: enable=protected-access 1865 1866 1867# Named like a function for compatibility with the previous 1868# @tf_contextlib.contextmanager definition. 1869class _pure_variable_scope(object): # pylint: disable=invalid-name 1870 """A context for the variable_scope, see `variable_scope` for docs.""" 1871 1872 def __init__(self, 1873 name_or_scope, 1874 reuse=None, 1875 initializer=None, 1876 regularizer=None, 1877 caching_device=None, 1878 partitioner=None, 1879 custom_getter=None, 1880 old_name_scope=None, 1881 dtype=dtypes.float32, 1882 use_resource=None, 1883 constraint=None): 1884 """Creates a context for the variable_scope, see `variable_scope` for docs. 1885 1886 Note: this does not create a name scope. 1887 1888 Args: 1889 name_or_scope: `string` or `VariableScope`: the scope to open. 1890 reuse: `True` or None, or tf.compat.v1.AUTO_REUSE; if `None`, we inherit 1891 the parent scope's reuse flag. 1892 initializer: default initializer for variables within this scope. 1893 regularizer: default regularizer for variables within this scope. 1894 caching_device: default caching device for variables within this scope. 1895 partitioner: default partitioner for variables within this scope. 1896 custom_getter: default custom getter for variables within this scope. 1897 old_name_scope: the original name scope when re-entering a variable scope. 1898 dtype: type of the variables within this scope (defaults to `DT_FLOAT`). 1899 use_resource: If False, variables in this scope will be regular Variables. 1900 If True, experimental ResourceVariables will be creates instead, with 1901 well-defined semantics. Defaults to False (will later change to True). 1902 constraint: An optional projection function to be applied to the variable 1903 after being updated by an `Optimizer` (e.g. used to implement norm 1904 constraints or value constraints for layer weights). The function must 1905 take as input the unprojected Tensor representing the value of the 1906 variable and return the Tensor for the projected value (which must have 1907 the same shape). Constraints are not safe to use when doing asynchronous 1908 distributed training. 1909 """ 1910 self._name_or_scope = name_or_scope 1911 self._reuse = reuse 1912 self._initializer = initializer 1913 self._regularizer = regularizer 1914 self._caching_device = caching_device 1915 self._partitioner = partitioner 1916 self._custom_getter = custom_getter 1917 self._old_name_scope = old_name_scope 1918 self._dtype = dtype 1919 self._use_resource = use_resource 1920 self._constraint = constraint 1921 self._var_store = _get_default_variable_store() 1922 self._var_scope_store = get_variable_scope_store() 1923 self._last_variable_scope_object = None 1924 if isinstance(self._name_or_scope, VariableScope): 1925 self._new_name = self._name_or_scope.name 1926 name_scope = self._name_or_scope._name_scope # pylint: disable=protected-access 1927 # Handler for the case when we jump to a shared scope. 
We create a new 1928 # VariableScope (self._var_scope_object) that contains a copy of the 1929 # provided shared scope, possibly with changed reuse and initializer, if 1930 # the user requested this. 1931 variable_scope_object = VariableScope( 1932 self._name_or_scope.reuse if not self._reuse else self._reuse, 1933 name=self._new_name, 1934 initializer=self._name_or_scope.initializer, 1935 regularizer=self._name_or_scope.regularizer, 1936 caching_device=self._name_or_scope.caching_device, 1937 partitioner=self._name_or_scope.partitioner, 1938 dtype=self._name_or_scope.dtype, 1939 custom_getter=self._name_or_scope.custom_getter, 1940 name_scope=name_scope, 1941 use_resource=self._name_or_scope.use_resource, 1942 constraint=self._constraint) 1943 if self._initializer is not None: 1944 variable_scope_object.set_initializer(self._initializer) 1945 if self._regularizer is not None: 1946 variable_scope_object.set_regularizer(self._regularizer) 1947 if self._caching_device is not None: 1948 variable_scope_object.set_caching_device(self._caching_device) 1949 if self._partitioner is not None: 1950 variable_scope_object.set_partitioner(self._partitioner) 1951 if self._custom_getter is not None: 1952 variable_scope_object.set_custom_getter( 1953 _maybe_wrap_custom_getter(self._custom_getter, 1954 self._name_or_scope.custom_getter)) 1955 if self._dtype is not None: 1956 variable_scope_object.set_dtype(self._dtype) 1957 if self._use_resource is not None: 1958 variable_scope_object.set_use_resource(self._use_resource) 1959 self._cached_variable_scope_object = variable_scope_object 1960 1961 def __enter__(self): 1962 """Begins the scope block. 1963 1964 Returns: 1965 A VariableScope. 1966 Raises: 1967 ValueError: when trying to reuse within a create scope, or create within 1968 a reuse scope, or if reuse is not `None` or `True`. 1969 TypeError: when the types of some arguments are not appropriate. 1970 """ 1971 self._old = self._var_scope_store.current_scope 1972 if isinstance(self._name_or_scope, VariableScope): 1973 self._var_scope_store.open_variable_scope(self._new_name) 1974 self._old_subscopes = copy.copy( 1975 self._var_scope_store.variable_scopes_count) 1976 variable_scope_object = self._cached_variable_scope_object 1977 else: 1978 # Handler for the case when we just prolong current variable scope. 1979 # VariableScope with name extended by the provided one, and inherited 1980 # reuse and initializer (except if the user provided values to set). 1981 self._new_name = ( 1982 self._old.name + "/" + 1983 self._name_or_scope if self._old.name else self._name_or_scope) 1984 self._reuse = (self._reuse or 1985 self._old.reuse) # Re-using is inherited by sub-scopes. 
1986 if self._old_name_scope is None: 1987 name_scope = self._name_or_scope 1988 else: 1989 name_scope = self._old_name_scope 1990 variable_scope_object = VariableScope( 1991 self._reuse, 1992 name=self._new_name, 1993 initializer=self._old.initializer, 1994 regularizer=self._old.regularizer, 1995 caching_device=self._old.caching_device, 1996 partitioner=self._old.partitioner, 1997 dtype=self._old.dtype, 1998 use_resource=self._old.use_resource, 1999 custom_getter=self._old.custom_getter, 2000 name_scope=name_scope, 2001 constraint=self._constraint) 2002 if self._initializer is not None: 2003 variable_scope_object.set_initializer(self._initializer) 2004 if self._regularizer is not None: 2005 variable_scope_object.set_regularizer(self._regularizer) 2006 if self._caching_device is not None: 2007 variable_scope_object.set_caching_device(self._caching_device) 2008 if self._partitioner is not None: 2009 variable_scope_object.set_partitioner(self._partitioner) 2010 if self._custom_getter is not None: 2011 variable_scope_object.set_custom_getter( 2012 _maybe_wrap_custom_getter(self._custom_getter, 2013 self._old.custom_getter)) 2014 if self._dtype is not None: 2015 variable_scope_object.set_dtype(self._dtype) 2016 if self._use_resource is not None: 2017 variable_scope_object.set_use_resource(self._use_resource) 2018 self._var_scope_store.open_variable_scope(self._new_name) 2019 self._var_scope_store.current_scope = variable_scope_object 2020 self._last_variable_scope_object = variable_scope_object 2021 return variable_scope_object 2022 2023 def __exit__(self, type_arg, value_arg, traceback_arg): 2024 if (self._var_scope_store.current_scope is 2025 not self._last_variable_scope_object): 2026 raise RuntimeError("Improper nesting of variable_scope.") 2027 # If jumping out from a non-prolonged scope, restore counts. 2028 if isinstance(self._name_or_scope, VariableScope): 2029 self._var_scope_store.variable_scopes_count = self._old_subscopes 2030 else: 2031 self._var_scope_store.close_variable_subscopes(self._new_name) 2032 self._var_scope_store.current_scope = self._old 2033 2034 2035def _maybe_wrap_custom_getter(custom_getter, old_getter): 2036 """Wrap a call to a custom_getter to use the old_getter internally.""" 2037 if old_getter is None: 2038 return custom_getter 2039 2040 # The new custom_getter should call the old one 2041 def wrapped_custom_getter(getter, *args, **kwargs): 2042 # Call: 2043 # custom_getter( 2044 # lambda: old_getter(true_getter, ...), *args, **kwargs) 2045 # which means custom_getter will call old_getter, which 2046 # will call the true_getter, perform any intermediate 2047 # processing, and return the results to the current 2048 # getter, which will also perform additional processing. 
    return custom_getter(functools.partial(old_getter, getter), *args,
                         **kwargs)

  return wrapped_custom_getter


def _get_unique_variable_scope(prefix):
  """Get a name with the given prefix unique in the current variable scope."""
  var_scope_store = get_variable_scope_store()
  current_scope = get_variable_scope()
  name = current_scope.name + "/" + prefix if current_scope.name else prefix
  if var_scope_store.variable_scope_count(name) == 0:
    return prefix
  idx = 1
  while var_scope_store.variable_scope_count(name + ("_%d" % idx)) > 0:
    idx += 1
  return prefix + ("_%d" % idx)


# Named like a function for backwards compatibility with the
# @tf_contextlib.contextmanager version, which was switched to a class to avoid
# some object creation overhead.
@tf_export(v1=["variable_scope"])  # pylint: disable=invalid-name
class variable_scope(object):
  """A context manager for defining ops that create variables (layers).

  This context manager validates that the (optional) `values` are from the same
  graph, ensures that graph is the default graph, and pushes a name scope and a
  variable scope.

  If `name_or_scope` is not None, it is used as is. If `name_or_scope` is None,
  then `default_name` is used. In that case, if the same name has been
  previously used in the same scope, it will be made unique by appending `_N`
  to it.

  Variable scope allows you to create new variables and to share already
  created ones while providing checks to not create or share by accident. For
  details, see the [Variable Scope How To](https://tensorflow.org/guide/variables);
  here we present only a few basic examples.

  Variable scope works as expected when eager execution is disabled:

  ```python
  tf.compat.v1.disable_eager_execution()
  ```

  Simple example of how to create a new variable:

  ```python
  with tf.compat.v1.variable_scope("foo"):
    with tf.compat.v1.variable_scope("bar"):
      v = tf.compat.v1.get_variable("v", [1])
      assert v.name == "foo/bar/v:0"
  ```

  Simple example of how to safely re-enter a premade variable scope:

  ```python
  with tf.compat.v1.variable_scope("foo") as vs:
    pass

  # Re-enter the variable scope.
  with tf.compat.v1.variable_scope(vs,
                                   auxiliary_name_scope=False) as vs1:
    # Restore the original name_scope.
    with tf.name_scope(vs1.original_name_scope):
      v = tf.compat.v1.get_variable("v", [1])
      assert v.name == "foo/v:0"
      c = tf.constant([1], name="c")
      assert c.name == "foo/c:0"
  ```

  Keep in mind that the counters for `default_name` are discarded once the
  parent scope is exited. Therefore, when the code re-enters the scope (for
  instance by capturing it and re-entering it later), all nested `default_name`
  counters will be restarted.

  For instance:

  ```python
  with tf.compat.v1.variable_scope("foo") as vs:
    with tf.compat.v1.variable_scope(None, default_name="bar"):
      v = tf.compat.v1.get_variable("a", [1])
      assert v.name == "foo/bar/a:0", v.name
    with tf.compat.v1.variable_scope(None, default_name="bar"):
      v = tf.compat.v1.get_variable("b", [1])
      assert v.name == "foo/bar_1/b:0"

  with tf.compat.v1.variable_scope(vs):
    with tf.compat.v1.variable_scope(None, default_name="bar"):
      v = tf.compat.v1.get_variable("c", [1])
      assert v.name == "foo/bar/c:0"  # Uses bar instead of bar_2!
  ```

  Basic example of sharing a variable with AUTO_REUSE:

  ```python
  def foo():
    with tf.compat.v1.variable_scope("foo", reuse=tf.compat.v1.AUTO_REUSE):
      v = tf.compat.v1.get_variable("v", [1])
    return v

  v1 = foo()  # Creates v.
  v2 = foo()  # Gets the same, existing v.
  assert v1 == v2
  ```

  Basic example of sharing a variable with reuse=True:

  ```python
  with tf.compat.v1.variable_scope("foo"):
    v = tf.compat.v1.get_variable("v", [1])
  with tf.compat.v1.variable_scope("foo", reuse=True):
    v1 = tf.compat.v1.get_variable("v", [1])
  assert v1 == v
  ```

  Sharing a variable by capturing a scope and setting reuse:

  ```python
  with tf.compat.v1.variable_scope("foo") as scope:
    v = tf.compat.v1.get_variable("v", [1])
    scope.reuse_variables()
    v1 = tf.compat.v1.get_variable("v", [1])
  assert v1 == v
  ```

  To prevent accidental sharing of variables, we raise an exception when getting
  an existing variable in a non-reusing scope.

  ```python
  with tf.compat.v1.variable_scope("foo"):
    v = tf.compat.v1.get_variable("v", [1])
    v1 = tf.compat.v1.get_variable("v", [1])
    #  Raises ValueError("... v already exists ...").
  ```

  Similarly, we raise an exception when trying to get a variable that does not
  exist in reuse mode.

  ```python
  with tf.compat.v1.variable_scope("foo", reuse=True):
    v = tf.compat.v1.get_variable("v", [1])
    #  Raises ValueError("... v does not exist ...").
  ```

  Note that the `reuse` flag is inherited: if we open a reusing scope, then all
  its sub-scopes become reusing as well.

  A note about name scoping: Setting `reuse` does not impact the naming of other
  ops, such as `mult`. See the related discussion on
  [github#6189](https://github.com/tensorflow/tensorflow/issues/6189).

  Note that up to and including version 1.0, it was allowed (though explicitly
  discouraged) to pass False to the reuse argument, yielding undocumented
  behaviour slightly different from None. Starting at 1.1.0, passing None and
  False as reuse has exactly the same effect.

  A note about using variable scopes in a multi-threaded environment: Variable
  scopes are thread-local, so one thread will not see another thread's current
  scope. Also, when using `default_name`, unique scope names are generated only
  on a per-thread basis; using the same name in a different thread does not
  prevent the new thread from creating a scope with that name. However, the
  underlying variable store is shared across threads (within the same graph).
  As such, if another thread tries to create a new variable with the same name
  as a variable created by a previous thread, it will fail unless reuse is
  True.
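
  As an illustration (a sketch only, assuming the second thread works in the
  same default graph), a second thread has to request reuse to pick up a
  variable created by the first one:

  ```python
  with tf.compat.v1.variable_scope("shared"):
    v = tf.compat.v1.get_variable("v", [1])

  def other_thread_fn():
    # The variable store is per-graph and shared across threads, so
    # "shared/v" already exists here; without reuse=True (or AUTO_REUSE)
    # this get_variable call would raise a ValueError.
    with tf.compat.v1.variable_scope("shared", reuse=True):
      return tf.compat.v1.get_variable("v", [1])
  ```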

  Further, each thread starts with an empty variable scope. So if you wish to
  preserve name prefixes from a scope of the main thread, you should capture
  the main thread's scope and re-enter it in each thread. For example:

  ```
  main_thread_scope = variable_scope.get_variable_scope()

  # Thread's target function:
  def thread_target_fn(captured_scope):
    with variable_scope.variable_scope(captured_scope):
      # ... regular code for this thread
      pass


  thread = threading.Thread(target=thread_target_fn, args=(main_thread_scope,))
  ```
  """

  def __init__(self,
               name_or_scope,
               default_name=None,
               values=None,
               initializer=None,
               regularizer=None,
               caching_device=None,
               partitioner=None,
               custom_getter=None,
               reuse=None,
               dtype=None,
               use_resource=None,
               constraint=None,
               auxiliary_name_scope=True):
    """Initialize the context manager.

    Args:
      name_or_scope: `string` or `VariableScope`: the scope to open.
      default_name: The default name to use if the `name_or_scope` argument is
        `None`; this name will be uniquified. If `name_or_scope` is provided,
        it won't be used and therefore it is not required and can be None.
      values: The list of `Tensor` arguments that are passed to the op function.
      initializer: default initializer for variables within this scope.
      regularizer: default regularizer for variables within this scope.
      caching_device: default caching device for variables within this scope.
      partitioner: default partitioner for variables within this scope.
      custom_getter: default custom getter for variables within this scope.
      reuse: `True`, None, or tf.compat.v1.AUTO_REUSE; if `True`, we go into
        reuse mode for this scope as well as all sub-scopes; if
        tf.compat.v1.AUTO_REUSE, we create variables if they do not exist, and
        return them otherwise; if None, we inherit the parent scope's reuse
        flag. When eager execution is enabled, new variables are always created
        unless an EagerVariableStore or template is currently active.
      dtype: type of variables created in this scope (defaults to the type in
        the passed scope, or inherited from parent scope).
      use_resource: If False, all variables will be regular Variables. If True,
        experimental ResourceVariables with well-defined semantics will be used
        instead. Defaults to False (will later change to True). When eager
        execution is enabled this argument is always forced to be True.
      constraint: An optional projection function to be applied to the variable
        after being updated by an `Optimizer` (e.g. used to implement norm
        constraints or value constraints for layer weights). The function must
        take as input the unprojected Tensor representing the value of the
        variable and return the Tensor for the projected value (which must have
        the same shape). Constraints are not safe to use when doing asynchronous
        distributed training.
      auxiliary_name_scope: If `True`, we create an auxiliary name scope with
        the scope. If `False`, we don't create it. Note that the argument is
        not inherited, and it only takes effect once, when the scope is
        created. You should only use it for re-entering a premade variable
        scope.

    Returns:
      A scope that can be captured and reused.

    Raises:
      ValueError: when trying to reuse within a create scope, or create within
        a reuse scope.
2289 TypeError: when the types of some arguments are not appropriate. 2290 """ 2291 self._name_or_scope = name_or_scope 2292 self._default_name = default_name 2293 self._values = values 2294 self._initializer = initializer 2295 self._regularizer = regularizer 2296 self._caching_device = caching_device 2297 self._partitioner = partitioner 2298 self._custom_getter = custom_getter 2299 self._reuse = reuse 2300 self._dtype = dtype 2301 self._use_resource = use_resource 2302 self._constraint = constraint 2303 if self._default_name is None and self._name_or_scope is None: 2304 raise TypeError("If default_name is None then name_or_scope is required") 2305 if self._reuse is False: 2306 # We don't allow non-inheriting scopes, False = None here. 2307 self._reuse = None 2308 if not (self._reuse is True 2309 or self._reuse is None 2310 or self._reuse is AUTO_REUSE): 2311 raise ValueError("The reuse parameter must be True or False or None.") 2312 if self._values is None: 2313 self._values = [] 2314 self._in_graph_mode = not context.executing_eagerly() 2315 if self._in_graph_mode: 2316 self._graph = ops._get_graph_from_inputs(self._values) # pylint: disable=protected-access 2317 self._cached_pure_variable_scope = None 2318 self._current_name_scope = None 2319 if not isinstance(auxiliary_name_scope, bool): 2320 raise TypeError("The auxiliary_name_scope must be `True` or `False`, " 2321 "while get {}".format(auxiliary_name_scope)) 2322 self._auxiliary_name_scope = auxiliary_name_scope 2323 2324 def __enter__(self): 2325 # If the default graph is building a function, then we should not replace it 2326 # with the cached graph. 2327 if ops.get_default_graph().building_function: 2328 self._building_function = True 2329 else: 2330 self._building_function = False 2331 if self._in_graph_mode and not self._building_function: 2332 self._graph_context_manager = self._graph.as_default() 2333 self._graph_context_manager.__enter__() 2334 if self._cached_pure_variable_scope is not None: 2335 # Fast path for re-entering variable_scopes. We've held on to the pure 2336 # variable scope from a previous successful __enter__, so we avoid some 2337 # overhead by re-using that object. 2338 if self._current_name_scope is not None: 2339 self._current_name_scope.__enter__() 2340 return self._cached_pure_variable_scope.__enter__() 2341 2342 try: 2343 return self._enter_scope_uncached() 2344 except: 2345 if (self._in_graph_mode and not self._building_function and 2346 self._graph_context_manager is not None): 2347 self._graph_context_manager.__exit__(*sys.exc_info()) 2348 raise 2349 2350 def _enter_scope_uncached(self): 2351 """Enters the context manager when there is no cached scope yet. 2352 2353 Returns: 2354 The entered variable scope. 2355 2356 Raises: 2357 TypeError: A wrong type is passed as `scope` at __init__(). 2358 ValueError: `reuse` is incorrectly set at __init__(). 2359 """ 2360 if self._auxiliary_name_scope: 2361 # Create a new name scope later 2362 current_name_scope = None 2363 else: 2364 # Reenter the current name scope 2365 name_scope = ops.get_name_scope() 2366 if name_scope: 2367 # Hack to reenter 2368 name_scope += "/" 2369 current_name_scope = ops.name_scope(name_scope, skip_on_eager=False) 2370 else: 2371 # Root scope 2372 current_name_scope = ops.name_scope(name_scope, skip_on_eager=False) 2373 2374 # IMPORTANT: Only assign to self._cached_pure_variable_scope and 2375 # self._current_name_scope after successful __enter__() calls. 
2376 if self._name_or_scope is not None: 2377 if not isinstance(self._name_or_scope, 2378 (VariableScope,) + six.string_types): 2379 raise TypeError("VariableScope: name_or_scope must be a string or " 2380 "VariableScope.") 2381 if isinstance(self._name_or_scope, six.string_types): 2382 name_scope = self._name_or_scope 2383 else: 2384 name_scope = self._name_or_scope.name.split("/")[-1] 2385 if name_scope or current_name_scope: 2386 current_name_scope = current_name_scope or ops.name_scope( 2387 name_scope, skip_on_eager=False) 2388 try: 2389 current_name_scope_name = current_name_scope.__enter__() 2390 except: 2391 current_name_scope.__exit__(*sys.exc_info()) 2392 raise 2393 self._current_name_scope = current_name_scope 2394 if isinstance(self._name_or_scope, six.string_types): 2395 old_name_scope = current_name_scope_name 2396 else: 2397 old_name_scope = self._name_or_scope.original_name_scope 2398 pure_variable_scope = _pure_variable_scope( 2399 self._name_or_scope, 2400 reuse=self._reuse, 2401 initializer=self._initializer, 2402 regularizer=self._regularizer, 2403 caching_device=self._caching_device, 2404 partitioner=self._partitioner, 2405 custom_getter=self._custom_getter, 2406 old_name_scope=old_name_scope, 2407 dtype=self._dtype, 2408 use_resource=self._use_resource, 2409 constraint=self._constraint) 2410 try: 2411 entered_pure_variable_scope = pure_variable_scope.__enter__() 2412 except: 2413 pure_variable_scope.__exit__(*sys.exc_info()) 2414 raise 2415 self._cached_pure_variable_scope = pure_variable_scope 2416 return entered_pure_variable_scope 2417 else: 2418 self._current_name_scope = None 2419 # This can only happen if someone is entering the root variable scope. 2420 pure_variable_scope = _pure_variable_scope( 2421 self._name_or_scope, 2422 reuse=self._reuse, 2423 initializer=self._initializer, 2424 regularizer=self._regularizer, 2425 caching_device=self._caching_device, 2426 partitioner=self._partitioner, 2427 custom_getter=self._custom_getter, 2428 dtype=self._dtype, 2429 use_resource=self._use_resource, 2430 constraint=self._constraint) 2431 try: 2432 entered_pure_variable_scope = pure_variable_scope.__enter__() 2433 except: 2434 pure_variable_scope.__exit__(*sys.exc_info()) 2435 raise 2436 self._cached_pure_variable_scope = pure_variable_scope 2437 return entered_pure_variable_scope 2438 2439 else: # Here name_or_scope is None. Using default name, but made unique. 
2440 if self._reuse: 2441 raise ValueError("reuse=True cannot be used without a name_or_scope") 2442 current_name_scope = current_name_scope or ops.name_scope( 2443 self._default_name, skip_on_eager=False) 2444 try: 2445 current_name_scope_name = current_name_scope.__enter__() 2446 except: 2447 current_name_scope.__exit__(*sys.exc_info()) 2448 raise 2449 self._current_name_scope = current_name_scope 2450 unique_default_name = _get_unique_variable_scope(self._default_name) 2451 pure_variable_scope = _pure_variable_scope( 2452 unique_default_name, 2453 initializer=self._initializer, 2454 regularizer=self._regularizer, 2455 caching_device=self._caching_device, 2456 partitioner=self._partitioner, 2457 custom_getter=self._custom_getter, 2458 old_name_scope=current_name_scope_name, 2459 dtype=self._dtype, 2460 use_resource=self._use_resource, 2461 constraint=self._constraint) 2462 try: 2463 entered_pure_variable_scope = pure_variable_scope.__enter__() 2464 except: 2465 pure_variable_scope.__exit__(*sys.exc_info()) 2466 raise 2467 self._cached_pure_variable_scope = pure_variable_scope 2468 return entered_pure_variable_scope 2469 2470 def __exit__(self, type_arg, value_arg, traceback_arg): 2471 try: 2472 self._cached_pure_variable_scope.__exit__(type_arg, value_arg, 2473 traceback_arg) 2474 finally: 2475 try: 2476 if self._current_name_scope: 2477 self._current_name_scope.__exit__(type_arg, value_arg, 2478 traceback_arg) 2479 finally: 2480 if self._in_graph_mode and not self._building_function: 2481 self._graph_context_manager.__exit__(type_arg, value_arg, 2482 traceback_arg) 2483 2484 2485# pylint: disable=g-doc-return-or-yield 2486@tf_export(v1=["variable_op_scope"]) 2487@tf_contextlib.contextmanager 2488def variable_op_scope(values, 2489 name_or_scope, 2490 default_name=None, 2491 initializer=None, 2492 regularizer=None, 2493 caching_device=None, 2494 partitioner=None, 2495 custom_getter=None, 2496 reuse=None, 2497 dtype=None, 2498 use_resource=None, 2499 constraint=None): 2500 """Deprecated: context manager for defining an op that creates variables.""" 2501 logging.warn("tf.variable_op_scope(values, name, default_name) is deprecated," 2502 " use tf.variable_scope(name, default_name, values)") 2503 with variable_scope( 2504 name_or_scope, 2505 default_name=default_name, 2506 values=values, 2507 initializer=initializer, 2508 regularizer=regularizer, 2509 caching_device=caching_device, 2510 partitioner=partitioner, 2511 custom_getter=custom_getter, 2512 reuse=reuse, 2513 dtype=dtype, 2514 use_resource=use_resource, 2515 constraint=constraint) as scope: 2516 yield scope 2517 2518 2519def _call_partitioner(partitioner, shape, dtype): 2520 """Call partitioner validating its inputs/output. 2521 2522 Args: 2523 partitioner: a function mapping `Tensor` shape and dtype to a list of 2524 partitions. 2525 shape: shape of the `Tensor` to partition, must have at least two 2526 dimensions. 2527 dtype: dtype of the elements in the `Tensor`. 2528 2529 Returns: 2530 A list with elements >=1 and exactly one >1. The index of that 2531 element corresponds to the partitioning axis. 2532 """ 2533 if not shape.is_fully_defined(): 2534 raise ValueError("Shape of a new partitioned variable must be " 2535 "fully defined, but instead was %s." 
% (shape,)) 2536 if shape.ndims < 1: 2537 raise ValueError("A partitioned Variable must have rank at least 1, " 2538 "shape: %s" % shape) 2539 2540 slicing = partitioner(shape=shape, dtype=dtype) 2541 if not isinstance(slicing, collections_abc.Sequence): 2542 raise ValueError("Partitioner must return a sequence, but saw: %s" % 2543 slicing) 2544 if len(slicing) != shape.ndims: 2545 raise ValueError( 2546 "Partitioner returned a partition list that does not match the " 2547 "Variable's rank: %s vs. %s" % (slicing, shape)) 2548 if any(p < 1 for p in slicing): 2549 raise ValueError("Partitioner returned zero partitions for some axes: %s" % 2550 slicing) 2551 if sum(p > 1 for p in slicing) > 1: 2552 raise ValueError("Can only slice a variable along one dimension: " 2553 "shape: %s, partitioning: %s" % (shape, slicing)) 2554 return slicing 2555 2556 2557# TODO(slebedev): could be inlined, but 2558# `_VariableStore._get_partitioned_variable` is too complex even 2559# without this logic. 2560def _get_slice_dim_and_num_slices(slicing): 2561 """Get slicing dimension and number of slices from the partitioner output.""" 2562 for slice_dim, num_slices in enumerate(slicing): 2563 if num_slices > 1: 2564 break 2565 else: 2566 # Degenerate case: no partitioning applied. 2567 slice_dim = 0 2568 num_slices = 1 2569 return slice_dim, num_slices 2570 2571 2572def _iter_slices(full_shape, num_slices, slice_dim): 2573 """Slices a given a shape along the specified dimension.""" 2574 num_slices_with_excess = full_shape[slice_dim] % num_slices 2575 offset = [0] * len(full_shape) 2576 min_slice_len = full_shape[slice_dim] // num_slices 2577 for i in xrange(num_slices): 2578 shape = full_shape[:] 2579 shape[slice_dim] = min_slice_len + bool(i < num_slices_with_excess) 2580 yield offset[:], shape 2581 offset[slice_dim] += shape[slice_dim] 2582 2583 2584def default_variable_creator(next_creator=None, **kwargs): 2585 """Default variable creator.""" 2586 assert next_creator is None 2587 initial_value = kwargs.get("initial_value", None) 2588 trainable = kwargs.get("trainable", None) 2589 collections = kwargs.get("collections", None) 2590 validate_shape = kwargs.get("validate_shape", True) 2591 caching_device = kwargs.get("caching_device", None) 2592 name = kwargs.get("name", None) 2593 variable_def = kwargs.get("variable_def", None) 2594 dtype = kwargs.get("dtype", None) 2595 expected_shape = kwargs.get("expected_shape", None) 2596 import_scope = kwargs.get("import_scope", None) 2597 constraint = kwargs.get("constraint", None) 2598 use_resource = kwargs.get("use_resource", None) 2599 synchronization = kwargs.get("synchronization", None) 2600 aggregation = kwargs.get("aggregation", None) 2601 shape = kwargs.get("shape", None) 2602 2603 if use_resource is None: 2604 use_resource = get_variable_scope().use_resource 2605 if use_resource is None: 2606 use_resource = _DEFAULT_USE_RESOURCE 2607 use_resource = use_resource or context.executing_eagerly() 2608 if use_resource: 2609 distribute_strategy = kwargs.get("distribute_strategy", None) 2610 return resource_variable_ops.ResourceVariable( 2611 initial_value=initial_value, 2612 trainable=trainable, 2613 collections=collections, 2614 validate_shape=validate_shape, 2615 caching_device=caching_device, 2616 name=name, 2617 dtype=dtype, 2618 constraint=constraint, 2619 variable_def=variable_def, 2620 import_scope=import_scope, 2621 distribute_strategy=distribute_strategy, 2622 synchronization=synchronization, 2623 aggregation=aggregation, 2624 shape=shape) 2625 else: 2626 return 
variables.RefVariable( 2627 initial_value=initial_value, 2628 trainable=trainable, 2629 collections=collections, 2630 validate_shape=validate_shape, 2631 caching_device=caching_device, 2632 name=name, 2633 dtype=dtype, 2634 constraint=constraint, 2635 variable_def=variable_def, 2636 expected_shape=expected_shape, 2637 import_scope=import_scope, 2638 synchronization=synchronization, 2639 aggregation=aggregation, 2640 shape=shape) 2641 2642 2643def default_variable_creator_v2(next_creator=None, **kwargs): 2644 """Default variable creator.""" 2645 assert next_creator is None 2646 initial_value = kwargs.get("initial_value", None) 2647 trainable = kwargs.get("trainable", None) 2648 validate_shape = kwargs.get("validate_shape", True) 2649 caching_device = kwargs.get("caching_device", None) 2650 name = kwargs.get("name", None) 2651 variable_def = kwargs.get("variable_def", None) 2652 dtype = kwargs.get("dtype", None) 2653 import_scope = kwargs.get("import_scope", None) 2654 constraint = kwargs.get("constraint", None) 2655 distribute_strategy = kwargs.get("distribute_strategy", None) 2656 synchronization = kwargs.get("synchronization", None) 2657 aggregation = kwargs.get("aggregation", None) 2658 shape = kwargs.get("shape", None) 2659 2660 return resource_variable_ops.ResourceVariable( 2661 initial_value=initial_value, 2662 trainable=trainable, 2663 validate_shape=validate_shape, 2664 caching_device=caching_device, 2665 name=name, 2666 dtype=dtype, 2667 constraint=constraint, 2668 variable_def=variable_def, 2669 import_scope=import_scope, 2670 distribute_strategy=distribute_strategy, 2671 synchronization=synchronization, 2672 aggregation=aggregation, 2673 shape=shape) 2674 2675 2676variables.default_variable_creator = default_variable_creator 2677variables.default_variable_creator_v2 = default_variable_creator_v2 2678 2679 2680def _make_getter(captured_getter, captured_previous): 2681 """Gets around capturing loop variables in python being broken.""" 2682 return lambda **kwargs: captured_getter(captured_previous, **kwargs) 2683 2684 2685# TODO(apassos) remove forwarding symbol 2686variable = variables.VariableV1 2687 2688 2689@tf_export(v1=["variable_creator_scope"]) 2690@tf_contextlib.contextmanager 2691def variable_creator_scope_v1(variable_creator): 2692 """Scope which defines a variable creation function to be used by variable(). 2693 2694 variable_creator is expected to be a function with the following signature: 2695 2696 ``` 2697 def variable_creator(next_creator, **kwargs) 2698 ``` 2699 2700 The creator is supposed to eventually call the next_creator to create a 2701 variable if it does want to create a variable and not call Variable or 2702 ResourceVariable directly. This helps make creators composable. A creator may 2703 choose to create multiple variables, return already existing variables, or 2704 simply register that a variable was created and defer to the next creators in 2705 line. Creators can also modify the keyword arguments seen by the next 2706 creators. 2707 2708 Custom getters in the variable scope will eventually resolve down to these 2709 custom creators when they do create variables. 2710 2711 The valid keyword arguments in kwds are: 2712 2713 * initial_value: A `Tensor`, or Python object convertible to a `Tensor`, 2714 which is the initial value for the Variable. The initial value must have 2715 a shape specified unless `validate_shape` is set to False. Can also be a 2716 callable with no argument that returns the initial value when called. 
In 2717 that case, `dtype` must be specified. (Note that initializer functions 2718 from init_ops.py must first be bound to a shape before being used here.) 2719 * trainable: If `True`, the default, also adds the variable to the graph 2720 collection `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as 2721 the default list of variables to use by the `Optimizer` classes. 2722 `trainable` defaults to `True`, unless `synchronization` is 2723 set to `ON_READ`, in which case it defaults to `False`. 2724 * collections: List of graph collections keys. The new variable is added to 2725 these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`. 2726 * validate_shape: If `False`, allows the variable to be initialized with a 2727 value of unknown shape. If `True`, the default, the shape of 2728 `initial_value` must be known. 2729 * caching_device: Optional device string describing where the Variable 2730 should be cached for reading. Defaults to the Variable's device. 2731 If not `None`, caches on another device. Typical use is to cache 2732 on the device where the Ops using the Variable reside, to deduplicate 2733 copying through `Switch` and other conditional statements. 2734 * name: Optional name for the variable. Defaults to `'Variable'` and gets 2735 uniquified automatically. 2736 * dtype: If set, initial_value will be converted to the given type. 2737 If `None`, either the datatype will be kept (if `initial_value` is 2738 a Tensor), or `convert_to_tensor` will decide. 2739 * constraint: A constraint function to be applied to the variable after 2740 updates by some algorithms. 2741 * use_resource: if True, a ResourceVariable is always created. 2742 * synchronization: Indicates when a distributed a variable will be 2743 aggregated. Accepted values are constants defined in the class 2744 `tf.VariableSynchronization`. By default the synchronization is set to 2745 `AUTO` and the current `DistributionStrategy` chooses 2746 when to synchronize. 2747 * aggregation: Indicates how a distributed variable will be aggregated. 2748 Accepted values are constants defined in the class 2749 `tf.VariableAggregation`. 2750 2751 This set may grow over time, so it's important the signature of creators is as 2752 mentioned above. 2753 2754 Args: 2755 variable_creator: the passed creator 2756 2757 Yields: 2758 A scope in which the creator is active 2759 """ 2760 with ops.get_default_graph()._variable_creator_scope(variable_creator): # pylint: disable=protected-access 2761 yield 2762 2763 2764# Note: only the docstrings differ between this and v1. 2765@tf_export("variable_creator_scope", v1=[]) 2766@tf_contextlib.contextmanager 2767def variable_creator_scope(variable_creator): 2768 """Scope which defines a variable creation function to be used by variable(). 2769 2770 variable_creator is expected to be a function with the following signature: 2771 2772 ``` 2773 def variable_creator(next_creator, **kwargs) 2774 ``` 2775 2776 The creator is supposed to eventually call the next_creator to create a 2777 variable if it does want to create a variable and not call Variable or 2778 ResourceVariable directly. This helps make creators composable. A creator may 2779 choose to create multiple variables, return already existing variables, or 2780 simply register that a variable was created and defer to the next creators in 2781 line. Creators can also modify the keyword arguments seen by the next 2782 creators. 
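
  For example, a minimal creator that just logs each variable it sees and then
  defers to the next creator in line might look like this (a sketch; the names
  here are illustrative, not part of the API):

  ```
  def logging_creator(next_creator, **kwargs):
    print("Creating variable:", kwargs.get("name"))
    return next_creator(**kwargs)

  with tf.variable_creator_scope(logging_creator):
    v = tf.Variable(1.0)
  ```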
2783 2784 Custom getters in the variable scope will eventually resolve down to these 2785 custom creators when they do create variables. 2786 2787 The valid keyword arguments in kwds are: 2788 2789 * initial_value: A `Tensor`, or Python object convertible to a `Tensor`, 2790 which is the initial value for the Variable. The initial value must have 2791 a shape specified unless `validate_shape` is set to False. Can also be a 2792 callable with no argument that returns the initial value when called. In 2793 that case, `dtype` must be specified. (Note that initializer functions 2794 from init_ops.py must first be bound to a shape before being used here.) 2795 * trainable: If `True`, the default, GradientTapes automatically watch 2796 uses of this Variable. 2797 * validate_shape: If `False`, allows the variable to be initialized with a 2798 value of unknown shape. If `True`, the default, the shape of 2799 `initial_value` must be known. 2800 * caching_device: Optional device string describing where the Variable 2801 should be cached for reading. Defaults to the Variable's device. 2802 If not `None`, caches on another device. Typical use is to cache 2803 on the device where the Ops using the Variable reside, to deduplicate 2804 copying through `Switch` and other conditional statements. 2805 * name: Optional name for the variable. Defaults to `'Variable'` and gets 2806 uniquified automatically. 2807 dtype: If set, initial_value will be converted to the given type. 2808 If `None`, either the datatype will be kept (if `initial_value` is 2809 a Tensor), or `convert_to_tensor` will decide. 2810 * constraint: A constraint function to be applied to the variable after 2811 updates by some algorithms. 2812 * synchronization: Indicates when a distributed a variable will be 2813 aggregated. Accepted values are constants defined in the class 2814 `tf.VariableSynchronization`. By default the synchronization is set to 2815 `AUTO` and the current `DistributionStrategy` chooses 2816 when to synchronize. 2817 * aggregation: Indicates how a distributed variable will be aggregated. 2818 Accepted values are constants defined in the class 2819 `tf.VariableAggregation`. 2820 2821 This set may grow over time, so it's important the signature of creators is as 2822 mentioned above. 2823 2824 Args: 2825 variable_creator: the passed creator 2826 2827 Yields: 2828 A scope in which the creator is active 2829 """ 2830 with ops.get_default_graph()._variable_creator_scope(variable_creator): # pylint: disable=protected-access 2831 yield 2832