# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""A class to store named variables and a scope operator to manage sharing."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections as collections_lib
import copy
import enum  # pylint: disable=g-bad-import-order
import functools
import sys
import threading
import traceback

import six
from six import iteritems
from six.moves import xrange  # pylint: disable=redefined-builtin

from tensorflow.python import tf2
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import deprecation
from tensorflow.python.util import function_utils
from tensorflow.python.util import tf_contextlib
from tensorflow.python.util import tf_inspect
from tensorflow.python.util.tf_export import tf_export

__all__ = [
    "AUTO_REUSE", "VariableScope", "get_variable_scope", "get_variable",
    "get_local_variable", "variable_scope", "variable_op_scope",
    "no_regularizer", "VariableSynchronization", "VariableAggregation"
]


class _PartitionInfo(object):
  """Holds partition info used by initializer functions.
  """

  def __init__(self, full_shape, var_offset):
    """Constructor.

    Args:
      full_shape: Tuple or list of `int` indicating the full combined shape
        of the partitioned variables.
      var_offset: Tuple or list of `int` specifying offset of this partition
        with respect to the full variable for each dimension.

    Raises:
      TypeError: If `full_shape` or `var_offset` is not a sequence.
      ValueError: If `full_shape` or `var_offset` differ in length. If
        `var_offset` exceeds `full_shape` in any dimension.
74 """ 75 if not isinstance(full_shape, collections_lib.Sequence) or isinstance( 76 full_shape, six.string_types): 77 raise TypeError( 78 "`full_shape` must be a sequence (like tuple or list) instead of " + 79 type(full_shape).__name__) 80 81 if not isinstance(var_offset, collections_lib.Sequence) or isinstance( 82 var_offset, six.string_types): 83 raise TypeError( 84 "`var_offset` must be a sequence (like tuple or list) instead of " + 85 type(var_offset).__name__) 86 87 if len(var_offset) != len(full_shape): 88 raise ValueError( 89 "Expected equal length, but `var_offset` is of length {} while " 90 "full_shape is of length {}.".format( 91 len(var_offset), len(full_shape))) 92 93 for i in xrange(len(full_shape)): 94 offset = var_offset[i] 95 shape = full_shape[i] 96 if offset < 0 or offset >= shape: 97 raise ValueError( 98 "Expected 0 <= offset < shape but found offset={}, shape={} for " 99 "var_offset={}, full_shape={}".format(offset, shape, var_offset, 100 full_shape)) 101 102 self._full_shape = full_shape 103 self._var_offset = var_offset 104 105 @property 106 def full_shape(self): 107 return self._full_shape 108 109 @property 110 def var_offset(self): 111 return self._var_offset 112 113 def single_offset(self, shape): 114 """Returns the offset when the variable is partitioned in at most one dim. 115 116 Args: 117 shape: Tuple or list of `int` indicating the shape of one specific 118 variable partition. 119 120 Returns: 121 `int` representing the offset in the dimension along which the variable is 122 partitioned. Returns 0 if the variable is not being partitioned. 123 124 Raises: 125 ValueError: Depending on self.single_slice_dim(). 126 """ 127 128 single_slice_dim = self.single_slice_dim(shape) 129 # If this variable is not being partitioned at all, single_slice_dim() could 130 # return None. 131 if single_slice_dim is None: 132 return 0 133 return self.var_offset[single_slice_dim] 134 135 def single_slice_dim(self, shape): 136 """Returns the slice dim when the variable is partitioned only in one dim. 137 138 Args: 139 shape: Tuple or list of `int` indicating the shape of one specific 140 variable partition. 141 142 Returns: 143 `int` representing the dimension that the variable is partitioned in, or 144 `None` if the variable doesn't seem to be partitioned at all. 145 146 Raises: 147 TypeError: If `shape` is not a sequence. 148 ValueError: If `shape` is not the same length as `self.full_shape`. If 149 the variable is partitioned in more than one dimension. 
150 """ 151 if not isinstance(shape, collections_lib.Sequence) or isinstance( 152 shape, six.string_types): 153 raise TypeError( 154 "`shape` must be a sequence (like tuple or list) instead of " + 155 type(shape).__name__) 156 157 if len(shape) != len(self.full_shape): 158 raise ValueError( 159 "Expected equal length, but received shape={} of length {} while " 160 "self.full_shape={} is of length {}.".format(shape, len( 161 shape), self.full_shape, len(self.full_shape))) 162 163 for i in xrange(len(shape)): 164 if self.var_offset[i] + shape[i] > self.full_shape[i]: 165 raise ValueError( 166 "With self.var_offset={}, a partition of shape={} would exceed " 167 "self.full_shape={} in dimension {}.".format( 168 self.var_offset, shape, self.full_shape, i)) 169 170 slice_dim = None 171 for i in xrange(len(shape)): 172 if shape[i] == self.full_shape[i]: 173 continue 174 if slice_dim is not None: 175 raise ValueError( 176 "Cannot use single_slice_dim() with shape={} and " 177 "self.full_shape={} since slice dim could be either dimension {} " 178 "or {}.".format(shape, self.full_shape, i, slice_dim)) 179 slice_dim = i 180 181 return slice_dim 182 183 184class _ReuseMode(enum.Enum): 185 """Mode for variable access within a variable scope.""" 186 187 # Indicates that variables are to be fetched if they already exist or 188 # otherwise created. 189 AUTO_REUSE = 1 190 191 # TODO(alive): For TensorFlow 2.0, Deprecate True/False/None API in favor of 192 # enum values. 193 # REUSE_FALSE = 2 194 # REUSE_TRUE = 3 195 196 197# TODO(apassos) remove these forwarding symbols. 198VariableSynchronization = variables.VariableSynchronization # pylint: disable=invalid-name 199VariableAggregation = variables.VariableAggregation # pylint: disable=invalid-name 200 201AUTO_REUSE = _ReuseMode.AUTO_REUSE 202tf_export(v1=["AUTO_REUSE"]).export_constant(__name__, "AUTO_REUSE") 203AUTO_REUSE.__doc__ = """ 204When passed in as the value for the `reuse` flag, AUTO_REUSE indicates that 205get_variable() should create the requested variable if it doesn't exist or, if 206it does exist, simply return it. 207""" 208 209 210_DEFAULT_USE_RESOURCE = tf2.enabled() 211 212 213@tf_export(v1=["enable_resource_variables"]) 214def enable_resource_variables(): 215 """Creates resource variables by default. 216 217 Resource variables are improved versions of TensorFlow variables with a 218 well-defined memory model. Accessing a resource variable reads its value, and 219 all ops which access a specific read value of the variable are guaranteed to 220 see the same value for that tensor. Writes which happen after a read (by 221 having a control or data dependency on the read) are guaranteed not to affect 222 the value of the read tensor, and similarly writes which happen before a read 223 are guaranteed to affect the value. No guarantees are made about unordered 224 read/write pairs. 225 226 Calling tf.enable_resource_variables() lets you opt-in to this TensorFlow 2.0 227 feature. 228 """ 229 global _DEFAULT_USE_RESOURCE 230 _DEFAULT_USE_RESOURCE = True 231 232 233@tf_export(v1=["resource_variables_enabled"]) 234def resource_variables_enabled(): 235 """Returns `True` if resource variables are enabled. 236 237 Resource variables are improved versions of TensorFlow variables with a 238 well-defined memory model. Accessing a resource variable reads its value, and 239 all ops which access a specific read value of the variable are guaranteed to 240 see the same value for that tensor. 
  having a control or data dependency on the read) are guaranteed not to affect
  the value of the read tensor, and similarly writes which happen before a read
  are guaranteed to affect the value. No guarantees are made about unordered
  read/write pairs.

  Calling tf.enable_resource_variables() lets you opt-in to this TensorFlow 2.0
  feature.
  """
  global _DEFAULT_USE_RESOURCE
  return _DEFAULT_USE_RESOURCE


@deprecation.deprecated(
    None, "non-resource variables are not supported in the long term")
@tf_export(v1=["disable_resource_variables"])
def disable_resource_variables():
  """Opts out of resource variables.

  If your code needs tf.disable_resource_variables() to be called to work
  properly please file a bug.
  """
  global _DEFAULT_USE_RESOURCE
  _DEFAULT_USE_RESOURCE = False


class _VariableStore(object):
  """Variable store that carries a number of named Variables.

  New variable names and new variables can be created; all stored
  variables are initialized with the initializer passed to __init__.

  Attributes:
    vars: a dictionary with string names (same as passed in GetVar) as keys
      and the corresponding TensorFlow Variables as values.
  """

  def __init__(self):
    """Create a variable store."""
    self._vars = {}  # A dictionary of the stored TensorFlow variables.
    self._partitioned_vars = {}  # A dict of the stored PartitionedVariables.
    self._store_eager_variables = False

  def get_variable(self,
                   name,
                   shape=None,
                   dtype=dtypes.float32,
                   initializer=None,
                   regularizer=None,
                   reuse=None,
                   trainable=None,
                   collections=None,
                   caching_device=None,
                   partitioner=None,
                   validate_shape=True,
                   use_resource=None,
                   custom_getter=None,
                   constraint=None,
                   synchronization=VariableSynchronization.AUTO,
                   aggregation=VariableAggregation.NONE):
    """Gets an existing variable with these parameters or create a new one.

    If a variable with the given name is already stored, we return the stored
    variable. Otherwise, we create a new one.

    Set `reuse` to `True` when you only want to reuse existing Variables.
    Set `reuse` to `False` when you only want to create new Variables.
    Set `reuse` to None (the default) or tf.AUTO_REUSE when you want
    variables to be created if they don't exist or returned if they do.

    If initializer is `None` (the default), the default initializer passed in
    the constructor is used. If that one is `None` too, we use a new
    `glorot_uniform_initializer`. If initializer is a Tensor, we use
    it as a value and derive the shape from the initializer.

    If a partitioner is provided, a `PartitionedVariable` is returned.
    Accessing this object as a `Tensor` returns the shards concatenated along
    the partition axis.

    Some useful partitioners are available. See, e.g.,
    `variable_axis_size_partitioner` and `min_max_variable_partitioner`.

    Args:
      name: The name of the new or existing variable.
      shape: Shape of the new or existing variable.
      dtype: Type of the new or existing variable (defaults to `DT_FLOAT`).
      initializer: Initializer for the variable.
      regularizer: A (Tensor -> Tensor or None) function; the result of
        applying it on a newly created variable will be added to the collection
        GraphKeys.REGULARIZATION_LOSSES and can be used for regularization.
      reuse: a Boolean, None, or tf.AUTO_REUSE. Controls reuse or creation
        of variables. When eager execution is enabled this argument is always
        forced to be False.
      trainable: If `True` also add the variable to the graph collection
        `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
        `trainable` defaults to `True` unless `synchronization` is
        set to `ON_READ`.
      collections: List of graph collections keys to add the `Variable` to.
        Defaults to `[GraphKeys.GLOBAL_VARIABLES]` (see `tf.Variable`).
      caching_device: Optional device string or function describing where the
        Variable should be cached for reading. Defaults to the Variable's
        device. If not `None`, caches on another device. Typical use is to
        cache on the device where the Ops using the `Variable` reside, to
        deduplicate copying through `Switch` and other conditional statements.
      partitioner: Optional callable that accepts a fully defined `TensorShape`
        and dtype of the `Variable` to be created, and returns a list of
        partitions for each axis (currently only one axis can be partitioned).
      validate_shape: If False, allows the variable to be initialized with a
        value of unknown shape. If True, the default, the shape of initial_value
        must be known.
      use_resource: If False, creates a regular Variable. If True, creates
        instead an experimental ResourceVariable which has well-defined
        semantics. Defaults to False (will later change to True).
        When eager execution is enabled this argument is always forced to be
        True.
      custom_getter: Callable that takes as a first argument the true getter,
        and allows overwriting the internal get_variable method.
        The signature of `custom_getter` should match that of this method,
        but the most future-proof version will allow for changes:
        `def custom_getter(getter, *args, **kwargs)`. Direct access to
        all `get_variable` parameters is also allowed:
        `def custom_getter(getter, name, *args, **kwargs)`. A simple identity
        custom getter that simply creates variables with modified names is:
        ```python
        def custom_getter(getter, name, *args, **kwargs):
          return getter(name + '_suffix', *args, **kwargs)
        ```
      constraint: An optional projection function to be applied to the variable
        after being updated by an `Optimizer` (e.g. used to implement norm
        constraints or value constraints for layer weights). The function must
        take as input the unprojected Tensor representing the value of the
        variable and return the Tensor for the projected value
        (which must have the same shape). Constraints are not safe to
        use when doing asynchronous distributed training.
      synchronization: Indicates when a distributed variable will be
        aggregated. Accepted values are constants defined in the class
        `tf.VariableSynchronization`. By default the synchronization is set to
        `AUTO` and the current `DistributionStrategy` chooses
        when to synchronize. If `synchronization` is set to `ON_READ`,
        `trainable` must not be set to `True`.
      aggregation: Indicates how a distributed variable will be aggregated.
        Accepted values are constants defined in the class
        `tf.VariableAggregation`.

    Returns:
      The created or existing `Variable` (or `PartitionedVariable`, if a
      partitioner was used).

    Raises:
      ValueError: when creating a new variable and shape is not declared,
        when reusing a variable and specifying a conflicting shape,
        or when violating reuse during variable creation.
      RuntimeError: when eager execution is enabled and not called from an
        EagerVariableStore.
    """
    if custom_getter is not None and not callable(custom_getter):
      raise ValueError(
          "Passed a custom_getter which is not callable: %s" % custom_getter)

    with ops.init_scope():
      if context.executing_eagerly():
        # Variable creation and initialization takes place in `init_scope`s;
        # as such, if an `init_scope` lifts us into the eager context, then we
        # need to use `ResourceVariable`s.
        use_resource = True

    # Note that it's fine to reuse eager variables whose initialization was
    # lifted from a function-building graph into the eager context (that's why
    # the following clause is not wrapped in an `init_scope`); lifted variables
    # are tracked by the graph's `VariableStore`.
    if context.executing_eagerly():
      if not self._store_eager_variables and reuse:
        raise RuntimeError(
            "When eager execution is enabled variable reuse is only supported"
            " when an EagerVariableStore is active. See the documentation on"
            " EagerVariableStore for example usage.")
      if self._store_eager_variables:
        reuse = AUTO_REUSE

    # If a *_ref type is passed in, an error would be triggered further down
    # the stack. We prevent this using base_dtype to get a non-ref version of
    # the type, before doing anything else. When _ref types are removed in
    # favor of resources, this line can be removed.
    try:
      dtype = dtype.base_dtype
    except AttributeError:
      # .base_dtype not existing means that we will try and use the raw dtype
      # which was passed in - this might be a NumPy type which is valid.
      pass

    # This is the main logic of get_variable. However, custom_getter
    # may override this logic. So we save it as a callable and pass
    # it to custom_getter.
    # Note: the parameters of _true_getter, and their documentation, match
    # *exactly* item-for-item with the docstring of this method.
    def _true_getter(  # pylint: disable=missing-docstring
        name,
        shape=None,
        dtype=dtypes.float32,
        initializer=None,
        regularizer=None,
        reuse=None,
        trainable=None,
        collections=None,
        caching_device=None,
        partitioner=None,
        validate_shape=True,
        use_resource=None,
        constraint=None,
        synchronization=VariableSynchronization.AUTO,
        aggregation=VariableAggregation.NONE):
      is_scalar = (shape is not None
                   and isinstance(shape, collections_lib.Sequence)
                   and not shape)
      # Partitioned variable case
      if partitioner is not None and not is_scalar:
        if not callable(partitioner):
          raise ValueError(
              "Partitioner must be callable, but received: %s" % partitioner)
        with ops.name_scope(None):
          return self._get_partitioned_variable(
              name=name,
              shape=shape,
              dtype=dtype,
              initializer=initializer,
              regularizer=regularizer,
              reuse=reuse,
              trainable=trainable,
              collections=collections,
              caching_device=caching_device,
              partitioner=partitioner,
              validate_shape=validate_shape,
              use_resource=use_resource,
              constraint=constraint,
              synchronization=synchronization,
              aggregation=aggregation)

      # Special case for partitioned variable to allow reuse without having to
      # specify partitioner.
      if (reuse is True and partitioner is None
          and name in self._partitioned_vars):
        return self._get_partitioned_variable(
            name=name,
            shape=shape,
            dtype=dtype,
            initializer=initializer,
            regularizer=regularizer,
            reuse=reuse,
            trainable=trainable,
            collections=collections,
            caching_device=caching_device,
            partitioner=None,
            validate_shape=validate_shape,
            use_resource=use_resource,
            constraint=constraint,
            synchronization=synchronization,
            aggregation=aggregation)

      # Single variable case
      if "%s/part_0" % name in self._vars:
        raise ValueError(
            "No partitioner was provided, but a partitioned version of the "
            "variable was found: %s/part_0. Perhaps a variable of the same "
            "name was already created with partitioning?" % name)

      return self._get_single_variable(
          name=name,
          shape=shape,
          dtype=dtype,
          initializer=initializer,
          regularizer=regularizer,
          reuse=reuse,
          trainable=trainable,
          collections=collections,
          caching_device=caching_device,
          validate_shape=validate_shape,
          use_resource=use_resource,
          constraint=constraint,
          synchronization=synchronization,
          aggregation=aggregation)

    # Set trainable value based on synchronization value.
    trainable = _get_trainable_value(
        synchronization=synchronization, trainable=trainable)

    if custom_getter is not None:
      # Handle backwards compatibility with getter arguments that were added
      # to the API after users started writing custom getters.
      custom_getter_kwargs = {
          "getter": _true_getter,
          "name": name,
          "shape": shape,
          "dtype": dtype,
          "initializer": initializer,
          "regularizer": regularizer,
          "reuse": reuse,
          "trainable": trainable,
          "collections": collections,
          "caching_device": caching_device,
          "partitioner": partitioner,
          "validate_shape": validate_shape,
          "use_resource": use_resource,
          "synchronization": synchronization,
          "aggregation": aggregation,
      }
      # `fn_args` and `has_kwargs` can handle functions, `functools.partial`,
      # `lambda`.
      if ("constraint" in function_utils.fn_args(custom_getter) or
          function_utils.has_kwargs(custom_getter)):
        custom_getter_kwargs["constraint"] = constraint
      return custom_getter(**custom_getter_kwargs)
    else:
      return _true_getter(
          name,
          shape=shape,
          dtype=dtype,
          initializer=initializer,
          regularizer=regularizer,
          reuse=reuse,
          trainable=trainable,
          collections=collections,
          caching_device=caching_device,
          partitioner=partitioner,
          validate_shape=validate_shape,
          use_resource=use_resource,
          constraint=constraint,
          synchronization=synchronization,
          aggregation=aggregation)

  def _get_partitioned_variable(self,
                                name,
                                partitioner,
                                shape=None,
                                dtype=dtypes.float32,
                                initializer=None,
                                regularizer=None,
                                reuse=None,
                                trainable=None,
                                collections=None,
                                caching_device=None,
                                validate_shape=True,
                                use_resource=None,
                                constraint=None,
                                synchronization=VariableSynchronization.AUTO,
                                aggregation=VariableAggregation.NONE):
    """Gets or creates a sharded variable list with these parameters.

    The `partitioner` must be a callable that accepts a fully defined
    `TensorShape` and returns a sequence of integers (the `partitions`).
    These integers describe how to partition the given sharded `Variable`
    along the given dimension. That is, `partitions[1] = 3` means split
    the `Variable` into 3 shards along dimension 1. Currently, sharding along
    only one axis is supported.

    If the list of variables with the given name (prefix) is already stored,
    we return the stored variables. Otherwise, we create a new one.

    Set `reuse` to `True` when you only want to reuse existing Variables.
    Set `reuse` to `False` when you only want to create new Variables.
    Set `reuse` to None (the default) or tf.AUTO_REUSE when you want
    variables to be created if they don't exist or returned if they do.

    If initializer is `None` (the default), the default initializer passed in
    the constructor is used. If that one is `None` too, we use a new
    `glorot_uniform_initializer`. If initializer is a Tensor, we use
    it as a value and derive the shape from the initializer.

    If the initializer is a callable, then it will be called for each
    shard. Otherwise the initializer should match the shape of the entire
    sharded Variable, and it will be sliced accordingly for each shard.

    Some useful partitioners are available. See, e.g.,
    `variable_axis_size_partitioner` and `min_max_variable_partitioner`.

    Args:
      name: the name of the new or existing sharded variable.
      partitioner: Optional callable that accepts a fully defined `TensorShape`
        and `dtype` of the Variable to be created, and returns a list of
        partitions for each axis (currently only one axis can be partitioned).
      shape: shape of the new or existing sharded variable.
      dtype: type of the new or existing sharded variable
        (defaults to `DT_FLOAT`).
      initializer: initializer for the sharded variable.
      regularizer: a (Tensor -> Tensor or None) function; the result of
        applying it on a newly created variable will be added to the collection
        GraphKeys.REGULARIZATION_LOSSES and can be used for regularization.
      reuse: a Boolean, None, or tf.AUTO_REUSE. Controls reuse or creation
        of variables.
      trainable: If `True` also add the variable to the graph collection
        `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
      collections: List of graph collections keys to add the Variable to.
        Defaults to `[GraphKeys.GLOBAL_VARIABLES]` (see `tf.Variable`).
      caching_device: Optional device string or function describing where the
        Variable should be cached for reading. Defaults to the Variable's
        device. If not `None`, caches on another device. Typical use is to
        cache on the device where the Ops using the Variable reside, to
        deduplicate copying through `Switch` and other conditional statements.
      validate_shape: If False, allows the variable to be initialized with a
        value of unknown shape. If True, the default, the shape of initial_value
        must be known.
      use_resource: If False, creates a regular Variable. If True, creates an
        experimental ResourceVariable which has well-defined semantics. Defaults
        to False (will later change to True).
      constraint: An optional projection function to be applied to the variable
        after being updated by an `Optimizer` (e.g. used to implement norm
        constraints or value constraints for layer weights). The function must
        take as input the unprojected Tensor representing the value of the
        variable and return the Tensor for the projected value
        (which must have the same shape). Constraints are not safe to
        use when doing asynchronous distributed training.
      synchronization: Indicates when a distributed variable will be
        aggregated. Accepted values are constants defined in the class
        `tf.VariableSynchronization`. By default the synchronization is set to
        `AUTO` and the current `DistributionStrategy` chooses
        when to synchronize. If `synchronization` is set to `ON_READ`,
        `trainable` must not be set to `True`.
      aggregation: Indicates how a distributed variable will be aggregated.
        Accepted values are constants defined in the class
        `tf.VariableAggregation`.

    Returns:
      A `PartitionedVariable` object.

    Raises:
      ValueError: when creating a new variable and shape is not declared,
        when reusing a variable and specifying a conflicting shape,
        when violating reuse during variable creation, or if an existing
        sharded variable exists for the given name but with different sharding.
    """
    initializing_from_value = initializer is not None and isinstance(
        initializer, ops.Tensor)
    if name in self._vars:
      raise ValueError(
          "A partitioner was provided, but an unpartitioned version of the "
          "variable was found: %s. Perhaps a variable of the same name was "
          "already created without partitioning?" % name)

    shape = tensor_shape.as_shape(shape)
    if initializing_from_value:
      shape = shape.merge_with(initializer.get_shape())

    partitions = None
    if not reuse or partitioner:
      partitions = _call_partitioner(partitioner, shape, dtype)

    if name in self._partitioned_vars:
      if reuse is False:
        raise ValueError(
            "Partitioned variable with name %s already exists. Did you mean to "
            "set reuse=True or reuse=tf.AUTO_REUSE in VarScope?"
            % name)

      existing_var = self._partitioned_vars[name]
      if not shape.is_compatible_with(existing_var.get_shape()):
        raise ValueError(
            "Trying to reuse partitioned variable %s, but specified shape %s "
            "and found shape %s."
            % (name, shape, existing_var.get_shape()))
      if not dtype.is_compatible_with(existing_var.dtype):
        raise ValueError(
            "Trying to reuse partitioned variable %s, but specified dtype %s "
            "and found dtype %s."
            % (name, dtype.name, existing_var.dtype.name))

      # pylint: disable=protected-access
      if (partitions is not None and
          existing_var._get_partitions() != partitions):
        raise ValueError(
            "Trying to reuse partitioned variable %s, but specified partitions "
            "%s and found partitions %s." %
            (name, partitions, existing_var._get_partitions()))
      # pylint: enable=protected-access

      return existing_var

    if reuse is True:
      raise ValueError("PartitionedVariable %s does not exist, or was not "
                       "created with tf.get_variable(). Did you mean to set "
                       "reuse=False or reuse=tf.AUTO_REUSE in VarScope?" % name)

    slice_dim, num_slices = _get_slice_dim_and_num_slices(partitions)

    if "%s/part_0" % name in self._vars:
      if "%s/part_%d" % (name, num_slices - 1) not in self._vars:
        raise ValueError(
            "Partitioner returned a different partitioning than what was "
            "already found. Partitioner returned %d shards, and shard "
            "%s/part_0 was found, but %s/part_%d was not."
            % (num_slices, name, name, num_slices - 1))
      if "%s/part_%d" % (name, num_slices) in self._vars:
        raise ValueError(
            "Partitioner returned a different partitioning than what was "
            "already found. Partitioner returned %d shards, and shard "
            "%s/part_0 was found, but so was the extra shard %s/part_%d."
            % (num_slices, name, name, num_slices))

    vs = []
    for i, (var_offset, var_shape) in enumerate(_iter_slices(
        shape.as_list(),
        num_slices,
        slice_dim
    )):
      partition_info = _PartitionInfo(
          full_shape=shape.as_list(), var_offset=var_offset)
      var_full_name = "%s/part_%d" % (name, i)
      with ops.name_scope(var_full_name + "/PartitionedInitializer"):
        # Create the tensor to initialize the variable with default value.
        if initializer is None:
          init, initializing_from_value = self._get_default_initializer(
              name=name, shape=shape, dtype=dtype)
          if initializing_from_value:
            init_shape = None
          else:
            init_shape = var_shape
        elif callable(initializer):
          init = initializer
          init_shape = var_shape
        elif isinstance(initializer, ops.Tensor):
          init = array_ops.slice(initializer, var_offset, var_shape)
          # Use the dtype of the given tensor.
          dtype = init.dtype.base_dtype
          init_shape = None
        else:
          init = ops.convert_to_tensor(initializer, dtype=dtype)
          init = array_ops.slice(init, var_offset, var_shape)
          init_shape = None

      with ops.name_scope(None):
        var = self._get_single_variable(
            name=var_full_name,
            shape=init_shape,
            dtype=dtype,
            initializer=init,
            partition_info=partition_info,
            regularizer=regularizer,
            reuse=reuse,
            trainable=trainable,
            collections=collections,
            caching_device=caching_device,
            validate_shape=validate_shape,
            use_resource=use_resource,
            constraint=constraint,
            synchronization=synchronization,
            aggregation=aggregation)

      # pylint: disable=protected-access
      var._set_save_slice_info(variables.Variable.SaveSliceInfo(
          name, shape.as_list(), var_offset, var_shape))
      vs.append(var)
      # pylint: enable=protected-access

    partitioned_var = variables.PartitionedVariable(name=name,
                                                    shape=shape,
                                                    dtype=dtype,
                                                    variable_list=vs,
                                                    partitions=partitions)
    if not context.executing_eagerly() or self._store_eager_variables:
      self._partitioned_vars[name] = partitioned_var
    return partitioned_var

  def _get_single_variable(self,
                           name,
                           shape=None,
                           dtype=dtypes.float32,
                           initializer=None,
                           regularizer=None,
                           partition_info=None,
                           reuse=None,
                           trainable=None,
                           collections=None,
                           caching_device=None,
                           validate_shape=True,
                           use_resource=None,
                           constraint=None,
                           synchronization=VariableSynchronization.AUTO,
                           aggregation=VariableAggregation.NONE):
    """Get or create a single Variable (e.g. a shard or entire variable).

    See the documentation of get_variable above (ignore partitioning components)
    for details.

    Args:
      name: see get_variable.
      shape: see get_variable.
      dtype: see get_variable.
      initializer: see get_variable.
      regularizer: see get_variable.
      partition_info: _PartitionInfo object.
      reuse: see get_variable.
      trainable: see get_variable.
      collections: see get_variable.
      caching_device: see get_variable.
      validate_shape: see get_variable.
      use_resource: see get_variable.
      constraint: see get_variable.
      synchronization: see get_variable.
      aggregation: see get_variable.

    Returns:
      A Variable. See documentation of get_variable above.

    Raises:
      ValueError: See documentation of get_variable above.
    """
    # Set to true if initializer is a constant.
    initializing_from_value = False
    if initializer is not None and not callable(initializer):
      initializing_from_value = True
    if shape is not None and initializing_from_value:
      raise ValueError("If initializer is a constant, do not specify shape.")

    dtype = dtypes.as_dtype(dtype)
    shape = tensor_shape.as_shape(shape)

    if name in self._vars:
      # Here we handle the case when returning an existing variable.
      if reuse is False:
        var = self._vars[name]
        err_msg = ("Variable %s already exists, disallowed."
                   " Did you mean to set reuse=True or "
                   "reuse=tf.AUTO_REUSE in VarScope?" % name)
        # ResourceVariables don't have an op associated with them, so there is
        # no traceback to report.
        if isinstance(var, resource_variable_ops.ResourceVariable):
          raise ValueError(err_msg)
        tb = var.op.traceback[::-1]
        # Throw away internal tf entries and only take a few lines. In some
        # cases the traceback can be longer (e.g. if someone uses factory
        # functions to create variables) so we take more than needed in the
        # default case.
        tb = [x for x in tb if "tensorflow/python" not in x[0]][:5]
        raise ValueError("%s Originally defined at:\n\n%s" % (err_msg, "".join(
            traceback.format_list(tb))))
      found_var = self._vars[name]
      if not shape.is_compatible_with(found_var.get_shape()):
        raise ValueError("Trying to share variable %s, but specified shape %s"
                         " and found shape %s." % (name, shape,
                                                   found_var.get_shape()))
      if not dtype.is_compatible_with(found_var.dtype):
        dtype_str = dtype.name
        found_type_str = found_var.dtype.name
        raise ValueError("Trying to share variable %s, but specified dtype %s"
                         " and found dtype %s." % (name, dtype_str,
                                                   found_type_str))
      return found_var

    # The code below handles only the case of creating a new variable.
    if reuse is True:
      raise ValueError("Variable %s does not exist, or was not created with "
                       "tf.get_variable(). Did you mean to set "
                       "reuse=tf.AUTO_REUSE in VarScope?" % name)

    # Create the tensor to initialize the variable with default value.
    if initializer is None:
      initializer, initializing_from_value = self._get_default_initializer(
          name=name, shape=shape, dtype=dtype)
    # Enter an init scope when creating the initializer.
    with ops.init_scope():
      if initializing_from_value:
        init_val = initializer
        variable_dtype = None
      else:
        # Instantiate initializer if provided initializer is a type object.
        if tf_inspect.isclass(initializer):
          initializer = initializer(dtype=dtype)
        if shape is not None and shape.is_fully_defined():
          init_val = lambda: initializer(  # pylint: disable=g-long-lambda
              shape.as_list(), dtype=dtype, partition_info=partition_info)
          variable_dtype = dtype.base_dtype
        elif len(tf_inspect.getargspec(initializer).args) == len(
            tf_inspect.getargspec(initializer).defaults or []):
          init_val = initializer
          variable_dtype = None
        else:
          raise ValueError("The initializer passed is not valid. It should "
                           "be a callable with no arguments and the "
                           "shape should not be provided, or an instance of "
                           "`tf.keras.initializers.*` and `shape` should be "
                           "fully defined.")

    # Create the variable.
    if use_resource is None:
      # Set the default value if unspecified.
      use_resource = _DEFAULT_USE_RESOURCE
    v = variables.VariableV1(
        initial_value=init_val,
        name=name,
        trainable=trainable,
        collections=collections,
        caching_device=caching_device,
        dtype=variable_dtype,
        validate_shape=validate_shape,
        constraint=constraint,
        use_resource=use_resource,
        synchronization=synchronization,
        aggregation=aggregation)
    if context.executing_eagerly() and self._store_eager_variables:
      if collections:
        ops.add_to_collections(collections, v)
      else:
        ops.add_to_collection(ops.GraphKeys.GLOBAL_VARIABLES, v)
      if trainable:
        ops.add_to_collection(ops.GraphKeys.TRAINABLE_VARIABLES, v)

    if not context.executing_eagerly() or self._store_eager_variables:
      # In eager mode we do not want to keep default references to Variable
      # objects as this will prevent their memory from being released.
      self._vars[name] = v
    logging.vlog(1, "Created variable %s with shape %s and init %s", v.name,
                 format(shape), initializer)

    # Run the regularizer if requested and save the resulting loss.
    if regularizer:
      with ops.colocate_with(v):
        with ops.name_scope(name + "/Regularizer/"):
          with ops.init_scope():
            loss = regularizer(v)
          if loss is not None:
            if context.executing_eagerly():
              v_name = "v_%s" % type(v)
              loss_name = "loss_%s" % type(loss)
            else:
              v_name = v.name
              loss_name = loss.name
            logging.vlog(1, "Applied regularizer to %s and added the result %s "
                         "to REGULARIZATION_LOSSES.", v_name, loss_name)
            ops.add_to_collection(ops.GraphKeys.REGULARIZATION_LOSSES, loss)
    return v

  # Initialize variable when no initializer provided
  def _get_default_initializer(self, name, shape=None, dtype=dtypes.float32):
    """Provide a default initializer and a corresponding value.

    Args:
      name: see get_variable.
      shape: see get_variable.
      dtype: see get_variable.

    Returns:
      initializer and initializing_from_value. See get_variable above.

    Raises:
      ValueError: When giving unsupported dtype.
    """
    del shape
    # If dtype is DT_FLOAT, provide a uniform unit scaling initializer
    if dtype.is_floating:
      initializer = init_ops.glorot_uniform_initializer()
      initializing_from_value = False
    # If dtype is DT_INT/DT_UINT, provide a default value `zero`
    # If dtype is DT_BOOL, provide a default value `FALSE`
    elif (dtype.is_integer or dtype.is_unsigned or dtype.is_bool
          or dtype == dtypes.string):
      initializer = init_ops.zeros_initializer()
      initializing_from_value = False
    # NOTE: do we need to support DT_STRING and DT_COMPLEX here?
    else:
      raise ValueError("An initializer for variable %s of %s is required"
                       % (name, dtype.base_dtype))

    return initializer, initializing_from_value


# To stop regularization, use this regularizer
@tf_export("no_regularizer")
def no_regularizer(_):
  """Use this function to prevent regularization of variables."""
  return None


# TODO(alive): support caching devices and partitioned variables in Eager mode.
@tf_export(v1=["VariableScope"])
class VariableScope(object):
  """Variable scope object to carry defaults to provide to `get_variable`.

  Many of the arguments we need for `get_variable` in a variable store are most
  easily handled with a context. This object is used for the defaults.
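
  For example (an illustrative sketch; the scope and variable names below are
  placeholders), a `VariableScope` is usually entered via `tf.variable_scope`
  rather than constructed directly:

  ```python
  with tf.variable_scope("layer1", reuse=tf.AUTO_REUSE):
    w = tf.get_variable("w", shape=[3, 3])
  ```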

  Attributes:
    name: name of the current scope, used as prefix in get_variable.
    initializer: default initializer passed to get_variable.
    regularizer: default regularizer passed to get_variable.
    reuse: Boolean, None, or tf.AUTO_REUSE, setting the reuse in
      get_variable. When eager execution is enabled this argument is always
      forced to be False.
    caching_device: string, callable, or None: the caching device passed to
      get_variable.
    partitioner: callable or `None`: the partitioner passed to `get_variable`.
    custom_getter: default custom getter passed to get_variable.
    name_scope: The name passed to `tf.name_scope`.
    dtype: default type passed to get_variable (defaults to DT_FLOAT).
    use_resource: if False, create a normal Variable; if True create an
      experimental ResourceVariable with well-defined semantics. Defaults
      to False (will later change to True). When eager execution is enabled
      this argument is always forced to be True.
    constraint: An optional projection function to be applied to the variable
      after being updated by an `Optimizer` (e.g. used to implement norm
      constraints or value constraints for layer weights). The function must
      take as input the unprojected Tensor representing the value of the
      variable and return the Tensor for the projected value
      (which must have the same shape). Constraints are not safe to
      use when doing asynchronous distributed training.
  """

  def __init__(self,
               reuse,
               name="",
               initializer=None,
               regularizer=None,
               caching_device=None,
               partitioner=None,
               custom_getter=None,
               name_scope="",
               dtype=dtypes.float32,
               use_resource=None,
               constraint=None):
    """Creates a new VariableScope with the given properties."""
    self._name = name
    self._initializer = initializer
    self._regularizer = regularizer
    self._reuse = reuse
    self._caching_device = caching_device
    self._partitioner = partitioner
    self._custom_getter = custom_getter
    self._name_scope = name_scope
    self._dtype = dtype
    self._use_resource = use_resource
    self._constraint = constraint
    if context.executing_eagerly():
      if self._caching_device is not None:
        raise NotImplementedError("Caching devices are not yet supported "
                                  "when eager execution is enabled.")
      self._reuse = AUTO_REUSE
      self._use_resource = True

  @property
  def name(self):
    return self._name

  @property
  def original_name_scope(self):
    return self._name_scope

  @property
  def reuse(self):
    return self._reuse

  @property
  def initializer(self):
    return self._initializer

  @property
  def dtype(self):
    return self._dtype

  @property
  def use_resource(self):
    return self._use_resource

  @property
  def regularizer(self):
    return self._regularizer

  @property
  def caching_device(self):
    return self._caching_device

  @property
  def partitioner(self):
    return self._partitioner

  @property
  def custom_getter(self):
    return self._custom_getter

  @property
  def constraint(self):
    return self._constraint

  def reuse_variables(self):
    """Reuse variables in this scope."""
    self._reuse = True

  def set_initializer(self, initializer):
    """Set initializer for this scope."""
    self._initializer = initializer

  def set_dtype(self, dtype):
    """Set data type for this scope."""
    self._dtype = dtype

  def set_use_resource(self, use_resource):
    """Sets whether to use ResourceVariables for this scope."""
    if context.executing_eagerly() and not use_resource:
      raise ValueError("When eager execution is enabled, "
                       "use_resource cannot be set to false.")
    self._use_resource = use_resource

  def set_regularizer(self, regularizer):
    """Set regularizer for this scope."""
    self._regularizer = regularizer

  def set_caching_device(self, caching_device):
    """Set caching_device for this scope."""
    if context.executing_eagerly():
      raise NotImplementedError("Caching devices are not yet supported "
                                "when eager execution is enabled.")
    self._caching_device = caching_device

  def set_partitioner(self, partitioner):
    """Set partitioner for this scope."""
    self._partitioner = partitioner

  def set_custom_getter(self, custom_getter):
    """Set custom getter for this scope."""
    self._custom_getter = custom_getter

  def get_collection(self, name):
    """Get this scope's variables."""
    scope = self._name + "/" if self._name else ""
    return ops.get_collection(name, scope)

  def trainable_variables(self):
    """Get this scope's trainable variables."""
    return self.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)

  def global_variables(self):
    """Get this scope's global variables."""
    return self.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)

  def local_variables(self):
    """Get this scope's local variables."""
    return self.get_collection(ops.GraphKeys.LOCAL_VARIABLES)

  def get_variable(self,
                   var_store,
                   name,
                   shape=None,
                   dtype=None,
                   initializer=None,
                   regularizer=None,
                   reuse=None,
                   trainable=None,
                   collections=None,
                   caching_device=None,
                   partitioner=None,
                   validate_shape=True,
                   use_resource=None,
                   custom_getter=None,
                   constraint=None,
                   synchronization=VariableSynchronization.AUTO,
                   aggregation=VariableAggregation.NONE):
    """Gets an existing variable with this name or create a new one."""
    if regularizer is None:
      regularizer = self._regularizer
    if caching_device is None:
      caching_device = self._caching_device
    if partitioner is None:
      partitioner = self._partitioner
    if custom_getter is None:
      custom_getter = self._custom_getter
    if context.executing_eagerly():
      reuse = False
      use_resource = True
    else:
      if reuse is None:
        reuse = self._reuse
      if use_resource is None:
        use_resource = self._use_resource

    full_name = self.name + "/" + name if self.name else name
    # Variable names only depend on variable_scope (full_name here),
    # not name_scope, so we reset it below for the time of variable creation.
    with ops.name_scope(None):
      # Check that `initializer` dtype and `dtype` are consistent before
      # replacing them with defaults.
      if (dtype is not None and initializer is not None and
          not callable(initializer)):
        init_dtype = ops.convert_to_tensor(initializer).dtype.base_dtype
        if init_dtype != dtype:
          raise ValueError("Initializer type '%s' and explicit dtype '%s' "
                           "don't match." % (init_dtype, dtype))
      if initializer is None:
        initializer = self._initializer
      if constraint is None:
        constraint = self._constraint
      if dtype is None:
        dtype = self._dtype
      return var_store.get_variable(
          full_name,
          shape=shape,
          dtype=dtype,
          initializer=initializer,
          regularizer=regularizer,
          reuse=reuse,
          trainable=trainable,
          collections=collections,
          caching_device=caching_device,
          partitioner=partitioner,
          validate_shape=validate_shape,
          use_resource=use_resource,
          custom_getter=custom_getter,
          constraint=constraint,
          synchronization=synchronization,
          aggregation=aggregation)

  def _get_partitioned_variable(self,
                                var_store,
                                name,
                                shape=None,
                                dtype=None,
                                initializer=None,
                                regularizer=None,
                                trainable=None,
                                collections=None,
                                caching_device=None,
                                partitioner=None,
                                validate_shape=True,
                                use_resource=None,
                                constraint=None,
                                synchronization=VariableSynchronization.AUTO,
                                aggregation=VariableAggregation.NONE):
    """Gets an existing variable with this name or create a new one."""
    if initializer is None:
      initializer = self._initializer
    if regularizer is None:
      regularizer = self._regularizer
    if constraint is None:
      constraint = self._constraint
    if caching_device is None:
      caching_device = self._caching_device
    if partitioner is None:
      partitioner = self._partitioner
    if dtype is None:
      dtype = self._dtype
    if use_resource is None:
      use_resource = self._use_resource

    if self._custom_getter is not None:
      raise ValueError(
          "Private access to _get_partitioned_variable is not allowed when "
          "a custom getter is set. Current custom getter: %s. "
          "It is likely that you're using create_partitioned_variables. "
          "If so, consider instead using get_variable with a non-empty "
          "partitioner parameter." % self._custom_getter)

    if partitioner is None:
      raise ValueError("No partitioner was specified")

    # This allows the variable scope name to be used as the variable name if
    # this function is invoked with an empty name arg, for backward
    # compatibility with create_partitioned_variables().
    full_name_list = []
    if self.name:
      full_name_list.append(self.name)
    if name:
      full_name_list.append(name)
    full_name = "/".join(full_name_list)

    # Variable names only depend on variable_scope (full_name here),
    # not name_scope, so we reset it below for the time of variable creation.
    with ops.name_scope(None):
      # pylint: disable=protected-access
      return var_store._get_partitioned_variable(
          full_name,
          shape=shape,
          dtype=dtype,
          initializer=initializer,
          regularizer=regularizer,
          reuse=self.reuse,
          trainable=trainable,
          collections=collections,
          caching_device=caching_device,
          partitioner=partitioner,
          validate_shape=validate_shape,
          use_resource=use_resource,
          constraint=constraint,
          synchronization=synchronization,
          aggregation=aggregation)
      # pylint: enable=protected-access


_VARSTORE_KEY = ("__variable_store",)
_VARSCOPESTORE_KEY = ("__varscope",)


class _VariableScopeStore(threading.local):
  """A thread local store for the current variable scope and scope counts."""

  def __init__(self):
    super(_VariableScopeStore, self).__init__()
    self.current_scope = VariableScope(False)
    self.variable_scopes_count = {}

  def open_variable_scope(self, scope_name):
    if scope_name in self.variable_scopes_count:
      self.variable_scopes_count[scope_name] += 1
    else:
      self.variable_scopes_count[scope_name] = 1

  def close_variable_subscopes(self, scope_name):
    for k in list(self.variable_scopes_count.keys()):
      if scope_name is None or k.startswith(scope_name + "/"):
        self.variable_scopes_count[k] = 0

  def variable_scope_count(self, scope_name):
    return self.variable_scopes_count.get(scope_name, 0)


def get_variable_scope_store():
  """Returns the variable scope store for current thread."""
  scope_store = ops.get_collection(_VARSCOPESTORE_KEY)

  if not scope_store:
    scope_store = _VariableScopeStore()
    ops.add_to_collection(_VARSCOPESTORE_KEY, scope_store)
  else:
    scope_store = scope_store[0]

  return scope_store


@tf_export(v1=["get_variable_scope"])
def get_variable_scope():
  """Returns the current variable scope."""
  return get_variable_scope_store().current_scope


def _get_default_variable_store():
  store = ops.get_collection(_VARSTORE_KEY)
  if store:
    return store[0]
  store = _VariableStore()
  ops.add_to_collection(_VARSTORE_KEY, store)
  return store


@tf_contextlib.contextmanager
def with_variable_store(store):
  store_collection = ops.get_collection_ref(_VARSTORE_KEY)
  old = list(store_collection)
  store_collection[:] = [store]
  try:
    yield
  finally:
    store_collection[:] = old


class EagerVariableStore(object):
  """Wrapper allowing functional layers to be used with eager execution.

  When eager execution is enabled Variables get deleted when they go out of
  scope, and are not stored in global collections by default. A lot of code
  (mostly the functional layers in tf.layers) assumes that variables are kept in
  a global list.

  EagerVariableStore can be used in conjunction with this code to make it
  eager-friendly. For example, to create a dense layer, use:

  ```
  container = tfe.EagerVariableStore()
  for input in dataset_iterator:
    with container.as_default():
      x = tf.layers.dense(input, name="l1")
  print(container.variables)  # Should print the variables used in the layer.
  ```
  """

  def __init__(self, store=None):
    if store is not None:
      if not store._store_eager_variables:  # pylint: disable=protected-access
        raise ValueError("Cannot construct EagerVariableStore from a "
                         "VariableStore object that does not hold eager "
                         "variables.")
      self._store = store
    else:
      self._store = _VariableStore()
    self._store._store_eager_variables = True  # pylint: disable=protected-access

  def as_default(self):
    return with_variable_store(self._store)

  def variables(self):
    return sorted(self._store._vars.values(), key=lambda x: x.name)  # pylint: disable=protected-access

  def trainable_variables(self):
    # pylint: disable=protected-access
    return sorted([x for x in self._store._vars.values() if x.trainable],
                  key=lambda x: x.name)
    # pylint: enable=protected-access

  def non_trainable_variables(self):
    # pylint: disable=protected-access
    return sorted([x for x in self._store._vars.values() if not x.trainable],
                  key=lambda x: x.name)
    # pylint: enable=protected-access

  def copy(self):
    """Copy this variable store and all of its contents.

    Variables contained in this store will be copied over to the new variable
    store, meaning that they can be modified without affecting the variables in
    this store.

    Returns:
      A new EagerVariableStore instance containing copied variables.
    """
    # pylint: disable=protected-access
    new_store = EagerVariableStore()
    for key, var in iteritems(self._store._vars):
      # Strip device out of variable name.
      try:
        index = var.name.index(":")
      except ValueError:
        stripped_var_name = var.name
      else:
        stripped_var_name = var.name[:index]

      # Create new variable with same value, name, and "trainable" flag.
      new_var = resource_variable_ops.ResourceVariable(
          var.read_value(),
          name=stripped_var_name,
          trainable=var.trainable)
      new_store._store._vars[key] = new_var
    return new_store
    # pylint: enable=protected-access


# The argument list for get_variable must match arguments to get_local_variable.
# So, if you are updating the arguments, also update arguments to
# get_local_variable below.
@tf_export(v1=["get_variable"])
def get_variable(name,
                 shape=None,
                 dtype=None,
                 initializer=None,
                 regularizer=None,
                 trainable=None,
                 collections=None,
                 caching_device=None,
                 partitioner=None,
                 validate_shape=True,
                 use_resource=None,
                 custom_getter=None,
                 constraint=None,
                 synchronization=VariableSynchronization.AUTO,
                 aggregation=VariableAggregation.NONE):
  return get_variable_scope().get_variable(
      _get_default_variable_store(),
      name,
      shape=shape,
      dtype=dtype,
      initializer=initializer,
      regularizer=regularizer,
      trainable=trainable,
      collections=collections,
      caching_device=caching_device,
      partitioner=partitioner,
      validate_shape=validate_shape,
      use_resource=use_resource,
      custom_getter=custom_getter,
      constraint=constraint,
      synchronization=synchronization,
      aggregation=aggregation)


get_variable_or_local_docstring = ("""%s

%sThis function prefixes the name with the current variable scope
and performs reuse checks. See the
[Variable Scope How To](https://tensorflow.org/guide/variables)
for an extensive description of how reusing works. Here is a basic example:

```python
def foo():
  with tf.variable_scope("foo", reuse=tf.AUTO_REUSE):
    v = tf.get_variable("v", [1])
  return v

v1 = foo()  # Creates v.
v2 = foo()  # Gets the same, existing v.
assert v1 == v2
```

If initializer is `None` (the default), the default initializer passed in
the variable scope will be used. If that one is `None` too, a
`glorot_uniform_initializer` will be used. The initializer can also be
a Tensor, in which case the variable is initialized to this value and shape.

Similarly, if the regularizer is `None` (the default), the default regularizer
passed in the variable scope will be used (if that is `None` too,
then by default no regularization is performed).

If a partitioner is provided, a `PartitionedVariable` is returned.
Accessing this object as a `Tensor` returns the shards concatenated along
the partition axis.

Some useful partitioners are available. See, e.g.,
`variable_axis_size_partitioner` and `min_max_variable_partitioner`.

Args:
  name: The name of the new or existing variable.
  shape: Shape of the new or existing variable.
  dtype: Type of the new or existing variable (defaults to `DT_FLOAT`).
  initializer: Initializer for the variable if one is created. Can either be
    an initializer object or a Tensor. If it's a Tensor, its shape must be known
    unless validate_shape is False.
  regularizer: A (Tensor -> Tensor or None) function; the result of
    applying it on a newly created variable will be added to the collection
    `tf.GraphKeys.REGULARIZATION_LOSSES` and can be used for regularization.
  %scollections: List of graph collections keys to add the Variable to.
    Defaults to `[%s]` (see `tf.Variable`).
  caching_device: Optional device string or function describing where the
    Variable should be cached for reading. Defaults to the Variable's
    device. If not `None`, caches on another device. Typical use is to
    cache on the device where the Ops using the Variable reside, to
    deduplicate copying through `Switch` and other conditional statements.
  partitioner: Optional callable that accepts a fully defined `TensorShape`
    and `dtype` of the Variable to be created, and returns a list of
    partitions for each axis (currently only one axis can be partitioned).
  validate_shape: If False, allows the variable to be initialized with a
    value of unknown shape. If True, the default, the shape of initial_value
    must be known. For this to be used the initializer must be a Tensor and
    not an initializer object.
  use_resource: If False, creates a regular Variable. If true, creates an
    experimental ResourceVariable instead with well-defined semantics.
    Defaults to False (will later change to True). When eager execution is
    enabled this argument is always forced to be True.
  custom_getter: Callable that takes as a first argument the true getter, and
    allows overwriting the internal get_variable method.
    The signature of `custom_getter` should match that of this method,
    but the most future-proof version will allow for changes:
    `def custom_getter(getter, *args, **kwargs)`. Direct access to
    all `get_variable` parameters is also allowed:
    `def custom_getter(getter, name, *args, **kwargs)`. A simple identity
A simple identity
    custom getter that creates variables with modified names is:
    ```python
    def custom_getter(getter, name, *args, **kwargs):
      return getter(name + '_suffix', *args, **kwargs)
    ```
  constraint: An optional projection function to be applied to the variable
    after being updated by an `Optimizer` (e.g. used to implement norm
    constraints or value constraints for layer weights). The function must
    take as input the unprojected Tensor representing the value of the
    variable and return the Tensor for the projected value
    (which must have the same shape). Constraints are not safe to
    use when doing asynchronous distributed training.
  synchronization: Indicates when a distributed variable will be
    aggregated. Accepted values are constants defined in the class
    `tf.VariableSynchronization`. By default the synchronization is set to
    `AUTO` and the current `DistributionStrategy` chooses
    when to synchronize. If `synchronization` is set to `ON_READ`,
    `trainable` must not be set to `True`.
  aggregation: Indicates how a distributed variable will be aggregated.
    Accepted values are constants defined in the class
    `tf.VariableAggregation`.

Returns:
  The created or existing `Variable` (or `PartitionedVariable`, if a
  partitioner was used).

Raises:
  ValueError: when creating a new variable and shape is not declared,
    when violating reuse during variable creation, or when `initializer` dtype
    and `dtype` don't match. Reuse is set inside `variable_scope`.
""")
get_variable.__doc__ = get_variable_or_local_docstring % (
    "Gets an existing variable with these parameters or creates a new one.",
    "",
    "trainable: If `True` also add the variable to the graph collection\n"
    "    `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).\n  ",
    "GraphKeys.GLOBAL_VARIABLES")


# The argument list for get_local_variable must match arguments to get_variable.
# So, if you are updating the arguments, also update arguments to get_variable.
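
# For illustration only, a small sketch (the names below are illustrative, not
# part of this module) of the partitioner behaviour documented above: when a
# partitioner is passed, `get_variable` returns a `PartitionedVariable`, and
# reading it as a `Tensor` concatenates the shards along the partitioned axis.
#
#   partitioner = tf.fixed_size_partitioner(num_shards=4)
#   with tf.variable_scope("embeddings", partitioner=partitioner):
#     emb = tf.get_variable("emb", shape=[1000, 64])  # A PartitionedVariable.
#   full = tf.convert_to_tensor(emb)                  # Concatenated, [1000, 64].
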
1615@tf_export(v1=["get_local_variable"]) 1616def get_local_variable( # pylint: disable=missing-docstring 1617 name, 1618 shape=None, 1619 dtype=None, 1620 initializer=None, 1621 regularizer=None, 1622 trainable=False, # pylint: disable=unused-argument 1623 collections=None, 1624 caching_device=None, 1625 partitioner=None, 1626 validate_shape=True, 1627 use_resource=None, 1628 custom_getter=None, 1629 constraint=None, 1630 synchronization=VariableSynchronization.AUTO, 1631 aggregation=VariableAggregation.NONE): 1632 if collections: 1633 collections += [ops.GraphKeys.LOCAL_VARIABLES] 1634 else: 1635 collections = [ops.GraphKeys.LOCAL_VARIABLES] 1636 return get_variable( 1637 name, 1638 shape=shape, 1639 dtype=dtype, 1640 initializer=initializer, 1641 regularizer=regularizer, 1642 trainable=False, 1643 collections=collections, 1644 caching_device=caching_device, 1645 partitioner=partitioner, 1646 validate_shape=validate_shape, 1647 use_resource=use_resource, 1648 synchronization=synchronization, 1649 aggregation=aggregation, 1650 custom_getter=custom_getter, 1651 constraint=constraint) 1652 1653 1654get_local_variable.__doc__ = get_variable_or_local_docstring % ( 1655 "Gets an existing *local* variable or creates a new one.", 1656 "Behavior is the same as in `get_variable`, except that variables are\n" 1657 "added to the `LOCAL_VARIABLES` collection and `trainable` is set to\n" 1658 "`False`.\n", 1659 "", 1660 "GraphKeys.LOCAL_VARIABLES") 1661 1662 1663def _get_partitioned_variable(name, 1664 shape=None, 1665 dtype=None, 1666 initializer=None, 1667 regularizer=None, 1668 trainable=True, 1669 collections=None, 1670 caching_device=None, 1671 partitioner=None, 1672 validate_shape=True, 1673 use_resource=None, 1674 constraint=None, 1675 synchronization=VariableSynchronization.AUTO, 1676 aggregation=VariableAggregation.NONE): 1677 """Gets or creates a sharded variable list with these parameters. 1678 1679 The `partitioner` must be a callable that accepts a fully defined 1680 `TensorShape` and returns a sequence of integers (the `partitions`). 1681 These integers describe how to partition the given sharded `Variable` 1682 along the given dimension. That is, `partitions[1] = 3` means split 1683 the `Variable` into 3 shards along dimension 1. Currently, sharding along 1684 only one axis is supported. 1685 1686 If the list of variables with the given name (prefix) is already stored, 1687 we return the stored variables. Otherwise, we create a new one. 1688 1689 If initializer is `None` (the default), the default initializer passed in 1690 the constructor is used. If that one is `None` too, we use a new 1691 `glorot_uniform_initializer`. If initializer is a Tensor, we use 1692 it as a value and derive the shape from the initializer. 1693 1694 If the initializer is a callable, then it will be called for each 1695 shard. Otherwise the initializer should match the shape of the entire 1696 sharded Variable, and it will be sliced accordingly for each shard. 1697 1698 Some useful partitioners are available. See, e.g., 1699 `variable_axis_size_partitioner` and `min_max_variable_partitioner`. 1700 1701 Args: 1702 name: The name of the new or existing variable. 1703 shape: Shape of the new or existing variable. 1704 dtype: Type of the new or existing variable (defaults to `DT_FLOAT`). 1705 initializer: Initializer for the variable if one is created. 
    regularizer: A (Tensor -> Tensor or None) function; the result of
      applying it on a newly created variable will be added to the collection
      GraphKeys.REGULARIZATION_LOSSES and can be used for regularization.
    trainable: If `True` also add the variable to the graph collection
      `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
    collections: List of graph collections keys to add the Variable to.
      Defaults to `[GraphKeys.GLOBAL_VARIABLES]` (see `tf.Variable`).
    caching_device: Optional device string or function describing where the
      Variable should be cached for reading. Defaults to the Variable's
      device. If not `None`, caches on another device. Typical use is to
      cache on the device where the Ops using the Variable reside, to
      deduplicate copying through `Switch` and other conditional statements.
    partitioner: Optional callable that accepts a fully defined `TensorShape`
      and `dtype` of the Variable to be created, and returns a list of
      partitions for each axis (currently only one axis can be partitioned).
    validate_shape: If False, allows the variable to be initialized with a
      value of unknown shape. If True, the default, the shape of initial_value
      must be known.
    use_resource: If False, creates a regular Variable. If True, creates an
      experimental ResourceVariable instead which has well-defined semantics.
      Defaults to False (will later change to True).
    constraint: An optional projection function to be applied to the variable
      after being updated by an `Optimizer` (e.g. used to implement norm
      constraints or value constraints for layer weights). The function must
      take as input the unprojected Tensor representing the value of the
      variable and return the Tensor for the projected value
      (which must have the same shape). Constraints are not safe to
      use when doing asynchronous distributed training.
    synchronization: Indicates when a distributed variable will be
      aggregated. Accepted values are constants defined in the class
      `tf.VariableSynchronization`. By default the synchronization is set to
      `AUTO` and the current `DistributionStrategy` chooses
      when to synchronize. If `synchronization` is set to `ON_READ`,
      `trainable` must not be set to `True`.
    aggregation: Indicates how a distributed variable will be aggregated.
      Accepted values are constants defined in the class
      `tf.VariableAggregation`.

  Returns:
    A tuple `(shards, partitions)` where `shards` is the list of `Variable`
    shards and `partitions` is the output of the partitioner on the input
    shape.

  Raises:
    ValueError: when creating a new variable and shape is not declared,
      or when violating reuse during variable creation. Reuse is set inside
      `variable_scope`.
  """
  # pylint: disable=protected-access
  scope = get_variable_scope()
  if scope.custom_getter is not None:
    raise ValueError(
        "Private access to _get_partitioned_variable is not allowed when "
        "a custom getter is set. Current custom getter: %s. "
        "It is likely that you're using create_partitioned_variables. "
        "If so, consider using get_variable with a non-empty "
        "partitioner parameter instead."
% scope.custom_getter) 1763 return scope._get_partitioned_variable( 1764 _get_default_variable_store(), 1765 name, 1766 shape=shape, 1767 dtype=dtype, 1768 initializer=initializer, 1769 regularizer=regularizer, 1770 trainable=trainable, 1771 collections=collections, 1772 caching_device=caching_device, 1773 partitioner=partitioner, 1774 validate_shape=validate_shape, 1775 use_resource=use_resource, 1776 constraint=constraint, 1777 synchronization=synchronization, 1778 aggregation=aggregation) 1779 # pylint: enable=protected-access 1780 1781 1782# Named like a function for compatibility with the previous 1783# @tf_contextlib.contextmanager definition. 1784class _pure_variable_scope(object): # pylint: disable=invalid-name 1785 """A context for the variable_scope, see `variable_scope` for docs.""" 1786 1787 def __init__(self, 1788 name_or_scope, 1789 reuse=None, 1790 initializer=None, 1791 regularizer=None, 1792 caching_device=None, 1793 partitioner=None, 1794 custom_getter=None, 1795 old_name_scope=None, 1796 dtype=dtypes.float32, 1797 use_resource=None, 1798 constraint=None): 1799 """Creates a context for the variable_scope, see `variable_scope` for docs. 1800 1801 Note: this does not create a name scope. 1802 1803 Args: 1804 name_or_scope: `string` or `VariableScope`: the scope to open. 1805 reuse: `True` or None, or tf.AUTO_REUSE; if `None`, we inherit the parent 1806 scope's reuse flag. 1807 initializer: default initializer for variables within this scope. 1808 regularizer: default regularizer for variables within this scope. 1809 caching_device: default caching device for variables within this scope. 1810 partitioner: default partitioner for variables within this scope. 1811 custom_getter: default custom getter for variables within this scope. 1812 old_name_scope: the original name scope when re-entering a variable scope. 1813 dtype: type of the variables within this scope (defaults to `DT_FLOAT`). 1814 use_resource: If False, variables in this scope will be regular Variables. 1815 If True, experimental ResourceVariables will be creates instead, with 1816 well-defined semantics. Defaults to False (will later change to True). 1817 constraint: An optional projection function to be applied to the variable 1818 after being updated by an `Optimizer` (e.g. used to implement norm 1819 constraints or value constraints for layer weights). The function must 1820 take as input the unprojected Tensor representing the value of the 1821 variable and return the Tensor for the projected value 1822 (which must have the same shape). Constraints are not safe to 1823 use when doing asynchronous distributed training. 1824 """ 1825 self._name_or_scope = name_or_scope 1826 self._reuse = reuse 1827 self._initializer = initializer 1828 self._regularizer = regularizer 1829 self._caching_device = caching_device 1830 self._partitioner = partitioner 1831 self._custom_getter = custom_getter 1832 self._old_name_scope = old_name_scope 1833 self._dtype = dtype 1834 self._use_resource = use_resource 1835 self._constraint = constraint 1836 self._var_store = _get_default_variable_store() 1837 self._var_scope_store = get_variable_scope_store() 1838 if isinstance(self._name_or_scope, VariableScope): 1839 self._new_name = self._name_or_scope.name 1840 name_scope = self._name_or_scope._name_scope # pylint: disable=protected-access 1841 # Handler for the case when we jump to a shared scope. 
We create a new 1842 # VariableScope (self._var_scope_object) that contains a copy of the 1843 # provided shared scope, possibly with changed reuse and initializer, if 1844 # the user requested this. 1845 variable_scope_object = VariableScope( 1846 self._name_or_scope.reuse if not self._reuse else self._reuse, 1847 name=self._new_name, 1848 initializer=self._name_or_scope.initializer, 1849 regularizer=self._name_or_scope.regularizer, 1850 caching_device=self._name_or_scope.caching_device, 1851 partitioner=self._name_or_scope.partitioner, 1852 dtype=self._name_or_scope.dtype, 1853 custom_getter=self._name_or_scope.custom_getter, 1854 name_scope=name_scope, 1855 use_resource=self._name_or_scope.use_resource, 1856 constraint=self._constraint) 1857 if self._initializer is not None: 1858 variable_scope_object.set_initializer(self._initializer) 1859 if self._regularizer is not None: 1860 variable_scope_object.set_regularizer(self._regularizer) 1861 if self._caching_device is not None: 1862 variable_scope_object.set_caching_device(self._caching_device) 1863 if self._partitioner is not None: 1864 variable_scope_object.set_partitioner(self._partitioner) 1865 if self._custom_getter is not None: 1866 variable_scope_object.set_custom_getter( 1867 _maybe_wrap_custom_getter( 1868 self._custom_getter, self._name_or_scope.custom_getter)) 1869 if self._dtype is not None: 1870 variable_scope_object.set_dtype(self._dtype) 1871 if self._use_resource is not None: 1872 variable_scope_object.set_use_resource(self._use_resource) 1873 self._cached_variable_scope_object = variable_scope_object 1874 1875 def __enter__(self): 1876 """Begins the scope block. 1877 1878 Returns: 1879 A VariableScope. 1880 Raises: 1881 ValueError: when trying to reuse within a create scope, or create within 1882 a reuse scope, or if reuse is not `None` or `True`. 1883 TypeError: when the types of some arguments are not appropriate. 1884 """ 1885 self._old = self._var_scope_store.current_scope 1886 if isinstance(self._name_or_scope, VariableScope): 1887 self._var_scope_store.open_variable_scope(self._new_name) 1888 self._old_subscopes = copy.copy( 1889 self._var_scope_store.variable_scopes_count) 1890 variable_scope_object = self._cached_variable_scope_object 1891 else: 1892 # Handler for the case when we just prolong current variable scope. 1893 # VariableScope with name extended by the provided one, and inherited 1894 # reuse and initializer (except if the user provided values to set). 1895 self._new_name = ( 1896 self._old.name + "/" + self._name_or_scope if self._old.name 1897 else self._name_or_scope) 1898 self._reuse = (self._reuse 1899 or self._old.reuse) # Re-using is inherited by sub-scopes. 
1900 if self._old_name_scope is None: 1901 name_scope = self._name_or_scope 1902 else: 1903 name_scope = self._old_name_scope 1904 variable_scope_object = VariableScope( 1905 self._reuse, 1906 name=self._new_name, 1907 initializer=self._old.initializer, 1908 regularizer=self._old.regularizer, 1909 caching_device=self._old.caching_device, 1910 partitioner=self._old.partitioner, 1911 dtype=self._old.dtype, 1912 use_resource=self._old.use_resource, 1913 custom_getter=self._old.custom_getter, 1914 name_scope=name_scope, 1915 constraint=self._constraint) 1916 if self._initializer is not None: 1917 variable_scope_object.set_initializer(self._initializer) 1918 if self._regularizer is not None: 1919 variable_scope_object.set_regularizer(self._regularizer) 1920 if self._caching_device is not None: 1921 variable_scope_object.set_caching_device(self._caching_device) 1922 if self._partitioner is not None: 1923 variable_scope_object.set_partitioner(self._partitioner) 1924 if self._custom_getter is not None: 1925 variable_scope_object.set_custom_getter( 1926 _maybe_wrap_custom_getter(self._custom_getter, 1927 self._old.custom_getter)) 1928 if self._dtype is not None: 1929 variable_scope_object.set_dtype(self._dtype) 1930 if self._use_resource is not None: 1931 variable_scope_object.set_use_resource(self._use_resource) 1932 self._var_scope_store.open_variable_scope(self._new_name) 1933 self._var_scope_store.current_scope = variable_scope_object 1934 return variable_scope_object 1935 1936 def __exit__(self, type_arg, value_arg, traceback_arg): 1937 # If jumping out from a non-prolonged scope, restore counts. 1938 if isinstance(self._name_or_scope, VariableScope): 1939 self._var_scope_store.variable_scopes_count = self._old_subscopes 1940 else: 1941 self._var_scope_store.close_variable_subscopes(self._new_name) 1942 self._var_scope_store.current_scope = self._old 1943 1944 1945def _maybe_wrap_custom_getter(custom_getter, old_getter): 1946 """Wrap a call to a custom_getter to use the old_getter internally.""" 1947 if old_getter is None: 1948 return custom_getter 1949 1950 # The new custom_getter should call the old one 1951 def wrapped_custom_getter(getter, *args, **kwargs): 1952 # Call: 1953 # custom_getter( 1954 # lambda: old_getter(true_getter, ...), *args, **kwargs) 1955 # which means custom_getter will call old_getter, which 1956 # will call the true_getter, perform any intermediate 1957 # processing, and return the results to the current 1958 # getter, which will also perform additional processing. 1959 return custom_getter( 1960 functools.partial(old_getter, getter), 1961 *args, **kwargs) 1962 return wrapped_custom_getter 1963 1964 1965def _get_unique_variable_scope(prefix): 1966 """Get a name with the given prefix unique in the current variable scope.""" 1967 var_scope_store = get_variable_scope_store() 1968 current_scope = get_variable_scope() 1969 name = current_scope.name + "/" + prefix if current_scope.name else prefix 1970 if var_scope_store.variable_scope_count(name) == 0: 1971 return prefix 1972 idx = 1 1973 while var_scope_store.variable_scope_count(name + ("_%d" % idx)) > 0: 1974 idx += 1 1975 return prefix + ("_%d" % idx) 1976 1977 1978# Named like a function for backwards compatibility with the 1979# @tf_contextlib.contextmanager version, which was switched to a class to avoid 1980# some object creation overhead. 1981@tf_export(v1=["variable_scope"]) # pylint: disable=invalid-name 1982class variable_scope(object): 1983 """A context manager for defining ops that creates variables (layers). 

  This context manager validates that the (optional) `values` are from the same
  graph, ensures that graph is the default graph, and pushes a name scope and a
  variable scope.

  If `name_or_scope` is not None, it is used as is. If `name_or_scope` is None,
  then `default_name` is used. In that case, if the same name has been
  previously used in the same scope, it will be made unique by appending `_N`
  to it.

  Variable scope allows you to create new variables and to share already
  created ones while providing checks to not create or share by accident. For
  details, see the [Variable Scope How To](https://tensorflow.org/guide/variables);
  here we present only a few basic examples.

  Simple example of how to create a new variable:

  ```python
  with tf.variable_scope("foo"):
    with tf.variable_scope("bar"):
      v = tf.get_variable("v", [1])
      assert v.name == "foo/bar/v:0"
  ```

  Simple example of how to reenter a premade variable scope safely:

  ```python
  with tf.variable_scope("foo") as vs:
    pass

  # Re-enter the variable scope.
  with tf.variable_scope(vs,
                         auxiliary_name_scope=False) as vs1:
    # Restore the original name_scope.
    with tf.name_scope(vs1.original_name_scope):
      v = tf.get_variable("v", [1])
      assert v.name == "foo/v:0"
      c = tf.constant([1], name="c")
      assert c.name == "foo/c:0"
  ```

  Basic example of sharing a variable with AUTO_REUSE:

  ```python
  def foo():
    with tf.variable_scope("foo", reuse=tf.AUTO_REUSE):
      v = tf.get_variable("v", [1])
    return v

  v1 = foo()  # Creates v.
  v2 = foo()  # Gets the same, existing v.
  assert v1 == v2
  ```

  Basic example of sharing a variable with reuse=True:

  ```python
  with tf.variable_scope("foo"):
    v = tf.get_variable("v", [1])
  with tf.variable_scope("foo", reuse=True):
    v1 = tf.get_variable("v", [1])
  assert v1 == v
  ```

  Sharing a variable by capturing a scope and setting reuse:

  ```python
  with tf.variable_scope("foo") as scope:
    v = tf.get_variable("v", [1])
    scope.reuse_variables()
    v1 = tf.get_variable("v", [1])
  assert v1 == v
  ```

  To prevent accidental sharing of variables, we raise an exception when getting
  an existing variable in a non-reusing scope.

  ```python
  with tf.variable_scope("foo"):
    v = tf.get_variable("v", [1])
    v1 = tf.get_variable("v", [1])
    #  Raises ValueError("... v already exists ...").
  ```

  Similarly, we raise an exception when trying to get a variable that does not
  exist in reuse mode.

  ```python
  with tf.variable_scope("foo", reuse=True):
    v = tf.get_variable("v", [1])
    #  Raises ValueError("... v does not exist ...").
  ```

  Note that the `reuse` flag is inherited: if we open a reusing scope, then all
  its sub-scopes become reusing as well.

  A note about name scoping: Setting `reuse` does not impact the naming of other
  ops such as mult. See related discussion on
  [github#6189](https://github.com/tensorflow/tensorflow/issues/6189)

  Note that up to and including version 1.0, it was allowed (though explicitly
  discouraged) to pass False to the reuse argument, yielding undocumented
  behaviour slightly different from None. Starting at 1.1.0 passing None and
  False as reuse has exactly the same effect.
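
  For example, as a small sketch (the scope names are only illustrative),
  passing `reuse=False` is treated exactly like passing `None`, so a nested
  scope still inherits its parent's `reuse` flag:

  ```python
  with tf.variable_scope("foo", reuse=True):
    # reuse=False does not switch reuse off here; the parent flag is inherited.
    with tf.variable_scope("bar", reuse=False):
      assert tf.get_variable_scope().reuse
  ```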

  A note about using variable scopes in a multi-threaded environment: Variable
  scopes are thread local, so one thread will not see another thread's current
  scope. Also, when using `default_name`, unique scope names are generated only
  on a per-thread basis. If the same name was used within a different thread,
  that doesn't prevent a new thread from creating the same scope. However, the
  underlying variable store is shared across threads (within the same graph).
  As such, if another thread tries to create a new variable with the same name
  as a variable created by a previous thread, it will fail unless reuse is True.

  Further, each thread starts with an empty variable scope. So if you wish to
  preserve name prefixes from a scope from the main thread, you should capture
  the main thread's scope and re-enter it in each thread. For example:

  ```
  main_thread_scope = variable_scope.get_variable_scope()

  # Thread's target function:
  def thread_target_fn(captured_scope):
    with variable_scope.variable_scope(captured_scope):
      # .... regular code for this thread


  thread = threading.Thread(target=thread_target_fn, args=(main_thread_scope,))
  ```
  """

  def __init__(self,
               name_or_scope,
               default_name=None,
               values=None,
               initializer=None,
               regularizer=None,
               caching_device=None,
               partitioner=None,
               custom_getter=None,
               reuse=None,
               dtype=None,
               use_resource=None,
               constraint=None,
               auxiliary_name_scope=True):
    """Initialize the context manager.

    Args:
      name_or_scope: `string` or `VariableScope`: the scope to open.
      default_name: The default name to use if the `name_or_scope` argument is
        `None`; this name will be uniquified. If name_or_scope is provided it
        won't be used and therefore it is not required and can be None.
      values: The list of `Tensor` arguments that are passed to the op
        function.
      initializer: default initializer for variables within this scope.
      regularizer: default regularizer for variables within this scope.
      caching_device: default caching device for variables within this scope.
      partitioner: default partitioner for variables within this scope.
      custom_getter: default custom getter for variables within this scope.
      reuse: `True`, None, or tf.AUTO_REUSE; if `True`, we go into reuse mode
        for this scope as well as all sub-scopes; if tf.AUTO_REUSE, we create
        variables if they do not exist, and return them otherwise; if None, we
        inherit the parent scope's reuse flag. When eager execution is enabled,
        new variables are always created unless an EagerVariableStore or
        template is currently active.
      dtype: type of variables created in this scope (defaults to the type
        in the passed scope, or inherited from parent scope).
      use_resource: If False, all variables will be regular Variables. If True,
        experimental ResourceVariables with well-defined semantics will be used
        instead. Defaults to False (will later change to True). When eager
        execution is enabled this argument is always forced to be True.
      constraint: An optional projection function to be applied to the variable
        after being updated by an `Optimizer` (e.g. used to implement norm
        constraints or value constraints for layer weights).
The function must 2158 take as input the unprojected Tensor representing the value of the 2159 variable and return the Tensor for the projected value 2160 (which must have the same shape). Constraints are not safe to 2161 use when doing asynchronous distributed training. 2162 auxiliary_name_scope: If `True`, we create an auxiliary name scope with 2163 the scope. If `False`, we don't create it. Note that the argument is 2164 not inherited, and it only takes effect for once when creating. You 2165 should only use it for re-entering a premade variable scope. 2166 2167 Returns: 2168 A scope that can be captured and reused. 2169 2170 Raises: 2171 ValueError: when trying to reuse within a create scope, or create within 2172 a reuse scope. 2173 TypeError: when the types of some arguments are not appropriate. 2174 """ 2175 self._name_or_scope = name_or_scope 2176 self._default_name = default_name 2177 self._values = values 2178 self._initializer = initializer 2179 self._regularizer = regularizer 2180 self._caching_device = caching_device 2181 self._partitioner = partitioner 2182 self._custom_getter = custom_getter 2183 self._reuse = reuse 2184 self._dtype = dtype 2185 self._use_resource = use_resource 2186 self._constraint = constraint 2187 if self._default_name is None and self._name_or_scope is None: 2188 raise TypeError("If default_name is None then name_or_scope is required") 2189 if self._reuse is False: 2190 # We don't allow non-inheriting scopes, False = None here. 2191 self._reuse = None 2192 if not (self._reuse is True 2193 or self._reuse is None 2194 or self._reuse is AUTO_REUSE): 2195 raise ValueError("The reuse parameter must be True or False or None.") 2196 if self._values is None: 2197 self._values = [] 2198 self._in_graph_mode = not context.executing_eagerly() 2199 if self._in_graph_mode: 2200 self._graph = ops._get_graph_from_inputs(self._values) # pylint: disable=protected-access 2201 self._cached_pure_variable_scope = None 2202 self._current_name_scope = None 2203 if not isinstance(auxiliary_name_scope, bool): 2204 raise TypeError("The auxiliary_name_scope must be `True` or `False`, " 2205 "while get {}".format(auxiliary_name_scope)) 2206 self._auxiliary_name_scope = auxiliary_name_scope 2207 2208 def __enter__(self): 2209 # If the default graph is building a function, then we should not replace it 2210 # with the cached graph. 2211 if ops.get_default_graph().building_function: 2212 self._building_function = True 2213 else: 2214 self._building_function = False 2215 if self._in_graph_mode and not self._building_function: 2216 self._graph_context_manager = self._graph.as_default() 2217 self._graph_context_manager.__enter__() 2218 if self._cached_pure_variable_scope is not None: 2219 # Fast path for re-entering variable_scopes. We've held on to the pure 2220 # variable scope from a previous successful __enter__, so we avoid some 2221 # overhead by re-using that object. 2222 if self._current_name_scope is not None: 2223 self._current_name_scope.__enter__() 2224 return self._cached_pure_variable_scope.__enter__() 2225 2226 try: 2227 return self._enter_scope_uncached() 2228 except Exception: 2229 if self._in_graph_mode and not self._building_function: 2230 if self._graph_context_manager is not None: 2231 self._graph_context_manager.__exit__(*sys.exc_info()) 2232 raise 2233 2234 def _enter_scope_uncached(self): 2235 """Enters the context manager when there is no cached scope yet. 2236 2237 Returns: 2238 The entered variable scope. 
2239 2240 Raises: 2241 TypeError: A wrong type is passed as `scope` at __init__(). 2242 ValueError: `reuse` is incorrectly set at __init__(). 2243 """ 2244 if self._auxiliary_name_scope: 2245 # Create a new name scope later 2246 current_name_scope = None 2247 else: 2248 # Reenter the current name scope 2249 name_scope = ops.get_name_scope() 2250 if name_scope: 2251 # Hack to reenter 2252 name_scope += "/" 2253 current_name_scope = ops.name_scope(name_scope) 2254 else: 2255 # Root scope 2256 current_name_scope = ops.name_scope(name_scope) 2257 2258 # IMPORTANT: Only assign to self._cached_pure_variable_scope and 2259 # self._current_name_scope after successful __enter__() calls. 2260 if self._name_or_scope is not None: 2261 if not isinstance(self._name_or_scope, 2262 (VariableScope,) + six.string_types): 2263 raise TypeError("VariableScope: name_or_scope must be a string or " 2264 "VariableScope.") 2265 if isinstance(self._name_or_scope, six.string_types): 2266 name_scope = self._name_or_scope 2267 else: 2268 name_scope = self._name_or_scope.name.split("/")[-1] 2269 if name_scope or current_name_scope: 2270 current_name_scope = current_name_scope or ops.name_scope(name_scope) 2271 try: 2272 current_name_scope_name = current_name_scope.__enter__() 2273 except: 2274 current_name_scope.__exit__(*sys.exc_info()) 2275 raise 2276 self._current_name_scope = current_name_scope 2277 if isinstance(self._name_or_scope, six.string_types): 2278 old_name_scope = current_name_scope_name 2279 else: 2280 old_name_scope = self._name_or_scope.original_name_scope 2281 pure_variable_scope = _pure_variable_scope( 2282 self._name_or_scope, 2283 reuse=self._reuse, 2284 initializer=self._initializer, 2285 regularizer=self._regularizer, 2286 caching_device=self._caching_device, 2287 partitioner=self._partitioner, 2288 custom_getter=self._custom_getter, 2289 old_name_scope=old_name_scope, 2290 dtype=self._dtype, 2291 use_resource=self._use_resource, 2292 constraint=self._constraint) 2293 try: 2294 entered_pure_variable_scope = pure_variable_scope.__enter__() 2295 except: 2296 pure_variable_scope.__exit__(*sys.exc_info()) 2297 raise 2298 self._cached_pure_variable_scope = pure_variable_scope 2299 return entered_pure_variable_scope 2300 else: 2301 self._current_name_scope = None 2302 # This can only happen if someone is entering the root variable scope. 2303 pure_variable_scope = _pure_variable_scope( 2304 self._name_or_scope, 2305 reuse=self._reuse, 2306 initializer=self._initializer, 2307 regularizer=self._regularizer, 2308 caching_device=self._caching_device, 2309 partitioner=self._partitioner, 2310 custom_getter=self._custom_getter, 2311 dtype=self._dtype, 2312 use_resource=self._use_resource, 2313 constraint=self._constraint) 2314 try: 2315 entered_pure_variable_scope = pure_variable_scope.__enter__() 2316 except: 2317 pure_variable_scope.__exit__(*sys.exc_info()) 2318 raise 2319 self._cached_pure_variable_scope = pure_variable_scope 2320 return entered_pure_variable_scope 2321 2322 else: # Here name_or_scope is None. Using default name, but made unique. 
2323 if self._reuse: 2324 raise ValueError("reuse=True cannot be used without a name_or_scope") 2325 current_name_scope = current_name_scope or ops.name_scope( 2326 self._default_name) 2327 try: 2328 current_name_scope_name = current_name_scope.__enter__() 2329 except: 2330 current_name_scope.__exit__(*sys.exc_info()) 2331 raise 2332 self._current_name_scope = current_name_scope 2333 unique_default_name = _get_unique_variable_scope(self._default_name) 2334 pure_variable_scope = _pure_variable_scope( 2335 unique_default_name, 2336 initializer=self._initializer, 2337 regularizer=self._regularizer, 2338 caching_device=self._caching_device, 2339 partitioner=self._partitioner, 2340 custom_getter=self._custom_getter, 2341 old_name_scope=current_name_scope_name, 2342 dtype=self._dtype, 2343 use_resource=self._use_resource, 2344 constraint=self._constraint) 2345 try: 2346 entered_pure_variable_scope = pure_variable_scope.__enter__() 2347 except: 2348 pure_variable_scope.__exit__(*sys.exc_info()) 2349 raise 2350 self._cached_pure_variable_scope = pure_variable_scope 2351 return entered_pure_variable_scope 2352 2353 def __exit__(self, type_arg, value_arg, traceback_arg): 2354 self._cached_pure_variable_scope.__exit__( 2355 type_arg, value_arg, traceback_arg) 2356 if self._current_name_scope: 2357 self._current_name_scope.__exit__(type_arg, value_arg, traceback_arg) 2358 if self._in_graph_mode and not self._building_function: 2359 self._graph_context_manager.__exit__(type_arg, value_arg, traceback_arg) 2360 2361 2362# pylint: disable=g-doc-return-or-yield 2363@tf_export(v1=["variable_op_scope"]) 2364@tf_contextlib.contextmanager 2365def variable_op_scope(values, 2366 name_or_scope, 2367 default_name=None, 2368 initializer=None, 2369 regularizer=None, 2370 caching_device=None, 2371 partitioner=None, 2372 custom_getter=None, 2373 reuse=None, 2374 dtype=None, 2375 use_resource=None, 2376 constraint=None): 2377 """Deprecated: context manager for defining an op that creates variables.""" 2378 logging.warn("tf.variable_op_scope(values, name, default_name) is deprecated," 2379 " use tf.variable_scope(name, default_name, values)") 2380 with variable_scope(name_or_scope, 2381 default_name=default_name, 2382 values=values, 2383 initializer=initializer, 2384 regularizer=regularizer, 2385 caching_device=caching_device, 2386 partitioner=partitioner, 2387 custom_getter=custom_getter, 2388 reuse=reuse, 2389 dtype=dtype, 2390 use_resource=use_resource, 2391 constraint=constraint) as scope: 2392 yield scope 2393 2394 2395def _call_partitioner(partitioner, shape, dtype): 2396 """Call partitioner validating its inputs/output. 2397 2398 Args: 2399 partitioner: a function mapping `Tensor` shape and dtype to a 2400 list of partitions. 2401 shape: shape of the `Tensor` to partition, must have at least two 2402 dimensions. 2403 dtype: dtype of the elements in the `Tensor`. 2404 2405 Returns: 2406 A list with elements >=1 and exactly one >1. The index of that 2407 element corresponds to the partitioning axis. 2408 """ 2409 if not shape.is_fully_defined(): 2410 raise ValueError("Shape of a new partitioned variable must be " 2411 "fully defined, but instead was %s." 
% (shape,)) 2412 if shape.ndims < 1: 2413 raise ValueError("A partitioned Variable must have rank at least 1, " 2414 "shape: %s" % shape) 2415 2416 slicing = partitioner(shape=shape, dtype=dtype) 2417 if not isinstance(slicing, collections_lib.Sequence): 2418 raise ValueError("Partitioner must return a sequence, but saw: %s" 2419 % slicing) 2420 if len(slicing) != shape.ndims: 2421 raise ValueError( 2422 "Partitioner returned a partition list that does not match the " 2423 "Variable's rank: %s vs. %s" % (slicing, shape)) 2424 if any(p < 1 for p in slicing): 2425 raise ValueError( 2426 "Partitioner returned zero partitions for some axes: %s" % 2427 slicing) 2428 if sum(p > 1 for p in slicing) > 1: 2429 raise ValueError( 2430 "Can only slice a variable along one dimension: " 2431 "shape: %s, partitioning: %s" % (shape, slicing)) 2432 return slicing 2433 2434 2435# TODO(slebedev): could be inlined, but 2436# `_VariableStore._get_partitioned_variable` is too complex even 2437# without this logic. 2438def _get_slice_dim_and_num_slices(slicing): 2439 """Get slicing dimension and number of slices from the partitioner output.""" 2440 for slice_dim, num_slices in enumerate(slicing): 2441 if num_slices > 1: 2442 break 2443 else: 2444 # Degenerate case: no partitioning applied. 2445 slice_dim = 0 2446 num_slices = 1 2447 return slice_dim, num_slices 2448 2449 2450def _iter_slices(full_shape, num_slices, slice_dim): 2451 """Slices a given a shape along the specified dimension.""" 2452 num_slices_with_excess = full_shape[slice_dim] % num_slices 2453 offset = [0] * len(full_shape) 2454 min_slice_len = full_shape[slice_dim] // num_slices 2455 for i in xrange(num_slices): 2456 shape = full_shape[:] 2457 shape[slice_dim] = min_slice_len + bool(i < num_slices_with_excess) 2458 yield offset[:], shape 2459 offset[slice_dim] += shape[slice_dim] 2460 2461 2462def _get_trainable_value(synchronization, trainable): 2463 """Computes the trainable value based on the given arguments.""" 2464 if synchronization == VariableSynchronization.ON_READ: 2465 if trainable: 2466 raise ValueError( 2467 "Synchronization value can be set to " 2468 "VariableSynchronization.ON_READ only for non-trainable variables. " 2469 "You have specified trainable=True and " 2470 "synchronization=VariableSynchronization.ON_READ.") 2471 else: 2472 # Set trainable to be false when variable is to be synced on read. 2473 trainable = False 2474 elif trainable is None: 2475 trainable = True 2476 return trainable 2477 2478 2479def default_variable_creator(next_creator=None, **kwargs): 2480 """Default variable creator.""" 2481 assert next_creator is None 2482 initial_value = kwargs.get("initial_value", None) 2483 trainable = kwargs.get("trainable", None) 2484 collections = kwargs.get("collections", None) 2485 validate_shape = kwargs.get("validate_shape", True) 2486 caching_device = kwargs.get("caching_device", None) 2487 name = kwargs.get("name", None) 2488 variable_def = kwargs.get("variable_def", None) 2489 dtype = kwargs.get("dtype", None) 2490 expected_shape = kwargs.get("expected_shape", None) 2491 import_scope = kwargs.get("import_scope", None) 2492 constraint = kwargs.get("constraint", None) 2493 use_resource = kwargs.get("use_resource", None) 2494 2495 # Set trainable value based on synchronization value. 
2496 synchronization = kwargs.get("synchronization", VariableSynchronization.AUTO) 2497 trainable = _get_trainable_value( 2498 synchronization=synchronization, trainable=trainable) 2499 2500 if use_resource is None: 2501 use_resource = get_variable_scope().use_resource 2502 if use_resource is None: 2503 use_resource = _DEFAULT_USE_RESOURCE 2504 use_resource = use_resource or context.executing_eagerly() 2505 if use_resource: 2506 distribute_strategy = kwargs.get("distribute_strategy", None) 2507 return resource_variable_ops.ResourceVariable( 2508 initial_value=initial_value, trainable=trainable, 2509 collections=collections, validate_shape=validate_shape, 2510 caching_device=caching_device, name=name, dtype=dtype, 2511 constraint=constraint, variable_def=variable_def, 2512 import_scope=import_scope, distribute_strategy=distribute_strategy) 2513 else: 2514 return variables.RefVariable( 2515 initial_value=initial_value, trainable=trainable, 2516 collections=collections, validate_shape=validate_shape, 2517 caching_device=caching_device, name=name, dtype=dtype, 2518 constraint=constraint, variable_def=variable_def, 2519 expected_shape=expected_shape, import_scope=import_scope) 2520 2521 2522def default_variable_creator_v2(next_creator=None, **kwargs): 2523 """Default variable creator.""" 2524 assert next_creator is None 2525 initial_value = kwargs.get("initial_value", None) 2526 trainable = kwargs.get("trainable", None) 2527 validate_shape = kwargs.get("validate_shape", True) 2528 caching_device = kwargs.get("caching_device", None) 2529 name = kwargs.get("name", None) 2530 variable_def = kwargs.get("variable_def", None) 2531 dtype = kwargs.get("dtype", None) 2532 import_scope = kwargs.get("import_scope", None) 2533 constraint = kwargs.get("constraint", None) 2534 distribute_strategy = kwargs.get("distribute_strategy", None) 2535 2536 # Set trainable value based on synchronization value. 2537 synchronization = kwargs.get("synchronization", VariableSynchronization.AUTO) 2538 trainable = _get_trainable_value( 2539 synchronization=synchronization, trainable=trainable) 2540 2541 return resource_variable_ops.ResourceVariable( 2542 initial_value=initial_value, trainable=trainable, 2543 validate_shape=validate_shape, caching_device=caching_device, 2544 name=name, dtype=dtype, constraint=constraint, variable_def=variable_def, 2545 import_scope=import_scope, distribute_strategy=distribute_strategy) 2546 2547 2548variables.default_variable_creator = default_variable_creator 2549variables.default_variable_creator_v2 = default_variable_creator_v2 2550 2551 2552def _make_getter(captured_getter, captured_previous): 2553 """Gets around capturing loop variables in python being broken.""" 2554 return lambda **kwargs: captured_getter(captured_previous, **kwargs) 2555 2556 2557# TODO(apassos) remove forwarding symbol 2558variable = variables.VariableV1 2559 2560 2561@tf_export(v1=["variable_creator_scope"]) 2562@tf_contextlib.contextmanager 2563def variable_creator_scope_v1(variable_creator): 2564 """Scope which defines a variable creation function to be used by variable(). 2565 2566 variable_creator is expected to be a function with the following signature: 2567 2568 ``` 2569 def variable_creator(next_creator, **kwargs) 2570 ``` 2571 2572 The creator is supposed to eventually call the next_creator to create a 2573 variable if it does want to create a variable and not call Variable or 2574 ResourceVariable directly. This helps make creators composable. 
A creator may
  choose to create multiple variables, return already existing variables, or
  simply register that a variable was created and defer to the next creators in
  line. Creators can also modify the keyword arguments seen by the next
  creators.

  Custom getters in the variable scope will eventually resolve down to these
  custom creators when they do create variables.

  The valid keyword arguments in kwargs are:
    initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
      which is the initial value for the Variable. The initial value must have
      a shape specified unless `validate_shape` is set to False. Can also be a
      callable with no argument that returns the initial value when called. In
      that case, `dtype` must be specified. (Note that initializer functions
      from init_ops.py must first be bound to a shape before being used here.)
    trainable: If `True`, the default, also adds the variable to the graph
      collection `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as
      the default list of variables to use by the `Optimizer` classes.
      `trainable` defaults to `True` unless `synchronization` is
      set to `ON_READ`.
    collections: List of graph collections keys. The new variable is added to
      these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
    validate_shape: If `False`, allows the variable to be initialized with a
      value of unknown shape. If `True`, the default, the shape of
      `initial_value` must be known.
    caching_device: Optional device string describing where the Variable
      should be cached for reading. Defaults to the Variable's device.
      If not `None`, caches on another device. Typical use is to cache
      on the device where the Ops using the Variable reside, to deduplicate
      copying through `Switch` and other conditional statements.
    name: Optional name for the variable. Defaults to `'Variable'` and gets
      uniquified automatically.
    dtype: If set, initial_value will be converted to the given type.
      If `None`, either the datatype will be kept (if `initial_value` is
      a Tensor), or `convert_to_tensor` will decide.
    constraint: A constraint function to be applied to the variable after
      updates by some algorithms.
    use_resource: if True, a ResourceVariable is always created.
    synchronization: Indicates when a distributed variable will be
      aggregated. Accepted values are constants defined in the class
      `tf.VariableSynchronization`. By default the synchronization is set to
      `AUTO` and the current `DistributionStrategy` chooses
      when to synchronize. If `synchronization` is set to `ON_READ`,
      `trainable` must not be set to `True`.
    aggregation: Indicates how a distributed variable will be aggregated.
      Accepted values are constants defined in the class
      `tf.VariableAggregation`.

  This set may grow over time, so it's important that the signature of creators
  is as mentioned above.

  Args:
    variable_creator: the passed creator

  Yields:
    A scope in which the creator is active
  """
  with ops.get_default_graph()._variable_creator_scope(variable_creator):  # pylint: disable=protected-access
    yield


# Note: only the docstrings differ between this and v1.
@tf_export("variable_creator_scope", v1=[])
@tf_contextlib.contextmanager
def variable_creator_scope(variable_creator):
  """Scope which defines a variable creation function to be used by variable().

  variable_creator is expected to be a function with the following signature:

  ```
  def variable_creator(next_creator, **kwargs)
  ```

  The creator is supposed to eventually call the next_creator to create a
  variable if it does want to create a variable, rather than calling Variable
  or ResourceVariable directly. This helps make creators composable. A creator
  may choose to create multiple variables, return already existing variables,
  or simply register that a variable was created and defer to the next creators
  in line. Creators can also modify the keyword arguments seen by the next
  creators.

  Custom getters in the variable scope will eventually resolve down to these
  custom creators when they do create variables.

  The valid keyword arguments in kwargs are:
    initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
      which is the initial value for the Variable. The initial value must have
      a shape specified unless `validate_shape` is set to False. Can also be a
      callable with no argument that returns the initial value when called. In
      that case, `dtype` must be specified. (Note that initializer functions
      from init_ops.py must first be bound to a shape before being used here.)
    trainable: If `True`, the default, GradientTapes automatically watch
      uses of this Variable.
    validate_shape: If `False`, allows the variable to be initialized with a
      value of unknown shape. If `True`, the default, the shape of
      `initial_value` must be known.
    caching_device: Optional device string describing where the Variable
      should be cached for reading. Defaults to the Variable's device.
      If not `None`, caches on another device. Typical use is to cache
      on the device where the Ops using the Variable reside, to deduplicate
      copying through `Switch` and other conditional statements.
    name: Optional name for the variable. Defaults to `'Variable'` and gets
      uniquified automatically.
    dtype: If set, initial_value will be converted to the given type.
      If `None`, either the datatype will be kept (if `initial_value` is
      a Tensor), or `convert_to_tensor` will decide.
    constraint: A constraint function to be applied to the variable after
      updates by some algorithms.
    synchronization: Indicates when a distributed variable will be
      aggregated. Accepted values are constants defined in the class
      `tf.VariableSynchronization`. By default the synchronization is set to
      `AUTO` and the current `DistributionStrategy` chooses
      when to synchronize. If `synchronization` is set to `ON_READ`,
      `trainable` must not be set to `True`.
    aggregation: Indicates how a distributed variable will be aggregated.
      Accepted values are constants defined in the class
      `tf.VariableAggregation`.

  This set may grow over time, so it's important that the signature of creators
  is as mentioned above.

  Args:
    variable_creator: the passed creator

  Yields:
    A scope in which the creator is active
  """
  with ops.get_default_graph()._variable_creator_scope(variable_creator):  # pylint: disable=protected-access
    yield
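

# For illustration only, a small sketch (the creator name is illustrative, not
# part of this module) of a creator compatible with `variable_creator_scope`:
# it modifies the keyword arguments seen by later creators and then defers to
# the next creator in line, as described in the docstring above.
#
#   def non_trainable_creator(next_creator, **kwargs):
#     kwargs["trainable"] = False     # Tweak kwargs for the rest of the stack.
#     return next_creator(**kwargs)   # Defer to the next creator in line.
#
#   with tf.variable_creator_scope(non_trainable_creator):
#     v = tf.Variable(1.0)            # Built through the creator stack.
#   assert not v.trainable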