# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A class to store named variables and a scope operator to manage sharing."""

import copy
import enum
import functools
import sys
import threading
import traceback

from tensorflow.python import tf2
from tensorflow.python.client import session
from tensorflow.python.eager import context
from tensorflow.python.eager import monitoring
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.types import core
from tensorflow.python.util import deprecation
from tensorflow.python.util import function_utils
from tensorflow.python.util import tf_contextlib
from tensorflow.python.util import tf_inspect
from tensorflow.python.util.compat import collections_abc
from tensorflow.python.util.tf_export import tf_export

__all__ = [
    "AUTO_REUSE", "VariableScope", "get_variable_scope", "get_variable",
    "get_local_variable", "variable_scope", "variable_op_scope",
    "no_regularizer", "VariableSynchronization", "VariableAggregation"
]

_api_usage_gauge = monitoring.BoolGauge(
    "/tensorflow/api/resource_variables",
    "Whether variable_scope.enable_resource_variables() is called.")


class _PartitionInfo:
  """Holds partition info used by initializer functions."""

  __slots__ = ["_full_shape", "_var_offset"]

  def __init__(self, full_shape, var_offset):
    """Constructor.

    Args:
      full_shape: Tuple or list of `int` indicating the full combined shape of
        the partitioned variables.
      var_offset: Tuple or list of `int` specifying offset of this partition
        with respect to the full variable for each dimension.

    Raises:
      TypeError: If `full_shape` or `var_offset` is not a sequence.
      ValueError: If `full_shape` or `var_offset` differ in length. If
        `var_offset` exceeds `full_shape` in any dimension.
73 """ 74 if not isinstance(full_shape, (list, tuple)): 75 raise TypeError( 76 "`full_shape` must be a sequence (like tuple or list) instead of " + 77 type(full_shape).__name__) 78 79 if not isinstance(var_offset, (list, tuple)): 80 raise TypeError( 81 "`var_offset` must be a sequence (like tuple or list) instead of " + 82 type(var_offset).__name__) 83 84 if len(var_offset) != len(full_shape): 85 raise ValueError( 86 "Expected equal length, but `var_offset` is of length {} while " 87 "full_shape is of length {}.".format( 88 len(var_offset), len(full_shape))) 89 90 for offset, shape in zip(var_offset, full_shape): 91 if offset < 0 or offset >= shape: 92 raise ValueError( 93 "Expected 0 <= offset < shape but found offset={}, shape={} for " 94 "var_offset={}, full_shape={}".format(offset, shape, var_offset, 95 full_shape)) 96 97 self._full_shape = full_shape 98 self._var_offset = var_offset 99 100 @property 101 def full_shape(self): 102 return self._full_shape 103 104 @property 105 def var_offset(self): 106 return self._var_offset 107 108 def single_offset(self, shape): 109 """Returns the offset when the variable is partitioned in at most one dim. 110 111 Args: 112 shape: Tuple or list of `int` indicating the shape of one specific 113 variable partition. 114 115 Returns: 116 `int` representing the offset in the dimension along which the variable is 117 partitioned. Returns 0 if the variable is not being partitioned. 118 119 Raises: 120 ValueError: Depending on self.single_slice_dim(). 121 """ 122 123 single_slice_dim = self.single_slice_dim(shape) 124 # If this variable is not being partitioned at all, single_slice_dim() could 125 # return None. 126 if single_slice_dim is None: 127 return 0 128 return self.var_offset[single_slice_dim] 129 130 def single_slice_dim(self, shape): 131 """Returns the slice dim when the variable is partitioned only in one dim. 132 133 Args: 134 shape: Tuple or list of `int` indicating the shape of one specific 135 variable partition. 136 137 Returns: 138 `int` representing the dimension that the variable is partitioned in, or 139 `None` if the variable doesn't seem to be partitioned at all. 140 141 Raises: 142 TypeError: If `shape` is not a sequence. 143 ValueError: If `shape` is not the same length as `self.full_shape`. If 144 the variable is partitioned in more than one dimension. 
145 """ 146 if not isinstance(shape, (tuple, list)): 147 raise TypeError( 148 "`shape` must be a sequence (like tuple or list) instead of " + 149 type(shape).__name__) 150 151 if len(shape) != len(self.full_shape): 152 raise ValueError( 153 "Expected equal length, but received shape={} of length {} while " 154 "self.full_shape={} is of length {}.".format(shape, len(shape), 155 self.full_shape, 156 len(self.full_shape))) 157 158 for i in range(len(shape)): 159 if self.var_offset[i] + shape[i] > self.full_shape[i]: 160 raise ValueError( 161 "With self.var_offset={}, a partition of shape={} would exceed " 162 "self.full_shape={} in dimension {}.".format( 163 self.var_offset, shape, self.full_shape, i)) 164 165 slice_dim = None 166 for i in range(len(shape)): 167 if shape[i] == self.full_shape[i]: 168 continue 169 if slice_dim is not None: 170 raise ValueError( 171 "Cannot use single_slice_dim() with shape={} and " 172 "self.full_shape={} since slice dim could be either dimension {} " 173 "or {}.".format(shape, self.full_shape, i, slice_dim)) 174 slice_dim = i 175 176 return slice_dim 177 178 179class _ReuseMode(enum.Enum): 180 """Mode for variable access within a variable scope.""" 181 182 # Indicates that variables are to be fetched if they already exist or 183 # otherwise created. 184 AUTO_REUSE = 1 185 186 # TODO(alive): For TensorFlow 2.0, Deprecate True/False/None API in favor of 187 # enum values. 188 # REUSE_FALSE = 2 189 # REUSE_TRUE = 3 190 191 192# TODO(apassos) remove these forwarding symbols. 193VariableSynchronization = variables.VariableSynchronization # pylint: disable=invalid-name 194VariableAggregation = variables.VariableAggregation # pylint: disable=invalid-name 195 196AUTO_REUSE = _ReuseMode.AUTO_REUSE 197tf_export(v1=["AUTO_REUSE"]).export_constant(__name__, "AUTO_REUSE") 198AUTO_REUSE.__doc__ = """ 199@compatibility(TF2) 200`tf.compat.v1.AUTO_REUSE` is a legacy API that is a no-op when TF2 behaviors 201are enabled. 202 203If you rely on `get_variable` and auto-reuse, see the 204[model mapping guide](https://www.tensorflow.org/guide/migrate/model_mapping) 205for more info on how to migrate your code. 206 207Note: when you use the `tf.compat.v1.keras.utils.track_tf1_style_variables` 208API as described in the above guide, `get_variable` will always behave as if 209`v1.AUTO_REUSE` is set. Without the decorator, reuse will be ignored and new 210variables will always be created, regardless of if they have already been 211created. 212@end_compatibility 213 214When passed in as the value for the `reuse` flag, `AUTO_REUSE` indicates that 215get_variable() should create the requested variable if it doesn't exist or, if 216it does exist, simply return it. 217""" 218 219_DEFAULT_USE_RESOURCE = tf2.enabled() 220 221 222@tf_export(v1=["enable_resource_variables"]) 223def enable_resource_variables(): 224 """Creates resource variables by default. 225 226 Resource variables are improved versions of TensorFlow variables with a 227 well-defined memory model. Accessing a resource variable reads its value, and 228 all ops which access a specific read value of the variable are guaranteed to 229 see the same value for that tensor. Writes which happen after a read (by 230 having a control or data dependency on the read) are guaranteed not to affect 231 the value of the read tensor, and similarly writes which happen before a read 232 are guaranteed to affect the value. No guarantees are made about unordered 233 read/write pairs. 

  Calling tf.enable_resource_variables() lets you opt-in to this TensorFlow 2.0
  feature.
  """
  global _DEFAULT_USE_RESOURCE
  _DEFAULT_USE_RESOURCE = True
  logging.vlog(1, "Enabling resource variables")
  _api_usage_gauge.get_cell().set(True)


@tf_export(v1=["resource_variables_enabled"])
def resource_variables_enabled():
  """Returns `True` if resource variables are enabled.

  Resource variables are improved versions of TensorFlow variables with a
  well-defined memory model. Accessing a resource variable reads its value, and
  all ops which access a specific read value of the variable are guaranteed to
  see the same value for that tensor. Writes which happen after a read (by
  having a control or data dependency on the read) are guaranteed not to affect
  the value of the read tensor, and similarly writes which happen before a read
  are guaranteed to affect the value. No guarantees are made about unordered
  read/write pairs.

  Calling tf.enable_resource_variables() lets you opt-in to this TensorFlow 2.0
  feature.
  """
  global _DEFAULT_USE_RESOURCE
  return _DEFAULT_USE_RESOURCE


@deprecation.deprecated(
    None, "non-resource variables are not supported in the long term")
@tf_export(v1=["disable_resource_variables"])
def disable_resource_variables():
  """Opts out of resource variables.

  If your code needs tf.disable_resource_variables() to be called to work
  properly please file a bug.
  """
  global _DEFAULT_USE_RESOURCE
  _DEFAULT_USE_RESOURCE = False
  logging.vlog(1, "Disabling resource variables")
  _api_usage_gauge.get_cell().set(False)


def _needs_no_arguments(python_callable):
  """Returns true if the callable needs no arguments to call."""
  # TODO(bfontain): Switch to inspect.signature when we are python 3 only.
  # signature = inspect.signature(python_callable)
  # return not [1 for param in signature.parameters.values()
  #             if param.default == param.empty]
  num_arguments = len(tf_inspect.getargspec(python_callable).args)
  if not tf_inspect.isfunction(python_callable) and not isinstance(
      python_callable, functools.partial):
    # getargspec includes self for function objects (which aren't
    # functools.partial). This has no default so we need to remove it.
    # It is not even an argument so it's odd that getargspec returns this.
    # Note that this is fixed with inspect.signature in Python 3.
    num_arguments -= 1
  return num_arguments == len(
      tf_inspect.getargspec(python_callable).defaults or [])


class _VariableStore:
  """Variable store that carries a number of named Variables.

  New variable names and new variables can be created; all stored
  variables are initialized with the initializer passed to __init__.

  Attributes:
    vars: a dictionary with string names (same as passed in GetVar) as keys and
      the corresponding TensorFlow Variables as values.
  """

  __slots__ = ["_vars", "_partitioned_vars", "_store_eager_variables"]

  def __init__(self):
    """Create a variable store."""
    self._vars = {}  # A dictionary of the stored TensorFlow variables.
    self._partitioned_vars = {}  # A dict of the stored PartitionedVariables.
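    # When True (set by EagerVariableStore), variables created while executing
    # eagerly are also recorded in this store so they can be reused and listed.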
    self._store_eager_variables = False

  def get_variable(self,
                   name,
                   shape=None,
                   dtype=dtypes.float32,
                   initializer=None,
                   regularizer=None,
                   reuse=None,
                   trainable=None,
                   collections=None,
                   caching_device=None,
                   partitioner=None,
                   validate_shape=True,
                   use_resource=None,
                   custom_getter=None,
                   constraint=None,
                   synchronization=VariableSynchronization.AUTO,
                   aggregation=VariableAggregation.NONE):
    """Gets an existing variable with these parameters or create a new one.

    If a variable with the given name is already stored, we return the stored
    variable. Otherwise, we create a new one.

    Set `reuse` to `True` when you only want to reuse existing Variables.
    Set `reuse` to `False` when you only want to create new Variables.
    Set `reuse` to None (the default) or tf.compat.v1.AUTO_REUSE when you want
    variables to be created if they don't exist or returned if they do.

    If initializer is `None` (the default), the default initializer passed in
    the constructor is used. If that one is `None` too, we use a new
    `glorot_uniform_initializer`. If initializer is a Tensor, we use
    it as a value and derive the shape from the initializer.

    If a partitioner is provided, a `PartitionedVariable` is returned.
    Accessing this object as a `Tensor` returns the shards concatenated along
    the partition axis.

    Some useful partitioners are available. See, e.g.,
    `variable_axis_size_partitioner` and `min_max_variable_partitioner`.

    Args:
      name: The name of the new or existing variable.
      shape: Shape of the new or existing variable.
      dtype: Type of the new or existing variable (defaults to `DT_FLOAT`).
      initializer: Initializer for the variable.
      regularizer: A (Tensor -> Tensor or None) function; the result of applying
        it on a newly created variable will be added to the collection
        GraphKeys.REGULARIZATION_LOSSES and can be used for regularization.
      reuse: a Boolean, None, or tf.AUTO_REUSE. Controls reuse or creation of
        variables. When eager execution is enabled this argument is always
        forced to be False.
      trainable: If `True` also add the variable to the graph collection
        `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). `trainable`
        defaults to `True`, unless `synchronization` is set to `ON_READ`, in
        which case it defaults to `False`.
      collections: List of graph collections keys to add the `Variable` to.
        Defaults to `[GraphKeys.GLOBAL_VARIABLES]` (see `tf.Variable`).
      caching_device: Optional device string or function describing where the
        Variable should be cached for reading. Defaults to the Variable's
        device. If not `None`, caches on another device. Typical use is to
        cache on the device where the Ops using the `Variable` reside, to
        deduplicate copying through `Switch` and other conditional statements.
      partitioner: Optional callable that accepts a fully defined `TensorShape`
        and dtype of the `Variable` to be created, and returns a list of
        partitions for each axis (currently only one axis can be partitioned).
      validate_shape: If False, allows the variable to be initialized with a
        value of unknown shape. If True, the default, the shape of initial_value
        must be known.
      use_resource: If False, creates a regular Variable. If True, creates
        instead an experimental ResourceVariable which has well-defined
        semantics. Defaults to False (will later change to True).
        When eager execution is enabled this argument is always forced to be
        true.
      custom_getter: Callable that takes as a first argument the true getter,
        and allows overwriting the internal get_variable method. The signature
        of `custom_getter` should match that of this method,
        but the most future-proof version will allow for changes: `def
        custom_getter(getter, *args, **kwargs)`. Direct access to
        all `get_variable` parameters is also allowed: `def
        custom_getter(getter, name, *args, **kwargs)`. A simple identity
        custom getter that simply creates variables with modified names is:
        ```python
        def custom_getter(getter, name, *args, **kwargs):
          return getter(name + '_suffix', *args, **kwargs)
        ```
      constraint: An optional projection function to be applied to the variable
        after being updated by an `Optimizer` (e.g. used to implement norm
        constraints or value constraints for layer weights). The function must
        take as input the unprojected Tensor representing the value of the
        variable and return the Tensor for the projected value (which must have
        the same shape). Constraints are not safe to use when doing asynchronous
        distributed training.
      synchronization: Indicates when a distributed variable will be
        aggregated. Accepted values are constants defined in the class
        `tf.VariableSynchronization`. By default the synchronization is set to
        `AUTO` and the current `DistributionStrategy` chooses when to
        synchronize.
      aggregation: Indicates how a distributed variable will be aggregated.
        Accepted values are constants defined in the class
        `tf.VariableAggregation`.

    Returns:
      The created or existing `Variable` (or `PartitionedVariable`, if a
      partitioner was used).

    Raises:
      ValueError: when creating a new variable and shape is not declared,
        when reusing a variable and specifying a conflicting shape,
        or when violating reuse during variable creation.
      RuntimeError: when eager execution is enabled and not called from an
        EagerVariableStore.
    """
    if custom_getter is not None and not callable(custom_getter):
      raise ValueError("Passed a custom_getter which is not callable: %s" %
                       custom_getter)

    with ops.init_scope():
      if context.executing_eagerly():
        # Variable creation and initialization takes place in `init_scope`s;
        # as such, if an `init_scope` lifts us into the eager context, then we
        # need to use `ResourceVariable`s.
        use_resource = True

    # Note that it's fine to reuse eager variables whose initialization was
    # lifted from a function-building graph into the eager context (that's why
    # the following clause is not wrapped in an `init_scope`); lifted variables
    # are tracked by the graph's `VariableStore`.
    if context.executing_eagerly():
      if not self._store_eager_variables and reuse:
        raise RuntimeError(
            "When eager execution is enabled variable reuse is only supported"
            " when an EagerVariableStore is active. See the documentation on"
            " EagerVariableStore for example usage.")
      if self._store_eager_variables:
        reuse = AUTO_REUSE

    # If a *_ref type is passed in an error would be triggered further down the
    # stack. We prevent this using base_dtype to get a non-ref version of the
    # type, before doing anything else. When _ref types are removed in favor of
    # resources, this line can be removed.
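    # For example, a dtype such as float32_ref is normalized to float32 here.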
    try:
      dtype = dtype.base_dtype
    except AttributeError:
      # .base_dtype not existing means that we will try and use the raw dtype
      # which was passed in - this might be a NumPy type which is valid.
      pass

    # This is the main logic of get_variable. However, custom_getter
    # may override this logic. So we save it as a callable and pass
    # it to custom_getter.
    # Note: the parameters of _true_getter, and their documentation, match
    # *exactly* item-for-item with the docstring of this method.
    def _true_getter(  # pylint: disable=missing-docstring
        name,
        shape=None,
        dtype=dtypes.float32,
        initializer=None,
        regularizer=None,
        reuse=None,
        trainable=None,
        collections=None,
        caching_device=None,
        partitioner=None,
        validate_shape=True,
        use_resource=None,
        constraint=None,
        synchronization=VariableSynchronization.AUTO,
        aggregation=VariableAggregation.NONE):
      is_scalar = (
          shape is not None and isinstance(shape, collections_abc.Sequence) and
          not shape)
      # Partitioned variable case
      if partitioner is not None and not is_scalar:
        if not callable(partitioner):
          raise ValueError("Partitioner must be callable, but received: %s" %
                           partitioner)
        with ops.name_scope(None):
          return self._get_partitioned_variable(
              name=name,
              shape=shape,
              dtype=dtype,
              initializer=initializer,
              regularizer=regularizer,
              reuse=reuse,
              trainable=trainable,
              collections=collections,
              caching_device=caching_device,
              partitioner=partitioner,
              validate_shape=validate_shape,
              use_resource=use_resource,
              constraint=constraint,
              synchronization=synchronization,
              aggregation=aggregation)

      # Special case for partitioned variable to allow reuse without having to
      # specify partitioner.
      if (reuse is True and partitioner is None
          and name in self._partitioned_vars):
        return self._get_partitioned_variable(
            name=name,
            shape=shape,
            dtype=dtype,
            initializer=initializer,
            regularizer=regularizer,
            reuse=reuse,
            trainable=trainable,
            collections=collections,
            caching_device=caching_device,
            partitioner=None,
            validate_shape=validate_shape,
            use_resource=use_resource,
            constraint=constraint,
            synchronization=synchronization,
            aggregation=aggregation)

      # Single variable case
      if "%s/part_0" % name in self._vars:
        raise ValueError(
            "No partitioner was provided, but a partitioned version of the "
            "variable was found: %s/part_0. Perhaps a variable of the same "
            "name was already created with partitioning?" % name)

      return self._get_single_variable(
          name=name,
          shape=shape,
          dtype=dtype,
          initializer=initializer,
          regularizer=regularizer,
          reuse=reuse,
          trainable=trainable,
          collections=collections,
          caching_device=caching_device,
          validate_shape=validate_shape,
          use_resource=use_resource,
          constraint=constraint,
          synchronization=synchronization,
          aggregation=aggregation)

    synchronization, aggregation, trainable = (
        variables.validate_synchronization_aggregation_trainable(
            synchronization, aggregation, trainable, name))

    if custom_getter is not None:
      # Handle backwards compatibility with getter arguments that were added
      # to the API after users started writing custom getters.
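      # `constraint` is only forwarded below when the custom getter's signature
      # (or **kwargs) can accept it, so older custom getters keep working.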
      custom_getter_kwargs = {
          "getter": _true_getter,
          "name": name,
          "shape": shape,
          "dtype": dtype,
          "initializer": initializer,
          "regularizer": regularizer,
          "reuse": reuse,
          "trainable": trainable,
          "collections": collections,
          "caching_device": caching_device,
          "partitioner": partitioner,
          "validate_shape": validate_shape,
          "use_resource": use_resource,
          "synchronization": synchronization,
          "aggregation": aggregation,
      }
      # `fn_args` and `has_kwargs` can handle functions, `functools.partial`,
      # `lambda`.
      if ("constraint" in function_utils.fn_args(custom_getter) or
          function_utils.has_kwargs(custom_getter)):
        custom_getter_kwargs["constraint"] = constraint
      return custom_getter(**custom_getter_kwargs)
    else:
      return _true_getter(
          name,
          shape=shape,
          dtype=dtype,
          initializer=initializer,
          regularizer=regularizer,
          reuse=reuse,
          trainable=trainable,
          collections=collections,
          caching_device=caching_device,
          partitioner=partitioner,
          validate_shape=validate_shape,
          use_resource=use_resource,
          constraint=constraint,
          synchronization=synchronization,
          aggregation=aggregation)

  def _get_partitioned_variable(self,
                                name,
                                partitioner,
                                shape=None,
                                dtype=dtypes.float32,
                                initializer=None,
                                regularizer=None,
                                reuse=None,
                                trainable=None,
                                collections=None,
                                caching_device=None,
                                validate_shape=True,
                                use_resource=None,
                                constraint=None,
                                synchronization=VariableSynchronization.AUTO,
                                aggregation=VariableAggregation.NONE):
    """Gets or creates a sharded variable list with these parameters.

    The `partitioner` must be a callable that accepts a fully defined
    `TensorShape` and returns a sequence of integers (the `partitions`).
    These integers describe how to partition the given sharded `Variable`
    along the given dimension. That is, `partitions[1] = 3` means split
    the `Variable` into 3 shards along dimension 1. Currently, sharding along
    only one axis is supported.

    If the list of variables with the given name (prefix) is already stored,
    we return the stored variables. Otherwise, we create a new one.

    Set `reuse` to `True` when you only want to reuse existing Variables.
    Set `reuse` to `False` when you only want to create new Variables.
    Set `reuse` to None (the default) or tf.compat.v1.AUTO_REUSE when you want
    variables to be created if they don't exist or returned if they do.

    If initializer is `None` (the default), the default initializer passed in
    the constructor is used. If that one is `None` too, we use a new
    `glorot_uniform_initializer`. If initializer is a Tensor, we use
    it as a value and derive the shape from the initializer.

    If the initializer is a callable, then it will be called for each
    shard. Otherwise the initializer should match the shape of the entire
    sharded Variable, and it will be sliced accordingly for each shard.

    Some useful partitioners are available. See, e.g.,
    `variable_axis_size_partitioner` and `min_max_variable_partitioner`.

    Args:
      name: the name of the new or existing sharded variable.
      partitioner: Optional callable that accepts a fully defined `TensorShape`
        and `dtype` of the Variable to be created, and returns a list of
        partitions for each axis (currently only one axis can be partitioned).
      shape: shape of the new or existing sharded variable.
      dtype: type of the new or existing sharded variable (defaults to
        `DT_FLOAT`).
      initializer: initializer for the sharded variable.
      regularizer: a (Tensor -> Tensor or None) function; the result of applying
        it on a newly created variable will be added to the collection
        GraphKeys.REGULARIZATION_LOSSES and can be used for regularization.
      reuse: a Boolean, None, or tf.AUTO_REUSE. Controls reuse or creation of
        variables.
      trainable: If `True` also add the variable to the graph collection
        `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
      collections: List of graph collections keys to add the Variable to.
        Defaults to `[GraphKeys.GLOBAL_VARIABLES]` (see `tf.Variable`).
      caching_device: Optional device string or function describing where the
        Variable should be cached for reading. Defaults to the Variable's
        device. If not `None`, caches on another device. Typical use is to
        cache on the device where the Ops using the Variable reside, to
        deduplicate copying through `Switch` and other conditional statements.
      validate_shape: If False, allows the variable to be initialized with a
        value of unknown shape. If True, the default, the shape of initial_value
        must be known.
      use_resource: If False, creates a regular Variable. If True, creates an
        experimental ResourceVariable which has well-defined semantics. Defaults
        to False (will later change to True).
      constraint: An optional projection function to be applied to the variable
        after being updated by an `Optimizer` (e.g. used to implement norm
        constraints or value constraints for layer weights). The function must
        take as input the unprojected Tensor representing the value of the
        variable and return the Tensor for the projected value (which must have
        the same shape). Constraints are not safe to use when doing asynchronous
        distributed training.
      synchronization: Indicates when a distributed variable will be
        aggregated. Accepted values are constants defined in the class
        `tf.VariableSynchronization`. By default the synchronization is set to
        `AUTO` and the current `DistributionStrategy` chooses when to
        synchronize.
      aggregation: Indicates how a distributed variable will be aggregated.
        Accepted values are constants defined in the class
        `tf.VariableAggregation`.

    Returns:
      A `PartitionedVariable` object.

    Raises:
      ValueError: when creating a new variable and shape is not declared,
        when reusing a variable and specifying a conflicting shape,
        when violating reuse during variable creation, or if an existing
        sharded variable exists for the given name but with different sharding.
    """
    initializing_from_value = initializer is not None and isinstance(
        initializer, ops.Tensor)
    if name in self._vars:
      raise ValueError(
          "A partitioner was provided, but an unpartitioned version of the "
          "variable was found: %s. Perhaps a variable of the same name was "
          "already created without partitioning?" % name)

    shape = tensor_shape.as_shape(shape)
    if initializing_from_value:
      shape = shape.merge_with(initializer.get_shape())

    partitions = None
    if not reuse or partitioner:
      partitions = _call_partitioner(partitioner, shape, dtype)

    if name in self._partitioned_vars:
      if reuse is False:
        raise ValueError(
            "Partitioned variable with name %s already exists. Did you mean to "
            "set reuse=True or reuse=tf.AUTO_REUSE in VarScope?"
            % name)

      existing_var = self._partitioned_vars[name]
      if not shape.is_compatible_with(existing_var.get_shape()):
        raise ValueError(
            "Trying to reuse partitioned variable %s, but specified shape %s "
            "and found shape %s." % (name, shape, existing_var.get_shape()))
      if not dtype.is_compatible_with(existing_var.dtype):
        raise ValueError(
            "Trying to reuse partitioned variable %s, but specified dtype %s "
            "and found dtype %s." % (name, dtype.name, existing_var.dtype.name))

      # pylint: disable=protected-access
      if (partitions is not None and
          existing_var._get_partitions() != partitions):
        raise ValueError(
            "Trying to reuse partitioned variable %s, but specified partitions "
            "%s and found partitions %s." %
            (name, partitions, existing_var._get_partitions()))
      # pylint: enable=protected-access

      return existing_var

    if reuse is True:
      raise ValueError("PartitionedVariable %s does not exist, or was not "
                       "created with tf.get_variable(). Did you mean to set "
                       "reuse=False or reuse=tf.AUTO_REUSE in VarScope?" % name)

    slice_dim, num_slices = _get_slice_dim_and_num_slices(partitions)

    if "%s/part_0" % name in self._vars:
      if "%s/part_%d" % (name, num_slices - 1) not in self._vars:
        raise ValueError(
            "Partitioner returned a different partitioning than what was "
            "already found. Partitioner returned %d shards, and shard "
            "%s/part_0 was found, but %s/part_%d was not." %
            (num_slices, name, name, num_slices - 1))
      if "%s/part_%d" % (name, num_slices) in self._vars:
        raise ValueError(
            "Partitioner returned a different partitioning than what was "
            "already found. Partitioner returned %d shards, and shard "
            "%s/part_0 was found, but so was the extra shard %s/part_%d." %
            (num_slices, name, name, num_slices))

    vs = []
    for i, (var_offset, var_shape) in enumerate(
        _iter_slices(shape.as_list(), num_slices, slice_dim)):
      partition_info = _PartitionInfo(
          full_shape=shape.as_list(), var_offset=var_offset)
      var_full_name = "%s/part_%d" % (name, i)
      with ops.name_scope(
          var_full_name + "/PartitionedInitializer", skip_on_eager=False):
        # Create the tensor to initialize the variable with default value.
        if initializer is None:
          init, initializing_from_value = self._get_default_initializer(
              name=name, shape=shape, dtype=dtype)
          if initializing_from_value:
            init_shape = None
          else:
            init_shape = var_shape
        elif callable(initializer):
          init = initializer
          init_shape = var_shape
        elif isinstance(initializer, ops.Tensor):
          init = array_ops.slice(initializer, var_offset, var_shape)
          # Use the dtype of the given tensor.
          dtype = init.dtype.base_dtype
          init_shape = None
        else:
          init = ops.convert_to_tensor(initializer, dtype=dtype)
          init = array_ops.slice(init, var_offset, var_shape)
          init_shape = None

      with ops.name_scope(None):
        var = self._get_single_variable(
            name=var_full_name,
            shape=init_shape,
            dtype=dtype,
            initializer=init,
            partition_info=partition_info,
            regularizer=regularizer,
            reuse=reuse,
            trainable=trainable,
            collections=collections,
            caching_device=caching_device,
            validate_shape=validate_shape,
            use_resource=use_resource,
            constraint=constraint,
            synchronization=synchronization,
            aggregation=aggregation)

      # pylint: disable=protected-access
      var._set_save_slice_info(
          variables.Variable.SaveSliceInfo(name, shape.as_list(), var_offset,
                                           var_shape))
      vs.append(var)
      # pylint: enable=protected-access

    partitioned_var = variables.PartitionedVariable(
        name=name,
        shape=shape,
        dtype=dtype,
        variable_list=vs,
        partitions=partitions)
    if not context.executing_eagerly() or self._store_eager_variables:
      self._partitioned_vars[name] = partitioned_var
    return partitioned_var

  def _get_single_variable(self,
                           name,
                           shape=None,
                           dtype=dtypes.float32,
                           initializer=None,
                           regularizer=None,
                           partition_info=None,
                           reuse=None,
                           trainable=None,
                           collections=None,
                           caching_device=None,
                           validate_shape=True,
                           use_resource=None,
                           constraint=None,
                           synchronization=VariableSynchronization.AUTO,
                           aggregation=VariableAggregation.NONE):
    """Get or create a single Variable (e.g. a shard or entire variable).

    See the documentation of get_variable above (ignore partitioning components)
    for details.

    Args:
      name: see get_variable.
      shape: see get_variable.
      dtype: see get_variable.
      initializer: see get_variable.
      regularizer: see get_variable.
      partition_info: _PartitionInfo object.
      reuse: see get_variable.
      trainable: see get_variable.
      collections: see get_variable.
      caching_device: see get_variable.
      validate_shape: see get_variable.
      use_resource: see get_variable.
      constraint: see get_variable.
      synchronization: see get_variable.
      aggregation: see get_variable.

    Returns:
      A Variable. See documentation of get_variable above.

    Raises:
      ValueError: See documentation of get_variable above.
    """
    # Set to true if initializer is a constant.
    initializing_from_value = False
    if initializer is not None and not callable(initializer):
      initializing_from_value = True
    if shape is not None and initializing_from_value:
      raise ValueError("If initializer is a constant, do not specify shape.")

    dtype = dtypes.as_dtype(dtype)
    shape = tensor_shape.as_shape(shape)

    if name in self._vars:
      # Here we handle the case when returning an existing variable.
      if reuse is False:
        var = self._vars[name]
        err_msg = ("Variable %s already exists, disallowed."
                   " Did you mean to set reuse=True or "
                   "reuse=tf.AUTO_REUSE in VarScope?" % name)
        # ResourceVariables don't have an op associated with them, so no
        # traceback is available.
        if isinstance(var, resource_variable_ops.ResourceVariable):
          raise ValueError(err_msg)
        tb = var.op.traceback[::-1]
        # Throw away internal tf entries and only take a few lines. In some
        # cases the traceback can be longer (e.g.
        # if someone uses factory functions to create variables) so we take
        # more than needed in the default case.
        tb = [x for x in tb if "tensorflow/python" not in x[0]][:5]
        raise ValueError("%s Originally defined at:\n\n%s" %
                         (err_msg, "".join(traceback.format_list(tb))))
      found_var = self._vars[name]
      if not shape.is_compatible_with(found_var.get_shape()):
        raise ValueError("Trying to share variable %s, but specified shape %s"
                         " and found shape %s." %
                         (name, shape, found_var.get_shape()))
      if not dtype.is_compatible_with(found_var.dtype):
        dtype_str = dtype.name
        found_type_str = found_var.dtype.name
        raise ValueError("Trying to share variable %s, but specified dtype %s"
                         " and found dtype %s." %
                         (name, dtype_str, found_type_str))
      return found_var

    # The code below handles only the case of creating a new variable.
    if reuse is True:
      raise ValueError("Variable %s does not exist, or was not created with "
                       "tf.get_variable(). Did you mean to set "
                       "reuse=tf.AUTO_REUSE in VarScope?" % name)

    # Create the tensor to initialize the variable with default value.
    if initializer is None:
      initializer, initializing_from_value = self._get_default_initializer(
          name=name, shape=shape, dtype=dtype)
    # Enter an init scope when creating the initializer.
    with ops.init_scope():
      if initializing_from_value:
        init_val = initializer
        variable_dtype = None
      else:
        # Instantiate initializer if provided initializer is a type object.
        if tf_inspect.isclass(initializer):
          initializer = initializer()
        if shape.is_fully_defined():
          if "partition_info" in tf_inspect.getargspec(initializer).args:
            init_val = functools.partial(initializer,
                                         shape.as_list(),
                                         dtype=dtype,
                                         partition_info=partition_info)
          else:
            init_val = functools.partial(initializer,
                                         shape.as_list(), dtype=dtype)
          variable_dtype = dtype.base_dtype
        elif _needs_no_arguments(initializer):
          init_val = initializer
          variable_dtype = None
        else:
          raise ValueError("The initializer passed is not valid. It should "
                           "be a callable with no arguments and the "
                           "shape should not be provided or an instance of "
                           "`tf.keras.initializers.*` and `shape` should be "
                           "fully defined.")

    # Create the variable.
    if use_resource is None:
      # Set the default value if unspecified.
      use_resource = _DEFAULT_USE_RESOURCE
    v = variables.VariableV1(
        initial_value=init_val,
        name=name,
        trainable=trainable,
        collections=collections,
        caching_device=caching_device,
        dtype=variable_dtype,
        validate_shape=validate_shape,
        constraint=constraint,
        use_resource=use_resource,
        synchronization=synchronization,
        aggregation=aggregation)
    if context.executing_eagerly() and self._store_eager_variables:
      if collections:
        ops.add_to_collections(collections, v)
      else:
        ops.add_to_collection(ops.GraphKeys.GLOBAL_VARIABLES, v)
      if trainable:
        ops.add_to_collection(ops.GraphKeys.TRAINABLE_VARIABLES, v)

    if not context.executing_eagerly() or self._store_eager_variables:
      # In eager mode we do not want to keep default references to Variable
      # objects as this will prevent their memory from being released.
      self._vars[name] = v
    logging.vlog(1, "Created variable %s with shape %s and init %s", v.name,
                 format(shape), initializer)

    # Run the regularizer if requested and save the resulting loss.
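    # The regularizer output is wrapped in a _LazyEvalTensor so the loss op is
    # rebuilt (colocated with the variable) whenever the collection entry is
    # converted to a concrete tensor.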
    if regularizer:
      def make_regularizer_op():
        with ops.colocate_with(v):
          with ops.name_scope(name + "/Regularizer/"):
            return regularizer(v)

      if regularizer(v) is not None:
        lazy_eval_tensor = _LazyEvalTensor(make_regularizer_op)
        ops.add_to_collection(ops.GraphKeys.REGULARIZATION_LOSSES,
                              lazy_eval_tensor)

    return v

  # Initialize variable when no initializer provided
  def _get_default_initializer(self, name, shape=None, dtype=dtypes.float32):
    """Provide a default initializer and a corresponding value.

    Args:
      name: see get_variable.
      shape: see get_variable.
      dtype: see get_variable.

    Returns:
      initializer and initializing_from_value. See get_variable above.

    Raises:
      ValueError: When giving unsupported dtype.
    """
    del shape
    # If dtype is DT_FLOAT, provide a glorot uniform initializer.
    if dtype.is_floating:
      initializer = init_ops.glorot_uniform_initializer()
      initializing_from_value = False
    # If dtype is DT_INT/DT_UINT, provide a default value `zero`
    # If dtype is DT_BOOL, provide a default value `FALSE`
    elif (dtype.is_integer or dtype.is_unsigned or dtype.is_bool or
          dtype == dtypes.string):
      initializer = init_ops.zeros_initializer()
      initializing_from_value = False
    # NOTE: Do we need to support handling DT_STRING and DT_COMPLEX here?
    else:
      raise ValueError("An initializer for variable %s of %s is required" %
                       (name, dtype.base_dtype))

    return initializer, initializing_from_value


class _LazyEvalTensor(core.Tensor):
  """A Tensor-like object that only evaluates its thunk when used."""

  def __init__(self, thunk):
    """Initializes a _LazyEvalTensor object.

    Args:
      thunk: A callable. A thunk which computes the value of the tensor.
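        The thunk is called once at construction time to build a template
        (master) tensor, and again each time this object is converted to a
        real tensor.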
1039 """ 1040 self._thunk = thunk 1041 self._master_tensor = thunk() 1042 1043 def _as_tensor(self, dtype=None, name=None, as_ref=False): 1044 del name 1045 assert not as_ref 1046 assert dtype in [None, self.dtype] 1047 1048 return self._thunk() 1049 1050 1051def _make_master_property(name): 1052 @property 1053 def prop(self): 1054 return getattr(self._master_tensor, name) # pylint: disable=protected-access 1055 return prop 1056 1057_master_property_list = ("device", "dtype", "graph", "name", "op", "shape", 1058 "value_index") 1059for _name in _master_property_list: 1060 setattr(_LazyEvalTensor, _name, _make_master_property(_name)) 1061 1062 1063def _make_master_method(name): 1064 def method(self, *args, **kwargs): 1065 return getattr(self._master_tensor, name)(*args, **kwargs) # pylint: disable=protected-access 1066 return method 1067 1068_master_method_list = ("get_shape", "__str__", "shape_as_list") 1069for _name in _master_method_list: 1070 setattr(_LazyEvalTensor, _name, _make_master_method(_name)) 1071 1072 1073def _make_op_method(name): 1074 def method(self, *args, **kwargs): 1075 return getattr(self._as_tensor(), name)(*args, **kwargs) # pylint: disable=protected-access 1076 return method 1077 1078_op_list = ("__abs__", "__add__", "__and__", "__bool__", "__div__", "__eq__", 1079 "__floordiv__", "__ge__", "__getitem__", "__gt__", "__invert__", 1080 "__iter__", "__le__", "__len__", "__lt__", "__matmul__", "__mod__", 1081 "__mul__", "__ne__", "__neg__", "__nonzero__", "__or__", "__pow__", 1082 "__radd__", "__rand__", "__rdiv__", "__rfloordiv__", "__rmatmul__", 1083 "__rmod__", "__rmul__", "__ror__", "__rpow__", "__rsub__", 1084 "__rtruediv__", "__rxor__", "__sub__", "__truediv__", "__xor__", 1085 "eval", "numpy") 1086for _name in _op_list: 1087 setattr(_LazyEvalTensor, _name, _make_op_method(_name)) 1088 1089 1090ops.register_tensor_conversion_function( 1091 _LazyEvalTensor, 1092 lambda val, dtype, name, as_ref: val._as_tensor(dtype, name, as_ref) # pylint: disable=protected-access 1093 ) 1094 1095session.register_session_run_conversion_functions( 1096 _LazyEvalTensor, 1097 lambda fetch: ([fetch._master_tensor], lambda fetched_vals: fetched_vals[0]) # pylint: disable=protected-access 1098 ) 1099 1100 1101# To stop regularization, use this regularizer 1102@tf_export(v1=["no_regularizer"]) 1103def no_regularizer(_): 1104 """Use this function to prevent regularization of variables.""" 1105 return None 1106 1107 1108# TODO(alive): support caching devices and partitioned variables in Eager mode. 1109@tf_export(v1=["VariableScope"]) 1110class VariableScope: 1111 """Variable scope object to carry defaults to provide to `get_variable`. 1112 1113 Many of the arguments we need for `get_variable` in a variable store are most 1114 easily handled with a context. This object is used for the defaults. 1115 1116 Attributes: 1117 name: name of the current scope, used as prefix in get_variable. 1118 initializer: default initializer passed to get_variable. 1119 regularizer: default regularizer passed to get_variable. 1120 reuse: Boolean, None, or tf.compat.v1.AUTO_REUSE, setting the reuse in 1121 get_variable. When eager execution is enabled this argument is always 1122 forced to be False. 1123 caching_device: string, callable, or None: the caching device passed to 1124 get_variable. 1125 partitioner: callable or `None`: the partitioner passed to `get_variable`. 1126 custom_getter: default custom getter passed to get_variable. 1127 name_scope: The name passed to `tf.name_scope`. 
    dtype: default type passed to get_variable (defaults to DT_FLOAT).
    use_resource: if False, create a normal Variable; if True create an
      experimental ResourceVariable with well-defined semantics. Defaults to
      False (will later change to True). When eager execution is enabled this
      argument is always forced to be True.
    constraint: An optional projection function to be applied to the variable
      after being updated by an `Optimizer` (e.g. used to implement norm
      constraints or value constraints for layer weights). The function must
      take as input the unprojected Tensor representing the value of the
      variable and return the Tensor for the projected value (which must have
      the same shape). Constraints are not safe to use when doing asynchronous
      distributed training.
  """

  def __init__(self,
               reuse,
               name="",
               initializer=None,
               regularizer=None,
               caching_device=None,
               partitioner=None,
               custom_getter=None,
               name_scope="",
               dtype=dtypes.float32,
               use_resource=None,
               constraint=None):
    """Creates a new VariableScope with the given properties."""
    self._name = name
    self._initializer = initializer
    self._regularizer = regularizer
    self._reuse = reuse
    self._caching_device = caching_device
    self._partitioner = partitioner
    self._custom_getter = custom_getter
    self._name_scope = name_scope
    self._dtype = dtype
    self._use_resource = use_resource
    self._constraint = constraint
    if context.executing_eagerly():
      if self._caching_device is not None:
        raise NotImplementedError("Caching devices are not yet supported "
                                  "when eager execution is enabled.")
      self._reuse = AUTO_REUSE
      self._use_resource = True

  @property
  def name(self):
    return self._name

  @property
  def original_name_scope(self):
    return self._name_scope

  @property
  def reuse(self):
    return self._reuse

  @property
  def initializer(self):
    return self._initializer

  @property
  def dtype(self):
    return self._dtype

  @property
  def use_resource(self):
    return self._use_resource

  @property
  def regularizer(self):
    return self._regularizer

  @property
  def caching_device(self):
    return self._caching_device

  @property
  def partitioner(self):
    return self._partitioner

  @property
  def custom_getter(self):
    return self._custom_getter

  @property
  def constraint(self):
    return self._constraint

  def reuse_variables(self):
    """Reuse variables in this scope."""
    self._reuse = True

  def set_initializer(self, initializer):
    """Set initializer for this scope."""
    self._initializer = initializer

  def set_dtype(self, dtype):
    """Set data type for this scope."""
    self._dtype = dtype

  def set_use_resource(self, use_resource):
    """Sets whether to use ResourceVariables for this scope."""
    if context.executing_eagerly() and not use_resource:
      raise ValueError("When eager execution is enabled, "
                       "use_resource cannot be set to false.")
    self._use_resource = use_resource

  def set_regularizer(self, regularizer):
    """Set regularizer for this scope."""
    self._regularizer = regularizer

  def set_caching_device(self, caching_device):
    """Set caching_device for this scope."""
    if context.executing_eagerly():
      raise NotImplementedError("Caching devices are not yet supported "
                                "when eager execution is enabled.")
    self._caching_device = caching_device

  def set_partitioner(self, partitioner):
    """Set partitioner for this scope."""
    self._partitioner = partitioner

  def set_custom_getter(self, custom_getter):
    """Set custom getter for this scope."""
    self._custom_getter = custom_getter

  def get_collection(self, name):
    """Get this scope's variables."""
    scope = self._name + "/" if self._name else ""
    return ops.get_collection(name, scope)

  def trainable_variables(self):
    """Get this scope's trainable variables."""
    return self.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)

  def global_variables(self):
    """Get this scope's global variables."""
    return self.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)

  def local_variables(self):
    """Get this scope's local variables."""
    return self.get_collection(ops.GraphKeys.LOCAL_VARIABLES)

  def get_variable(self,
                   var_store,
                   name,
                   shape=None,
                   dtype=None,
                   initializer=None,
                   regularizer=None,
                   reuse=None,
                   trainable=None,
                   collections=None,
                   caching_device=None,
                   partitioner=None,
                   validate_shape=True,
                   use_resource=None,
                   custom_getter=None,
                   constraint=None,
                   synchronization=VariableSynchronization.AUTO,
                   aggregation=VariableAggregation.NONE):
    """Gets an existing variable with this name or create a new one."""
    if regularizer is None:
      regularizer = self._regularizer
    if caching_device is None:
      caching_device = self._caching_device
    if partitioner is None:
      partitioner = self._partitioner
    if custom_getter is None:
      custom_getter = self._custom_getter
    if context.executing_eagerly():
      reuse = False
      use_resource = True
    else:
      if reuse is None:
        reuse = self._reuse
      if use_resource is None:
        use_resource = self._use_resource

    full_name = self.name + "/" + name if self.name else name
    # Variable names only depend on variable_scope (full_name here),
    # not name_scope, so we reset it below for the time of variable creation.
    with ops.name_scope(None, skip_on_eager=False):
      # Check that `initializer` dtype and `dtype` are consistent before
      # replacing them with defaults.
      if (dtype is not None and initializer is not None and
          not callable(initializer)):
        init_dtype = ops.convert_to_tensor(initializer).dtype.base_dtype
        if init_dtype != dtype:
          raise ValueError("Initializer type '%s' and explicit dtype '%s' "
                           "don't match."
                           % (init_dtype, dtype))
      if initializer is None:
        initializer = self._initializer
      if constraint is None:
        constraint = self._constraint
      if dtype is None:
        dtype = self._dtype
      return var_store.get_variable(
          full_name,
          shape=shape,
          dtype=dtype,
          initializer=initializer,
          regularizer=regularizer,
          reuse=reuse,
          trainable=trainable,
          collections=collections,
          caching_device=caching_device,
          partitioner=partitioner,
          validate_shape=validate_shape,
          use_resource=use_resource,
          custom_getter=custom_getter,
          constraint=constraint,
          synchronization=synchronization,
          aggregation=aggregation)

  def _get_partitioned_variable(self,
                                var_store,
                                name,
                                shape=None,
                                dtype=None,
                                initializer=None,
                                regularizer=None,
                                trainable=None,
                                collections=None,
                                caching_device=None,
                                partitioner=None,
                                validate_shape=True,
                                use_resource=None,
                                constraint=None,
                                synchronization=VariableSynchronization.AUTO,
                                aggregation=VariableAggregation.NONE):
    """Gets an existing variable with this name or create a new one."""
    if initializer is None:
      initializer = self._initializer
    if regularizer is None:
      regularizer = self._regularizer
    if constraint is None:
      constraint = self._constraint
    if caching_device is None:
      caching_device = self._caching_device
    if partitioner is None:
      partitioner = self._partitioner
    if dtype is None:
      dtype = self._dtype
    if use_resource is None:
      use_resource = self._use_resource

    if self._custom_getter is not None:
      raise ValueError(
          "Private access to _get_partitioned_variable is not allowed when "
          "a custom getter is set. Current custom getter: %s. "
          "It is likely that you're using create_partitioned_variables. "
          "If so, consider using get_variable with a non-empty "
          "partitioner parameter instead." % self._custom_getter)

    if partitioner is None:
      raise ValueError("No partitioner was specified")

    # This allows the variable scope name to be used as the variable name if
    # this function is invoked with an empty name arg, for backward
    # compatibility with create_partitioned_variables().
    full_name_list = []
    if self.name:
      full_name_list.append(self.name)
    if name:
      full_name_list.append(name)
    full_name = "/".join(full_name_list)

    # Variable names only depend on variable_scope (full_name here),
    # not name_scope, so we reset it below for the time of variable creation.
    with ops.name_scope(None, skip_on_eager=False):
      # pylint: disable=protected-access
      return var_store._get_partitioned_variable(
          full_name,
          shape=shape,
          dtype=dtype,
          initializer=initializer,
          regularizer=regularizer,
          reuse=self.reuse,
          trainable=trainable,
          collections=collections,
          caching_device=caching_device,
          partitioner=partitioner,
          validate_shape=validate_shape,
          use_resource=use_resource,
          constraint=constraint,
          synchronization=synchronization,
          aggregation=aggregation)
      # pylint: enable=protected-access


_VARSTORE_KEY = ("__variable_store",)
_VARSCOPESTORE_KEY = ("__varscope",)


class _VariableScopeStore(threading.local):
  """A thread local store for the current variable scope and scope counts."""

  def __init__(self):
    super(_VariableScopeStore, self).__init__()
    self.current_scope = VariableScope(False)
    self.variable_scopes_count = {}

  def open_variable_scope(self, scope_name):
    if scope_name in self.variable_scopes_count:
      self.variable_scopes_count[scope_name] += 1
    else:
      self.variable_scopes_count[scope_name] = 1

  def close_variable_subscopes(self, scope_name):
    if scope_name is None:
      for k in self.variable_scopes_count:
        self.variable_scopes_count[k] = 0
    else:
      startswith_check = scope_name + "/"
      startswith_len = len(startswith_check)
      for k in self.variable_scopes_count:
        if k[:startswith_len] == startswith_check:
          self.variable_scopes_count[k] = 0

  def variable_scope_count(self, scope_name):
    return self.variable_scopes_count.get(scope_name, 0)


def get_variable_scope_store():
  """Returns the variable scope store for current thread."""
  scope_store = ops.get_collection(_VARSCOPESTORE_KEY)

  if not scope_store:
    scope_store = _VariableScopeStore()
    ops.add_to_collection(_VARSCOPESTORE_KEY, scope_store)
  else:
    scope_store = scope_store[0]

  return scope_store


@tf_export(v1=["get_variable_scope"])
def get_variable_scope():
  """Returns the current variable scope.

  @compatibility(TF2)
  Although it is a legacy `compat.v1` api,
  `tf.compat.v1.get_variable` is compatible with eager
  execution and `tf.function`.

  However, to maintain variable-scope based variable reuse
  you will need to combine it with
  `tf.compat.v1.keras.utils.track_tf1_style_variables`. (Though
  it will behave as if reuse is always set to `tf.compat.v1.AUTO_REUSE`.)

  See the
  [migration guide](https://www.tensorflow.org/guide/migrate/model_mapping)
  for more info.

  The TF2 equivalent, if you are just trying to track
  variable name prefixes and not control `get_variable`-based variable reuse,
  would be to use `tf.name_scope` and capture the output of opening the
  scope (which represents the current name prefix).

  For example:
  ```python
  with tf.name_scope('foo') as current_scope:
    ...
  ```
  @end_compatibility
  """
  return get_variable_scope_store().current_scope


def _get_default_variable_store():
  store = ops.get_collection(_VARSTORE_KEY)
  if store:
    return store[0]
  store = _VariableStore()
  ops.add_to_collection(_VARSTORE_KEY, store)
  return store


@tf_contextlib.contextmanager
def with_variable_store(store):
  store_collection = ops.get_collection_ref(_VARSTORE_KEY)
  old = list(store_collection)
  store_collection[:] = [store]
  try:
    yield
  finally:
    store_collection[:] = old


class EagerVariableStore:
  """Wrapper allowing functional layers to be used with eager execution.

  When eager execution is enabled Variables get deleted when they go out of
  scope, and are not stored in global collections by default. A lot of code
  (mostly the functional layers in tf.layers) assumes that variables are kept in
  a global list.

  EagerVariableStore can be used in conjunction with this code to make it
  eager-friendly. For example, to create a dense layer, use:

  ```
  container = tfe.EagerVariableStore()
  for input in dataset_iterator:
    with container.as_default():
      x = tf.compat.v1.layers.dense(input, name="l1")
  print(container.variables)  # Should print the variables used in the layer.
  ```
  """

  def __init__(self, store=None):
    if store is not None:
      if not store._store_eager_variables:  # pylint: disable=protected-access
        raise ValueError("Cannot construct EagerVariableStore from a "
                         "VariableStore object that does not hold eager "
                         "variables.")
      self._store = store
    else:
      self._store = _VariableStore()
    self._store._store_eager_variables = True  # pylint: disable=protected-access

  def as_default(self):
    return with_variable_store(self._store)

  def variables(self):
    return sorted(self._store._vars.values(), key=lambda x: x.name)  # pylint: disable=protected-access

  def trainable_variables(self):
    # pylint: disable=protected-access
    return sorted([x for x in self._store._vars.values() if x.trainable],
                  key=lambda x: x.name)
    # pylint: enable=protected-access

  def non_trainable_variables(self):
    # pylint: disable=protected-access
    return sorted([x for x in self._store._vars.values() if not x.trainable],
                  key=lambda x: x.name)
    # pylint: enable=protected-access

  def copy(self):
    """Copy this variable store and all of its contents.

    Variables contained in this store will be copied over to the new variable
    store, meaning that they can be modified without affecting the variables in
    this store.

    Returns:
      A new EagerVariableStore instance containing copied variables.
    """
    # pylint: disable=protected-access
    new_store = EagerVariableStore()
    for key, var in self._store._vars.items():
      # Strip device out of variable name.
      try:
        index = var.name.index(":")
      except ValueError:
        stripped_var_name = var.name
      else:
        stripped_var_name = var.name[:index]

      # Create new variable with same value, name, and "trainable" flag.
1590 new_var = resource_variable_ops.ResourceVariable( 1591 var.read_value(), name=stripped_var_name, trainable=var.trainable) 1592 new_store._store._vars[key] = new_var 1593 return new_store 1594 # pylint: enable=protected-access 1595 1596 1597# The argument list for get_variable must match arguments to get_local_variable. 1598# So, if you are updating the arguments, also update arguments to 1599# get_local_variable below. 1600@tf_export(v1=["get_variable"]) 1601def get_variable(name, 1602 shape=None, 1603 dtype=None, 1604 initializer=None, 1605 regularizer=None, 1606 trainable=None, 1607 collections=None, 1608 caching_device=None, 1609 partitioner=None, 1610 validate_shape=True, 1611 use_resource=None, 1612 custom_getter=None, 1613 constraint=None, 1614 synchronization=VariableSynchronization.AUTO, 1615 aggregation=VariableAggregation.NONE): 1616 return get_variable_scope().get_variable( 1617 _get_default_variable_store(), 1618 name, 1619 shape=shape, 1620 dtype=dtype, 1621 initializer=initializer, 1622 regularizer=regularizer, 1623 trainable=trainable, 1624 collections=collections, 1625 caching_device=caching_device, 1626 partitioner=partitioner, 1627 validate_shape=validate_shape, 1628 use_resource=use_resource, 1629 custom_getter=custom_getter, 1630 constraint=constraint, 1631 synchronization=synchronization, 1632 aggregation=aggregation) 1633 1634 1635get_variable_or_local_docstring = ("""%s 1636 1637@compatibility(TF2) 1638Although it is a legacy `compat.v1` api, 1639`tf.compat.v1.get_variable` is mostly compatible with eager 1640execution and `tf.function` but only if you combine it with the 1641`tf.compat.v1.keras.utils.track_tf1_style_variables` decorator. (Though 1642it will behave as if reuse is always set to `AUTO_REUSE`.) 1643 1644See the 1645[model migration guide](https://www.tensorflow.org/guide/migrate/model_mapping) 1646for more info. 1647 1648If you do not combine it with 1649`tf.compat.v1.keras.utils.track_tf1_style_variables`, `get_variable` will create 1650a brand new variable every single time it is called and will never reuse 1651variables, regardless of variable names or `reuse` arguments. 1652 1653The TF2 equivalent of this symbol would be `tf.Variable`, but note 1654that when using `tf.Variable` you must make sure you track your variables 1655(and regularizer arguments) either manually or via `tf.Module` or 1656`tf.keras.layers.Layer` mechanisms. 1657 1658A section of the 1659[migration guide](https://www.tensorflow.org/guide/migrate/model_mapping#incremental_migration_to_native_tf2) 1660provides more details on incrementally migrating these usages to `tf.Variable` 1661as well. 1662 1663Note: The `partitioner` arg is not compatible with TF2 behaviors even when 1664using `tf.compat.v1.keras.utils.track_tf1_style_variables`. It can be replaced 1665by using `ParameterServerStrategy` and its partitioners. See the 1666[multi-gpu migration guide](https://www.tensorflow.org/guide/migrate/multi_worker_cpu_gpu_training) 1667and the ParameterServerStrategy guides it references for more info. 1668@end_compatibility 1669 1670%sThis function prefixes the name with the current variable scope 1671and performs reuse checks. See the 1672[Variable Scope How To](https://tensorflow.org/guide/variables) 1673for an extensive description of how reusing works. Here is a basic example: 1674 1675```python 1676def foo(): 1677 with tf.variable_scope("foo", reuse=tf.AUTO_REUSE): 1678 v = tf.get_variable("v", [1]) 1679 return v 1680 1681v1 = foo() # Creates v. 
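# The second call takes the reuse path: with `AUTO_REUSE`, the existing
# variable "foo/v" is returned instead of a new one being created.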
1682v2 = foo() # Gets the same, existing v. 1683assert v1 == v2 1684``` 1685 1686If initializer is `None` (the default), the default initializer passed in 1687the variable scope will be used. If that one is `None` too, a 1688`glorot_uniform_initializer` will be used. The initializer can also be 1689a Tensor, in which case the variable is initialized to this value and shape. 1690 1691Similarly, if the regularizer is `None` (the default), the default regularizer 1692passed in the variable scope will be used (if that is `None` too, 1693then by default no regularization is performed). 1694 1695If a partitioner is provided, a `PartitionedVariable` is returned. 1696Accessing this object as a `Tensor` returns the shards concatenated along 1697the partition axis. 1698 1699Some useful partitioners are available. See, e.g., 1700`variable_axis_size_partitioner` and `min_max_variable_partitioner`. 1701 1702Args: 1703 name: The name of the new or existing variable. 1704 shape: Shape of the new or existing variable. 1705 dtype: Type of the new or existing variable (defaults to `DT_FLOAT`). 1706 initializer: Initializer for the variable if one is created. Can either be 1707 an initializer object or a Tensor. If it's a Tensor, its shape must be known 1708 unless validate_shape is False. 1709 regularizer: A (Tensor -> Tensor or None) function; the result of 1710 applying it on a newly created variable will be added to the collection 1711 `tf.GraphKeys.REGULARIZATION_LOSSES` and can be used for regularization. 1712 %scollections: List of graph collections keys to add the Variable to. 1713 Defaults to `[%s]` (see `tf.Variable`). 1714 caching_device: Optional device string or function describing where the 1715 Variable should be cached for reading. Defaults to the Variable's 1716 device. If not `None`, caches on another device. Typical use is to 1717 cache on the device where the Ops using the Variable reside, to 1718 deduplicate copying through `Switch` and other conditional statements. 1719 partitioner: Optional callable that accepts a fully defined `TensorShape` 1720 and `dtype` of the Variable to be created, and returns a list of 1721 partitions for each axis (currently only one axis can be partitioned). 1722 validate_shape: If False, allows the variable to be initialized with a 1723 value of unknown shape. If True, the default, the shape of initial_value 1724 must be known. For this to be used the initializer must be a Tensor and 1725 not an initializer object. 1726 use_resource: If False, creates a regular Variable. If true, creates an 1727 experimental ResourceVariable instead with well-defined semantics. 1728 Defaults to False (will later change to True). When eager execution is 1729 enabled this argument is always forced to be True. 1730 custom_getter: Callable that takes as a first argument the true getter, and 1731 allows overwriting the internal get_variable method. 1732 The signature of `custom_getter` should match that of this method, 1733 but the most future-proof version will allow for changes: 1734 `def custom_getter(getter, *args, **kwargs)`. Direct access to 1735 all `get_variable` parameters is also allowed: 1736 `def custom_getter(getter, name, *args, **kwargs)`. 
    A simple identity custom getter that simply creates variables with
    modified names is:
    ```python
    def custom_getter(getter, name, *args, **kwargs):
      return getter(name + '_suffix', *args, **kwargs)
    ```
  constraint: An optional projection function to be applied to the variable
    after being updated by an `Optimizer` (e.g. used to implement norm
    constraints or value constraints for layer weights). The function must
    take as input the unprojected Tensor representing the value of the
    variable and return the Tensor for the projected value
    (which must have the same shape). Constraints are not safe to
    use when doing asynchronous distributed training.
  synchronization: Indicates when a distributed variable will be
    aggregated. Accepted values are constants defined in the class
    `tf.VariableSynchronization`. By default the synchronization is set to
    `AUTO` and the current `DistributionStrategy` chooses
    when to synchronize.
  aggregation: Indicates how a distributed variable will be aggregated.
    Accepted values are constants defined in the class
    `tf.VariableAggregation`.

Returns:
  The created or existing `Variable` (or `PartitionedVariable`, if a
  partitioner was used).

Raises:
  ValueError: when creating a new variable and shape is not declared,
    when violating reuse during variable creation, or when `initializer` dtype
    and `dtype` don't match. Reuse is set inside `variable_scope`.
""")
get_variable.__doc__ = get_variable_or_local_docstring % (
    "Gets an existing variable with these parameters or creates a new one.", "",
    "trainable: If `True` also add the variable to the graph collection\n"
    "  `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).\n  ",
    "GraphKeys.GLOBAL_VARIABLES")


# The argument list for get_local_variable must match arguments to get_variable.
# So, if you are updating the arguments, also update arguments to get_variable.
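
# A minimal usage sketch of the scope-level defaults and the partitioner hook
# described in the docstring above; graph mode is assumed and the names below
# are purely illustrative:
#
#   partitioner = tf.compat.v1.fixed_size_partitioner(num_shards=3, axis=0)
#   with tf.compat.v1.variable_scope(
#       "embed",
#       initializer=tf.compat.v1.zeros_initializer(),
#       partitioner=partitioner):
#     # Inherits the scope's initializer and partitioner; returns a
#     # `PartitionedVariable` backed by three [3, 4] shards.
#     weights = tf.compat.v1.get_variable("weights", shape=[9, 4])
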
1776@tf_export(v1=["get_local_variable"]) 1777def get_local_variable( # pylint: disable=missing-docstring 1778 name, 1779 shape=None, 1780 dtype=None, 1781 initializer=None, 1782 regularizer=None, 1783 trainable=False, # pylint: disable=unused-argument 1784 collections=None, 1785 caching_device=None, 1786 partitioner=None, 1787 validate_shape=True, 1788 use_resource=None, 1789 custom_getter=None, 1790 constraint=None, 1791 synchronization=VariableSynchronization.AUTO, 1792 aggregation=VariableAggregation.NONE): 1793 if collections: 1794 collections += [ops.GraphKeys.LOCAL_VARIABLES] 1795 else: 1796 collections = [ops.GraphKeys.LOCAL_VARIABLES] 1797 return get_variable( 1798 name, 1799 shape=shape, 1800 dtype=dtype, 1801 initializer=initializer, 1802 regularizer=regularizer, 1803 trainable=False, 1804 collections=collections, 1805 caching_device=caching_device, 1806 partitioner=partitioner, 1807 validate_shape=validate_shape, 1808 use_resource=use_resource, 1809 synchronization=synchronization, 1810 aggregation=aggregation, 1811 custom_getter=custom_getter, 1812 constraint=constraint) 1813 1814 1815get_local_variable.__doc__ = get_variable_or_local_docstring % ( 1816 "Gets an existing *local* variable or creates a new one.", 1817 "Behavior is the same as in `get_variable`, except that variables are\n" 1818 "added to the `LOCAL_VARIABLES` collection and `trainable` is set to\n" 1819 "`False`.\n", "", "GraphKeys.LOCAL_VARIABLES") 1820 1821 1822def _get_partitioned_variable(name, 1823 shape=None, 1824 dtype=None, 1825 initializer=None, 1826 regularizer=None, 1827 trainable=True, 1828 collections=None, 1829 caching_device=None, 1830 partitioner=None, 1831 validate_shape=True, 1832 use_resource=None, 1833 constraint=None, 1834 synchronization=VariableSynchronization.AUTO, 1835 aggregation=VariableAggregation.NONE): 1836 """Gets or creates a sharded variable list with these parameters. 1837 1838 The `partitioner` must be a callable that accepts a fully defined 1839 `TensorShape` and returns a sequence of integers (the `partitions`). 1840 These integers describe how to partition the given sharded `Variable` 1841 along the given dimension. That is, `partitions[1] = 3` means split 1842 the `Variable` into 3 shards along dimension 1. Currently, sharding along 1843 only one axis is supported. 1844 1845 If the list of variables with the given name (prefix) is already stored, 1846 we return the stored variables. Otherwise, we create a new one. 1847 1848 If initializer is `None` (the default), the default initializer passed in 1849 the constructor is used. If that one is `None` too, we use a new 1850 `glorot_uniform_initializer`. If initializer is a Tensor, we use 1851 it as a value and derive the shape from the initializer. 1852 1853 If the initializer is a callable, then it will be called for each 1854 shard. Otherwise the initializer should match the shape of the entire 1855 sharded Variable, and it will be sliced accordingly for each shard. 1856 1857 Some useful partitioners are available. See, e.g., 1858 `variable_axis_size_partitioner` and `min_max_variable_partitioner`. 1859 1860 Args: 1861 name: The name of the new or existing variable. 1862 shape: Shape of the new or existing variable. 1863 dtype: Type of the new or existing variable (defaults to `DT_FLOAT`). 1864 initializer: Initializer for the variable if one is created. 
1865 regularizer: A (Tensor -> Tensor or None) function; the result of applying 1866 it on a newly created variable will be added to the collection 1867 GraphKeys.REGULARIZATION_LOSSES and can be used for regularization. 1868 trainable: If `True` also add the variable to the graph collection 1869 `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). 1870 collections: List of graph collections keys to add the Variable to. Defaults 1871 to `[GraphKeys.GLOBAL_VARIABLES]` (see `tf.Variable`). 1872 caching_device: Optional device string or function describing where the 1873 Variable should be cached for reading. Defaults to the Variable's device. 1874 If not `None`, caches on another device. Typical use is to cache on the 1875 device where the Ops using the Variable reside, to deduplicate copying 1876 through `Switch` and other conditional statements. 1877 partitioner: Optional callable that accepts a fully defined `TensorShape` 1878 and `dtype` of the Variable to be created, and returns a list of 1879 partitions for each axis (currently only one axis can be partitioned). 1880 validate_shape: If False, allows the variable to be initialized with a value 1881 of unknown shape. If True, the default, the shape of initial_value must be 1882 known. 1883 use_resource: If False, creates a regular Variable. If True, creates an 1884 experimental ResourceVariable instead which has well-defined semantics. 1885 Defaults to False (will later change to True). 1886 constraint: An optional projection function to be applied to the variable 1887 after being updated by an `Optimizer` (e.g. used to implement norm 1888 constraints or value constraints for layer weights). The function must 1889 take as input the unprojected Tensor representing the value of the 1890 variable and return the Tensor for the projected value (which must have 1891 the same shape). Constraints are not safe to use when doing asynchronous 1892 distributed training. 1893 synchronization: Indicates when a distributed a variable will be aggregated. 1894 Accepted values are constants defined in the class 1895 `tf.VariableSynchronization`. By default the synchronization is set to 1896 `AUTO` and the current `DistributionStrategy` chooses when to synchronize. 1897 aggregation: Indicates how a distributed variable will be aggregated. 1898 Accepted values are constants defined in the class 1899 `tf.VariableAggregation`. 1900 1901 Returns: 1902 A tuple `(shards, partitions)` where `shards` is the list of `Variable` 1903 shards and `partitions` is the output of the partitioner on the input 1904 shape. 1905 1906 Raises: 1907 ValueError: when creating a new variable and shape is not declared, 1908 or when violating reuse during variable creation. Reuse is set inside 1909 `variable_scope`. 1910 """ 1911 # pylint: disable=protected-access 1912 scope = get_variable_scope() 1913 if scope.custom_getter is not None: 1914 raise ValueError( 1915 "Private access to _get_partitioned_variable is not allowed when " 1916 "a custom getter is set. Current custom getter: %s. " 1917 "It is likely that you're using create_partitioned_variables. " 1918 "If so, consider instead using get_variable with a non-empty " 1919 "partitioner parameter instead." 
% scope.custom_getter) 1920 return scope._get_partitioned_variable( 1921 _get_default_variable_store(), 1922 name, 1923 shape=shape, 1924 dtype=dtype, 1925 initializer=initializer, 1926 regularizer=regularizer, 1927 trainable=trainable, 1928 collections=collections, 1929 caching_device=caching_device, 1930 partitioner=partitioner, 1931 validate_shape=validate_shape, 1932 use_resource=use_resource, 1933 constraint=constraint, 1934 synchronization=synchronization, 1935 aggregation=aggregation) 1936 # pylint: enable=protected-access 1937 1938 1939# Named like a function for compatibility with the previous 1940# @tf_contextlib.contextmanager definition. 1941class _pure_variable_scope: # pylint: disable=invalid-name 1942 """A context for the variable_scope, see `variable_scope` for docs.""" 1943 1944 def __init__(self, 1945 name_or_scope, 1946 reuse=None, 1947 initializer=None, 1948 regularizer=None, 1949 caching_device=None, 1950 partitioner=None, 1951 custom_getter=None, 1952 old_name_scope=None, 1953 dtype=dtypes.float32, 1954 use_resource=None, 1955 constraint=None): 1956 """Creates a context for the variable_scope, see `variable_scope` for docs. 1957 1958 Note: this does not create a name scope. 1959 1960 Args: 1961 name_or_scope: `string` or `VariableScope`: the scope to open. 1962 reuse: `True` or None, or tf.compat.v1.AUTO_REUSE; if `None`, we inherit 1963 the parent scope's reuse flag. 1964 initializer: default initializer for variables within this scope. 1965 regularizer: default regularizer for variables within this scope. 1966 caching_device: default caching device for variables within this scope. 1967 partitioner: default partitioner for variables within this scope. 1968 custom_getter: default custom getter for variables within this scope. 1969 old_name_scope: the original name scope when re-entering a variable scope. 1970 dtype: type of the variables within this scope (defaults to `DT_FLOAT`). 1971 use_resource: If False, variables in this scope will be regular Variables. 1972 If True, experimental ResourceVariables will be creates instead, with 1973 well-defined semantics. Defaults to False (will later change to True). 1974 constraint: An optional projection function to be applied to the variable 1975 after being updated by an `Optimizer` (e.g. used to implement norm 1976 constraints or value constraints for layer weights). The function must 1977 take as input the unprojected Tensor representing the value of the 1978 variable and return the Tensor for the projected value (which must have 1979 the same shape). Constraints are not safe to use when doing asynchronous 1980 distributed training. 1981 """ 1982 self._name_or_scope = name_or_scope 1983 self._reuse = reuse 1984 self._initializer = initializer 1985 self._regularizer = regularizer 1986 self._caching_device = caching_device 1987 self._partitioner = partitioner 1988 self._custom_getter = custom_getter 1989 self._old_name_scope = old_name_scope 1990 self._dtype = dtype 1991 self._use_resource = use_resource 1992 self._constraint = constraint 1993 self._var_store = _get_default_variable_store() 1994 self._var_scope_store = get_variable_scope_store() 1995 self._last_variable_scope_object = None 1996 if isinstance(self._name_or_scope, VariableScope): 1997 self._new_name = self._name_or_scope.name 1998 name_scope = self._name_or_scope._name_scope # pylint: disable=protected-access 1999 # Handler for the case when we jump to a shared scope. 
We create a new 2000 # VariableScope (self._var_scope_object) that contains a copy of the 2001 # provided shared scope, possibly with changed reuse and initializer, if 2002 # the user requested this. 2003 variable_scope_object = VariableScope( 2004 self._name_or_scope.reuse if not self._reuse else self._reuse, 2005 name=self._new_name, 2006 initializer=self._name_or_scope.initializer, 2007 regularizer=self._name_or_scope.regularizer, 2008 caching_device=self._name_or_scope.caching_device, 2009 partitioner=self._name_or_scope.partitioner, 2010 dtype=self._name_or_scope.dtype, 2011 custom_getter=self._name_or_scope.custom_getter, 2012 name_scope=name_scope, 2013 use_resource=self._name_or_scope.use_resource, 2014 constraint=self._constraint) 2015 if self._initializer is not None: 2016 variable_scope_object.set_initializer(self._initializer) 2017 if self._regularizer is not None: 2018 variable_scope_object.set_regularizer(self._regularizer) 2019 if self._caching_device is not None: 2020 variable_scope_object.set_caching_device(self._caching_device) 2021 if self._partitioner is not None: 2022 variable_scope_object.set_partitioner(self._partitioner) 2023 if self._custom_getter is not None: 2024 variable_scope_object.set_custom_getter( 2025 _maybe_wrap_custom_getter(self._custom_getter, 2026 self._name_or_scope.custom_getter)) 2027 if self._dtype is not None: 2028 variable_scope_object.set_dtype(self._dtype) 2029 if self._use_resource is not None: 2030 variable_scope_object.set_use_resource(self._use_resource) 2031 self._cached_variable_scope_object = variable_scope_object 2032 2033 def __enter__(self): 2034 """Begins the scope block. 2035 2036 Returns: 2037 A VariableScope. 2038 Raises: 2039 ValueError: when trying to reuse within a create scope, or create within 2040 a reuse scope, or if reuse is not `None` or `True`. 2041 TypeError: when the types of some arguments are not appropriate. 2042 """ 2043 self._old = self._var_scope_store.current_scope 2044 if isinstance(self._name_or_scope, VariableScope): 2045 self._var_scope_store.open_variable_scope(self._new_name) 2046 self._old_subscopes = copy.copy( 2047 self._var_scope_store.variable_scopes_count) 2048 variable_scope_object = self._cached_variable_scope_object 2049 else: 2050 # Handler for the case when we just prolong current variable scope. 2051 # VariableScope with name extended by the provided one, and inherited 2052 # reuse and initializer (except if the user provided values to set). 2053 self._new_name = ( 2054 self._old.name + "/" + 2055 self._name_or_scope if self._old.name else self._name_or_scope) 2056 self._reuse = (self._reuse or 2057 self._old.reuse) # Re-using is inherited by sub-scopes. 
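      # Note that this also propagates `AUTO_REUSE`: a sub-scope opened with
      # reuse=None under an AUTO_REUSE parent keeps AUTO_REUSE semantics.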
2058 if self._old_name_scope is None: 2059 name_scope = self._name_or_scope 2060 else: 2061 name_scope = self._old_name_scope 2062 variable_scope_object = VariableScope( 2063 self._reuse, 2064 name=self._new_name, 2065 initializer=self._old.initializer, 2066 regularizer=self._old.regularizer, 2067 caching_device=self._old.caching_device, 2068 partitioner=self._old.partitioner, 2069 dtype=self._old.dtype, 2070 use_resource=self._old.use_resource, 2071 custom_getter=self._old.custom_getter, 2072 name_scope=name_scope, 2073 constraint=self._constraint) 2074 if self._initializer is not None: 2075 variable_scope_object.set_initializer(self._initializer) 2076 if self._regularizer is not None: 2077 variable_scope_object.set_regularizer(self._regularizer) 2078 if self._caching_device is not None: 2079 variable_scope_object.set_caching_device(self._caching_device) 2080 if self._partitioner is not None: 2081 variable_scope_object.set_partitioner(self._partitioner) 2082 if self._custom_getter is not None: 2083 variable_scope_object.set_custom_getter( 2084 _maybe_wrap_custom_getter(self._custom_getter, 2085 self._old.custom_getter)) 2086 if self._dtype is not None: 2087 variable_scope_object.set_dtype(self._dtype) 2088 if self._use_resource is not None: 2089 variable_scope_object.set_use_resource(self._use_resource) 2090 self._var_scope_store.open_variable_scope(self._new_name) 2091 self._var_scope_store.current_scope = variable_scope_object 2092 self._last_variable_scope_object = variable_scope_object 2093 return variable_scope_object 2094 2095 def __exit__(self, type_arg, value_arg, traceback_arg): 2096 if (self._var_scope_store.current_scope is 2097 not self._last_variable_scope_object): 2098 raise RuntimeError("Improper nesting of variable_scope.") 2099 # If jumping out from a non-prolonged scope, restore counts. 2100 if isinstance(self._name_or_scope, VariableScope): 2101 self._var_scope_store.variable_scopes_count = self._old_subscopes 2102 else: 2103 self._var_scope_store.close_variable_subscopes(self._new_name) 2104 self._var_scope_store.current_scope = self._old 2105 2106 2107def _maybe_wrap_custom_getter(custom_getter, old_getter): 2108 """Wrap a call to a custom_getter to use the old_getter internally.""" 2109 if old_getter is None: 2110 return custom_getter 2111 2112 # The new custom_getter should call the old one 2113 def wrapped_custom_getter(getter, *args, **kwargs): 2114 # Call: 2115 # custom_getter( 2116 # lambda: old_getter(true_getter, ...), *args, **kwargs) 2117 # which means custom_getter will call old_getter, which 2118 # will call the true_getter, perform any intermediate 2119 # processing, and return the results to the current 2120 # getter, which will also perform additional processing. 
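    # `functools.partial` pre-binds the true getter as `old_getter`'s first
    # argument, so each wrapping layer composes on top of the previous one.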
2121 return custom_getter(functools.partial(old_getter, getter), *args, **kwargs) 2122 2123 return wrapped_custom_getter 2124 2125 2126def _get_unique_variable_scope(prefix): 2127 """Get a name with the given prefix unique in the current variable scope.""" 2128 var_scope_store = get_variable_scope_store() 2129 current_scope = get_variable_scope() 2130 name = current_scope.name + "/" + prefix if current_scope.name else prefix 2131 if var_scope_store.variable_scope_count(name) == 0: 2132 return prefix 2133 idx = 1 2134 while var_scope_store.variable_scope_count(name + ("_%d" % idx)) > 0: 2135 idx += 1 2136 return prefix + ("_%d" % idx) 2137 2138 2139# Named like a function for backwards compatibility with the 2140# @tf_contextlib.contextmanager version, which was switched to a class to avoid 2141# some object creation overhead. 2142@tf_export(v1=["variable_scope"]) # pylint: disable=invalid-name 2143class variable_scope: 2144 """A context manager for defining ops that creates variables (layers). 2145 2146 @compatibility(TF2) 2147 Although it is a legacy `compat.v1` api, 2148 `tf.compat.v1.variable_scope` is mostly compatible with eager 2149 execution and `tf.function` as long as you combine it with the 2150 `tf.compat.v1.keras.utils.track_tf1_style_variables` decorator (though 2151 it will behave as if reuse is always set to `AUTO_REUSE`.) 2152 2153 See the 2154 [model migration guide]( 2155 https://www.tensorflow.org/guide/migrate/model_mapping) 2156 for more info on 2157 migrating code that relies on `variable_scope`-based variable reuse. 2158 2159 When you use it with eager execution enabled but without 2160 `tf.compat.v1.keras.utils.track_tf1_style_variables`, 2161 `tf.compat.v1.variable_scope` will still be able to prefix the names 2162 of variables created within the scope but it will not enable variable reuse 2163 or error-raising checks around variable reuse (`get_variable` calls within 2164 it would always create new variables). 2165 2166 Once you have switched away from `get_variable`-based variable reuse 2167 mechanisms, to switch to TF2 APIs you can just use 2168 `tf.name_scope` to prefix variable names. 2169 @end_compatibility 2170 2171 This context manager validates that the (optional) `values` are from the same 2172 graph, ensures that graph is the default graph, and pushes a name scope and a 2173 variable scope. 2174 2175 If `name_or_scope` is not None, it is used as is. If `name_or_scope` is None, 2176 then `default_name` is used. In that case, if the same name has been 2177 previously used in the same scope, it will be made unique by appending `_N` 2178 to it. 2179 2180 Variable scope allows you to create new variables and to share already created 2181 ones while providing checks to not create or share by accident. For details, 2182 see the [Variable Scope How To](https://tensorflow.org/guide/variables), here 2183 we present only a few basic examples. 2184 2185 The Variable Scope works as expected when the Eager Execution is Disabled. 2186 2187 ```python 2188 tf.compat.v1.disable_eager_execution() 2189 ``` 2190 2191 Simple example of how to create a new variable: 2192 2193 ```python 2194 with tf.compat.v1.variable_scope("foo"): 2195 with tf.compat.v1.variable_scope("bar"): 2196 v = tf.compat.v1.get_variable("v", [1]) 2197 assert v.name == "foo/bar/v:0" 2198 ``` 2199 2200 Simple example of how to reenter a premade variable scope safely: 2201 2202 ```python 2203 with tf.compat.v1.variable_scope("foo") as vs: 2204 pass 2205 2206 # Re-enter the variable scope. 
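  # `auxiliary_name_scope=False` keeps the current name scope untouched here;
  # re-entering the captured scope normally would open a fresh name scope
  # (e.g. "foo_1") instead.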
2207 with tf.compat.v1.variable_scope(vs, 2208 auxiliary_name_scope=False) as vs1: 2209 # Restore the original name_scope. 2210 with tf.name_scope(vs1.original_name_scope): 2211 v = tf.compat.v1.get_variable("v", [1]) 2212 assert v.name == "foo/v:0" 2213 c = tf.constant([1], name="c") 2214 assert c.name == "foo/c:0" 2215 ``` 2216 2217 Keep in mind that the counters for `default_name` are discarded once the 2218 parent scope is exited. Therefore when the code re-enters the scope (for 2219 instance by saving it), all nested default_name counters will be restarted. 2220 2221 For instance: 2222 2223 ```python 2224 with tf.compat.v1.variable_scope("foo") as vs: 2225 with tf.compat.v1.variable_scope(None, default_name="bar"): 2226 v = tf.compat.v1.get_variable("a", [1]) 2227 assert v.name == "foo/bar/a:0", v.name 2228 with tf.compat.v1.variable_scope(None, default_name="bar"): 2229 v = tf.compat.v1.get_variable("b", [1]) 2230 assert v.name == "foo/bar_1/b:0" 2231 2232 with tf.compat.v1.variable_scope(vs): 2233 with tf.compat.v1.variable_scope(None, default_name="bar"): 2234 v = tf.compat.v1.get_variable("c", [1]) 2235 assert v.name == "foo/bar/c:0" # Uses bar instead of bar_2! 2236 ``` 2237 2238 Basic example of sharing a variable AUTO_REUSE: 2239 2240 ```python 2241 def foo(): 2242 with tf.compat.v1.variable_scope("foo", reuse=tf.compat.v1.AUTO_REUSE): 2243 v = tf.compat.v1.get_variable("v", [1]) 2244 return v 2245 2246 v1 = foo() # Creates v. 2247 v2 = foo() # Gets the same, existing v. 2248 assert v1 == v2 2249 ``` 2250 2251 Basic example of sharing a variable with reuse=True: 2252 2253 ```python 2254 with tf.compat.v1.variable_scope("foo"): 2255 v = tf.compat.v1.get_variable("v", [1]) 2256 with tf.compat.v1.variable_scope("foo", reuse=True): 2257 v1 = tf.compat.v1.get_variable("v", [1]) 2258 assert v1 == v 2259 ``` 2260 2261 Sharing a variable by capturing a scope and setting reuse: 2262 2263 ```python 2264 with tf.compat.v1.variable_scope("foo") as scope: 2265 v = tf.compat.v1.get_variable("v", [1]) 2266 scope.reuse_variables() 2267 v1 = tf.compat.v1.get_variable("v", [1]) 2268 assert v1 == v 2269 ``` 2270 2271 To prevent accidental sharing of variables, we raise an exception when getting 2272 an existing variable in a non-reusing scope. 2273 2274 ```python 2275 with tf.compat.v1.variable_scope("foo"): 2276 v = tf.compat.v1.get_variable("v", [1]) 2277 v1 = tf.compat.v1.get_variable("v", [1]) 2278 # Raises ValueError("... v already exists ..."). 2279 ``` 2280 2281 Similarly, we raise an exception when trying to get a variable that does not 2282 exist in reuse mode. 2283 2284 ```python 2285 with tf.compat.v1.variable_scope("foo", reuse=True): 2286 v = tf.compat.v1.get_variable("v", [1]) 2287 # Raises ValueError("... v does not exists ..."). 2288 ``` 2289 2290 Note that the `reuse` flag is inherited: if we open a reusing scope, then all 2291 its sub-scopes become reusing as well. 2292 2293 A note about name scoping: Setting `reuse` does not impact the naming of other 2294 ops such as mult. See related discussion on 2295 [github#6189](https://github.com/tensorflow/tensorflow/issues/6189) 2296 2297 Note that up to and including version 1.0, it was allowed (though explicitly 2298 discouraged) to pass False to the reuse argument, yielding undocumented 2299 behaviour slightly different from None. Starting at 1.1.0 passing None and 2300 False as reuse has exactly the same effect. 
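
  As noted above, the `reuse` flag set on an outer scope also applies to
  variables requested in nested scopes. A small sketch (names are
  illustrative):

  ```python
  with tf.compat.v1.variable_scope("outer"):
    with tf.compat.v1.variable_scope("inner"):
      v = tf.compat.v1.get_variable("v", [1])

  with tf.compat.v1.variable_scope("outer", reuse=True):
    with tf.compat.v1.variable_scope("inner"):  # Inherits reuse=True.
      v1 = tf.compat.v1.get_variable("v", [1])
  assert v1 == v
  ```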
2301 2302 A note about using variable scopes in multi-threaded environment: Variable 2303 scopes are thread local, so one thread will not see another thread's current 2304 scope. Also, when using `default_name`, unique scopes names are also generated 2305 only on a per thread basis. If the same name was used within a different 2306 thread, that doesn't prevent a new thread from creating the same scope. 2307 However, the underlying variable store is shared across threads (within the 2308 same graph). As such, if another thread tries to create a new variable with 2309 the same name as a variable created by a previous thread, it will fail unless 2310 reuse is True. 2311 2312 Further, each thread starts with an empty variable scope. So if you wish to 2313 preserve name prefixes from a scope from the main thread, you should capture 2314 the main thread's scope and re-enter it in each thread. For e.g. 2315 2316 ``` 2317 main_thread_scope = variable_scope.get_variable_scope() 2318 2319 # Thread's target function: 2320 def thread_target_fn(captured_scope): 2321 with variable_scope.variable_scope(captured_scope): 2322 # .... regular code for this thread 2323 2324 2325 thread = threading.Thread(target=thread_target_fn, args=(main_thread_scope,)) 2326 ``` 2327 """ 2328 2329 def __init__(self, 2330 name_or_scope, 2331 default_name=None, 2332 values=None, 2333 initializer=None, 2334 regularizer=None, 2335 caching_device=None, 2336 partitioner=None, 2337 custom_getter=None, 2338 reuse=None, 2339 dtype=None, 2340 use_resource=None, 2341 constraint=None, 2342 auxiliary_name_scope=True): 2343 """Initialize the context manager. 2344 2345 Args: 2346 name_or_scope: `string` or `VariableScope`: the scope to open. 2347 default_name: The default name to use if the `name_or_scope` argument is 2348 `None`, this name will be uniquified. If name_or_scope is provided it 2349 won't be used and therefore it is not required and can be None. 2350 values: The list of `Tensor` arguments that are passed to the op function. 2351 initializer: default initializer for variables within this scope. 2352 regularizer: default regularizer for variables within this scope. 2353 caching_device: default caching device for variables within this scope. 2354 partitioner: default partitioner for variables within this scope. 2355 custom_getter: default custom getter for variables within this scope. 2356 reuse: `True`, None, or tf.compat.v1.AUTO_REUSE; if `True`, we go into 2357 reuse mode for this scope as well as all sub-scopes; if 2358 tf.compat.v1.AUTO_REUSE, we create variables if they do not exist, and 2359 return them otherwise; if None, we inherit the parent scope's reuse 2360 flag. When eager execution is enabled, new variables are always created 2361 unless an EagerVariableStore or template is currently active. 2362 dtype: type of variables created in this scope (defaults to the type in 2363 the passed scope, or inherited from parent scope). 2364 use_resource: If False, all variables will be regular Variables. If True, 2365 experimental ResourceVariables with well-defined semantics will be used 2366 instead. Defaults to False (will later change to True). When eager 2367 execution is enabled this argument is always forced to be True. 2368 constraint: An optional projection function to be applied to the variable 2369 after being updated by an `Optimizer` (e.g. used to implement norm 2370 constraints or value constraints for layer weights). 
The function must 2371 take as input the unprojected Tensor representing the value of the 2372 variable and return the Tensor for the projected value (which must have 2373 the same shape). Constraints are not safe to use when doing asynchronous 2374 distributed training. 2375 auxiliary_name_scope: If `True`, we create an auxiliary name scope with 2376 the scope. If `False`, we don't create it. Note that the argument is not 2377 inherited, and it only takes effect for once when creating. You should 2378 only use it for re-entering a premade variable scope. 2379 2380 Returns: 2381 A scope that can be captured and reused. 2382 2383 Raises: 2384 ValueError: when trying to reuse within a create scope, or create within 2385 a reuse scope. 2386 TypeError: when the types of some arguments are not appropriate. 2387 """ 2388 self._name_or_scope = name_or_scope 2389 self._default_name = default_name 2390 self._values = values 2391 self._initializer = initializer 2392 self._regularizer = regularizer 2393 self._caching_device = caching_device 2394 self._partitioner = partitioner 2395 self._custom_getter = custom_getter 2396 self._reuse = reuse 2397 self._dtype = dtype 2398 self._use_resource = use_resource 2399 self._constraint = constraint 2400 if self._default_name is None and self._name_or_scope is None: 2401 raise TypeError("If default_name is None then name_or_scope is required") 2402 if self._reuse is False: 2403 # We don't allow non-inheriting scopes, False = None here. 2404 self._reuse = None 2405 if not (self._reuse is True 2406 or self._reuse is None 2407 or self._reuse is AUTO_REUSE): 2408 raise ValueError("The reuse parameter must be True or False or None.") 2409 if self._values is None: 2410 self._values = [] 2411 self._in_graph_mode = not context.executing_eagerly() 2412 if self._in_graph_mode: 2413 self._graph = ops._get_graph_from_inputs(self._values) # pylint: disable=protected-access 2414 self._cached_pure_variable_scope = None 2415 self._current_name_scope = None 2416 if not isinstance(auxiliary_name_scope, bool): 2417 raise TypeError("The auxiliary_name_scope must be `True` or `False`, " 2418 "while get {}".format(auxiliary_name_scope)) 2419 self._auxiliary_name_scope = auxiliary_name_scope 2420 2421 def __enter__(self): 2422 # If the default graph is building a function, then we should not replace it 2423 # with the cached graph. 2424 if ops.get_default_graph().building_function: 2425 self._building_function = True 2426 else: 2427 self._building_function = False 2428 if self._in_graph_mode and not self._building_function: 2429 self._graph_context_manager = self._graph.as_default() 2430 self._graph_context_manager.__enter__() 2431 if self._cached_pure_variable_scope is not None: 2432 # Fast path for re-entering variable_scopes. We've held on to the pure 2433 # variable scope from a previous successful __enter__, so we avoid some 2434 # overhead by re-using that object. 2435 if self._current_name_scope is not None: 2436 self._current_name_scope.__enter__() 2437 return self._cached_pure_variable_scope.__enter__() 2438 2439 try: 2440 return self._enter_scope_uncached() 2441 except: 2442 if (self._in_graph_mode and not self._building_function and 2443 self._graph_context_manager is not None): 2444 self._graph_context_manager.__exit__(*sys.exc_info()) 2445 raise 2446 2447 def _enter_scope_uncached(self): 2448 """Enters the context manager when there is no cached scope yet. 2449 2450 Returns: 2451 The entered variable scope. 
2452 2453 Raises: 2454 TypeError: A wrong type is passed as `scope` at __init__(). 2455 ValueError: `reuse` is incorrectly set at __init__(). 2456 """ 2457 if self._auxiliary_name_scope: 2458 # Create a new name scope later 2459 current_name_scope = None 2460 else: 2461 # Reenter the current name scope 2462 name_scope = ops.get_name_scope() 2463 if name_scope: 2464 # Hack to reenter 2465 name_scope += "/" 2466 current_name_scope = ops.name_scope(name_scope, skip_on_eager=False) 2467 else: 2468 # Root scope 2469 current_name_scope = ops.name_scope(name_scope, skip_on_eager=False) 2470 2471 # IMPORTANT: Only assign to self._cached_pure_variable_scope and 2472 # self._current_name_scope after successful __enter__() calls. 2473 if self._name_or_scope is not None: 2474 if not isinstance(self._name_or_scope, (VariableScope, str)): 2475 raise TypeError("VariableScope: name_or_scope must be a string or " 2476 "VariableScope.") 2477 if isinstance(self._name_or_scope, str): 2478 name_scope = self._name_or_scope 2479 else: 2480 name_scope = self._name_or_scope.name.split("/")[-1] 2481 if name_scope or current_name_scope: 2482 current_name_scope = current_name_scope or ops.name_scope( 2483 name_scope, skip_on_eager=False) 2484 try: 2485 current_name_scope_name = current_name_scope.__enter__() 2486 except: 2487 current_name_scope.__exit__(*sys.exc_info()) 2488 raise 2489 self._current_name_scope = current_name_scope 2490 if isinstance(self._name_or_scope, str): 2491 old_name_scope = current_name_scope_name 2492 else: 2493 old_name_scope = self._name_or_scope.original_name_scope 2494 pure_variable_scope = _pure_variable_scope( 2495 self._name_or_scope, 2496 reuse=self._reuse, 2497 initializer=self._initializer, 2498 regularizer=self._regularizer, 2499 caching_device=self._caching_device, 2500 partitioner=self._partitioner, 2501 custom_getter=self._custom_getter, 2502 old_name_scope=old_name_scope, 2503 dtype=self._dtype, 2504 use_resource=self._use_resource, 2505 constraint=self._constraint) 2506 try: 2507 entered_pure_variable_scope = pure_variable_scope.__enter__() 2508 except: 2509 pure_variable_scope.__exit__(*sys.exc_info()) 2510 raise 2511 self._cached_pure_variable_scope = pure_variable_scope 2512 return entered_pure_variable_scope 2513 else: 2514 self._current_name_scope = None 2515 # This can only happen if someone is entering the root variable scope. 2516 pure_variable_scope = _pure_variable_scope( 2517 self._name_or_scope, 2518 reuse=self._reuse, 2519 initializer=self._initializer, 2520 regularizer=self._regularizer, 2521 caching_device=self._caching_device, 2522 partitioner=self._partitioner, 2523 custom_getter=self._custom_getter, 2524 dtype=self._dtype, 2525 use_resource=self._use_resource, 2526 constraint=self._constraint) 2527 try: 2528 entered_pure_variable_scope = pure_variable_scope.__enter__() 2529 except: 2530 pure_variable_scope.__exit__(*sys.exc_info()) 2531 raise 2532 self._cached_pure_variable_scope = pure_variable_scope 2533 return entered_pure_variable_scope 2534 2535 else: # Here name_or_scope is None. Using default name, but made unique. 
2536 if self._reuse: 2537 raise ValueError("reuse=True cannot be used without a name_or_scope") 2538 current_name_scope = current_name_scope or ops.name_scope( 2539 self._default_name, skip_on_eager=False) 2540 try: 2541 current_name_scope_name = current_name_scope.__enter__() 2542 except: 2543 current_name_scope.__exit__(*sys.exc_info()) 2544 raise 2545 self._current_name_scope = current_name_scope 2546 unique_default_name = _get_unique_variable_scope(self._default_name) 2547 pure_variable_scope = _pure_variable_scope( 2548 unique_default_name, 2549 initializer=self._initializer, 2550 regularizer=self._regularizer, 2551 caching_device=self._caching_device, 2552 partitioner=self._partitioner, 2553 custom_getter=self._custom_getter, 2554 old_name_scope=current_name_scope_name, 2555 dtype=self._dtype, 2556 use_resource=self._use_resource, 2557 constraint=self._constraint) 2558 try: 2559 entered_pure_variable_scope = pure_variable_scope.__enter__() 2560 except: 2561 pure_variable_scope.__exit__(*sys.exc_info()) 2562 raise 2563 self._cached_pure_variable_scope = pure_variable_scope 2564 return entered_pure_variable_scope 2565 2566 def __exit__(self, type_arg, value_arg, traceback_arg): 2567 try: 2568 self._cached_pure_variable_scope.__exit__(type_arg, value_arg, 2569 traceback_arg) 2570 finally: 2571 try: 2572 if self._current_name_scope: 2573 self._current_name_scope.__exit__(type_arg, value_arg, 2574 traceback_arg) 2575 finally: 2576 if self._in_graph_mode and not self._building_function: 2577 self._graph_context_manager.__exit__(type_arg, value_arg, 2578 traceback_arg) 2579 2580 2581# pylint: disable=g-doc-return-or-yield 2582@tf_export(v1=["variable_op_scope"]) 2583@tf_contextlib.contextmanager 2584def variable_op_scope(values, 2585 name_or_scope, 2586 default_name=None, 2587 initializer=None, 2588 regularizer=None, 2589 caching_device=None, 2590 partitioner=None, 2591 custom_getter=None, 2592 reuse=None, 2593 dtype=None, 2594 use_resource=None, 2595 constraint=None): 2596 """Deprecated: context manager for defining an op that creates variables.""" 2597 logging.warn("tf.variable_op_scope(values, name, default_name) is deprecated," 2598 " use tf.variable_scope(name, default_name, values)") 2599 with variable_scope( 2600 name_or_scope, 2601 default_name=default_name, 2602 values=values, 2603 initializer=initializer, 2604 regularizer=regularizer, 2605 caching_device=caching_device, 2606 partitioner=partitioner, 2607 custom_getter=custom_getter, 2608 reuse=reuse, 2609 dtype=dtype, 2610 use_resource=use_resource, 2611 constraint=constraint) as scope: 2612 yield scope 2613 2614 2615def _call_partitioner(partitioner, shape, dtype): 2616 """Call partitioner validating its inputs/output. 2617 2618 Args: 2619 partitioner: a function mapping `Tensor` shape and dtype to a list of 2620 partitions. 2621 shape: shape of the `Tensor` to partition, must have at least two 2622 dimensions. 2623 dtype: dtype of the elements in the `Tensor`. 2624 2625 Returns: 2626 A list with elements >=1 and exactly one >1. The index of that 2627 element corresponds to the partitioning axis. 2628 """ 2629 if not shape.is_fully_defined(): 2630 raise ValueError("Shape of a new partitioned variable must be " 2631 "fully defined, but instead was %s." 
% (shape,)) 2632 if shape.ndims < 1: 2633 raise ValueError("A partitioned Variable must have rank at least 1, " 2634 "shape: %s" % shape) 2635 2636 slicing = partitioner(shape=shape, dtype=dtype) 2637 if not isinstance(slicing, collections_abc.Sequence): 2638 raise ValueError("Partitioner must return a sequence, but saw: %s" % 2639 slicing) 2640 if len(slicing) != shape.ndims: 2641 raise ValueError( 2642 "Partitioner returned a partition list that does not match the " 2643 "Variable's rank: %s vs. %s" % (slicing, shape)) 2644 if any(p < 1 for p in slicing): 2645 raise ValueError("Partitioner returned zero partitions for some axes: %s" % 2646 slicing) 2647 if sum(p > 1 for p in slicing) > 1: 2648 raise ValueError("Can only slice a variable along one dimension: " 2649 "shape: %s, partitioning: %s" % (shape, slicing)) 2650 return slicing 2651 2652 2653# TODO(slebedev): could be inlined, but 2654# `_VariableStore._get_partitioned_variable` is too complex even 2655# without this logic. 2656def _get_slice_dim_and_num_slices(slicing): 2657 """Get slicing dimension and number of slices from the partitioner output.""" 2658 for slice_dim, num_slices in enumerate(slicing): 2659 if num_slices > 1: 2660 break 2661 else: 2662 # Degenerate case: no partitioning applied. 2663 slice_dim = 0 2664 num_slices = 1 2665 return slice_dim, num_slices 2666 2667 2668def _iter_slices(full_shape, num_slices, slice_dim): 2669 """Slices a given a shape along the specified dimension.""" 2670 num_slices_with_excess = full_shape[slice_dim] % num_slices 2671 offset = [0] * len(full_shape) 2672 min_slice_len = full_shape[slice_dim] // num_slices 2673 for i in range(num_slices): 2674 shape = full_shape[:] 2675 shape[slice_dim] = min_slice_len + bool(i < num_slices_with_excess) 2676 yield offset[:], shape 2677 offset[slice_dim] += shape[slice_dim] 2678 2679 2680def default_variable_creator(next_creator=None, **kwargs): 2681 """Default variable creator.""" 2682 assert next_creator is None 2683 initial_value = kwargs.get("initial_value", None) 2684 trainable = kwargs.get("trainable", None) 2685 collections = kwargs.get("collections", None) 2686 validate_shape = kwargs.get("validate_shape", True) 2687 caching_device = kwargs.get("caching_device", None) 2688 name = kwargs.get("name", None) 2689 variable_def = kwargs.get("variable_def", None) 2690 dtype = kwargs.get("dtype", None) 2691 expected_shape = kwargs.get("expected_shape", None) 2692 import_scope = kwargs.get("import_scope", None) 2693 constraint = kwargs.get("constraint", None) 2694 use_resource = kwargs.get("use_resource", None) 2695 synchronization = kwargs.get("synchronization", None) 2696 aggregation = kwargs.get("aggregation", None) 2697 shape = kwargs.get("shape", None) 2698 2699 if use_resource is None: 2700 use_resource = get_variable_scope().use_resource 2701 if use_resource is None: 2702 use_resource = _DEFAULT_USE_RESOURCE 2703 use_resource = use_resource or context.executing_eagerly() 2704 if use_resource: 2705 distribute_strategy = kwargs.get("distribute_strategy", None) 2706 return resource_variable_ops.ResourceVariable( 2707 initial_value=initial_value, 2708 trainable=trainable, 2709 collections=collections, 2710 validate_shape=validate_shape, 2711 caching_device=caching_device, 2712 name=name, 2713 dtype=dtype, 2714 constraint=constraint, 2715 variable_def=variable_def, 2716 import_scope=import_scope, 2717 distribute_strategy=distribute_strategy, 2718 synchronization=synchronization, 2719 aggregation=aggregation, 2720 shape=shape) 2721 else: 2722 return 
variables.RefVariable( 2723 initial_value=initial_value, 2724 trainable=trainable, 2725 collections=collections, 2726 validate_shape=validate_shape, 2727 caching_device=caching_device, 2728 name=name, 2729 dtype=dtype, 2730 constraint=constraint, 2731 variable_def=variable_def, 2732 expected_shape=expected_shape, 2733 import_scope=import_scope, 2734 synchronization=synchronization, 2735 aggregation=aggregation, 2736 shape=shape) 2737 2738 2739def default_variable_creator_v2(next_creator=None, **kwargs): 2740 """Default variable creator.""" 2741 assert next_creator is None 2742 initial_value = kwargs.get("initial_value", None) 2743 trainable = kwargs.get("trainable", None) 2744 validate_shape = kwargs.get("validate_shape", True) 2745 caching_device = kwargs.get("caching_device", None) 2746 name = kwargs.get("name", None) 2747 variable_def = kwargs.get("variable_def", None) 2748 dtype = kwargs.get("dtype", None) 2749 import_scope = kwargs.get("import_scope", None) 2750 constraint = kwargs.get("constraint", None) 2751 distribute_strategy = kwargs.get("distribute_strategy", None) 2752 synchronization = kwargs.get("synchronization", None) 2753 aggregation = kwargs.get("aggregation", None) 2754 shape = kwargs.get("shape", None) 2755 2756 return resource_variable_ops.ResourceVariable( 2757 initial_value=initial_value, 2758 trainable=trainable, 2759 validate_shape=validate_shape, 2760 caching_device=caching_device, 2761 name=name, 2762 dtype=dtype, 2763 constraint=constraint, 2764 variable_def=variable_def, 2765 import_scope=import_scope, 2766 distribute_strategy=distribute_strategy, 2767 synchronization=synchronization, 2768 aggregation=aggregation, 2769 shape=shape) 2770 2771 2772variables.default_variable_creator = default_variable_creator 2773variables.default_variable_creator_v2 = default_variable_creator_v2 2774 2775 2776def _make_getter(captured_getter, captured_previous): 2777 """Gets around capturing loop variables in python being broken.""" 2778 return lambda **kwargs: captured_getter(captured_previous, **kwargs) 2779 2780 2781# TODO(apassos) remove forwarding symbol 2782variable = variables.VariableV1 2783 2784 2785@tf_export(v1=["variable_creator_scope"]) 2786@tf_contextlib.contextmanager 2787def variable_creator_scope_v1(variable_creator): 2788 """Scope which defines a variable creation function to be used by variable(). 2789 2790 variable_creator is expected to be a function with the following signature: 2791 2792 ``` 2793 def variable_creator(next_creator, **kwargs) 2794 ``` 2795 2796 The creator is supposed to eventually call the next_creator to create a 2797 variable if it does want to create a variable and not call Variable or 2798 ResourceVariable directly. This helps make creators composable. A creator may 2799 choose to create multiple variables, return already existing variables, or 2800 simply register that a variable was created and defer to the next creators in 2801 line. Creators can also modify the keyword arguments seen by the next 2802 creators. 2803 2804 Custom getters in the variable scope will eventually resolve down to these 2805 custom creators when they do create variables. 2806 2807 The valid keyword arguments in kwds are: 2808 2809 * initial_value: A `Tensor`, or Python object convertible to a `Tensor`, 2810 which is the initial value for the Variable. The initial value must have 2811 a shape specified unless `validate_shape` is set to False. Can also be a 2812 callable with no argument that returns the initial value when called. 
In 2813 that case, `dtype` must be specified. (Note that initializer functions 2814 from init_ops.py must first be bound to a shape before being used here.) 2815 * trainable: If `True`, the default, also adds the variable to the graph 2816 collection `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as 2817 the default list of variables to use by the `Optimizer` classes. 2818 `trainable` defaults to `True`, unless `synchronization` is 2819 set to `ON_READ`, in which case it defaults to `False`. 2820 * collections: List of graph collections keys. The new variable is added to 2821 these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`. 2822 * validate_shape: If `False`, allows the variable to be initialized with a 2823 value of unknown shape. If `True`, the default, the shape of 2824 `initial_value` must be known. 2825 * caching_device: Optional device string describing where the Variable 2826 should be cached for reading. Defaults to the Variable's device. 2827 If not `None`, caches on another device. Typical use is to cache 2828 on the device where the Ops using the Variable reside, to deduplicate 2829 copying through `Switch` and other conditional statements. 2830 * name: Optional name for the variable. Defaults to `'Variable'` and gets 2831 uniquified automatically. 2832 * dtype: If set, initial_value will be converted to the given type. 2833 If `None`, either the datatype will be kept (if `initial_value` is 2834 a Tensor), or `convert_to_tensor` will decide. 2835 * constraint: A constraint function to be applied to the variable after 2836 updates by some algorithms. 2837 * use_resource: if True, a ResourceVariable is always created. 2838 * synchronization: Indicates when a distributed a variable will be 2839 aggregated. Accepted values are constants defined in the class 2840 `tf.VariableSynchronization`. By default the synchronization is set to 2841 `AUTO` and the current `DistributionStrategy` chooses 2842 when to synchronize. 2843 * aggregation: Indicates how a distributed variable will be aggregated. 2844 Accepted values are constants defined in the class 2845 `tf.VariableAggregation`. 2846 2847 This set may grow over time, so it's important the signature of creators is as 2848 mentioned above. 2849 2850 Args: 2851 variable_creator: the passed creator 2852 2853 Yields: 2854 A scope in which the creator is active 2855 """ 2856 with ops.get_default_graph()._variable_creator_scope(variable_creator): # pylint: disable=protected-access 2857 yield 2858 2859 2860# Note: only the docstrings differ between this and v1. 2861@tf_export("variable_creator_scope", v1=[]) 2862@tf_contextlib.contextmanager 2863def variable_creator_scope(variable_creator): 2864 """Scope which defines a variable creation function to be used by variable(). 2865 2866 variable_creator is expected to be a function with the following signature: 2867 2868 ``` 2869 def variable_creator(next_creator, **kwargs) 2870 ``` 2871 2872 The creator is supposed to eventually call the next_creator to create a 2873 variable if it does want to create a variable and not call Variable or 2874 ResourceVariable directly. This helps make creators composable. A creator may 2875 choose to create multiple variables, return already existing variables, or 2876 simply register that a variable was created and defer to the next creators in 2877 line. Creators can also modify the keyword arguments seen by the next 2878 creators. 
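
  For example, here is a sketch of a composable creator that overrides one
  keyword argument before deferring to the next creator in line (the creator
  name is illustrative):

  ```python
  def make_non_trainable(next_creator, **kwargs):
    kwargs["trainable"] = False
    return next_creator(**kwargs)

  with tf.variable_creator_scope(make_non_trainable):
    v = tf.Variable(1.0)
  assert not v.trainable
  ```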

  Custom getters in the variable scope will eventually resolve down to these
  custom creators when they do create variables.

  The valid keyword arguments in kwds are:

  * initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
    which is the initial value for the Variable. The initial value must have
    a shape specified unless `validate_shape` is set to False. Can also be a
    callable with no argument that returns the initial value when called. In
    that case, `dtype` must be specified. (Note that initializer functions
    from init_ops.py must first be bound to a shape before being used here.)
  * trainable: If `True`, the default, GradientTapes automatically watch
    uses of this Variable.
  * validate_shape: If `False`, allows the variable to be initialized with a
    value of unknown shape. If `True`, the default, the shape of
    `initial_value` must be known.
  * caching_device: Optional device string describing where the Variable
    should be cached for reading. Defaults to the Variable's device.
    If not `None`, caches on another device. Typical use is to cache
    on the device where the Ops using the Variable reside, to deduplicate
    copying through `Switch` and other conditional statements.
  * name: Optional name for the variable. Defaults to `'Variable'` and gets
    uniquified automatically.
  * dtype: If set, initial_value will be converted to the given type.
    If `None`, either the datatype will be kept (if `initial_value` is
    a Tensor), or `convert_to_tensor` will decide.
  * constraint: A constraint function to be applied to the variable after
    updates by some algorithms.
  * synchronization: Indicates when a distributed variable will be
    aggregated. Accepted values are constants defined in the class
    `tf.VariableSynchronization`. By default the synchronization is set to
    `AUTO` and the current `DistributionStrategy` chooses
    when to synchronize.
  * aggregation: Indicates how a distributed variable will be aggregated.
    Accepted values are constants defined in the class
    `tf.VariableAggregation`.

  This set may grow over time, so it's important that the signature of creators
  is as mentioned above.

  Args:
    variable_creator: the passed creator

  Yields:
    A scope in which the creator is active
  """
  with ops.get_default_graph()._variable_creator_scope(variable_creator):  # pylint: disable=protected-access
    yield
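
# A small sketch (graph mode assumed, names illustrative) of how a creator
# registered via `tf.compat.v1.variable_creator_scope` also intercepts
# variables requested through `get_variable`:
#
#   created_names = []
#
#   def recording_creator(next_creator, **kwargs):
#     # Record the requested name, then defer to the next creator in line.
#     created_names.append(kwargs.get("name"))
#     return next_creator(**kwargs)
#
#   with tf.compat.v1.variable_creator_scope(recording_creator):
#     tf.compat.v1.get_variable("w", shape=[2])
#   # created_names now holds the requested (scope-prefixed) name, e.g. "w".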