# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A class to store named variables and a scope operator to manage sharing."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections as collections_lib
import copy
import enum  # pylint: disable=g-bad-import-order
import functools
import sys
import threading
import traceback

import six
from six import iteritems
from six.moves import xrange, zip  # pylint: disable=redefined-builtin

from tensorflow.python import tf2
from tensorflow.python.client import session
from tensorflow.python.eager import context
from tensorflow.python.eager import monitoring
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import deprecation
from tensorflow.python.util import function_utils
from tensorflow.python.util import tf_contextlib
from tensorflow.python.util import tf_inspect
from tensorflow.python.util.tf_export import tf_export

__all__ = [
    "AUTO_REUSE", "VariableScope", "get_variable_scope", "get_variable",
    "get_local_variable", "variable_scope", "variable_op_scope",
    "no_regularizer", "VariableSynchronization", "VariableAggregation"
]

_api_usage_gauge = monitoring.BoolGauge(
    "/tensorflow/api/resource_variables",
    "Whether variable_scope.enable_resource_variables() is called.")


class _PartitionInfo(object):
  """Holds partition info used by initializer functions."""

  def __init__(self, full_shape, var_offset):
    """Constructor.

    Args:
      full_shape: Tuple or list of `int` indicating the full combined shape of
        the partitioned variables.
      var_offset: Tuple or list of `int` specifying offset of this partition
        with respect to the full variable for each dimension.

    Raises:
      TypeError: If `full_shape` or `var_offset` is not a sequence.
      ValueError: If `full_shape` or `var_offset` differ in length. If
        `var_offset` exceeds `full_shape` in any dimension.
    """
    if not isinstance(full_shape, collections_lib.Sequence) or isinstance(
        full_shape, six.string_types):
      raise TypeError(
          "`full_shape` must be a sequence (like tuple or list) instead of " +
          type(full_shape).__name__)

    if not isinstance(var_offset, collections_lib.Sequence) or isinstance(
        var_offset, six.string_types):
      raise TypeError(
          "`var_offset` must be a sequence (like tuple or list) instead of " +
          type(var_offset).__name__)

    if len(var_offset) != len(full_shape):
      raise ValueError(
          "Expected equal length, but `var_offset` is of length {} while "
          "full_shape is of length {}.".format(
              len(var_offset), len(full_shape)))

    for offset, shape in zip(var_offset, full_shape):
      if offset < 0 or offset >= shape:
        raise ValueError(
            "Expected 0 <= offset < shape but found offset={}, shape={} for "
            "var_offset={}, full_shape={}".format(offset, shape, var_offset,
                                                  full_shape))

    self._full_shape = full_shape
    self._var_offset = var_offset

  @property
  def full_shape(self):
    return self._full_shape

  @property
  def var_offset(self):
    return self._var_offset

  def single_offset(self, shape):
    """Returns the offset when the variable is partitioned in at most one dim.

    Args:
      shape: Tuple or list of `int` indicating the shape of one specific
        variable partition.

    Returns:
      `int` representing the offset in the dimension along which the variable is
      partitioned. Returns 0 if the variable is not being partitioned.

    Raises:
      ValueError: Depending on self.single_slice_dim().
    """

    single_slice_dim = self.single_slice_dim(shape)
    # If this variable is not being partitioned at all, single_slice_dim() could
    # return None.
    if single_slice_dim is None:
      return 0
    return self.var_offset[single_slice_dim]

  def single_slice_dim(self, shape):
    """Returns the slice dim when the variable is partitioned only in one dim.

    Args:
      shape: Tuple or list of `int` indicating the shape of one specific
        variable partition.

    Returns:
      `int` representing the dimension that the variable is partitioned in, or
      `None` if the variable doesn't seem to be partitioned at all.

    Raises:
      TypeError: If `shape` is not a sequence.
      ValueError: If `shape` is not the same length as `self.full_shape`. If
        the variable is partitioned in more than one dimension.
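
    For example, a sketch of how this relates to `single_offset` (the shapes
    and offsets below are illustrative only):

    ```python
    # Full variable of shape [4, 10] partitioned along dimension 1; this
    # partition starts at column 5 and covers columns 5..9.
    info = _PartitionInfo(full_shape=[4, 10], var_offset=[0, 5])
    info.single_slice_dim([4, 5])  # == 1
    info.single_offset([4, 5])     # == 5
    ```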
    """
    if not isinstance(shape, collections_lib.Sequence) or isinstance(
        shape, six.string_types):
      raise TypeError(
          "`shape` must be a sequence (like tuple or list) instead of " +
          type(shape).__name__)

    if len(shape) != len(self.full_shape):
      raise ValueError(
          "Expected equal length, but received shape={} of length {} while "
          "self.full_shape={} is of length {}.".format(shape, len(shape),
                                                       self.full_shape,
                                                       len(self.full_shape)))

    for i in xrange(len(shape)):
      if self.var_offset[i] + shape[i] > self.full_shape[i]:
        raise ValueError(
            "With self.var_offset={}, a partition of shape={} would exceed "
            "self.full_shape={} in dimension {}.".format(
                self.var_offset, shape, self.full_shape, i))

    slice_dim = None
    for i in xrange(len(shape)):
      if shape[i] == self.full_shape[i]:
        continue
      if slice_dim is not None:
        raise ValueError(
            "Cannot use single_slice_dim() with shape={} and "
            "self.full_shape={} since slice dim could be either dimension {} "
            "or {}.".format(shape, self.full_shape, i, slice_dim))
      slice_dim = i

    return slice_dim


class _ReuseMode(enum.Enum):
  """Mode for variable access within a variable scope."""

  # Indicates that variables are to be fetched if they already exist or
  # otherwise created.
  AUTO_REUSE = 1

  # TODO(alive): For TensorFlow 2.0, Deprecate True/False/None API in favor of
  # enum values.
  # REUSE_FALSE = 2
  # REUSE_TRUE = 3


# TODO(apassos) remove these forwarding symbols.
VariableSynchronization = variables.VariableSynchronization  # pylint: disable=invalid-name
VariableAggregation = variables.VariableAggregation  # pylint: disable=invalid-name

AUTO_REUSE = _ReuseMode.AUTO_REUSE
tf_export(v1=["AUTO_REUSE"]).export_constant(__name__, "AUTO_REUSE")
AUTO_REUSE.__doc__ = """
When passed in as the value for the `reuse` flag, AUTO_REUSE indicates that
get_variable() should create the requested variable if it doesn't exist or, if
it does exist, simply return it.
"""

_DEFAULT_USE_RESOURCE = tf2.enabled()


@tf_export(v1=["enable_resource_variables"])
def enable_resource_variables():
  """Creates resource variables by default.

  Resource variables are improved versions of TensorFlow variables with a
  well-defined memory model. Accessing a resource variable reads its value, and
  all ops which access a specific read value of the variable are guaranteed to
  see the same value for that tensor. Writes which happen after a read (by
  having a control or data dependency on the read) are guaranteed not to affect
  the value of the read tensor, and similarly writes which happen before a read
  are guaranteed to affect the value. No guarantees are made about unordered
  read/write pairs.

  Calling tf.enable_resource_variables() lets you opt-in to this TensorFlow 2.0
  feature.
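
  For example, a minimal sketch of opting in through the v1 entry points
  (illustrative only; the printed class name assumes the default
  `get_variable` path):

  ```python
  import tensorflow.compat.v1 as tf

  tf.enable_resource_variables()
  # get_variable now creates resource variables by default.
  v = tf.get_variable("v", shape=[], initializer=tf.zeros_initializer())
  print(type(v).__name__)  # ResourceVariable rather than RefVariable
  ```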
  """
  global _DEFAULT_USE_RESOURCE
  _DEFAULT_USE_RESOURCE = True
  _api_usage_gauge.get_cell().set(True)


@tf_export(v1=["resource_variables_enabled"])
def resource_variables_enabled():
  """Returns `True` if resource variables are enabled.

  Resource variables are improved versions of TensorFlow variables with a
  well-defined memory model. Accessing a resource variable reads its value, and
  all ops which access a specific read value of the variable are guaranteed to
  see the same value for that tensor. Writes which happen after a read (by
  having a control or data dependency on the read) are guaranteed not to affect
  the value of the read tensor, and similarly writes which happen before a read
  are guaranteed to affect the value. No guarantees are made about unordered
  read/write pairs.

  Calling tf.enable_resource_variables() lets you opt-in to this TensorFlow 2.0
  feature.
  """
  global _DEFAULT_USE_RESOURCE
  return _DEFAULT_USE_RESOURCE


@deprecation.deprecated(
    None, "non-resource variables are not supported in the long term")
@tf_export(v1=["disable_resource_variables"])
def disable_resource_variables():
  """Opts out of resource variables.

  If your code needs tf.disable_resource_variables() to be called to work
  properly please file a bug.
  """
  global _DEFAULT_USE_RESOURCE
  _DEFAULT_USE_RESOURCE = False
  _api_usage_gauge.get_cell().set(False)


class _VariableStore(object):
  """Variable store that carries a number of named Variables.

  New variable names and new variables can be created; all stored
  variables are initialized with the initializer passed to __init__.

  Attributes:
    vars: a dictionary with string names (same as passed in GetVar) as keys and
      the corresponding TensorFlow Variables as values.
  """

  def __init__(self):
    """Create a variable store."""
    self._vars = {}  # A dictionary of the stored TensorFlow variables.
    self._partitioned_vars = {}  # A dict of the stored PartitionedVariables.
    self._store_eager_variables = False

  def get_variable(self,
                   name,
                   shape=None,
                   dtype=dtypes.float32,
                   initializer=None,
                   regularizer=None,
                   reuse=None,
                   trainable=None,
                   collections=None,
                   caching_device=None,
                   partitioner=None,
                   validate_shape=True,
                   use_resource=None,
                   custom_getter=None,
                   constraint=None,
                   synchronization=VariableSynchronization.AUTO,
                   aggregation=VariableAggregation.NONE):
    """Gets an existing variable with these parameters or creates a new one.

    If a variable with the given name is already stored, we return the stored
    variable. Otherwise, we create a new one.

    Set `reuse` to `True` when you only want to reuse existing Variables.
    Set `reuse` to `False` when you only want to create new Variables.
    Set `reuse` to None (the default) or tf.compat.v1.AUTO_REUSE when you want
    variables to be created if they don't exist or returned if they do.

    If initializer is `None` (the default), the default initializer passed in
    the constructor is used. If that one is `None` too, we use a new
    `glorot_uniform_initializer`. If initializer is a Tensor, we use
    it as a value and derive the shape from the initializer.

    If a partitioner is provided, a `PartitionedVariable` is returned.
    Accessing this object as a `Tensor` returns the shards concatenated along
    the partition axis.

    Some useful partitioners are available. See, e.g.,
    `variable_axis_size_partitioner` and `min_max_variable_partitioner`.

    Args:
      name: The name of the new or existing variable.
      shape: Shape of the new or existing variable.
      dtype: Type of the new or existing variable (defaults to `DT_FLOAT`).
      initializer: Initializer for the variable.
      regularizer: A (Tensor -> Tensor or None) function; the result of applying
        it on a newly created variable will be added to the collection
        GraphKeys.REGULARIZATION_LOSSES and can be used for regularization.
      reuse: a Boolean, None, or tf.AUTO_REUSE. Controls reuse or creation of
        variables. When eager execution is enabled this argument is always
        forced to be False.
      trainable: If `True` also add the variable to the graph collection
        `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). `trainable`
        defaults to `True`, unless `synchronization` is set to `ON_READ`, in
        which case it defaults to `False`.
      collections: List of graph collections keys to add the `Variable` to.
        Defaults to `[GraphKeys.GLOBAL_VARIABLES]` (see `tf.Variable`).
      caching_device: Optional device string or function describing where the
        Variable should be cached for reading. Defaults to the Variable's
        device. If not `None`, caches on another device. Typical use is to
        cache on the device where the Ops using the `Variable` reside, to
        deduplicate copying through `Switch` and other conditional statements.
      partitioner: Optional callable that accepts a fully defined `TensorShape`
        and dtype of the `Variable` to be created, and returns a list of
        partitions for each axis (currently only one axis can be partitioned).
      validate_shape: If False, allows the variable to be initialized with a
        value of unknown shape. If True, the default, the shape of initial_value
        must be known.
      use_resource: If False, creates a regular Variable. If True, creates
        instead an experimental ResourceVariable which has well-defined
        semantics. Defaults to False (will later change to True). When eager
        execution is enabled this argument is always forced to be True.
      custom_getter: Callable that takes as a first argument the true getter,
        and allows overwriting the internal get_variable method. The signature
        of `custom_getter` should match that of this method,
        but the most future-proof version will allow for changes:
        `def custom_getter(getter, *args, **kwargs)`. Direct access to
        all `get_variable` parameters is also allowed:
        `def custom_getter(getter, name, *args, **kwargs)`. A simple identity
        custom getter that simply creates variables with modified names is:
        ```python
        def custom_getter(getter, name, *args, **kwargs):
          return getter(name + '_suffix', *args, **kwargs)
        ```
      constraint: An optional projection function to be applied to the variable
        after being updated by an `Optimizer` (e.g. used to implement norm
        constraints or value constraints for layer weights). The function must
        take as input the unprojected Tensor representing the value of the
        variable and return the Tensor for the projected value (which must have
        the same shape). Constraints are not safe to use when doing asynchronous
        distributed training.
      synchronization: Indicates when a distributed variable will be
        aggregated. Accepted values are constants defined in the class
        `tf.VariableSynchronization`. By default the synchronization is set to
        `AUTO` and the current `DistributionStrategy` chooses when to
        synchronize.
      aggregation: Indicates how a distributed variable will be aggregated.
        Accepted values are constants defined in the class
        `tf.VariableAggregation`.

    Returns:
      The created or existing `Variable` (or `PartitionedVariable`, if a
      partitioner was used).

    Raises:
      ValueError: when creating a new variable and shape is not declared,
        when reusing a variable and specifying a conflicting shape,
        or when violating reuse during variable creation.
      RuntimeError: when eager execution is enabled and not called from an
        EagerVariableStore.
    """
    if custom_getter is not None and not callable(custom_getter):
      raise ValueError("Passed a custom_getter which is not callable: %s" %
                       custom_getter)

    with ops.init_scope():
      if context.executing_eagerly():
        # Variable creation and initialization takes place in `init_scope`s;
        # as such, if an `init_scope` lifts us into the eager context, then we
        # need to use `ResourceVariable`s.
        use_resource = True

    # Note that it's fine to reuse eager variables whose initialization was
    # lifted from a function-building graph into the eager context (that's why
    # the following clause is not wrapped in an `init_scope`); lifted variables
    # are tracked by the graph's `VariableStore`.
    if context.executing_eagerly():
      if not self._store_eager_variables and reuse:
        raise RuntimeError(
            "When eager execution is enabled variable reuse is only supported"
            " when an EagerVariableStore is active. See the documentation on"
            " EagerVariableStore for example usage.")
      if self._store_eager_variables:
        reuse = AUTO_REUSE

    # If a *_ref type is passed in an error would be triggered further down the
    # stack. We prevent this using base_dtype to get a non-ref version of the
    # type, before doing anything else. When _ref types are removed in favor of
    # resources, this line can be removed.
    try:
      dtype = dtype.base_dtype
    except AttributeError:
      # .base_dtype not existing means that we will try and use the raw dtype
      # which was passed in - this might be a NumPy type which is valid.
      pass

    # This is the main logic of get_variable. However, custom_getter
    # may override this logic. So we save it as a callable and pass
    # it to custom_getter.
    # Note: the parameters of _true_getter, and their documentation, match
    # *exactly* item-for-item with the docstring of this method.
    def _true_getter(  # pylint: disable=missing-docstring
        name,
        shape=None,
        dtype=dtypes.float32,
        initializer=None,
        regularizer=None,
        reuse=None,
        trainable=None,
        collections=None,
        caching_device=None,
        partitioner=None,
        validate_shape=True,
        use_resource=None,
        constraint=None,
        synchronization=VariableSynchronization.AUTO,
        aggregation=VariableAggregation.NONE):
      is_scalar = (
          shape is not None and isinstance(shape, collections_lib.Sequence) and
          not shape)
      # Partitioned variable case
      if partitioner is not None and not is_scalar:
        if not callable(partitioner):
          raise ValueError("Partitioner must be callable, but received: %s" %
                           partitioner)
        with ops.name_scope(None):
          return self._get_partitioned_variable(
              name=name,
              shape=shape,
              dtype=dtype,
              initializer=initializer,
              regularizer=regularizer,
              reuse=reuse,
              trainable=trainable,
              collections=collections,
              caching_device=caching_device,
              partitioner=partitioner,
              validate_shape=validate_shape,
              use_resource=use_resource,
              constraint=constraint,
              synchronization=synchronization,
              aggregation=aggregation)

      # Special case for partitioned variable to allow reuse without having to
      # specify partitioner.
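      # (A reuse=True lookup whose name already maps to a stored
      # PartitionedVariable returns those shards with their existing
      # partitioning.)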
      if (reuse is True and partitioner is None
          and name in self._partitioned_vars):
        return self._get_partitioned_variable(
            name=name,
            shape=shape,
            dtype=dtype,
            initializer=initializer,
            regularizer=regularizer,
            reuse=reuse,
            trainable=trainable,
            collections=collections,
            caching_device=caching_device,
            partitioner=None,
            validate_shape=validate_shape,
            use_resource=use_resource,
            constraint=constraint,
            synchronization=synchronization,
            aggregation=aggregation)

      # Single variable case
      if "%s/part_0" % name in self._vars:
        raise ValueError(
            "No partitioner was provided, but a partitioned version of the "
            "variable was found: %s/part_0. Perhaps a variable of the same "
            "name was already created with partitioning?" % name)

      return self._get_single_variable(
          name=name,
          shape=shape,
          dtype=dtype,
          initializer=initializer,
          regularizer=regularizer,
          reuse=reuse,
          trainable=trainable,
          collections=collections,
          caching_device=caching_device,
          validate_shape=validate_shape,
          use_resource=use_resource,
          constraint=constraint,
          synchronization=synchronization,
          aggregation=aggregation)

    synchronization, aggregation, trainable = (
        variables.validate_synchronization_aggregation_trainable(
            synchronization, aggregation, trainable, name))

    if custom_getter is not None:
      # Handle backwards compatibility with getter arguments that were added
      # to the API after users started writing custom getters.
      custom_getter_kwargs = {
          "getter": _true_getter,
          "name": name,
          "shape": shape,
          "dtype": dtype,
          "initializer": initializer,
          "regularizer": regularizer,
          "reuse": reuse,
          "trainable": trainable,
          "collections": collections,
          "caching_device": caching_device,
          "partitioner": partitioner,
          "validate_shape": validate_shape,
          "use_resource": use_resource,
          "synchronization": synchronization,
          "aggregation": aggregation,
      }
      # `fn_args` and `has_kwargs` can handle functions, `functools.partial`,
      # `lambda`.
      if ("constraint" in function_utils.fn_args(custom_getter) or
          function_utils.has_kwargs(custom_getter)):
        custom_getter_kwargs["constraint"] = constraint
      return custom_getter(**custom_getter_kwargs)
    else:
      return _true_getter(
          name,
          shape=shape,
          dtype=dtype,
          initializer=initializer,
          regularizer=regularizer,
          reuse=reuse,
          trainable=trainable,
          collections=collections,
          caching_device=caching_device,
          partitioner=partitioner,
          validate_shape=validate_shape,
          use_resource=use_resource,
          constraint=constraint,
          synchronization=synchronization,
          aggregation=aggregation)

  def _get_partitioned_variable(self,
                                name,
                                partitioner,
                                shape=None,
                                dtype=dtypes.float32,
                                initializer=None,
                                regularizer=None,
                                reuse=None,
                                trainable=None,
                                collections=None,
                                caching_device=None,
                                validate_shape=True,
                                use_resource=None,
                                constraint=None,
                                synchronization=VariableSynchronization.AUTO,
                                aggregation=VariableAggregation.NONE):
    """Gets or creates a sharded variable list with these parameters.

    The `partitioner` must be a callable that accepts a fully defined
    `TensorShape` and returns a sequence of integers (the `partitions`).
    These integers describe how to partition the given sharded `Variable`
    along the given dimension. That is, `partitions[1] = 3` means split
    the `Variable` into 3 shards along dimension 1. Currently, sharding along
    only one axis is supported.

    If the list of variables with the given name (prefix) is already stored,
    we return the stored variables. Otherwise, we create a new one.

    Set `reuse` to `True` when you only want to reuse existing Variables.
    Set `reuse` to `False` when you only want to create new Variables.
    Set `reuse` to None (the default) or tf.compat.v1.AUTO_REUSE when you want
    variables to be created if they don't exist or returned if they do.

    If initializer is `None` (the default), the default initializer passed in
    the constructor is used. If that one is `None` too, we use a new
    `glorot_uniform_initializer`. If initializer is a Tensor, we use
    it as a value and derive the shape from the initializer.

    If the initializer is a callable, then it will be called for each
    shard. Otherwise the initializer should match the shape of the entire
    sharded Variable, and it will be sliced accordingly for each shard.

    Some useful partitioners are available. See, e.g.,
    `variable_axis_size_partitioner` and `min_max_variable_partitioner`.

    Args:
      name: the name of the new or existing sharded variable.
      partitioner: Optional callable that accepts a fully defined `TensorShape`
        and `dtype` of the Variable to be created, and returns a list of
        partitions for each axis (currently only one axis can be partitioned).
      shape: shape of the new or existing sharded variable.
      dtype: type of the new or existing sharded variable (defaults to
        `DT_FLOAT`).
      initializer: initializer for the sharded variable.
      regularizer: a (Tensor -> Tensor or None) function; the result of applying
        it on a newly created variable will be added to the collection
        GraphKeys.REGULARIZATION_LOSSES and can be used for regularization.
      reuse: a Boolean, None, or tf.AUTO_REUSE. Controls reuse or creation of
        variables.
      trainable: If `True` also add the variable to the graph collection
        `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
      collections: List of graph collections keys to add the Variable to.
        Defaults to `[GraphKeys.GLOBAL_VARIABLES]` (see `tf.Variable`).
      caching_device: Optional device string or function describing where the
        Variable should be cached for reading. Defaults to the Variable's
        device. If not `None`, caches on another device. Typical use is to
        cache on the device where the Ops using the Variable reside, to
        deduplicate copying through `Switch` and other conditional statements.
      validate_shape: If False, allows the variable to be initialized with a
        value of unknown shape. If True, the default, the shape of initial_value
        must be known.
      use_resource: If False, creates a regular Variable. If True, creates an
        experimental ResourceVariable which has well-defined semantics. Defaults
        to False (will later change to True).
      constraint: An optional projection function to be applied to the variable
        after being updated by an `Optimizer` (e.g. used to implement norm
        constraints or value constraints for layer weights). The function must
        take as input the unprojected Tensor representing the value of the
        variable and return the Tensor for the projected value (which must have
        the same shape). Constraints are not safe to use when doing asynchronous
        distributed training.
      synchronization: Indicates when a distributed variable will be
        aggregated. Accepted values are constants defined in the class
        `tf.VariableSynchronization`. By default the synchronization is set to
        `AUTO` and the current `DistributionStrategy` chooses when to
        synchronize.
      aggregation: Indicates how a distributed variable will be aggregated.
        Accepted values are constants defined in the class
        `tf.VariableAggregation`.

    Returns:
      A `PartitionedVariable` object.

    Raises:
      ValueError: when creating a new variable and shape is not declared,
        when reusing a variable and specifying a conflicting shape,
        when violating reuse during variable creation, or if an existing
        sharded variable exists for the given name but with different sharding.
    """
    initializing_from_value = initializer is not None and isinstance(
        initializer, ops.Tensor)
    if name in self._vars:
      raise ValueError(
          "A partitioner was provided, but an unpartitioned version of the "
          "variable was found: %s. Perhaps a variable of the same name was "
          "already created without partitioning?" % name)

    shape = tensor_shape.as_shape(shape)
    if initializing_from_value:
      shape = shape.merge_with(initializer.get_shape())

    partitions = None
    if not reuse or partitioner:
      partitions = _call_partitioner(partitioner, shape, dtype)

    if name in self._partitioned_vars:
      if reuse is False:
        raise ValueError(
            "Partitioned variable with name %s already exists. Did you mean to "
            "set reuse=True or reuse=tf.AUTO_REUSE in VarScope?" % name)

      existing_var = self._partitioned_vars[name]
      if not shape.is_compatible_with(existing_var.get_shape()):
        raise ValueError(
            "Trying to reuse partitioned variable %s, but specified shape %s "
            "and found shape %s." % (name, shape, existing_var.get_shape()))
      if not dtype.is_compatible_with(existing_var.dtype):
        raise ValueError(
            "Trying to reuse partitioned variable %s, but specified dtype %s "
            "and found dtype %s." % (name, dtype.name, existing_var.dtype.name))

      # pylint: disable=protected-access
      if (partitions is not None and
          existing_var._get_partitions() != partitions):
        raise ValueError(
            "Trying to reuse partitioned variable %s, but specified partitions "
            "%s and found partitions %s." %
            (name, partitions, existing_var._get_partitions()))
      # pylint: enable=protected-access

      return existing_var

    if reuse is True:
      raise ValueError("PartitionedVariable %s does not exist, or was not "
                       "created with tf.get_variable(). Did you mean to set "
                       "reuse=False or reuse=tf.AUTO_REUSE in VarScope?" % name)

    slice_dim, num_slices = _get_slice_dim_and_num_slices(partitions)

    if "%s/part_0" % name in self._vars:
      if "%s/part_%d" % (name, num_slices - 1) not in self._vars:
        raise ValueError(
            "Partitioner returned a different partitioning than what was "
            "already found. Partitioner returned %d shards, and shard "
            "%s/part_0 was found, but %s/part_%d was not." %
            (num_slices, name, name, num_slices - 1))
      if "%s/part_%d" % (name, num_slices) in self._vars:
        raise ValueError(
            "Partitioner returned a different partitioning than what was "
            "already found. Partitioner returned %d shards, and shard "
            "%s/part_0 was found, but so was the extra shard %s/part_%d."
            % (num_slices, name, name, num_slices))

    vs = []
    for i, (var_offset, var_shape) in enumerate(
        _iter_slices(shape.as_list(), num_slices, slice_dim)):
      partition_info = _PartitionInfo(
          full_shape=shape.as_list(), var_offset=var_offset)
      var_full_name = "%s/part_%d" % (name, i)
      with ops.name_scope(
          var_full_name + "/PartitionedInitializer", skip_on_eager=False):
        # Create the tensor to initialize the variable with default value.
        if initializer is None:
          init, initializing_from_value = self._get_default_initializer(
              name=name, shape=shape, dtype=dtype)
          if initializing_from_value:
            init_shape = None
          else:
            init_shape = var_shape
        elif callable(initializer):
          init = initializer
          init_shape = var_shape
        elif isinstance(initializer, ops.Tensor):
          init = array_ops.slice(initializer, var_offset, var_shape)
          # Use the dtype of the given tensor.
          dtype = init.dtype.base_dtype
          init_shape = None
        else:
          init = ops.convert_to_tensor(initializer, dtype=dtype)
          init = array_ops.slice(init, var_offset, var_shape)
          init_shape = None

      with ops.name_scope(None):
        var = self._get_single_variable(
            name=var_full_name,
            shape=init_shape,
            dtype=dtype,
            initializer=init,
            partition_info=partition_info,
            regularizer=regularizer,
            reuse=reuse,
            trainable=trainable,
            collections=collections,
            caching_device=caching_device,
            validate_shape=validate_shape,
            use_resource=use_resource,
            constraint=constraint,
            synchronization=synchronization,
            aggregation=aggregation)

      # pylint: disable=protected-access
      var._set_save_slice_info(
          variables.Variable.SaveSliceInfo(name, shape.as_list(), var_offset,
                                           var_shape))
      vs.append(var)
      # pylint: enable=protected-access

    partitioned_var = variables.PartitionedVariable(
        name=name,
        shape=shape,
        dtype=dtype,
        variable_list=vs,
        partitions=partitions)
    if not context.executing_eagerly() or self._store_eager_variables:
      self._partitioned_vars[name] = partitioned_var
    return partitioned_var

  def _get_single_variable(self,
                           name,
                           shape=None,
                           dtype=dtypes.float32,
                           initializer=None,
                           regularizer=None,
                           partition_info=None,
                           reuse=None,
                           trainable=None,
                           collections=None,
                           caching_device=None,
                           validate_shape=True,
                           use_resource=None,
                           constraint=None,
                           synchronization=VariableSynchronization.AUTO,
                           aggregation=VariableAggregation.NONE):
    """Get or create a single Variable (e.g. a shard or entire variable).

    See the documentation of get_variable above (ignore partitioning components)
    for details.

    Args:
      name: see get_variable.
      shape: see get_variable.
      dtype: see get_variable.
      initializer: see get_variable.
      regularizer: see get_variable.
      partition_info: _PartitionInfo object.
      reuse: see get_variable.
      trainable: see get_variable.
      collections: see get_variable.
      caching_device: see get_variable.
      validate_shape: see get_variable.
      use_resource: see get_variable.
      constraint: see get_variable.
      synchronization: see get_variable.
      aggregation: see get_variable.

    Returns:
      A Variable. See documentation of get_variable above.

    Raises:
      ValueError: See documentation of get_variable above.
    """
    # Set to true if initializer is a constant.
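    # (i.e. a Tensor or other non-callable value; in that case the shape is
    # derived from the value rather than passed explicitly.)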
    initializing_from_value = False
    if initializer is not None and not callable(initializer):
      initializing_from_value = True
    if shape is not None and initializing_from_value:
      raise ValueError("If initializer is a constant, do not specify shape.")

    dtype = dtypes.as_dtype(dtype)
    shape = tensor_shape.as_shape(shape)

    if name in self._vars:
      # Here we handle the case when returning an existing variable.
      if reuse is False:
        var = self._vars[name]
        err_msg = ("Variable %s already exists, disallowed."
                   " Did you mean to set reuse=True or "
                   "reuse=tf.AUTO_REUSE in VarScope?" % name)
        # ResourceVariables don't have an op associated with them, so no
        # traceback.
        if isinstance(var, resource_variable_ops.ResourceVariable):
          raise ValueError(err_msg)
        tb = var.op.traceback[::-1]
        # Throw away internal tf entries and only take a few lines. In some
        # cases the traceback can be longer (e.g. if someone uses factory
        # functions to create variables) so we take more than needed in the
        # default case.
        tb = [x for x in tb if "tensorflow/python" not in x[0]][:5]
        raise ValueError("%s Originally defined at:\n\n%s" %
                         (err_msg, "".join(traceback.format_list(tb))))
      found_var = self._vars[name]
      if not shape.is_compatible_with(found_var.get_shape()):
        raise ValueError("Trying to share variable %s, but specified shape %s"
                         " and found shape %s." %
                         (name, shape, found_var.get_shape()))
      if not dtype.is_compatible_with(found_var.dtype):
        dtype_str = dtype.name
        found_type_str = found_var.dtype.name
        raise ValueError("Trying to share variable %s, but specified dtype %s"
                         " and found dtype %s." %
                         (name, dtype_str, found_type_str))
      return found_var

    # The code below handles only the case of creating a new variable.
    if reuse is True:
      raise ValueError("Variable %s does not exist, or was not created with "
                       "tf.get_variable(). Did you mean to set "
                       "reuse=tf.AUTO_REUSE in VarScope?" % name)

    # Create the tensor to initialize the variable with default value.
    if initializer is None:
      initializer, initializing_from_value = self._get_default_initializer(
          name=name, shape=shape, dtype=dtype)
    # Enter an init scope when creating the initializer.
    with ops.init_scope():
      if initializing_from_value:
        init_val = initializer
        variable_dtype = None
      else:
        # Instantiate initializer if provided initializer is a type object.
        if tf_inspect.isclass(initializer):
          initializer = initializer()
        if shape is not None and shape.is_fully_defined():
          if "partition_info" in tf_inspect.getargspec(initializer).args:
            init_val = lambda: initializer(  # pylint: disable=g-long-lambda
                shape.as_list(),
                dtype=dtype,
                partition_info=partition_info)
          else:
            init_val = lambda: initializer(  # pylint: disable=g-long-lambda
                shape.as_list(), dtype=dtype)
          variable_dtype = dtype.base_dtype
        elif len(tf_inspect.getargspec(initializer).args) == len(
            tf_inspect.getargspec(initializer).defaults or []):
          init_val = initializer
          variable_dtype = None
        else:
          raise ValueError("The initializer passed is not valid. It should "
                           "be a callable with no arguments and the "
                           "shape should not be provided, or an instance of "
                           "`tf.keras.initializers.*` and `shape` should be "
                           "fully defined.")

    # Create the variable.
    if use_resource is None:
      # Set the default value if unspecified.
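      # (_DEFAULT_USE_RESOURCE reflects tf2.enabled() at import time and any
      # later calls to enable/disable_resource_variables above.)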
      use_resource = _DEFAULT_USE_RESOURCE
    v = variables.VariableV1(
        initial_value=init_val,
        name=name,
        trainable=trainable,
        collections=collections,
        caching_device=caching_device,
        dtype=variable_dtype,
        validate_shape=validate_shape,
        constraint=constraint,
        use_resource=use_resource,
        synchronization=synchronization,
        aggregation=aggregation)
    if context.executing_eagerly() and self._store_eager_variables:
      if collections:
        ops.add_to_collections(collections, v)
      else:
        ops.add_to_collection(ops.GraphKeys.GLOBAL_VARIABLES, v)
      if trainable:
        ops.add_to_collection(ops.GraphKeys.TRAINABLE_VARIABLES, v)

    if not context.executing_eagerly() or self._store_eager_variables:
      # In eager mode we do not want to keep default references to Variable
      # objects as this will prevent their memory from being released.
      self._vars[name] = v
    logging.vlog(1, "Created variable %s with shape %s and init %s", v.name,
                 format(shape), initializer)

    # Run the regularizer if requested and save the resulting loss.
    if regularizer:
      def make_regularizer_op():
        with ops.colocate_with(v):
          with ops.name_scope(name + "/Regularizer/"):
            return regularizer(v)

      if regularizer(v) is not None:
        lazy_eval_tensor = _LazyEvalTensor(make_regularizer_op)
        ops.add_to_collection(ops.GraphKeys.REGULARIZATION_LOSSES,
                              lazy_eval_tensor)

    return v

  # Initialize variable when no initializer provided
  def _get_default_initializer(self, name, shape=None, dtype=dtypes.float32):
    """Provide a default initializer and a corresponding value.

    Args:
      name: see get_variable.
      shape: see get_variable.
      dtype: see get_variable.

    Returns:
      initializer and initializing_from_value. See get_variable above.

    Raises:
      ValueError: When giving unsupported dtype.
    """
    del shape
    # If dtype is DT_FLOAT, provide a uniform unit scaling initializer
    if dtype.is_floating:
      initializer = init_ops.glorot_uniform_initializer()
      initializing_from_value = False
    # If dtype is DT_INT/DT_UINT, provide a default value `zero`
    # If dtype is DT_BOOL, provide a default value `FALSE`
    elif (dtype.is_integer or dtype.is_unsigned or dtype.is_bool or
          dtype == dtypes.string):
      initializer = init_ops.zeros_initializer()
      initializing_from_value = False
    # NOTE: Do we need to support handling DT_STRING and DT_COMPLEX here?
    else:
      raise ValueError("An initializer for variable %s of %s is required" %
                       (name, dtype.base_dtype))

    return initializer, initializing_from_value


class _LazyEvalTensor(object):
  """A Tensor-like object that only evaluates its thunk when used."""

  def __init__(self, thunk):
    """Initializes a _LazyEvalTensor object.

    Args:
      thunk: A callable. A thunk which computes the value of the tensor.
    """
    self._thunk = thunk
    self._master_tensor = thunk()

  def _as_tensor(self, dtype=None, name=None, as_ref=False):
    del name
    assert not as_ref
    assert dtype in [None, self.dtype]

    return self._thunk()


def _make_master_property(name):
  @property
  def prop(self):
    return getattr(self._master_tensor, name)  # pylint: disable=protected-access
  return prop

_master_property_list = ("device", "dtype", "graph", "name", "op", "shape",
                         "value_index")
for _name in _master_property_list:
  setattr(_LazyEvalTensor, _name, _make_master_property(_name))


def _make_master_method(name):
  def method(self, *args, **kwargs):
    return getattr(self._master_tensor, name)(*args, **kwargs)  # pylint: disable=protected-access
  return method

_master_method_list = ("get_shape", "__str__", "shape_as_list")
for _name in _master_method_list:
  setattr(_LazyEvalTensor, _name, _make_master_method(_name))


def _make_op_method(name):
  def method(self, *args, **kwargs):
    return getattr(self._as_tensor(), name)(*args, **kwargs)  # pylint: disable=protected-access
  return method

_op_list = ("__abs__", "__add__", "__and__", "__bool__", "__div__", "__eq__",
            "__floordiv__", "__ge__", "__getitem__", "__gt__", "__invert__",
            "__iter__", "__le__", "__len__", "__lt__", "__matmul__", "__mod__",
            "__mul__", "__ne__", "__neg__", "__nonzero__", "__or__", "__pow__",
            "__radd__", "__rand__", "__rdiv__", "__rfloordiv__", "__rmatmul__",
            "__rmod__", "__rmul__", "__ror__", "__rpow__", "__rsub__",
            "__rtruediv__", "__rxor__", "__sub__", "__truediv__", "__xor__",
            "eval", "numpy")
for _name in _op_list:
  setattr(_LazyEvalTensor, _name, _make_op_method(_name))


ops.register_tensor_conversion_function(
    _LazyEvalTensor,
    lambda val, dtype, name, as_ref: val._as_tensor(dtype, name, as_ref)  # pylint: disable=protected-access
    )

session.register_session_run_conversion_functions(
    _LazyEvalTensor,
    lambda fetch: ([fetch._master_tensor], lambda fetched_vals: fetched_vals[0])  # pylint: disable=protected-access
    )

ops.register_dense_tensor_like_type(_LazyEvalTensor)


# To stop regularization, use this regularizer
@tf_export(v1=["no_regularizer"])
def no_regularizer(_):
  """Use this function to prevent regularization of variables."""
  return None


# TODO(alive): support caching devices and partitioned variables in Eager mode.
@tf_export(v1=["VariableScope"])
class VariableScope(object):
  """Variable scope object to carry defaults to provide to `get_variable`.

  Many of the arguments we need for `get_variable` in a variable store are most
  easily handled with a context. This object is used for the defaults.

  Attributes:
    name: name of the current scope, used as prefix in get_variable.
    initializer: default initializer passed to get_variable.
    regularizer: default regularizer passed to get_variable.
    reuse: Boolean, None, or tf.compat.v1.AUTO_REUSE, setting the reuse in
      get_variable. When eager execution is enabled this argument is always
      forced to be False.
    caching_device: string, callable, or None: the caching device passed to
      get_variable.
    partitioner: callable or `None`: the partitioner passed to `get_variable`.
    custom_getter: default custom getter passed to get_variable.
    name_scope: The name passed to `tf.name_scope`.
    dtype: default type passed to get_variable (defaults to DT_FLOAT).
    use_resource: if False, create a normal Variable; if True create an
      experimental ResourceVariable with well-defined semantics. Defaults to
      False (will later change to True). When eager execution is enabled this
      argument is always forced to be True.
    constraint: An optional projection function to be applied to the variable
      after being updated by an `Optimizer` (e.g. used to implement norm
      constraints or value constraints for layer weights). The function must
      take as input the unprojected Tensor representing the value of the
      variable and return the Tensor for the projected value (which must have
      the same shape). Constraints are not safe to use when doing asynchronous
      distributed training.
  """

  def __init__(self,
               reuse,
               name="",
               initializer=None,
               regularizer=None,
               caching_device=None,
               partitioner=None,
               custom_getter=None,
               name_scope="",
               dtype=dtypes.float32,
               use_resource=None,
               constraint=None):
    """Creates a new VariableScope with the given properties."""
    self._name = name
    self._initializer = initializer
    self._regularizer = regularizer
    self._reuse = reuse
    self._caching_device = caching_device
    self._partitioner = partitioner
    self._custom_getter = custom_getter
    self._name_scope = name_scope
    self._dtype = dtype
    self._use_resource = use_resource
    self._constraint = constraint
    if context.executing_eagerly():
      if self._caching_device is not None:
        raise NotImplementedError("Caching devices are not yet supported "
                                  "when eager execution is enabled.")
      self._reuse = AUTO_REUSE
      self._use_resource = True

  @property
  def name(self):
    return self._name

  @property
  def original_name_scope(self):
    return self._name_scope

  @property
  def reuse(self):
    return self._reuse

  @property
  def initializer(self):
    return self._initializer

  @property
  def dtype(self):
    return self._dtype

  @property
  def use_resource(self):
    return self._use_resource

  @property
  def regularizer(self):
    return self._regularizer

  @property
  def caching_device(self):
    return self._caching_device

  @property
  def partitioner(self):
    return self._partitioner

  @property
  def custom_getter(self):
    return self._custom_getter

  @property
  def constraint(self):
    return self._constraint

  def reuse_variables(self):
    """Reuse variables in this scope."""
    self._reuse = True

  def set_initializer(self, initializer):
    """Set initializer for this scope."""
    self._initializer = initializer

  def set_dtype(self, dtype):
    """Set data type for this scope."""
    self._dtype = dtype

  def set_use_resource(self, use_resource):
    """Sets whether to use ResourceVariables for this scope."""
    if context.executing_eagerly() and not use_resource:
      raise ValueError("When eager execution is enabled, "
                       "use_resource cannot be set to false.")
    self._use_resource = use_resource

  def set_regularizer(self, regularizer):
    """Set regularizer for this scope."""
    self._regularizer = regularizer

  def set_caching_device(self, caching_device):
    """Set caching_device for this scope."""
    if context.executing_eagerly():
      raise NotImplementedError("Caching devices are not yet supported "
                                "when eager execution is enabled.")
    self._caching_device = caching_device

  def set_partitioner(self, partitioner):
    """Set partitioner for this scope."""
    self._partitioner = partitioner

  def set_custom_getter(self, custom_getter):
    """Set custom getter for this scope."""
    self._custom_getter = custom_getter

  def get_collection(self, name):
    """Get this scope's variables."""
    scope = self._name + "/" if self._name else ""
    return ops.get_collection(name, scope)

  def trainable_variables(self):
    """Get this scope's trainable variables."""
    return self.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)

  def global_variables(self):
    """Get this scope's global variables."""
    return self.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)

  def local_variables(self):
    """Get this scope's local variables."""
    return self.get_collection(ops.GraphKeys.LOCAL_VARIABLES)

  def get_variable(self,
                   var_store,
                   name,
                   shape=None,
                   dtype=None,
                   initializer=None,
                   regularizer=None,
                   reuse=None,
                   trainable=None,
                   collections=None,
                   caching_device=None,
                   partitioner=None,
                   validate_shape=True,
                   use_resource=None,
                   custom_getter=None,
                   constraint=None,
                   synchronization=VariableSynchronization.AUTO,
                   aggregation=VariableAggregation.NONE):
    """Gets an existing variable with this name or creates a new one."""
    if regularizer is None:
      regularizer = self._regularizer
    if caching_device is None:
      caching_device = self._caching_device
    if partitioner is None:
      partitioner = self._partitioner
    if custom_getter is None:
      custom_getter = self._custom_getter
    if context.executing_eagerly():
      reuse = False
      use_resource = True
    else:
      if reuse is None:
        reuse = self._reuse
      if use_resource is None:
        use_resource = self._use_resource

    full_name = self.name + "/" + name if self.name else name
    # Variable names only depend on variable_scope (full_name here),
    # not name_scope, so we reset it below for the time of variable creation.
    with ops.name_scope(None, skip_on_eager=False):
      # Check that `initializer` dtype and `dtype` are consistent before
      # replacing them with defaults.
      if (dtype is not None and initializer is not None and
          not callable(initializer)):
        init_dtype = ops.convert_to_tensor(initializer).dtype.base_dtype
        if init_dtype != dtype:
          raise ValueError("Initializer type '%s' and explicit dtype '%s' "
                           "don't match."
                           % (init_dtype, dtype))
      if initializer is None:
        initializer = self._initializer
      if constraint is None:
        constraint = self._constraint
      if dtype is None:
        dtype = self._dtype
      return var_store.get_variable(
          full_name,
          shape=shape,
          dtype=dtype,
          initializer=initializer,
          regularizer=regularizer,
          reuse=reuse,
          trainable=trainable,
          collections=collections,
          caching_device=caching_device,
          partitioner=partitioner,
          validate_shape=validate_shape,
          use_resource=use_resource,
          custom_getter=custom_getter,
          constraint=constraint,
          synchronization=synchronization,
          aggregation=aggregation)

  def _get_partitioned_variable(self,
                                var_store,
                                name,
                                shape=None,
                                dtype=None,
                                initializer=None,
                                regularizer=None,
                                trainable=None,
                                collections=None,
                                caching_device=None,
                                partitioner=None,
                                validate_shape=True,
                                use_resource=None,
                                constraint=None,
                                synchronization=VariableSynchronization.AUTO,
                                aggregation=VariableAggregation.NONE):
    """Gets an existing variable with this name or creates a new one."""
    if initializer is None:
      initializer = self._initializer
    if regularizer is None:
      regularizer = self._regularizer
    if constraint is None:
      constraint = self._constraint
    if caching_device is None:
      caching_device = self._caching_device
    if partitioner is None:
      partitioner = self._partitioner
    if dtype is None:
      dtype = self._dtype
    if use_resource is None:
      use_resource = self._use_resource

    if self._custom_getter is not None:
      raise ValueError(
          "Private access to _get_partitioned_variable is not allowed when "
          "a custom getter is set. Current custom getter: %s. "
          "It is likely that you're using create_partitioned_variables. "
          "If so, consider using get_variable with a non-empty "
          "partitioner parameter instead." % self._custom_getter)

    if partitioner is None:
      raise ValueError("No partitioner was specified")

    # This allows the variable scope name to be used as the variable name if
    # this function is invoked with an empty name arg, for backward
    # compatibility with create_partitioned_variables().
    full_name_list = []
    if self.name:
      full_name_list.append(self.name)
    if name:
      full_name_list.append(name)
    full_name = "/".join(full_name_list)

    # Variable names only depend on variable_scope (full_name here),
    # not name_scope, so we reset it below for the time of variable creation.
    with ops.name_scope(None, skip_on_eager=False):
      # pylint: disable=protected-access
      return var_store._get_partitioned_variable(
          full_name,
          shape=shape,
          dtype=dtype,
          initializer=initializer,
          regularizer=regularizer,
          reuse=self.reuse,
          trainable=trainable,
          collections=collections,
          caching_device=caching_device,
          partitioner=partitioner,
          validate_shape=validate_shape,
          use_resource=use_resource,
          constraint=constraint,
          synchronization=synchronization,
          aggregation=aggregation)
      # pylint: enable=protected-access


_VARSTORE_KEY = ("__variable_store",)
_VARSCOPESTORE_KEY = ("__varscope",)


class _VariableScopeStore(threading.local):
  """A thread local store for the current variable scope and scope counts."""

  def __init__(self):
    super(_VariableScopeStore, self).__init__()
    self.current_scope = VariableScope(False)
    self.variable_scopes_count = {}

  def open_variable_scope(self, scope_name):
    if scope_name in self.variable_scopes_count:
      self.variable_scopes_count[scope_name] += 1
    else:
      self.variable_scopes_count[scope_name] = 1

  def close_variable_subscopes(self, scope_name):
    for k in list(self.variable_scopes_count.keys()):
      if scope_name is None or k.startswith(scope_name + "/"):
        self.variable_scopes_count[k] = 0

  def variable_scope_count(self, scope_name):
    return self.variable_scopes_count.get(scope_name, 0)


def get_variable_scope_store():
  """Returns the variable scope store for current thread."""
  scope_store = ops.get_collection(_VARSCOPESTORE_KEY)

  if not scope_store:
    scope_store = _VariableScopeStore()
    ops.add_to_collection(_VARSCOPESTORE_KEY, scope_store)
  else:
    scope_store = scope_store[0]

  return scope_store


@tf_export(v1=["get_variable_scope"])
def get_variable_scope():
  """Returns the current variable scope."""
  return get_variable_scope_store().current_scope


def _get_default_variable_store():
  store = ops.get_collection(_VARSTORE_KEY)
  if store:
    return store[0]
  store = _VariableStore()
  ops.add_to_collection(_VARSTORE_KEY, store)
  return store


@tf_contextlib.contextmanager
def with_variable_store(store):
  store_collection = ops.get_collection_ref(_VARSTORE_KEY)
  old = list(store_collection)
  store_collection[:] = [store]
  try:
    yield
  finally:
    store_collection[:] = old


class EagerVariableStore(object):
  """Wrapper allowing functional layers to be used with eager execution.

  When eager execution is enabled Variables get deleted when they go out of
  scope, and are not stored in global collections by default. A lot of code
  (mostly the functional layers in tf.layers) assumes that variables are kept in
  a global list.

  EagerVariableStore can be used in conjunction with this code to make it
  eager-friendly. For example, to create a dense layer, use:

  ```
  container = tfe.EagerVariableStore()
  for input in dataset_iterator:
    with container.as_default():
      x = tf.compat.v1.layers.dense(input, name="l1")
  print(container.variables)  # Should print the variables used in the layer.
  ```
  """

  def __init__(self, store=None):
    if store is not None:
      if not store._store_eager_variables:  # pylint: disable=protected-access
        raise ValueError("Cannot construct EagerVariableStore from a "
                         "VariableStore object that does not hold eager "
                         "variables.")
      self._store = store
    else:
      self._store = _VariableStore()
    self._store._store_eager_variables = True  # pylint: disable=protected-access

  def as_default(self):
    return with_variable_store(self._store)

  def variables(self):
    return sorted(self._store._vars.values(), key=lambda x: x.name)  # pylint: disable=protected-access

  def trainable_variables(self):
    # pylint: disable=protected-access
    return sorted([x for x in self._store._vars.values() if x.trainable],
                  key=lambda x: x.name)
    # pylint: enable=protected-access

  def non_trainable_variables(self):
    # pylint: disable=protected-access
    return sorted([x for x in self._store._vars.values() if not x.trainable],
                  key=lambda x: x.name)
    # pylint: enable=protected-access

  def copy(self):
    """Copy this variable store and all of its contents.

    Variables contained in this store will be copied over to the new variable
    store, meaning that they can be modified without affecting the variables in
    this store.

    Returns:
      A new EagerVariableStore instance containing copied variables.
    """
    # pylint: disable=protected-access
    new_store = EagerVariableStore()
    for key, var in iteritems(self._store._vars):
      # Strip device out of variable name.
      try:
        index = var.name.index(":")
      except ValueError:
        stripped_var_name = var.name
      else:
        stripped_var_name = var.name[:index]

      # Create new variable with same value, name, and "trainable" flag.
      new_var = resource_variable_ops.ResourceVariable(
          var.read_value(), name=stripped_var_name, trainable=var.trainable)
      new_store._store._vars[key] = new_var
    return new_store
    # pylint: enable=protected-access


# The argument list for get_variable must match arguments to get_local_variable.
# So, if you are updating the arguments, also update arguments to
# get_local_variable below.
@tf_export(v1=["get_variable"])
def get_variable(name,
                 shape=None,
                 dtype=None,
                 initializer=None,
                 regularizer=None,
                 trainable=None,
                 collections=None,
                 caching_device=None,
                 partitioner=None,
                 validate_shape=True,
                 use_resource=None,
                 custom_getter=None,
                 constraint=None,
                 synchronization=VariableSynchronization.AUTO,
                 aggregation=VariableAggregation.NONE):
  return get_variable_scope().get_variable(
      _get_default_variable_store(),
      name,
      shape=shape,
      dtype=dtype,
      initializer=initializer,
      regularizer=regularizer,
      trainable=trainable,
      collections=collections,
      caching_device=caching_device,
      partitioner=partitioner,
      validate_shape=validate_shape,
      use_resource=use_resource,
      custom_getter=custom_getter,
      constraint=constraint,
      synchronization=synchronization,
      aggregation=aggregation)


get_variable_or_local_docstring = ("""%s

%sThis function prefixes the name with the current variable scope
and performs reuse checks. See the
[Variable Scope How To](https://tensorflow.org/guide/variables)
for an extensive description of how reusing works.
Here is a basic example: 1582 1583```python 1584def foo(): 1585 with tf.variable_scope("foo", reuse=tf.AUTO_REUSE): 1586 v = tf.get_variable("v", [1]) 1587 return v 1588 1589v1 = foo() # Creates v. 1590v2 = foo() # Gets the same, existing v. 1591assert v1 == v2 1592``` 1593 1594If initializer is `None` (the default), the default initializer passed in 1595the variable scope will be used. If that one is `None` too, a 1596`glorot_uniform_initializer` will be used. The initializer can also be 1597a Tensor, in which case the variable is initialized to this value and shape. 1598 1599Similarly, if the regularizer is `None` (the default), the default regularizer 1600passed in the variable scope will be used (if that is `None` too, 1601then by default no regularization is performed). 1602 1603If a partitioner is provided, a `PartitionedVariable` is returned. 1604Accessing this object as a `Tensor` returns the shards concatenated along 1605the partition axis. 1606 1607Some useful partitioners are available. See, e.g., 1608`variable_axis_size_partitioner` and `min_max_variable_partitioner`. 1609 1610Args: 1611 name: The name of the new or existing variable. 1612 shape: Shape of the new or existing variable. 1613 dtype: Type of the new or existing variable (defaults to `DT_FLOAT`). 1614 initializer: Initializer for the variable if one is created. Can either be 1615 an initializer object or a Tensor. If it's a Tensor, its shape must be known 1616 unless validate_shape is False. 1617 regularizer: A (Tensor -> Tensor or None) function; the result of 1618 applying it on a newly created variable will be added to the collection 1619 `tf.GraphKeys.REGULARIZATION_LOSSES` and can be used for regularization. 1620 %scollections: List of graph collections keys to add the Variable to. 1621 Defaults to `[%s]` (see `tf.Variable`). 1622 caching_device: Optional device string or function describing where the 1623 Variable should be cached for reading. Defaults to the Variable's 1624 device. If not `None`, caches on another device. Typical use is to 1625 cache on the device where the Ops using the Variable reside, to 1626 deduplicate copying through `Switch` and other conditional statements. 1627 partitioner: Optional callable that accepts a fully defined `TensorShape` 1628 and `dtype` of the Variable to be created, and returns a list of 1629 partitions for each axis (currently only one axis can be partitioned). 1630 validate_shape: If False, allows the variable to be initialized with a 1631 value of unknown shape. If True, the default, the shape of initial_value 1632 must be known. For this to be used the initializer must be a Tensor and 1633 not an initializer object. 1634 use_resource: If False, creates a regular Variable. If true, creates an 1635 experimental ResourceVariable instead with well-defined semantics. 1636 Defaults to False (will later change to True). When eager execution is 1637 enabled this argument is always forced to be True. 1638 custom_getter: Callable that takes as a first argument the true getter, and 1639 allows overwriting the internal get_variable method. 1640 The signature of `custom_getter` should match that of this method, 1641 but the most future-proof version will allow for changes: 1642 `def custom_getter(getter, *args, **kwargs)`. Direct access to 1643 all `get_variable` parameters is also allowed: 1644 `def custom_getter(getter, name, *args, **kwargs)`. 
A simple
    custom getter that creates variables with modified names is:
    ```python
    def custom_getter(getter, name, *args, **kwargs):
      return getter(name + '_suffix', *args, **kwargs)
    ```
  constraint: An optional projection function to be applied to the variable
    after being updated by an `Optimizer` (e.g. used to implement norm
    constraints or value constraints for layer weights). The function must
    take as input the unprojected Tensor representing the value of the
    variable and return the Tensor for the projected value
    (which must have the same shape). Constraints are not safe to
    use when doing asynchronous distributed training.
  synchronization: Indicates when a distributed variable will be
    aggregated. Accepted values are constants defined in the class
    `tf.VariableSynchronization`. By default the synchronization is set to
    `AUTO` and the current `DistributionStrategy` chooses
    when to synchronize.
  aggregation: Indicates how a distributed variable will be aggregated.
    Accepted values are constants defined in the class
    `tf.VariableAggregation`.

Returns:
  The created or existing `Variable` (or `PartitionedVariable`, if a
  partitioner was used).

Raises:
  ValueError: when creating a new variable and shape is not declared,
    when violating reuse during variable creation, or when `initializer` dtype
    and `dtype` don't match. Reuse is set inside `variable_scope`.
""")
get_variable.__doc__ = get_variable_or_local_docstring % (
    "Gets an existing variable with these parameters or creates a new one.",
    "",
    "trainable: If `True` also add the variable to the graph collection\n"
    "    `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).\n  ",
    "GraphKeys.GLOBAL_VARIABLES")


# The argument list for get_local_variable must match arguments to get_variable.
# So, if you are updating the arguments, also update arguments to get_variable.
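
# Illustrative usage sketch (the names "my_scope" and "counter" below are
# placeholders, not part of this module): `get_local_variable`, defined next,
# behaves like `get_variable` but forces `trainable=False` and adds the result
# to the `LOCAL_VARIABLES` collection.
#
#   with tf.compat.v1.variable_scope("my_scope"):
#     counter = tf.compat.v1.get_local_variable(
#         "counter", shape=[], dtype=tf.int64,
#         initializer=tf.compat.v1.zeros_initializer())
#   # `counter` is now in the tf.compat.v1.GraphKeys.LOCAL_VARIABLES
#   # collection and is excluded from TRAINABLE_VARIABLES.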
1684@tf_export(v1=["get_local_variable"]) 1685def get_local_variable( # pylint: disable=missing-docstring 1686 name, 1687 shape=None, 1688 dtype=None, 1689 initializer=None, 1690 regularizer=None, 1691 trainable=False, # pylint: disable=unused-argument 1692 collections=None, 1693 caching_device=None, 1694 partitioner=None, 1695 validate_shape=True, 1696 use_resource=None, 1697 custom_getter=None, 1698 constraint=None, 1699 synchronization=VariableSynchronization.AUTO, 1700 aggregation=VariableAggregation.NONE): 1701 if collections: 1702 collections += [ops.GraphKeys.LOCAL_VARIABLES] 1703 else: 1704 collections = [ops.GraphKeys.LOCAL_VARIABLES] 1705 return get_variable( 1706 name, 1707 shape=shape, 1708 dtype=dtype, 1709 initializer=initializer, 1710 regularizer=regularizer, 1711 trainable=False, 1712 collections=collections, 1713 caching_device=caching_device, 1714 partitioner=partitioner, 1715 validate_shape=validate_shape, 1716 use_resource=use_resource, 1717 synchronization=synchronization, 1718 aggregation=aggregation, 1719 custom_getter=custom_getter, 1720 constraint=constraint) 1721 1722 1723get_local_variable.__doc__ = get_variable_or_local_docstring % ( 1724 "Gets an existing *local* variable or creates a new one.", 1725 "Behavior is the same as in `get_variable`, except that variables are\n" 1726 "added to the `LOCAL_VARIABLES` collection and `trainable` is set to\n" 1727 "`False`.\n", "", "GraphKeys.LOCAL_VARIABLES") 1728 1729 1730def _get_partitioned_variable(name, 1731 shape=None, 1732 dtype=None, 1733 initializer=None, 1734 regularizer=None, 1735 trainable=True, 1736 collections=None, 1737 caching_device=None, 1738 partitioner=None, 1739 validate_shape=True, 1740 use_resource=None, 1741 constraint=None, 1742 synchronization=VariableSynchronization.AUTO, 1743 aggregation=VariableAggregation.NONE): 1744 """Gets or creates a sharded variable list with these parameters. 1745 1746 The `partitioner` must be a callable that accepts a fully defined 1747 `TensorShape` and returns a sequence of integers (the `partitions`). 1748 These integers describe how to partition the given sharded `Variable` 1749 along the given dimension. That is, `partitions[1] = 3` means split 1750 the `Variable` into 3 shards along dimension 1. Currently, sharding along 1751 only one axis is supported. 1752 1753 If the list of variables with the given name (prefix) is already stored, 1754 we return the stored variables. Otherwise, we create a new one. 1755 1756 If initializer is `None` (the default), the default initializer passed in 1757 the constructor is used. If that one is `None` too, we use a new 1758 `glorot_uniform_initializer`. If initializer is a Tensor, we use 1759 it as a value and derive the shape from the initializer. 1760 1761 If the initializer is a callable, then it will be called for each 1762 shard. Otherwise the initializer should match the shape of the entire 1763 sharded Variable, and it will be sliced accordingly for each shard. 1764 1765 Some useful partitioners are available. See, e.g., 1766 `variable_axis_size_partitioner` and `min_max_variable_partitioner`. 1767 1768 Args: 1769 name: The name of the new or existing variable. 1770 shape: Shape of the new or existing variable. 1771 dtype: Type of the new or existing variable (defaults to `DT_FLOAT`). 1772 initializer: Initializer for the variable if one is created. 
    regularizer: A (Tensor -> Tensor or None) function; the result of applying
      it on a newly created variable will be added to the collection
      GraphKeys.REGULARIZATION_LOSSES and can be used for regularization.
    trainable: If `True` also add the variable to the graph collection
      `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
    collections: List of graph collections keys to add the Variable to. Defaults
      to `[GraphKeys.GLOBAL_VARIABLES]` (see `tf.Variable`).
    caching_device: Optional device string or function describing where the
      Variable should be cached for reading. Defaults to the Variable's device.
      If not `None`, caches on another device. Typical use is to cache on the
      device where the Ops using the Variable reside, to deduplicate copying
      through `Switch` and other conditional statements.
    partitioner: Optional callable that accepts a fully defined `TensorShape`
      and `dtype` of the Variable to be created, and returns a list of
      partitions for each axis (currently only one axis can be partitioned).
    validate_shape: If False, allows the variable to be initialized with a value
      of unknown shape. If True, the default, the shape of initial_value must be
      known.
    use_resource: If False, creates a regular Variable. If True, creates an
      experimental ResourceVariable instead which has well-defined semantics.
      Defaults to False (will later change to True).
    constraint: An optional projection function to be applied to the variable
      after being updated by an `Optimizer` (e.g. used to implement norm
      constraints or value constraints for layer weights). The function must
      take as input the unprojected Tensor representing the value of the
      variable and return the Tensor for the projected value (which must have
      the same shape). Constraints are not safe to use when doing asynchronous
      distributed training.
    synchronization: Indicates when a distributed variable will be aggregated.
      Accepted values are constants defined in the class
      `tf.VariableSynchronization`. By default the synchronization is set to
      `AUTO` and the current `DistributionStrategy` chooses when to synchronize.
    aggregation: Indicates how a distributed variable will be aggregated.
      Accepted values are constants defined in the class
      `tf.VariableAggregation`.

  Returns:
    A tuple `(shards, partitions)` where `shards` is the list of `Variable`
    shards and `partitions` is the output of the partitioner on the input
    shape.

  Raises:
    ValueError: when creating a new variable and shape is not declared,
      or when violating reuse during variable creation. Reuse is set inside
      `variable_scope`.
  """
  # pylint: disable=protected-access
  scope = get_variable_scope()
  if scope.custom_getter is not None:
    raise ValueError(
        "Private access to _get_partitioned_variable is not allowed when "
        "a custom getter is set. Current custom getter: %s. "
        "It is likely that you're using create_partitioned_variables. "
        "If so, consider using get_variable with a non-empty "
        "partitioner parameter instead."
% scope.custom_getter) 1828 return scope._get_partitioned_variable( 1829 _get_default_variable_store(), 1830 name, 1831 shape=shape, 1832 dtype=dtype, 1833 initializer=initializer, 1834 regularizer=regularizer, 1835 trainable=trainable, 1836 collections=collections, 1837 caching_device=caching_device, 1838 partitioner=partitioner, 1839 validate_shape=validate_shape, 1840 use_resource=use_resource, 1841 constraint=constraint, 1842 synchronization=synchronization, 1843 aggregation=aggregation) 1844 # pylint: enable=protected-access 1845 1846 1847# Named like a function for compatibility with the previous 1848# @tf_contextlib.contextmanager definition. 1849class _pure_variable_scope(object): # pylint: disable=invalid-name 1850 """A context for the variable_scope, see `variable_scope` for docs.""" 1851 1852 def __init__(self, 1853 name_or_scope, 1854 reuse=None, 1855 initializer=None, 1856 regularizer=None, 1857 caching_device=None, 1858 partitioner=None, 1859 custom_getter=None, 1860 old_name_scope=None, 1861 dtype=dtypes.float32, 1862 use_resource=None, 1863 constraint=None): 1864 """Creates a context for the variable_scope, see `variable_scope` for docs. 1865 1866 Note: this does not create a name scope. 1867 1868 Args: 1869 name_or_scope: `string` or `VariableScope`: the scope to open. 1870 reuse: `True` or None, or tf.compat.v1.AUTO_REUSE; if `None`, we inherit 1871 the parent scope's reuse flag. 1872 initializer: default initializer for variables within this scope. 1873 regularizer: default regularizer for variables within this scope. 1874 caching_device: default caching device for variables within this scope. 1875 partitioner: default partitioner for variables within this scope. 1876 custom_getter: default custom getter for variables within this scope. 1877 old_name_scope: the original name scope when re-entering a variable scope. 1878 dtype: type of the variables within this scope (defaults to `DT_FLOAT`). 1879 use_resource: If False, variables in this scope will be regular Variables. 1880 If True, experimental ResourceVariables will be creates instead, with 1881 well-defined semantics. Defaults to False (will later change to True). 1882 constraint: An optional projection function to be applied to the variable 1883 after being updated by an `Optimizer` (e.g. used to implement norm 1884 constraints or value constraints for layer weights). The function must 1885 take as input the unprojected Tensor representing the value of the 1886 variable and return the Tensor for the projected value (which must have 1887 the same shape). Constraints are not safe to use when doing asynchronous 1888 distributed training. 1889 """ 1890 self._name_or_scope = name_or_scope 1891 self._reuse = reuse 1892 self._initializer = initializer 1893 self._regularizer = regularizer 1894 self._caching_device = caching_device 1895 self._partitioner = partitioner 1896 self._custom_getter = custom_getter 1897 self._old_name_scope = old_name_scope 1898 self._dtype = dtype 1899 self._use_resource = use_resource 1900 self._constraint = constraint 1901 self._var_store = _get_default_variable_store() 1902 self._var_scope_store = get_variable_scope_store() 1903 self._last_variable_scope_object = None 1904 if isinstance(self._name_or_scope, VariableScope): 1905 self._new_name = self._name_or_scope.name 1906 name_scope = self._name_or_scope._name_scope # pylint: disable=protected-access 1907 # Handler for the case when we jump to a shared scope. 
We create a new 1908 # VariableScope (self._var_scope_object) that contains a copy of the 1909 # provided shared scope, possibly with changed reuse and initializer, if 1910 # the user requested this. 1911 variable_scope_object = VariableScope( 1912 self._name_or_scope.reuse if not self._reuse else self._reuse, 1913 name=self._new_name, 1914 initializer=self._name_or_scope.initializer, 1915 regularizer=self._name_or_scope.regularizer, 1916 caching_device=self._name_or_scope.caching_device, 1917 partitioner=self._name_or_scope.partitioner, 1918 dtype=self._name_or_scope.dtype, 1919 custom_getter=self._name_or_scope.custom_getter, 1920 name_scope=name_scope, 1921 use_resource=self._name_or_scope.use_resource, 1922 constraint=self._constraint) 1923 if self._initializer is not None: 1924 variable_scope_object.set_initializer(self._initializer) 1925 if self._regularizer is not None: 1926 variable_scope_object.set_regularizer(self._regularizer) 1927 if self._caching_device is not None: 1928 variable_scope_object.set_caching_device(self._caching_device) 1929 if self._partitioner is not None: 1930 variable_scope_object.set_partitioner(self._partitioner) 1931 if self._custom_getter is not None: 1932 variable_scope_object.set_custom_getter( 1933 _maybe_wrap_custom_getter(self._custom_getter, 1934 self._name_or_scope.custom_getter)) 1935 if self._dtype is not None: 1936 variable_scope_object.set_dtype(self._dtype) 1937 if self._use_resource is not None: 1938 variable_scope_object.set_use_resource(self._use_resource) 1939 self._cached_variable_scope_object = variable_scope_object 1940 1941 def __enter__(self): 1942 """Begins the scope block. 1943 1944 Returns: 1945 A VariableScope. 1946 Raises: 1947 ValueError: when trying to reuse within a create scope, or create within 1948 a reuse scope, or if reuse is not `None` or `True`. 1949 TypeError: when the types of some arguments are not appropriate. 1950 """ 1951 self._old = self._var_scope_store.current_scope 1952 if isinstance(self._name_or_scope, VariableScope): 1953 self._var_scope_store.open_variable_scope(self._new_name) 1954 self._old_subscopes = copy.copy( 1955 self._var_scope_store.variable_scopes_count) 1956 variable_scope_object = self._cached_variable_scope_object 1957 else: 1958 # Handler for the case when we just prolong current variable scope. 1959 # VariableScope with name extended by the provided one, and inherited 1960 # reuse and initializer (except if the user provided values to set). 1961 self._new_name = ( 1962 self._old.name + "/" + 1963 self._name_or_scope if self._old.name else self._name_or_scope) 1964 self._reuse = (self._reuse or 1965 self._old.reuse) # Re-using is inherited by sub-scopes. 
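      # Use the provided string as the name scope unless this scope is being
      # re-entered, in which case the previously recorded name scope is reused.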
1966 if self._old_name_scope is None: 1967 name_scope = self._name_or_scope 1968 else: 1969 name_scope = self._old_name_scope 1970 variable_scope_object = VariableScope( 1971 self._reuse, 1972 name=self._new_name, 1973 initializer=self._old.initializer, 1974 regularizer=self._old.regularizer, 1975 caching_device=self._old.caching_device, 1976 partitioner=self._old.partitioner, 1977 dtype=self._old.dtype, 1978 use_resource=self._old.use_resource, 1979 custom_getter=self._old.custom_getter, 1980 name_scope=name_scope, 1981 constraint=self._constraint) 1982 if self._initializer is not None: 1983 variable_scope_object.set_initializer(self._initializer) 1984 if self._regularizer is not None: 1985 variable_scope_object.set_regularizer(self._regularizer) 1986 if self._caching_device is not None: 1987 variable_scope_object.set_caching_device(self._caching_device) 1988 if self._partitioner is not None: 1989 variable_scope_object.set_partitioner(self._partitioner) 1990 if self._custom_getter is not None: 1991 variable_scope_object.set_custom_getter( 1992 _maybe_wrap_custom_getter(self._custom_getter, 1993 self._old.custom_getter)) 1994 if self._dtype is not None: 1995 variable_scope_object.set_dtype(self._dtype) 1996 if self._use_resource is not None: 1997 variable_scope_object.set_use_resource(self._use_resource) 1998 self._var_scope_store.open_variable_scope(self._new_name) 1999 self._var_scope_store.current_scope = variable_scope_object 2000 self._last_variable_scope_object = variable_scope_object 2001 return variable_scope_object 2002 2003 def __exit__(self, type_arg, value_arg, traceback_arg): 2004 if (self._var_scope_store.current_scope is 2005 not self._last_variable_scope_object): 2006 raise RuntimeError("Improper nesting of variable_scope.") 2007 # If jumping out from a non-prolonged scope, restore counts. 2008 if isinstance(self._name_or_scope, VariableScope): 2009 self._var_scope_store.variable_scopes_count = self._old_subscopes 2010 else: 2011 self._var_scope_store.close_variable_subscopes(self._new_name) 2012 self._var_scope_store.current_scope = self._old 2013 2014 2015def _maybe_wrap_custom_getter(custom_getter, old_getter): 2016 """Wrap a call to a custom_getter to use the old_getter internally.""" 2017 if old_getter is None: 2018 return custom_getter 2019 2020 # The new custom_getter should call the old one 2021 def wrapped_custom_getter(getter, *args, **kwargs): 2022 # Call: 2023 # custom_getter( 2024 # lambda: old_getter(true_getter, ...), *args, **kwargs) 2025 # which means custom_getter will call old_getter, which 2026 # will call the true_getter, perform any intermediate 2027 # processing, and return the results to the current 2028 # getter, which will also perform additional processing. 
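    # functools.partial binds the true getter as `old_getter`'s first
    # positional argument, so `old_getter` keeps the (getter, *args, **kwargs)
    # signature expected of a custom getter.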
2029 return custom_getter(functools.partial(old_getter, getter), *args, **kwargs) 2030 2031 return wrapped_custom_getter 2032 2033 2034def _get_unique_variable_scope(prefix): 2035 """Get a name with the given prefix unique in the current variable scope.""" 2036 var_scope_store = get_variable_scope_store() 2037 current_scope = get_variable_scope() 2038 name = current_scope.name + "/" + prefix if current_scope.name else prefix 2039 if var_scope_store.variable_scope_count(name) == 0: 2040 return prefix 2041 idx = 1 2042 while var_scope_store.variable_scope_count(name + ("_%d" % idx)) > 0: 2043 idx += 1 2044 return prefix + ("_%d" % idx) 2045 2046 2047# Named like a function for backwards compatibility with the 2048# @tf_contextlib.contextmanager version, which was switched to a class to avoid 2049# some object creation overhead. 2050@tf_export(v1=["variable_scope"]) # pylint: disable=invalid-name 2051class variable_scope(object): 2052 """A context manager for defining ops that creates variables (layers). 2053 2054 This context manager validates that the (optional) `values` are from the same 2055 graph, ensures that graph is the default graph, and pushes a name scope and a 2056 variable scope. 2057 2058 If `name_or_scope` is not None, it is used as is. If `name_or_scope` is None, 2059 then `default_name` is used. In that case, if the same name has been 2060 previously used in the same scope, it will be made unique by appending `_N` 2061 to it. 2062 2063 Variable scope allows you to create new variables and to share already created 2064 ones while providing checks to not create or share by accident. For details, 2065 see the [Variable Scope How To](https://tensorflow.org/guide/variables), here 2066 we present only a few basic examples. 2067 2068 Simple example of how to create a new variable: 2069 2070 ```python 2071 with tf.compat.v1.variable_scope("foo"): 2072 with tf.compat.v1.variable_scope("bar"): 2073 v = tf.compat.v1.get_variable("v", [1]) 2074 assert v.name == "foo/bar/v:0" 2075 ``` 2076 2077 Simple example of how to reenter a premade variable scope safely: 2078 2079 ```python 2080 with tf.compat.v1.variable_scope("foo") as vs: 2081 pass 2082 2083 # Re-enter the variable scope. 2084 with tf.compat.v1.variable_scope(vs, 2085 auxiliary_name_scope=False) as vs1: 2086 # Restore the original name_scope. 2087 with tf.name_scope(vs1.original_name_scope): 2088 v = tf.compat.v1.get_variable("v", [1]) 2089 assert v.name == "foo/v:0" 2090 c = tf.constant([1], name="c") 2091 assert c.name == "foo/c:0" 2092 ``` 2093 2094 Keep in mind that the counters for `default_name` are discarded once the 2095 parent scope is exited. Therefore when the code re-enters the scope (for 2096 instance by saving it), all nested default_name counters will be restarted. 2097 2098 For instance: 2099 2100 ```python 2101 with tf.compat.v1.variable_scope("foo") as vs: 2102 with tf.compat.v1.variable_scope(None, default_name="bar"): 2103 v = tf.compat.v1.get_variable("a", [1]) 2104 assert v.name == "foo/bar/a:0", v.name 2105 with tf.compat.v1.variable_scope(None, default_name="bar"): 2106 v = tf.compat.v1.get_variable("b", [1]) 2107 assert v.name == "foo/bar_1/b:0" 2108 2109 with tf.compat.v1.variable_scope(vs): 2110 with tf.compat.v1.variable_scope(None, default_name="bar"): 2111 v = tf.compat.v1.get_variable("c", [1]) 2112 assert v.name == "foo/bar/c:0" # Uses bar instead of bar_2! 
  ```

  Basic example of sharing a variable with AUTO_REUSE:

  ```python
  def foo():
    with tf.compat.v1.variable_scope("foo", reuse=tf.compat.v1.AUTO_REUSE):
      v = tf.compat.v1.get_variable("v", [1])
    return v

  v1 = foo()  # Creates v.
  v2 = foo()  # Gets the same, existing v.
  assert v1 == v2
  ```

  Basic example of sharing a variable with reuse=True:

  ```python
  with tf.compat.v1.variable_scope("foo"):
    v = tf.compat.v1.get_variable("v", [1])
  with tf.compat.v1.variable_scope("foo", reuse=True):
    v1 = tf.compat.v1.get_variable("v", [1])
  assert v1 == v
  ```

  Sharing a variable by capturing a scope and setting reuse:

  ```python
  with tf.compat.v1.variable_scope("foo") as scope:
    v = tf.compat.v1.get_variable("v", [1])
    scope.reuse_variables()
    v1 = tf.compat.v1.get_variable("v", [1])
  assert v1 == v
  ```

  To prevent accidental sharing of variables, we raise an exception when getting
  an existing variable in a non-reusing scope.

  ```python
  with tf.compat.v1.variable_scope("foo"):
    v = tf.compat.v1.get_variable("v", [1])
    v1 = tf.compat.v1.get_variable("v", [1])
    #  Raises ValueError("... v already exists ...").
  ```

  Similarly, we raise an exception when trying to get a variable that does not
  exist in reuse mode.

  ```python
  with tf.compat.v1.variable_scope("foo", reuse=True):
    v = tf.compat.v1.get_variable("v", [1])
    #  Raises ValueError("... v does not exist ...").
  ```

  Note that the `reuse` flag is inherited: if we open a reusing scope, then all
  its sub-scopes become reusing as well.

  A note about name scoping: Setting `reuse` does not impact the naming of other
  ops such as mult. See related discussion on
  [github#6189](https://github.com/tensorflow/tensorflow/issues/6189)

  Note that up to and including version 1.0, it was allowed (though explicitly
  discouraged) to pass False to the reuse argument, yielding undocumented
  behaviour slightly different from None. Starting at 1.1.0 passing None and
  False as reuse has exactly the same effect.

  A note about using variable scopes in a multi-threaded environment: Variable
  scopes are thread local, so one thread will not see another thread's current
  scope. Also, when using `default_name`, unique scope names are also generated
  only on a per-thread basis. If the same name was used within a different
  thread, that doesn't prevent a new thread from creating the same scope.
  However, the underlying variable store is shared across threads (within the
  same graph). As such, if another thread tries to create a new variable with
  the same name as a variable created by a previous thread, it will fail unless
  reuse is True.

  Further, each thread starts with an empty variable scope. So if you wish to
  preserve name prefixes from a scope from the main thread, you should capture
  the main thread's scope and re-enter it in each thread. For example:

  ```
  main_thread_scope = variable_scope.get_variable_scope()

  # Thread's target function:
  def thread_target_fn(captured_scope):
    with variable_scope.variable_scope(captured_scope):
      # ....
regular code for this thread 2200 2201 2202 thread = threading.Thread(target=thread_target_fn, args=(main_thread_scope,)) 2203 ``` 2204 """ 2205 2206 def __init__(self, 2207 name_or_scope, 2208 default_name=None, 2209 values=None, 2210 initializer=None, 2211 regularizer=None, 2212 caching_device=None, 2213 partitioner=None, 2214 custom_getter=None, 2215 reuse=None, 2216 dtype=None, 2217 use_resource=None, 2218 constraint=None, 2219 auxiliary_name_scope=True): 2220 """Initialize the context manager. 2221 2222 Args: 2223 name_or_scope: `string` or `VariableScope`: the scope to open. 2224 default_name: The default name to use if the `name_or_scope` argument is 2225 `None`, this name will be uniquified. If name_or_scope is provided it 2226 won't be used and therefore it is not required and can be None. 2227 values: The list of `Tensor` arguments that are passed to the op function. 2228 initializer: default initializer for variables within this scope. 2229 regularizer: default regularizer for variables within this scope. 2230 caching_device: default caching device for variables within this scope. 2231 partitioner: default partitioner for variables within this scope. 2232 custom_getter: default custom getter for variables within this scope. 2233 reuse: `True`, None, or tf.compat.v1.AUTO_REUSE; if `True`, we go into 2234 reuse mode for this scope as well as all sub-scopes; if 2235 tf.compat.v1.AUTO_REUSE, we create variables if they do not exist, and 2236 return them otherwise; if None, we inherit the parent scope's reuse 2237 flag. When eager execution is enabled, new variables are always created 2238 unless an EagerVariableStore or template is currently active. 2239 dtype: type of variables created in this scope (defaults to the type in 2240 the passed scope, or inherited from parent scope). 2241 use_resource: If False, all variables will be regular Variables. If True, 2242 experimental ResourceVariables with well-defined semantics will be used 2243 instead. Defaults to False (will later change to True). When eager 2244 execution is enabled this argument is always forced to be True. 2245 constraint: An optional projection function to be applied to the variable 2246 after being updated by an `Optimizer` (e.g. used to implement norm 2247 constraints or value constraints for layer weights). The function must 2248 take as input the unprojected Tensor representing the value of the 2249 variable and return the Tensor for the projected value (which must have 2250 the same shape). Constraints are not safe to use when doing asynchronous 2251 distributed training. 2252 auxiliary_name_scope: If `True`, we create an auxiliary name scope with 2253 the scope. If `False`, we don't create it. Note that the argument is not 2254 inherited, and it only takes effect for once when creating. You should 2255 only use it for re-entering a premade variable scope. 2256 2257 Returns: 2258 A scope that can be captured and reused. 2259 2260 Raises: 2261 ValueError: when trying to reuse within a create scope, or create within 2262 a reuse scope. 2263 TypeError: when the types of some arguments are not appropriate. 
2264 """ 2265 self._name_or_scope = name_or_scope 2266 self._default_name = default_name 2267 self._values = values 2268 self._initializer = initializer 2269 self._regularizer = regularizer 2270 self._caching_device = caching_device 2271 self._partitioner = partitioner 2272 self._custom_getter = custom_getter 2273 self._reuse = reuse 2274 self._dtype = dtype 2275 self._use_resource = use_resource 2276 self._constraint = constraint 2277 if self._default_name is None and self._name_or_scope is None: 2278 raise TypeError("If default_name is None then name_or_scope is required") 2279 if self._reuse is False: 2280 # We don't allow non-inheriting scopes, False = None here. 2281 self._reuse = None 2282 if not (self._reuse is True 2283 or self._reuse is None 2284 or self._reuse is AUTO_REUSE): 2285 raise ValueError("The reuse parameter must be True or False or None.") 2286 if self._values is None: 2287 self._values = [] 2288 self._in_graph_mode = not context.executing_eagerly() 2289 if self._in_graph_mode: 2290 self._graph = ops._get_graph_from_inputs(self._values) # pylint: disable=protected-access 2291 self._cached_pure_variable_scope = None 2292 self._current_name_scope = None 2293 if not isinstance(auxiliary_name_scope, bool): 2294 raise TypeError("The auxiliary_name_scope must be `True` or `False`, " 2295 "while get {}".format(auxiliary_name_scope)) 2296 self._auxiliary_name_scope = auxiliary_name_scope 2297 2298 def __enter__(self): 2299 # If the default graph is building a function, then we should not replace it 2300 # with the cached graph. 2301 if ops.get_default_graph().building_function: 2302 self._building_function = True 2303 else: 2304 self._building_function = False 2305 if self._in_graph_mode and not self._building_function: 2306 self._graph_context_manager = self._graph.as_default() 2307 self._graph_context_manager.__enter__() 2308 if self._cached_pure_variable_scope is not None: 2309 # Fast path for re-entering variable_scopes. We've held on to the pure 2310 # variable scope from a previous successful __enter__, so we avoid some 2311 # overhead by re-using that object. 2312 if self._current_name_scope is not None: 2313 self._current_name_scope.__enter__() 2314 return self._cached_pure_variable_scope.__enter__() 2315 2316 try: 2317 return self._enter_scope_uncached() 2318 except: 2319 if (self._in_graph_mode and not self._building_function and 2320 self._graph_context_manager is not None): 2321 self._graph_context_manager.__exit__(*sys.exc_info()) 2322 raise 2323 2324 def _enter_scope_uncached(self): 2325 """Enters the context manager when there is no cached scope yet. 2326 2327 Returns: 2328 The entered variable scope. 2329 2330 Raises: 2331 TypeError: A wrong type is passed as `scope` at __init__(). 2332 ValueError: `reuse` is incorrectly set at __init__(). 2333 """ 2334 if self._auxiliary_name_scope: 2335 # Create a new name scope later 2336 current_name_scope = None 2337 else: 2338 # Reenter the current name scope 2339 name_scope = ops.get_name_scope() 2340 if name_scope: 2341 # Hack to reenter 2342 name_scope += "/" 2343 current_name_scope = ops.name_scope(name_scope, skip_on_eager=False) 2344 else: 2345 # Root scope 2346 current_name_scope = ops.name_scope(name_scope, skip_on_eager=False) 2347 2348 # IMPORTANT: Only assign to self._cached_pure_variable_scope and 2349 # self._current_name_scope after successful __enter__() calls. 
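    # Two top-level cases follow: an explicit `name_or_scope` (a string or a
    # VariableScope, including the root scope) is entered directly, while a
    # `None` `name_or_scope` falls back to a uniquified `default_name`.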
2350 if self._name_or_scope is not None: 2351 if not isinstance(self._name_or_scope, 2352 (VariableScope,) + six.string_types): 2353 raise TypeError("VariableScope: name_or_scope must be a string or " 2354 "VariableScope.") 2355 if isinstance(self._name_or_scope, six.string_types): 2356 name_scope = self._name_or_scope 2357 else: 2358 name_scope = self._name_or_scope.name.split("/")[-1] 2359 if name_scope or current_name_scope: 2360 current_name_scope = current_name_scope or ops.name_scope( 2361 name_scope, skip_on_eager=False) 2362 try: 2363 current_name_scope_name = current_name_scope.__enter__() 2364 except: 2365 current_name_scope.__exit__(*sys.exc_info()) 2366 raise 2367 self._current_name_scope = current_name_scope 2368 if isinstance(self._name_or_scope, six.string_types): 2369 old_name_scope = current_name_scope_name 2370 else: 2371 old_name_scope = self._name_or_scope.original_name_scope 2372 pure_variable_scope = _pure_variable_scope( 2373 self._name_or_scope, 2374 reuse=self._reuse, 2375 initializer=self._initializer, 2376 regularizer=self._regularizer, 2377 caching_device=self._caching_device, 2378 partitioner=self._partitioner, 2379 custom_getter=self._custom_getter, 2380 old_name_scope=old_name_scope, 2381 dtype=self._dtype, 2382 use_resource=self._use_resource, 2383 constraint=self._constraint) 2384 try: 2385 entered_pure_variable_scope = pure_variable_scope.__enter__() 2386 except: 2387 pure_variable_scope.__exit__(*sys.exc_info()) 2388 raise 2389 self._cached_pure_variable_scope = pure_variable_scope 2390 return entered_pure_variable_scope 2391 else: 2392 self._current_name_scope = None 2393 # This can only happen if someone is entering the root variable scope. 2394 pure_variable_scope = _pure_variable_scope( 2395 self._name_or_scope, 2396 reuse=self._reuse, 2397 initializer=self._initializer, 2398 regularizer=self._regularizer, 2399 caching_device=self._caching_device, 2400 partitioner=self._partitioner, 2401 custom_getter=self._custom_getter, 2402 dtype=self._dtype, 2403 use_resource=self._use_resource, 2404 constraint=self._constraint) 2405 try: 2406 entered_pure_variable_scope = pure_variable_scope.__enter__() 2407 except: 2408 pure_variable_scope.__exit__(*sys.exc_info()) 2409 raise 2410 self._cached_pure_variable_scope = pure_variable_scope 2411 return entered_pure_variable_scope 2412 2413 else: # Here name_or_scope is None. Using default name, but made unique. 
2414 if self._reuse: 2415 raise ValueError("reuse=True cannot be used without a name_or_scope") 2416 current_name_scope = current_name_scope or ops.name_scope( 2417 self._default_name, skip_on_eager=False) 2418 try: 2419 current_name_scope_name = current_name_scope.__enter__() 2420 except: 2421 current_name_scope.__exit__(*sys.exc_info()) 2422 raise 2423 self._current_name_scope = current_name_scope 2424 unique_default_name = _get_unique_variable_scope(self._default_name) 2425 pure_variable_scope = _pure_variable_scope( 2426 unique_default_name, 2427 initializer=self._initializer, 2428 regularizer=self._regularizer, 2429 caching_device=self._caching_device, 2430 partitioner=self._partitioner, 2431 custom_getter=self._custom_getter, 2432 old_name_scope=current_name_scope_name, 2433 dtype=self._dtype, 2434 use_resource=self._use_resource, 2435 constraint=self._constraint) 2436 try: 2437 entered_pure_variable_scope = pure_variable_scope.__enter__() 2438 except: 2439 pure_variable_scope.__exit__(*sys.exc_info()) 2440 raise 2441 self._cached_pure_variable_scope = pure_variable_scope 2442 return entered_pure_variable_scope 2443 2444 def __exit__(self, type_arg, value_arg, traceback_arg): 2445 try: 2446 self._cached_pure_variable_scope.__exit__(type_arg, value_arg, 2447 traceback_arg) 2448 finally: 2449 try: 2450 if self._current_name_scope: 2451 self._current_name_scope.__exit__(type_arg, value_arg, 2452 traceback_arg) 2453 finally: 2454 if self._in_graph_mode and not self._building_function: 2455 self._graph_context_manager.__exit__(type_arg, value_arg, 2456 traceback_arg) 2457 2458 2459# pylint: disable=g-doc-return-or-yield 2460@tf_export(v1=["variable_op_scope"]) 2461@tf_contextlib.contextmanager 2462def variable_op_scope(values, 2463 name_or_scope, 2464 default_name=None, 2465 initializer=None, 2466 regularizer=None, 2467 caching_device=None, 2468 partitioner=None, 2469 custom_getter=None, 2470 reuse=None, 2471 dtype=None, 2472 use_resource=None, 2473 constraint=None): 2474 """Deprecated: context manager for defining an op that creates variables.""" 2475 logging.warn("tf.variable_op_scope(values, name, default_name) is deprecated," 2476 " use tf.variable_scope(name, default_name, values)") 2477 with variable_scope( 2478 name_or_scope, 2479 default_name=default_name, 2480 values=values, 2481 initializer=initializer, 2482 regularizer=regularizer, 2483 caching_device=caching_device, 2484 partitioner=partitioner, 2485 custom_getter=custom_getter, 2486 reuse=reuse, 2487 dtype=dtype, 2488 use_resource=use_resource, 2489 constraint=constraint) as scope: 2490 yield scope 2491 2492 2493def _call_partitioner(partitioner, shape, dtype): 2494 """Call partitioner validating its inputs/output. 2495 2496 Args: 2497 partitioner: a function mapping `Tensor` shape and dtype to a list of 2498 partitions. 2499 shape: shape of the `Tensor` to partition, must have at least two 2500 dimensions. 2501 dtype: dtype of the elements in the `Tensor`. 2502 2503 Returns: 2504 A list with elements >=1 and exactly one >1. The index of that 2505 element corresponds to the partitioning axis. 2506 """ 2507 if not shape.is_fully_defined(): 2508 raise ValueError("Shape of a new partitioned variable must be " 2509 "fully defined, but instead was %s." 
% (shape,)) 2510 if shape.ndims < 1: 2511 raise ValueError("A partitioned Variable must have rank at least 1, " 2512 "shape: %s" % shape) 2513 2514 slicing = partitioner(shape=shape, dtype=dtype) 2515 if not isinstance(slicing, collections_lib.Sequence): 2516 raise ValueError("Partitioner must return a sequence, but saw: %s" % 2517 slicing) 2518 if len(slicing) != shape.ndims: 2519 raise ValueError( 2520 "Partitioner returned a partition list that does not match the " 2521 "Variable's rank: %s vs. %s" % (slicing, shape)) 2522 if any(p < 1 for p in slicing): 2523 raise ValueError("Partitioner returned zero partitions for some axes: %s" % 2524 slicing) 2525 if sum(p > 1 for p in slicing) > 1: 2526 raise ValueError("Can only slice a variable along one dimension: " 2527 "shape: %s, partitioning: %s" % (shape, slicing)) 2528 return slicing 2529 2530 2531# TODO(slebedev): could be inlined, but 2532# `_VariableStore._get_partitioned_variable` is too complex even 2533# without this logic. 2534def _get_slice_dim_and_num_slices(slicing): 2535 """Get slicing dimension and number of slices from the partitioner output.""" 2536 for slice_dim, num_slices in enumerate(slicing): 2537 if num_slices > 1: 2538 break 2539 else: 2540 # Degenerate case: no partitioning applied. 2541 slice_dim = 0 2542 num_slices = 1 2543 return slice_dim, num_slices 2544 2545 2546def _iter_slices(full_shape, num_slices, slice_dim): 2547 """Slices a given a shape along the specified dimension.""" 2548 num_slices_with_excess = full_shape[slice_dim] % num_slices 2549 offset = [0] * len(full_shape) 2550 min_slice_len = full_shape[slice_dim] // num_slices 2551 for i in xrange(num_slices): 2552 shape = full_shape[:] 2553 shape[slice_dim] = min_slice_len + bool(i < num_slices_with_excess) 2554 yield offset[:], shape 2555 offset[slice_dim] += shape[slice_dim] 2556 2557 2558def default_variable_creator(next_creator=None, **kwargs): 2559 """Default variable creator.""" 2560 assert next_creator is None 2561 initial_value = kwargs.get("initial_value", None) 2562 trainable = kwargs.get("trainable", None) 2563 collections = kwargs.get("collections", None) 2564 validate_shape = kwargs.get("validate_shape", True) 2565 caching_device = kwargs.get("caching_device", None) 2566 name = kwargs.get("name", None) 2567 variable_def = kwargs.get("variable_def", None) 2568 dtype = kwargs.get("dtype", None) 2569 expected_shape = kwargs.get("expected_shape", None) 2570 import_scope = kwargs.get("import_scope", None) 2571 constraint = kwargs.get("constraint", None) 2572 use_resource = kwargs.get("use_resource", None) 2573 synchronization = kwargs.get("synchronization", None) 2574 aggregation = kwargs.get("aggregation", None) 2575 shape = kwargs.get("shape", None) 2576 2577 if use_resource is None: 2578 use_resource = get_variable_scope().use_resource 2579 if use_resource is None: 2580 use_resource = _DEFAULT_USE_RESOURCE 2581 use_resource = use_resource or context.executing_eagerly() 2582 if use_resource: 2583 distribute_strategy = kwargs.get("distribute_strategy", None) 2584 return resource_variable_ops.ResourceVariable( 2585 initial_value=initial_value, 2586 trainable=trainable, 2587 collections=collections, 2588 validate_shape=validate_shape, 2589 caching_device=caching_device, 2590 name=name, 2591 dtype=dtype, 2592 constraint=constraint, 2593 variable_def=variable_def, 2594 import_scope=import_scope, 2595 distribute_strategy=distribute_strategy, 2596 synchronization=synchronization, 2597 aggregation=aggregation, 2598 shape=shape) 2599 else: 2600 return 
variables.RefVariable( 2601 initial_value=initial_value, 2602 trainable=trainable, 2603 collections=collections, 2604 validate_shape=validate_shape, 2605 caching_device=caching_device, 2606 name=name, 2607 dtype=dtype, 2608 constraint=constraint, 2609 variable_def=variable_def, 2610 expected_shape=expected_shape, 2611 import_scope=import_scope, 2612 synchronization=synchronization, 2613 aggregation=aggregation, 2614 shape=shape) 2615 2616 2617def default_variable_creator_v2(next_creator=None, **kwargs): 2618 """Default variable creator.""" 2619 assert next_creator is None 2620 initial_value = kwargs.get("initial_value", None) 2621 trainable = kwargs.get("trainable", None) 2622 validate_shape = kwargs.get("validate_shape", True) 2623 caching_device = kwargs.get("caching_device", None) 2624 name = kwargs.get("name", None) 2625 variable_def = kwargs.get("variable_def", None) 2626 dtype = kwargs.get("dtype", None) 2627 import_scope = kwargs.get("import_scope", None) 2628 constraint = kwargs.get("constraint", None) 2629 distribute_strategy = kwargs.get("distribute_strategy", None) 2630 synchronization = kwargs.get("synchronization", None) 2631 aggregation = kwargs.get("aggregation", None) 2632 shape = kwargs.get("shape", None) 2633 2634 return resource_variable_ops.ResourceVariable( 2635 initial_value=initial_value, 2636 trainable=trainable, 2637 validate_shape=validate_shape, 2638 caching_device=caching_device, 2639 name=name, 2640 dtype=dtype, 2641 constraint=constraint, 2642 variable_def=variable_def, 2643 import_scope=import_scope, 2644 distribute_strategy=distribute_strategy, 2645 synchronization=synchronization, 2646 aggregation=aggregation, 2647 shape=shape) 2648 2649 2650variables.default_variable_creator = default_variable_creator 2651variables.default_variable_creator_v2 = default_variable_creator_v2 2652 2653 2654def _make_getter(captured_getter, captured_previous): 2655 """Gets around capturing loop variables in python being broken.""" 2656 return lambda **kwargs: captured_getter(captured_previous, **kwargs) 2657 2658 2659# TODO(apassos) remove forwarding symbol 2660variable = variables.VariableV1 2661 2662 2663@tf_export(v1=["variable_creator_scope"]) 2664@tf_contextlib.contextmanager 2665def variable_creator_scope_v1(variable_creator): 2666 """Scope which defines a variable creation function to be used by variable(). 2667 2668 variable_creator is expected to be a function with the following signature: 2669 2670 ``` 2671 def variable_creator(next_creator, **kwargs) 2672 ``` 2673 2674 The creator is supposed to eventually call the next_creator to create a 2675 variable if it does want to create a variable and not call Variable or 2676 ResourceVariable directly. This helps make creators composable. A creator may 2677 choose to create multiple variables, return already existing variables, or 2678 simply register that a variable was created and defer to the next creators in 2679 line. Creators can also modify the keyword arguments seen by the next 2680 creators. 2681 2682 Custom getters in the variable scope will eventually resolve down to these 2683 custom creators when they do create variables. 2684 2685 The valid keyword arguments in kwds are: 2686 2687 * initial_value: A `Tensor`, or Python object convertible to a `Tensor`, 2688 which is the initial value for the Variable. The initial value must have 2689 a shape specified unless `validate_shape` is set to False. Can also be a 2690 callable with no argument that returns the initial value when called. 
In 2691 that case, `dtype` must be specified. (Note that initializer functions 2692 from init_ops.py must first be bound to a shape before being used here.) 2693 * trainable: If `True`, the default, also adds the variable to the graph 2694 collection `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as 2695 the default list of variables to use by the `Optimizer` classes. 2696 `trainable` defaults to `True`, unless `synchronization` is 2697 set to `ON_READ`, in which case it defaults to `False`. 2698 * collections: List of graph collections keys. The new variable is added to 2699 these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`. 2700 * validate_shape: If `False`, allows the variable to be initialized with a 2701 value of unknown shape. If `True`, the default, the shape of 2702 `initial_value` must be known. 2703 * caching_device: Optional device string describing where the Variable 2704 should be cached for reading. Defaults to the Variable's device. 2705 If not `None`, caches on another device. Typical use is to cache 2706 on the device where the Ops using the Variable reside, to deduplicate 2707 copying through `Switch` and other conditional statements. 2708 * name: Optional name for the variable. Defaults to `'Variable'` and gets 2709 uniquified automatically. 2710 * dtype: If set, initial_value will be converted to the given type. 2711 If `None`, either the datatype will be kept (if `initial_value` is 2712 a Tensor), or `convert_to_tensor` will decide. 2713 * constraint: A constraint function to be applied to the variable after 2714 updates by some algorithms. 2715 * use_resource: if True, a ResourceVariable is always created. 2716 * synchronization: Indicates when a distributed a variable will be 2717 aggregated. Accepted values are constants defined in the class 2718 `tf.VariableSynchronization`. By default the synchronization is set to 2719 `AUTO` and the current `DistributionStrategy` chooses 2720 when to synchronize. 2721 * aggregation: Indicates how a distributed variable will be aggregated. 2722 Accepted values are constants defined in the class 2723 `tf.VariableAggregation`. 2724 2725 This set may grow over time, so it's important the signature of creators is as 2726 mentioned above. 2727 2728 Args: 2729 variable_creator: the passed creator 2730 2731 Yields: 2732 A scope in which the creator is active 2733 """ 2734 with ops.get_default_graph()._variable_creator_scope(variable_creator): # pylint: disable=protected-access 2735 yield 2736 2737 2738# Note: only the docstrings differ between this and v1. 2739@tf_export("variable_creator_scope", v1=[]) 2740@tf_contextlib.contextmanager 2741def variable_creator_scope(variable_creator): 2742 """Scope which defines a variable creation function to be used by variable(). 2743 2744 variable_creator is expected to be a function with the following signature: 2745 2746 ``` 2747 def variable_creator(next_creator, **kwargs) 2748 ``` 2749 2750 The creator is supposed to eventually call the next_creator to create a 2751 variable if it does want to create a variable and not call Variable or 2752 ResourceVariable directly. This helps make creators composable. A creator may 2753 choose to create multiple variables, return already existing variables, or 2754 simply register that a variable was created and defer to the next creators in 2755 line. Creators can also modify the keyword arguments seen by the next 2756 creators. 

  Custom getters in the variable scope will eventually resolve down to these
  custom creators when they do create variables.

  The valid keyword arguments in kwds are:

  * initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
    which is the initial value for the Variable. The initial value must have
    a shape specified unless `validate_shape` is set to False. Can also be a
    callable with no argument that returns the initial value when called. In
    that case, `dtype` must be specified. (Note that initializer functions
    from init_ops.py must first be bound to a shape before being used here.)
  * trainable: If `True`, the default, GradientTapes automatically watch
    uses of this Variable.
  * validate_shape: If `False`, allows the variable to be initialized with a
    value of unknown shape. If `True`, the default, the shape of
    `initial_value` must be known.
  * caching_device: Optional device string describing where the Variable
    should be cached for reading. Defaults to the Variable's device.
    If not `None`, caches on another device. Typical use is to cache
    on the device where the Ops using the Variable reside, to deduplicate
    copying through `Switch` and other conditional statements.
  * name: Optional name for the variable. Defaults to `'Variable'` and gets
    uniquified automatically.
  * dtype: If set, initial_value will be converted to the given type.
    If `None`, either the datatype will be kept (if `initial_value` is
    a Tensor), or `convert_to_tensor` will decide.
  * constraint: A constraint function to be applied to the variable after
    updates by some algorithms.
  * synchronization: Indicates when a distributed variable will be
    aggregated. Accepted values are constants defined in the class
    `tf.VariableSynchronization`. By default the synchronization is set to
    `AUTO` and the current `DistributionStrategy` chooses
    when to synchronize.
  * aggregation: Indicates how a distributed variable will be aggregated.
    Accepted values are constants defined in the class
    `tf.VariableAggregation`.

  This set may grow over time, so it's important the signature of creators is as
  mentioned above.

  Args:
    variable_creator: the passed creator

  Yields:
    A scope in which the creator is active
  """
  with ops.get_default_graph()._variable_creator_scope(variable_creator):  # pylint: disable=protected-access
    yield
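

# Illustrative sketch (the creator name and printed message below are
# placeholders, not part of this module): a composable creator receives
# `next_creator` plus the keyword arguments documented above and forwards to
# `next_creator` instead of constructing a Variable directly.
#
#   def logging_creator(next_creator, **kwargs):
#     print("creating variable:", kwargs.get("name"))
#     return next_creator(**kwargs)
#
#   with tf.variable_creator_scope(logging_creator):
#     v = tf.Variable(1.0)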