# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Variable class."""

import enum
import functools
import itertools
import os

from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import variable_pb2
from tensorflow.python import pywrap_tensorflow  # pylint: disable=unused-import
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import indexed_slices
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import gen_state_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.trackable import base as trackable
from tensorflow.python.types import core
from tensorflow.python.util import _pywrap_utils
from tensorflow.python.util import compat
from tensorflow.python.util import object_identity
from tensorflow.python.util import tf_should_use
from tensorflow.python.util import traceback_utils
from tensorflow.python.util.deprecation import deprecated
from tensorflow.python.util.deprecation import deprecated_args
from tensorflow.python.util.tf_export import tf_export


def default_variable_creator(_, **kwds):
  del kwds
  raise NotImplementedError("variable_scope needs to be imported")


def default_variable_creator_v2(_, **kwds):
  del kwds
  raise NotImplementedError("variable_scope needs to be imported")


def _make_getter(captured_getter, captured_previous):
  """To avoid capturing loop variables."""

  def getter(**kwargs):
    return captured_getter(captured_previous, **kwargs)

  return getter


@tf_export("VariableSynchronization")
class VariableSynchronization(enum.Enum):
  """Indicates when a distributed variable will be synced.

  * `AUTO`: Indicates that the synchronization will be determined by the
    current `DistributionStrategy` (e.g. with `MirroredStrategy` this would be
    `ON_WRITE`).
  * `NONE`: Indicates that there will only be one copy of the variable, so
    there is no need to sync.
  * `ON_WRITE`: Indicates that the variable will be updated across devices
    every time it is written.
  * `ON_READ`: Indicates that the variable will be aggregated across devices
    when it is read (e.g. when checkpointing or when evaluating an op that
    uses the variable).

  Example:
  >>> temp_grad=[tf.Variable([0.], trainable=False,
  ...                        synchronization=tf.VariableSynchronization.ON_READ,
  ...                        aggregation=tf.VariableAggregation.MEAN
  ...                        )]
  """
  AUTO = 0
  NONE = 1
  ON_WRITE = 2
  ON_READ = 3


# LINT.IfChange
@tf_export("VariableAggregation", v1=[])
class VariableAggregationV2(enum.Enum):
  """Indicates how a distributed variable will be aggregated.

  `tf.distribute.Strategy` distributes a model by making multiple copies
  (called "replicas") acting data-parallel on different elements of the input
  batch. When performing some variable-update operation, say
  `var.assign_add(x)`, in a model, we need to resolve how to combine the
  different values for `x` computed in the different replicas.

  * `NONE`: This is the default, giving an error if you use a
    variable-update operation with multiple replicas.
  * `SUM`: Add the updates across replicas.
  * `MEAN`: Take the arithmetic mean ("average") of the updates across
    replicas.
  * `ONLY_FIRST_REPLICA`: This is for when every replica is performing the same
    update, but we only want to perform the update once. Used, e.g., for the
    global step counter.
  """
  NONE = 0
  SUM = 1
  MEAN = 2
  ONLY_FIRST_REPLICA = 3

  def __hash__(self):
    return hash(self.value)

  def __eq__(self, other):
    if self is other:
      return True
    elif isinstance(other, VariableAggregation):
      return int(self.value) == int(other.value)
    else:
      return False


@tf_export(v1=["VariableAggregation"])
class VariableAggregation(enum.Enum):
  NONE = 0
  SUM = 1
  MEAN = 2
  ONLY_FIRST_REPLICA = 3
  ONLY_FIRST_TOWER = 3  # DEPRECATED

  def __hash__(self):
    return hash(self.value)


# LINT.ThenChange(//tensorflow/core/framework/variable.proto)
#
# Note that we are currently relying on the integer values of the Python enums
# matching the integer values of the proto enums.
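# For illustration only (kept as a comment so nothing runs at import time):
# the enums above are normally passed straight to the `tf.Variable`
# constructor when a variable needs to be synchronized or aggregated across
# replicas, e.g. under a `tf.distribute` strategy scope. The strategy object
# and variable name below are hypothetical.
#
#   strategy = tf.distribute.MirroredStrategy()
#   with strategy.scope():
#     running_sum = tf.Variable(
#         0.,
#         trainable=False,
#         synchronization=tf.VariableSynchronization.ON_READ,
#         aggregation=tf.VariableAggregation.SUM)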
149 150VariableAggregation.__doc__ = ( 151 VariableAggregationV2.__doc__ + 152 "* `ONLY_FIRST_TOWER`: Deprecated alias for `ONLY_FIRST_REPLICA`.\n ") 153 154 155def validate_synchronization_aggregation_trainable(synchronization, aggregation, 156 trainable, name): 157 """Given user-provided variable properties, sets defaults and validates.""" 158 if aggregation is None: 159 aggregation = VariableAggregation.NONE 160 else: 161 if not isinstance(aggregation, 162 (VariableAggregation, VariableAggregationV2)): 163 try: 164 aggregation = VariableAggregationV2(aggregation) 165 except ValueError: 166 raise ValueError( 167 "Invalid variable aggregation mode: {} for variable: {}".format( 168 aggregation, name)) 169 if synchronization is None: 170 synchronization = VariableSynchronization.AUTO 171 else: 172 try: 173 synchronization = VariableSynchronization(synchronization) 174 except ValueError: 175 raise ValueError( 176 "Invalid variable synchronization mode: {} for variable: {}".format( 177 synchronization, name)) 178 if trainable is None: 179 trainable = synchronization != VariableSynchronization.ON_READ 180 return synchronization, aggregation, trainable 181 182 183class VariableMetaclass(type): 184 """Metaclass to allow construction of tf.Variable to be overridden.""" 185 186 def _variable_v1_call(cls, 187 initial_value=None, 188 trainable=None, 189 collections=None, 190 validate_shape=True, 191 caching_device=None, 192 name=None, 193 variable_def=None, 194 dtype=None, 195 expected_shape=None, 196 import_scope=None, 197 constraint=None, 198 use_resource=None, 199 synchronization=VariableSynchronization.AUTO, 200 aggregation=VariableAggregation.NONE, 201 shape=None): 202 """Call on Variable class. Useful to force the signature.""" 203 previous_getter = lambda **kwargs: default_variable_creator(None, **kwargs) 204 for _, getter in ops.get_default_graph()._variable_creator_stack: # pylint: disable=protected-access 205 previous_getter = _make_getter(getter, previous_getter) 206 207 # Reset `aggregation` that is explicitly set as `None` to the enum NONE. 208 if aggregation is None: 209 aggregation = VariableAggregation.NONE 210 return previous_getter( 211 initial_value=initial_value, 212 trainable=trainable, 213 collections=collections, 214 validate_shape=validate_shape, 215 caching_device=caching_device, 216 name=name, 217 variable_def=variable_def, 218 dtype=dtype, 219 expected_shape=expected_shape, 220 import_scope=import_scope, 221 constraint=constraint, 222 use_resource=use_resource, 223 synchronization=synchronization, 224 aggregation=aggregation, 225 shape=shape) 226 227 def _variable_v2_call(cls, 228 initial_value=None, 229 trainable=None, 230 validate_shape=True, 231 caching_device=None, 232 name=None, 233 variable_def=None, 234 dtype=None, 235 import_scope=None, 236 constraint=None, 237 synchronization=VariableSynchronization.AUTO, 238 aggregation=VariableAggregation.NONE, 239 shape=None): 240 """Call on Variable class. Useful to force the signature.""" 241 previous_getter = lambda **kws: default_variable_creator_v2(None, **kws) 242 for _, getter in ops.get_default_graph()._variable_creator_stack: # pylint: disable=protected-access 243 previous_getter = _make_getter(getter, previous_getter) 244 245 # Reset `aggregation` that is explicitly set as `None` to the enum NONE. 
246 if aggregation is None: 247 aggregation = VariableAggregation.NONE 248 return previous_getter( 249 initial_value=initial_value, 250 trainable=trainable, 251 validate_shape=validate_shape, 252 caching_device=caching_device, 253 name=name, 254 variable_def=variable_def, 255 dtype=dtype, 256 import_scope=import_scope, 257 constraint=constraint, 258 synchronization=synchronization, 259 aggregation=aggregation, 260 shape=shape) 261 262 @traceback_utils.filter_traceback 263 def __call__(cls, *args, **kwargs): 264 if cls is VariableV1: 265 return cls._variable_v1_call(*args, **kwargs) 266 elif cls is Variable: 267 return cls._variable_v2_call(*args, **kwargs) 268 else: 269 return super(VariableMetaclass, cls).__call__(*args, **kwargs) 270 271 272@tf_export("Variable", v1=[]) 273# TODO(mdan): This should subclass core.Tensor, and not all its subclasses? 274class Variable(trackable.Trackable, metaclass=VariableMetaclass): 275 """See the [variable guide](https://tensorflow.org/guide/variable). 276 277 A variable maintains shared, persistent state manipulated by a program. 278 279 The `Variable()` constructor requires an initial value for the variable, which 280 can be a `Tensor` of any type and shape. This initial value defines the type 281 and shape of the variable. After construction, the type and shape of the 282 variable are fixed. The value can be changed using one of the assign methods. 283 284 >>> v = tf.Variable(1.) 285 >>> v.assign(2.) 286 <tf.Variable ... shape=() dtype=float32, numpy=2.0> 287 >>> v.assign_add(0.5) 288 <tf.Variable ... shape=() dtype=float32, numpy=2.5> 289 290 The `shape` argument to `Variable`'s constructor allows you to construct a 291 variable with a less defined shape than its `initial_value`: 292 293 >>> v = tf.Variable(1., shape=tf.TensorShape(None)) 294 >>> v.assign([[1.]]) 295 <tf.Variable ... shape=<unknown> dtype=float32, numpy=array([[1.]], ...)> 296 297 Just like any `Tensor`, variables created with `Variable()` can be used as 298 inputs to operations. Additionally, all the operators overloaded for the 299 `Tensor` class are carried over to variables. 300 301 >>> w = tf.Variable([[1.], [2.]]) 302 >>> x = tf.constant([[3., 4.]]) 303 >>> tf.matmul(w, x) 304 <tf.Tensor:... shape=(2, 2), ... numpy= 305 array([[3., 4.], 306 [6., 8.]], dtype=float32)> 307 >>> tf.sigmoid(w + x) 308 <tf.Tensor:... shape=(2, 2), ...> 309 310 When building a machine learning model it is often convenient to distinguish 311 between variables holding trainable model parameters and other variables such 312 as a `step` variable used to count training steps. To make this easier, the 313 variable constructor supports a `trainable=<bool>` 314 parameter. `tf.GradientTape` watches trainable variables by default: 315 316 >>> with tf.GradientTape(persistent=True) as tape: 317 ... trainable = tf.Variable(1.) 318 ... non_trainable = tf.Variable(2., trainable=False) 319 ... x1 = trainable * 2. 320 ... x2 = non_trainable * 3. 321 >>> tape.gradient(x1, trainable) 322 <tf.Tensor:... shape=(), dtype=float32, numpy=2.0> 323 >>> assert tape.gradient(x2, non_trainable) is None # Unwatched 324 325 Variables are automatically tracked when assigned to attributes of types 326 inheriting from `tf.Module`. 327 328 >>> m = tf.Module() 329 >>> m.v = tf.Variable([1.]) 330 >>> m.trainable_variables 331 (<tf.Variable ... shape=(1,) ... 
numpy=array([1.], dtype=float32)>,) 332 333 This tracking then allows saving variable values to 334 [training checkpoints](https://www.tensorflow.org/guide/checkpoint), or to 335 [SavedModels](https://www.tensorflow.org/guide/saved_model) which include 336 serialized TensorFlow graphs. 337 338 Variables are often captured and manipulated by `tf.function`s. This works the 339 same way the un-decorated function would have: 340 341 >>> v = tf.Variable(0.) 342 >>> read_and_decrement = tf.function(lambda: v.assign_sub(0.1)) 343 >>> read_and_decrement() 344 <tf.Tensor: shape=(), dtype=float32, numpy=-0.1> 345 >>> read_and_decrement() 346 <tf.Tensor: shape=(), dtype=float32, numpy=-0.2> 347 348 Variables created inside a `tf.function` must be owned outside the function 349 and be created only once: 350 351 >>> class M(tf.Module): 352 ... @tf.function 353 ... def __call__(self, x): 354 ... if not hasattr(self, "v"): # Or set self.v to None in __init__ 355 ... self.v = tf.Variable(x) 356 ... return self.v * x 357 >>> m = M() 358 >>> m(2.) 359 <tf.Tensor: shape=(), dtype=float32, numpy=4.0> 360 >>> m(3.) 361 <tf.Tensor: shape=(), dtype=float32, numpy=6.0> 362 >>> m.v 363 <tf.Variable ... shape=() dtype=float32, numpy=2.0> 364 365 See the `tf.function` documentation for details. 366 """ 367 368 @deprecated_args( 369 None, "A variable's value can be manually cached by calling " 370 "tf.Variable.read_value() under a tf.device scope. The caching_device " 371 "argument does not work properly.", "caching_device") 372 def __init__(self, 373 initial_value=None, 374 trainable=None, 375 validate_shape=True, 376 caching_device=None, 377 name=None, 378 variable_def=None, 379 dtype=None, 380 import_scope=None, 381 constraint=None, 382 synchronization=VariableSynchronization.AUTO, 383 aggregation=VariableAggregation.NONE, 384 shape=None): 385 """Creates a new variable with value `initial_value`. 386 387 Args: 388 initial_value: A `Tensor`, or Python object convertible to a `Tensor`, 389 which is the initial value for the Variable. The initial value must have 390 a shape specified unless `validate_shape` is set to False. Can also be a 391 callable with no argument that returns the initial value when called. In 392 that case, `dtype` must be specified. (Note that initializer functions 393 from init_ops.py must first be bound to a shape before being used here.) 394 trainable: If `True`, GradientTapes automatically watch uses of this 395 variable. Defaults to `True`, unless `synchronization` is set to 396 `ON_READ`, in which case it defaults to `False`. 397 validate_shape: If `False`, allows the variable to be initialized with a 398 value of unknown shape. If `True`, the default, the shape of 399 `initial_value` must be known. 400 caching_device: Note: This argument is only valid when using a v1-style 401 `Session`. Optional device string describing where the Variable should 402 be cached for reading. Defaults to the Variable's device. If not `None`, 403 caches on another device. Typical use is to cache on the device where 404 the Ops using the Variable reside, to deduplicate copying through 405 `Switch` and other conditional statements. 406 name: Optional name for the variable. Defaults to `'Variable'` and gets 407 uniquified automatically. 408 variable_def: `VariableDef` protocol buffer. If not `None`, recreates the 409 Variable object with its contents, referencing the variable's nodes in 410 the graph, which must already exist. The graph is not changed. 
        `variable_def` and the other arguments are mutually exclusive.
      dtype: If set, initial_value will be converted to the given type. If
        `None`, either the datatype will be kept (if `initial_value` is a
        Tensor), or `convert_to_tensor` will decide.
      import_scope: Optional `string`. Name scope to add to the `Variable`.
        Only used when initializing from protocol buffer.
      constraint: An optional projection function to be applied to the variable
        after being updated by an `Optimizer` (e.g. used to implement norm
        constraints or value constraints for layer weights). The function must
        take as input the unprojected Tensor representing the value of the
        variable and return the Tensor for the projected value (which must have
        the same shape). Constraints are not safe to use when doing
        asynchronous distributed training.
      synchronization: Indicates when a distributed variable will be
        aggregated. Accepted values are constants defined in the class
        `tf.VariableSynchronization`. By default the synchronization is set to
        `AUTO` and the current `DistributionStrategy` chooses when to
        synchronize.
      aggregation: Indicates how a distributed variable will be aggregated.
        Accepted values are constants defined in the class
        `tf.VariableAggregation`.
      shape: (optional) The shape of this variable. If None, the shape of
        `initial_value` will be used. When setting this argument to
        `tf.TensorShape(None)` (representing an unspecified shape), the
        variable can be assigned with values of different shapes.

    Raises:
      ValueError: If both `variable_def` and initial_value are specified.
      ValueError: If the initial value is not specified, or does not have a
        shape and `validate_shape` is `True`.
    """
    raise NotImplementedError

  def __repr__(self):
    raise NotImplementedError

  def value(self):
    """Returns the last snapshot of this variable.

    You usually do not need to call this method as all ops that need the value
    of the variable call it automatically through a `convert_to_tensor()` call.

    Returns a `Tensor` which holds the value of the variable. You cannot
    assign a new value to this tensor as it is not a reference to the variable.

    To avoid copies, if the consumer of the returned value is on the same
    device as the variable, this actually returns the live value of the
    variable, not a copy. Updates to the variable are seen by the consumer.
    If the consumer is on a different device it will get a copy of the
    variable.

    Returns:
      A `Tensor` containing the value of the variable.
    """
    raise NotImplementedError

  def read_value(self):
    """Returns the value of this variable, read in the current context.

    Can be different from value() if it's on another device, with control
    dependencies, etc.

    Returns:
      A `Tensor` containing the value of the variable.
    """
    raise NotImplementedError

  def set_shape(self, shape):
    """Overrides the shape for this variable.

    Args:
      shape: the `TensorShape` representing the overridden shape.
    """
    raise NotImplementedError

  @property
  def trainable(self):
    raise NotImplementedError

  @property
  def synchronization(self):
    raise NotImplementedError

  @property
  def aggregation(self):
    raise NotImplementedError

  def eval(self, session=None):
    """In a session, computes and returns the value of this variable.
499 500 This is not a graph construction method, it does not add ops to the graph. 501 502 This convenience method requires a session where the graph 503 containing this variable has been launched. If no session is 504 passed, the default session is used. See `tf.compat.v1.Session` for more 505 information on launching a graph and on sessions. 506 507 ```python 508 v = tf.Variable([1, 2]) 509 init = tf.compat.v1.global_variables_initializer() 510 511 with tf.compat.v1.Session() as sess: 512 sess.run(init) 513 # Usage passing the session explicitly. 514 print(v.eval(sess)) 515 # Usage with the default session. The 'with' block 516 # above makes 'sess' the default session. 517 print(v.eval()) 518 ``` 519 520 Args: 521 session: The session to use to evaluate this variable. If none, the 522 default session is used. 523 524 Returns: 525 A numpy `ndarray` with a copy of the value of this variable. 526 """ 527 raise NotImplementedError 528 529 @deprecated( 530 None, "Use Variable.read_value. Variables in 2.X are initialized " 531 "automatically both in eager and graph (inside tf.defun) contexts.") 532 def initialized_value(self): 533 """Returns the value of the initialized variable. 534 535 You should use this instead of the variable itself to initialize another 536 variable with a value that depends on the value of this variable. 537 538 ```python 539 # Initialize 'v' with a random tensor. 540 v = tf.Variable(tf.random.truncated_normal([10, 40])) 541 # Use `initialized_value` to guarantee that `v` has been 542 # initialized before its value is used to initialize `w`. 543 # The random values are picked only once. 544 w = tf.Variable(v.initialized_value() * 2.0) 545 ``` 546 547 Returns: 548 A `Tensor` holding the value of this variable after its initializer 549 has run. 550 """ 551 with ops.init_scope(): 552 return control_flow_ops.cond( 553 is_variable_initialized(self), self.read_value, 554 lambda: self.initial_value) 555 556 @property 557 def initial_value(self): 558 """Returns the Tensor used as the initial value for the variable. 559 560 Note that this is different from `initialized_value()` which runs 561 the op that initializes the variable before returning its value. 562 This method returns the tensor that is used by the op that initializes 563 the variable. 564 565 Returns: 566 A `Tensor`. 567 """ 568 raise NotImplementedError 569 570 @property 571 def constraint(self): 572 """Returns the constraint function associated with this variable. 573 574 Returns: 575 The constraint function that was passed to the variable constructor. 576 Can be `None` if no constraint was passed. 577 """ 578 raise NotImplementedError 579 580 def assign(self, value, use_locking=False, name=None, read_value=True): 581 """Assigns a new value to the variable. 582 583 This is essentially a shortcut for `assign(self, value)`. 584 585 Args: 586 value: A `Tensor`. The new value for this variable. 587 use_locking: If `True`, use locking during the assignment. 588 name: The name of the operation to be created 589 read_value: if True, will return something which evaluates to the new 590 value of the variable; if False will return the assign op. 591 592 Returns: 593 The updated variable. If `read_value` is false, instead returns None in 594 Eager mode and the assign op in graph mode. 595 """ 596 raise NotImplementedError 597 598 def assign_add(self, delta, use_locking=False, name=None, read_value=True): 599 """Adds a value to this variable. 600 601 This is essentially a shortcut for `assign_add(self, delta)`. 
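    For example, a minimal usage sketch (assuming eager execution, mirroring
    the `assign_add` example in the class docstring above):

    >>> v = tf.Variable(1.)
    >>> v.assign_add(2.)
    <tf.Variable ... shape=() dtype=float32, numpy=3.0>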
602 603 Args: 604 delta: A `Tensor`. The value to add to this variable. 605 use_locking: If `True`, use locking during the operation. 606 name: The name of the operation to be created 607 read_value: if True, will return something which evaluates to the new 608 value of the variable; if False will return the assign op. 609 610 Returns: 611 The updated variable. If `read_value` is false, instead returns None in 612 Eager mode and the assign op in graph mode. 613 """ 614 raise NotImplementedError 615 616 def assign_sub(self, delta, use_locking=False, name=None, read_value=True): 617 """Subtracts a value from this variable. 618 619 This is essentially a shortcut for `assign_sub(self, delta)`. 620 621 Args: 622 delta: A `Tensor`. The value to subtract from this variable. 623 use_locking: If `True`, use locking during the operation. 624 name: The name of the operation to be created 625 read_value: if True, will return something which evaluates to the new 626 value of the variable; if False will return the assign op. 627 628 Returns: 629 The updated variable. If `read_value` is false, instead returns None in 630 Eager mode and the assign op in graph mode. 631 """ 632 raise NotImplementedError 633 634 def scatter_sub(self, sparse_delta, use_locking=False, name=None): 635 """Subtracts `tf.IndexedSlices` from this variable. 636 637 Args: 638 sparse_delta: `tf.IndexedSlices` to be subtracted from this variable. 639 use_locking: If `True`, use locking during the operation. 640 name: the name of the operation. 641 642 Returns: 643 The updated variable. 644 645 Raises: 646 TypeError: if `sparse_delta` is not an `IndexedSlices`. 647 """ 648 raise NotImplementedError 649 650 def scatter_add(self, sparse_delta, use_locking=False, name=None): 651 """Adds `tf.IndexedSlices` to this variable. 652 653 Args: 654 sparse_delta: `tf.IndexedSlices` to be added to this variable. 655 use_locking: If `True`, use locking during the operation. 656 name: the name of the operation. 657 658 Returns: 659 The updated variable. 660 661 Raises: 662 TypeError: if `sparse_delta` is not an `IndexedSlices`. 663 """ 664 raise NotImplementedError 665 666 def scatter_max(self, sparse_delta, use_locking=False, name=None): 667 """Updates this variable with the max of `tf.IndexedSlices` and itself. 668 669 Args: 670 sparse_delta: `tf.IndexedSlices` to use as an argument of max with this 671 variable. 672 use_locking: If `True`, use locking during the operation. 673 name: the name of the operation. 674 675 Returns: 676 The updated variable. 677 678 Raises: 679 TypeError: if `sparse_delta` is not an `IndexedSlices`. 680 """ 681 raise NotImplementedError 682 683 def scatter_min(self, sparse_delta, use_locking=False, name=None): 684 """Updates this variable with the min of `tf.IndexedSlices` and itself. 685 686 Args: 687 sparse_delta: `tf.IndexedSlices` to use as an argument of min with this 688 variable. 689 use_locking: If `True`, use locking during the operation. 690 name: the name of the operation. 691 692 Returns: 693 The updated variable. 694 695 Raises: 696 TypeError: if `sparse_delta` is not an `IndexedSlices`. 697 """ 698 raise NotImplementedError 699 700 def scatter_mul(self, sparse_delta, use_locking=False, name=None): 701 """Multiply this variable by `tf.IndexedSlices`. 702 703 Args: 704 sparse_delta: `tf.IndexedSlices` to multiply this variable by. 705 use_locking: If `True`, use locking during the operation. 706 name: the name of the operation. 707 708 Returns: 709 The updated variable. 
    Raises:
      TypeError: if `sparse_delta` is not an `IndexedSlices`.
    """
    raise NotImplementedError

  def scatter_div(self, sparse_delta, use_locking=False, name=None):
    """Divide this variable by `tf.IndexedSlices`.

    Args:
      sparse_delta: `tf.IndexedSlices` to divide this variable by.
      use_locking: If `True`, use locking during the operation.
      name: the name of the operation.

    Returns:
      The updated variable.

    Raises:
      TypeError: if `sparse_delta` is not an `IndexedSlices`.
    """
    raise NotImplementedError

  def scatter_update(self, sparse_delta, use_locking=False, name=None):
    """Assigns `tf.IndexedSlices` to this variable.

    Args:
      sparse_delta: `tf.IndexedSlices` to be assigned to this variable.
      use_locking: If `True`, use locking during the operation.
      name: the name of the operation.

    Returns:
      The updated variable.

    Raises:
      TypeError: if `sparse_delta` is not an `IndexedSlices`.
    """
    raise NotImplementedError

  def batch_scatter_update(self, sparse_delta, use_locking=False, name=None):
    """Assigns `tf.IndexedSlices` to this variable batch-wise.

    Analogous to `batch_gather`. This assumes that this variable and the
    sparse_delta IndexedSlices have a series of leading dimensions that are the
    same for all of them, and the updates are performed on the last dimension
    of indices. In other words, the dimensions should be the following:

    `num_prefix_dims = sparse_delta.indices.ndims - 1`
    `batch_dim = num_prefix_dims + 1`
    `sparse_delta.updates.shape = sparse_delta.indices.shape + var.shape[
      batch_dim:]`

    where

    `sparse_delta.updates.shape[:num_prefix_dims]`
    `== sparse_delta.indices.shape[:num_prefix_dims]`
    `== var.shape[:num_prefix_dims]`

    And the operation performed can be expressed as:

    `var[i_1, ..., i_n,
      sparse_delta.indices[i_1, ..., i_n, j]] = sparse_delta.updates[
      i_1, ..., i_n, j]`

    When sparse_delta.indices is a 1D tensor, this operation is equivalent to
    `scatter_update`.

    To avoid this operation one can loop over the first `ndims` of the
    variable and use `scatter_update` on the subtensors that result from
    slicing the first dimension. This is a valid option for `ndims = 1`, but
    less efficient than this implementation.

    Args:
      sparse_delta: `tf.IndexedSlices` to be assigned to this variable.
      use_locking: If `True`, use locking during the operation.
      name: the name of the operation.

    Returns:
      The updated variable.

    Raises:
      TypeError: if `sparse_delta` is not an `IndexedSlices`.
    """
    raise NotImplementedError

  def scatter_nd_sub(self, indices, updates, name=None):
    """Applies sparse subtraction to individual values or slices in a Variable.

    Assuming the variable has rank `P` and `indices` is a `Tensor` of rank `Q`.

    `indices` must be integer tensor, containing indices into self.
    It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.

    The innermost dimension of `indices` (with length `K`) corresponds to
    indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
    dimension of self.

    `updates` is `Tensor` of rank `Q-1+P-K` with shape:

    ```
    [d_0, ..., d_{Q-2}, self.shape[K], ..., self.shape[P-1]].
    ```

    For example, say we want to add 4 scattered elements to a rank-1 tensor
    with 8 elements.
In Python, that update would look like this: 814 815 ```python 816 v = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8]) 817 indices = tf.constant([[4], [3], [1] ,[7]]) 818 updates = tf.constant([9, 10, 11, 12]) 819 v.scatter_nd_sub(indices, updates) 820 print(v) 821 ``` 822 823 After the update `v` would look like this: 824 825 [1, -9, 3, -6, -4, 6, 7, -4] 826 827 See `tf.scatter_nd` for more details about how to make updates to 828 slices. 829 830 Args: 831 indices: The indices to be used in the operation. 832 updates: The values to be used in the operation. 833 name: the name of the operation. 834 835 Returns: 836 The updated variable. 837 """ 838 raise NotImplementedError 839 840 def scatter_nd_add(self, indices, updates, name=None): 841 """Applies sparse addition to individual values or slices in a Variable. 842 843 The Variable has rank `P` and `indices` is a `Tensor` of rank `Q`. 844 845 `indices` must be integer tensor, containing indices into self. 846 It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`. 847 848 The innermost dimension of `indices` (with length `K`) corresponds to 849 indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th 850 dimension of self. 851 852 `updates` is `Tensor` of rank `Q-1+P-K` with shape: 853 854 ``` 855 [d_0, ..., d_{Q-2}, self.shape[K], ..., self.shape[P-1]]. 856 ``` 857 858 For example, say we want to add 4 scattered elements to a rank-1 tensor to 859 8 elements. In Python, that update would look like this: 860 861 ```python 862 v = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8]) 863 indices = tf.constant([[4], [3], [1] ,[7]]) 864 updates = tf.constant([9, 10, 11, 12]) 865 v.scatter_nd_add(indices, updates) 866 print(v) 867 ``` 868 869 The resulting update to v would look like this: 870 871 [1, 13, 3, 14, 14, 6, 7, 20] 872 873 See `tf.scatter_nd` for more details about how to make updates to 874 slices. 875 876 Args: 877 indices: The indices to be used in the operation. 878 updates: The values to be used in the operation. 879 name: the name of the operation. 880 881 Returns: 882 The updated variable. 883 """ 884 raise NotImplementedError 885 886 def scatter_nd_update(self, indices, updates, name=None): 887 """Applies sparse assignment to individual values or slices in a Variable. 888 889 The Variable has rank `P` and `indices` is a `Tensor` of rank `Q`. 890 891 `indices` must be integer tensor, containing indices into self. 892 It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`. 893 894 The innermost dimension of `indices` (with length `K`) corresponds to 895 indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th 896 dimension of self. 897 898 `updates` is `Tensor` of rank `Q-1+P-K` with shape: 899 900 ``` 901 [d_0, ..., d_{Q-2}, self.shape[K], ..., self.shape[P-1]]. 902 ``` 903 904 For example, say we want to add 4 scattered elements to a rank-1 tensor to 905 8 elements. In Python, that update would look like this: 906 907 ```python 908 v = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8]) 909 indices = tf.constant([[4], [3], [1] ,[7]]) 910 updates = tf.constant([9, 10, 11, 12]) 911 v.scatter_nd_update(indices, updates) 912 print(v) 913 ``` 914 915 The resulting update to v would look like this: 916 917 [1, 11, 3, 10, 9, 6, 7, 12] 918 919 See `tf.scatter_nd` for more details about how to make updates to 920 slices. 921 922 Args: 923 indices: The indices to be used in the operation. 924 updates: The values to be used in the operation. 925 name: the name of the operation. 926 927 Returns: 928 The updated variable. 
929 """ 930 raise NotImplementedError 931 932 def sparse_read(self, indices, name=None): 933 r"""Gather slices from params axis axis according to indices. 934 935 This function supports a subset of tf.gather, see tf.gather for details on 936 usage. 937 938 Args: 939 indices: The index `Tensor`. Must be one of the following types: `int32`, 940 `int64`. Must be in range `[0, params.shape[axis])`. 941 name: A name for the operation (optional). 942 943 Returns: 944 A `Tensor`. Has the same type as `params`. 945 """ 946 raise AttributeError 947 948 def gather_nd(self, indices, name=None): 949 r"""Gather slices from `params` into a Tensor with shape specified by `indices`. 950 951 See tf.gather_nd for details. 952 953 Args: 954 indices: A `Tensor`. Must be one of the following types: `int32`, `int64`. 955 Index tensor. 956 name: A name for the operation (optional). 957 958 Returns: 959 A `Tensor`. Has the same type as `params`. 960 """ 961 raise AttributeError 962 963 @deprecated(None, "Prefer Dataset.range instead.") 964 def count_up_to(self, limit): 965 """Increments this variable until it reaches `limit`. 966 967 When that Op is run it tries to increment the variable by `1`. If 968 incrementing the variable would bring it above `limit` then the Op raises 969 the exception `OutOfRangeError`. 970 971 If no error is raised, the Op outputs the value of the variable before 972 the increment. 973 974 This is essentially a shortcut for `count_up_to(self, limit)`. 975 976 Args: 977 limit: value at which incrementing the variable raises an error. 978 979 Returns: 980 A `Tensor` that will hold the variable value before the increment. If no 981 other Op modifies this variable, the values produced will all be 982 distinct. 983 """ 984 raise NotImplementedError 985 986 @deprecated(None, 987 "Prefer Variable.assign which has equivalent behavior in 2.X.") 988 def load(self, value, session=None): 989 """Load new value into this variable. 990 991 Writes new value to variable's memory. Doesn't add ops to the graph. 992 993 This convenience method requires a session where the graph 994 containing this variable has been launched. If no session is 995 passed, the default session is used. See `tf.compat.v1.Session` for more 996 information on launching a graph and on sessions. 997 998 ```python 999 v = tf.Variable([1, 2]) 1000 init = tf.compat.v1.global_variables_initializer() 1001 1002 with tf.compat.v1.Session() as sess: 1003 sess.run(init) 1004 # Usage passing the session explicitly. 1005 v.load([2, 3], sess) 1006 print(v.eval(sess)) # prints [2 3] 1007 # Usage with the default session. The 'with' block 1008 # above makes 'sess' the default session. 1009 v.load([3, 4], sess) 1010 print(v.eval()) # prints [3 4] 1011 ``` 1012 1013 Args: 1014 value: New variable value 1015 session: The session to use to evaluate this variable. If none, the 1016 default session is used. 1017 1018 Raises: 1019 ValueError: Session is not passed and no default session 1020 """ 1021 if context.executing_eagerly(): 1022 self.assign(value) 1023 else: 1024 session = session or ops.get_default_session() 1025 if session is None: 1026 raise ValueError( 1027 "Either session argument should be provided or default session " 1028 "should be established") 1029 session.run(self.initializer, {self.initializer.inputs[1]: value}) 1030 1031 # Conversion to tensor. 
1032 @staticmethod 1033 def _TensorConversionFunction(v, dtype=None, name=None, as_ref=False): # pylint: disable=invalid-name 1034 """Utility function for converting a Variable to a Tensor.""" 1035 _ = name 1036 if dtype and not dtype.is_compatible_with(v.dtype): 1037 raise ValueError( 1038 f"Incompatible type conversion requested to type '{dtype.name}' for " 1039 f"variable of type '{v.dtype.name}' (Variable: {v}).") 1040 if as_ref: 1041 return v._ref() # pylint: disable=protected-access 1042 else: 1043 return v.value() 1044 1045 @classmethod 1046 def _OverloadAllOperators(cls): # pylint: disable=invalid-name 1047 """Register overloads for all operators.""" 1048 for operator in ops.Tensor.OVERLOADABLE_OPERATORS: 1049 cls._OverloadOperator(operator) 1050 # For slicing, bind getitem differently than a tensor (use SliceHelperVar 1051 # instead) 1052 # pylint: disable=protected-access 1053 setattr(cls, "__getitem__", array_ops._SliceHelperVar) 1054 1055 @classmethod 1056 def _OverloadOperator(cls, operator): # pylint: disable=invalid-name 1057 """Defer an operator overload to `ops.Tensor`. 1058 1059 We pull the operator out of ops.Tensor dynamically to avoid ordering issues. 1060 1061 Args: 1062 operator: string. The operator name. 1063 """ 1064 # We can't use the overload mechanism on __eq__ & __ne__ since __eq__ is 1065 # called when adding a variable to sets. As a result we call a.value() which 1066 # causes infinite recursion when operating within a GradientTape 1067 # TODO(gjn): Consider removing this 1068 if operator == "__eq__" or operator == "__ne__": 1069 return 1070 1071 tensor_oper = getattr(ops.Tensor, operator) 1072 1073 def _run_op(a, *args, **kwargs): 1074 # pylint: disable=protected-access 1075 return tensor_oper(a.value(), *args, **kwargs) 1076 1077 functools.update_wrapper(_run_op, tensor_oper) 1078 setattr(cls, operator, _run_op) 1079 1080 def __hash__(self): 1081 if ops.Tensor._USE_EQUALITY and ops.executing_eagerly_outside_functions(): # pylint: disable=protected-access 1082 raise TypeError( 1083 "Variable is unhashable. " 1084 f"Instead, use variable.ref() as the key. (Variable: {self})") 1085 else: 1086 return id(self) 1087 1088 # TODO(gjn): duplicate of math_ops.tensor_equals, consider removing 1089 def __eq__(self, other): 1090 """Compares two variables element-wise for equality.""" 1091 if ops.Tensor._USE_EQUALITY and ops.executing_eagerly_outside_functions(): # pylint: disable=protected-access 1092 return gen_math_ops.equal(self, other, incompatible_shape_error=False) 1093 else: 1094 # In legacy graph mode, tensor equality is object equality 1095 return self is other 1096 1097 # TODO(gjn): duplicate of math_ops.tensor_not_equals, consider removing 1098 def __ne__(self, other): 1099 """Compares two variables element-wise for equality.""" 1100 if ops.Tensor._USE_EQUALITY and ops.executing_eagerly_outside_functions(): # pylint: disable=protected-access 1101 return gen_math_ops.not_equal(self, other, incompatible_shape_error=False) 1102 else: 1103 # In legacy graph mode, tensor equality is object equality 1104 return self is not other 1105 1106 def __iter__(self): 1107 """When executing eagerly, iterates over the value of the variable.""" 1108 return iter(self.read_value()) 1109 1110 # NOTE(mrry): This enables the Variable's overloaded "right" binary 1111 # operators to run when the left operand is an ndarray, because it 1112 # accords the Variable class higher priority than an ndarray, or a 1113 # numpy matrix. 
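  # For example (an illustrative sketch, assuming eager execution and that
  # NumPy is available as `np`):
  #
  #   v = tf.Variable([2.])
  #   np.array([3.]) + v  # defers to `v.__radd__` and returns a tf.Tensor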
1114 # TODO(mrry): Convert this to using numpy's __numpy_ufunc__ 1115 # mechanism, which allows more control over how Variables interact 1116 # with ndarrays. 1117 __array_priority__ = 100 1118 1119 @property 1120 def name(self): 1121 """The name of this variable.""" 1122 raise NotImplementedError 1123 1124 @property 1125 def _shared_name(self): 1126 """The shared name of the variable. 1127 1128 Unlike name(), shared_name doesn't have ":0" suffix. It is user-specified 1129 name with name scope prefix. 1130 1131 Returns: 1132 variable name. 1133 """ 1134 return self.name[:self.name.index(":")] 1135 1136 @property 1137 def initializer(self): 1138 """The initializer operation for this variable.""" 1139 raise NotImplementedError 1140 1141 @property 1142 def device(self): 1143 """The device of this variable.""" 1144 raise NotImplementedError 1145 1146 @property 1147 def dtype(self): 1148 """The `DType` of this variable.""" 1149 raise NotImplementedError 1150 1151 @property 1152 def op(self): 1153 """The `Operation` of this variable.""" 1154 raise NotImplementedError 1155 1156 @property 1157 def graph(self): 1158 """The `Graph` of this variable.""" 1159 raise NotImplementedError 1160 1161 @property 1162 def shape(self): 1163 """The `TensorShape` of this variable. 1164 1165 Returns: 1166 A `TensorShape`. 1167 """ 1168 raise NotImplementedError 1169 1170 def get_shape(self): 1171 """Alias of `Variable.shape`.""" 1172 return self.shape 1173 1174 def _gather_saveables_for_checkpoint(self): 1175 """For implementing `Trackable`. This object is saveable on its own.""" 1176 return {trackable.VARIABLE_VALUE_KEY: self} 1177 1178 def to_proto(self, export_scope=None): 1179 """Converts a `Variable` to a `VariableDef` protocol buffer. 1180 1181 Args: 1182 export_scope: Optional `string`. Name scope to remove. 1183 1184 Returns: 1185 A `VariableDef` protocol buffer, or `None` if the `Variable` is not 1186 in the specified name scope. 1187 """ 1188 raise NotImplementedError 1189 1190 @staticmethod 1191 def from_proto(variable_def, import_scope=None): 1192 """Returns a `Variable` object created from `variable_def`.""" 1193 return RefVariable(variable_def=variable_def, import_scope=import_scope) 1194 1195 def _set_save_slice_info(self, save_slice_info): 1196 """Sets the slice info for this `Variable`. 1197 1198 Args: 1199 save_slice_info: A `Variable.SaveSliceInfo` object. 1200 """ 1201 self._save_slice_info = save_slice_info 1202 1203 def _get_save_slice_info(self): 1204 return self._save_slice_info 1205 1206 @deprecated(None, "Use ref() instead.") 1207 def experimental_ref(self): 1208 return self.ref() 1209 1210 def ref(self): 1211 # tf.Tensor also has the same ref() API. If you update the 1212 # documentation here, please update tf.Tensor.ref() as well. 1213 """Returns a hashable reference object to this Variable. 1214 1215 The primary use case for this API is to put variables in a set/dictionary. 1216 We can't put variables in a set/dictionary as `variable.__hash__()` is no 1217 longer available starting Tensorflow 2.0. 1218 1219 The following will raise an exception starting 2.0 1220 1221 >>> x = tf.Variable(5) 1222 >>> y = tf.Variable(10) 1223 >>> z = tf.Variable(10) 1224 >>> variable_set = {x, y, z} 1225 Traceback (most recent call last): 1226 ... 1227 TypeError: Variable is unhashable. Instead, use tensor.ref() as the key. 1228 >>> variable_dict = {x: 'five', y: 'ten'} 1229 Traceback (most recent call last): 1230 ... 1231 TypeError: Variable is unhashable. Instead, use tensor.ref() as the key. 
1232 1233 Instead, we can use `variable.ref()`. 1234 1235 >>> variable_set = {x.ref(), y.ref(), z.ref()} 1236 >>> x.ref() in variable_set 1237 True 1238 >>> variable_dict = {x.ref(): 'five', y.ref(): 'ten', z.ref(): 'ten'} 1239 >>> variable_dict[y.ref()] 1240 'ten' 1241 1242 Also, the reference object provides `.deref()` function that returns the 1243 original Variable. 1244 1245 >>> x = tf.Variable(5) 1246 >>> x.ref().deref() 1247 <tf.Variable 'Variable:0' shape=() dtype=int32, numpy=5> 1248 """ 1249 return object_identity.Reference(self) 1250 1251 class SaveSliceInfo: 1252 """Information on how to save this Variable as a slice. 1253 1254 Provides internal support for saving variables as slices of a larger 1255 variable. This API is not public and is subject to change. 1256 1257 Available properties: 1258 1259 * full_name 1260 * full_shape 1261 * var_offset 1262 * var_shape 1263 """ 1264 1265 def __init__(self, 1266 full_name=None, 1267 full_shape=None, 1268 var_offset=None, 1269 var_shape=None, 1270 save_slice_info_def=None, 1271 import_scope=None): 1272 """Create a `SaveSliceInfo`. 1273 1274 Args: 1275 full_name: Name of the full variable of which this `Variable` is a 1276 slice. 1277 full_shape: Shape of the full variable, as a list of int. 1278 var_offset: Offset of this `Variable` into the full variable, as a list 1279 of int. 1280 var_shape: Shape of this `Variable`, as a list of int. 1281 save_slice_info_def: `SaveSliceInfoDef` protocol buffer. If not `None`, 1282 recreates the SaveSliceInfo object its contents. `save_slice_info_def` 1283 and other arguments are mutually exclusive. 1284 import_scope: Optional `string`. Name scope to add. Only used when 1285 initializing from protocol buffer. 1286 """ 1287 if save_slice_info_def: 1288 assert isinstance(save_slice_info_def, variable_pb2.SaveSliceInfoDef) 1289 self.full_name = ops.prepend_name_scope( 1290 save_slice_info_def.full_name, import_scope=import_scope) 1291 self.full_shape = [i for i in save_slice_info_def.full_shape] 1292 self.var_offset = [i for i in save_slice_info_def.var_offset] 1293 self.var_shape = [i for i in save_slice_info_def.var_shape] 1294 else: 1295 self.full_name = full_name 1296 self.full_shape = full_shape 1297 self.var_offset = var_offset 1298 self.var_shape = var_shape 1299 1300 @property 1301 def spec(self): 1302 """Computes the spec string used for saving.""" 1303 full_shape_str = " ".join("%d" % d for d in self.full_shape) + " " 1304 sl_spec = ":".join( 1305 "%d,%d" % (o, s) for o, s in zip(self.var_offset, self.var_shape)) 1306 return full_shape_str + sl_spec 1307 1308 def to_proto(self, export_scope=None): 1309 """Returns a SaveSliceInfoDef() proto. 1310 1311 Args: 1312 export_scope: Optional `string`. Name scope to remove. 1313 1314 Returns: 1315 A `SaveSliceInfoDef` protocol buffer, or None if the `Variable` is not 1316 in the specified name scope. 
1317 """ 1318 if (export_scope is None or self.full_name.startswith(export_scope)): 1319 save_slice_info_def = variable_pb2.SaveSliceInfoDef() 1320 save_slice_info_def.full_name = ops.strip_name_scope( 1321 self.full_name, export_scope) 1322 for i in self.full_shape: 1323 save_slice_info_def.full_shape.append(i) 1324 for i in self.var_offset: 1325 save_slice_info_def.var_offset.append(i) 1326 for i in self.var_shape: 1327 save_slice_info_def.var_shape.append(i) 1328 return save_slice_info_def 1329 else: 1330 return None 1331 1332 1333Variable._OverloadAllOperators() # pylint: disable=protected-access 1334_pywrap_utils.RegisterType("Variable", Variable) 1335 1336 1337@tf_export(v1=["Variable"]) 1338class VariableV1(Variable): 1339 """See the [Variables Guide](https://tensorflow.org/guide/variables). 1340 1341 A variable maintains state in the graph across calls to `run()`. You add a 1342 variable to the graph by constructing an instance of the class `Variable`. 1343 1344 The `Variable()` constructor requires an initial value for the variable, 1345 which can be a `Tensor` of any type and shape. The initial value defines the 1346 type and shape of the variable. After construction, the type and shape of 1347 the variable are fixed. The value can be changed using one of the assign 1348 methods. 1349 1350 If you want to change the shape of a variable later you have to use an 1351 `assign` Op with `validate_shape=False`. 1352 1353 Just like any `Tensor`, variables created with `Variable()` can be used as 1354 inputs for other Ops in the graph. Additionally, all the operators 1355 overloaded for the `Tensor` class are carried over to variables, so you can 1356 also add nodes to the graph by just doing arithmetic on variables. 1357 1358 ```python 1359 import tensorflow as tf 1360 1361 # Create a variable. 1362 w = tf.Variable(<initial-value>, name=<optional-name>) 1363 1364 # Use the variable in the graph like any Tensor. 1365 y = tf.matmul(w, ...another variable or tensor...) 1366 1367 # The overloaded operators are available too. 1368 z = tf.sigmoid(w + y) 1369 1370 # Assign a new value to the variable with `assign()` or a related method. 1371 w.assign(w + 1.0) 1372 w.assign_add(1.0) 1373 ``` 1374 1375 When you launch the graph, variables have to be explicitly initialized before 1376 you can run Ops that use their value. You can initialize a variable by 1377 running its *initializer op*, restoring the variable from a save file, or 1378 simply running an `assign` Op that assigns a value to the variable. In fact, 1379 the variable *initializer op* is just an `assign` Op that assigns the 1380 variable's initial value to the variable itself. 1381 1382 ```python 1383 # Launch the graph in a session. 1384 with tf.compat.v1.Session() as sess: 1385 # Run the variable initializer. 1386 sess.run(w.initializer) 1387 # ...you now can run ops that use the value of 'w'... 1388 ``` 1389 1390 The most common initialization pattern is to use the convenience function 1391 `global_variables_initializer()` to add an Op to the graph that initializes 1392 all the variables. You then run that Op after launching the graph. 1393 1394 ```python 1395 # Add an Op to initialize global variables. 1396 init_op = tf.compat.v1.global_variables_initializer() 1397 1398 # Launch the graph in a session. 1399 with tf.compat.v1.Session() as sess: 1400 # Run the Op that initializes global variables. 1401 sess.run(init_op) 1402 # ...you can now run any Op that uses variable values... 
1403 ``` 1404 1405 If you need to create a variable with an initial value dependent on another 1406 variable, use the other variable's `initialized_value()`. This ensures that 1407 variables are initialized in the right order. 1408 1409 All variables are automatically collected in the graph where they are 1410 created. By default, the constructor adds the new variable to the graph 1411 collection `GraphKeys.GLOBAL_VARIABLES`. The convenience function 1412 `global_variables()` returns the contents of that collection. 1413 1414 When building a machine learning model it is often convenient to distinguish 1415 between variables holding the trainable model parameters and other variables 1416 such as a `global step` variable used to count training steps. To make this 1417 easier, the variable constructor supports a `trainable=<bool>` parameter. If 1418 `True`, the new variable is also added to the graph collection 1419 `GraphKeys.TRAINABLE_VARIABLES`. The convenience function 1420 `trainable_variables()` returns the contents of this collection. The 1421 various `Optimizer` classes use this collection as the default list of 1422 variables to optimize. 1423 1424 WARNING: tf.Variable objects by default have a non-intuitive memory model. A 1425 Variable is represented internally as a mutable Tensor which can 1426 non-deterministically alias other Tensors in a graph. The set of operations 1427 which consume a Variable and can lead to aliasing is undetermined and can 1428 change across TensorFlow versions. Avoid writing code which relies on the 1429 value of a Variable either changing or not changing as other operations 1430 happen. For example, using Variable objects or simple functions thereof as 1431 predicates in a `tf.cond` is dangerous and error-prone: 1432 1433 ``` 1434 v = tf.Variable(True) 1435 tf.cond(v, lambda: v.assign(False), my_false_fn) # Note: this is broken. 1436 ``` 1437 1438 Here, adding `use_resource=True` when constructing the variable will 1439 fix any nondeterminism issues: 1440 ``` 1441 v = tf.Variable(True, use_resource=True) 1442 tf.cond(v, lambda: v.assign(False), my_false_fn) 1443 ``` 1444 1445 To use the replacement for variables which does 1446 not have these issues: 1447 1448 * Add `use_resource=True` when constructing `tf.Variable`; 1449 * Call `tf.compat.v1.get_variable_scope().set_use_resource(True)` inside a 1450 `tf.compat.v1.variable_scope` before the `tf.compat.v1.get_variable()` call. 1451 """ 1452 1453 def __init__( 1454 self, # pylint: disable=super-init-not-called 1455 initial_value=None, 1456 trainable=None, 1457 collections=None, 1458 validate_shape=True, 1459 caching_device=None, 1460 name=None, 1461 variable_def=None, 1462 dtype=None, 1463 expected_shape=None, 1464 import_scope=None, 1465 constraint=None, 1466 use_resource=None, 1467 synchronization=VariableSynchronization.AUTO, 1468 aggregation=VariableAggregation.NONE, 1469 shape=None): 1470 """Creates a new variable with value `initial_value`. 1471 1472 The new variable is added to the graph collections listed in `collections`, 1473 which defaults to `[GraphKeys.GLOBAL_VARIABLES]`. 1474 1475 If `trainable` is `True` the variable is also added to the graph collection 1476 `GraphKeys.TRAINABLE_VARIABLES`. 1477 1478 This constructor creates both a `variable` Op and an `assign` Op to set the 1479 variable to its initial value. 1480 1481 Args: 1482 initial_value: A `Tensor`, or Python object convertible to a `Tensor`, 1483 which is the initial value for the Variable. 
The initial value must have 1484 a shape specified unless `validate_shape` is set to False. Can also be a 1485 callable with no argument that returns the initial value when called. In 1486 that case, `dtype` must be specified. (Note that initializer functions 1487 from init_ops.py must first be bound to a shape before being used here.) 1488 trainable: If `True`, also adds the variable to the graph collection 1489 `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as the default 1490 list of variables to use by the `Optimizer` classes. Defaults to `True`, 1491 unless `synchronization` is set to `ON_READ`, in which case it defaults 1492 to `False`. 1493 collections: List of graph collections keys. The new variable is added to 1494 these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`. 1495 validate_shape: If `False`, allows the variable to be initialized with a 1496 value of unknown shape. If `True`, the default, the shape of 1497 `initial_value` must be known. 1498 caching_device: Optional device string describing where the Variable 1499 should be cached for reading. Defaults to the Variable's device. If not 1500 `None`, caches on another device. Typical use is to cache on the device 1501 where the Ops using the Variable reside, to deduplicate copying through 1502 `Switch` and other conditional statements. 1503 name: Optional name for the variable. Defaults to `'Variable'` and gets 1504 uniquified automatically. 1505 variable_def: `VariableDef` protocol buffer. If not `None`, recreates the 1506 Variable object with its contents, referencing the variable's nodes in 1507 the graph, which must already exist. The graph is not changed. 1508 `variable_def` and the other arguments are mutually exclusive. 1509 dtype: If set, initial_value will be converted to the given type. If 1510 `None`, either the datatype will be kept (if `initial_value` is a 1511 Tensor), or `convert_to_tensor` will decide. 1512 expected_shape: A TensorShape. If set, initial_value is expected to have 1513 this shape. 1514 import_scope: Optional `string`. Name scope to add to the `Variable.` Only 1515 used when initializing from protocol buffer. 1516 constraint: An optional projection function to be applied to the variable 1517 after being updated by an `Optimizer` (e.g. used to implement norm 1518 constraints or value constraints for layer weights). The function must 1519 take as input the unprojected Tensor representing the value of the 1520 variable and return the Tensor for the projected value (which must have 1521 the same shape). Constraints are not safe to use when doing asynchronous 1522 distributed training. 1523 use_resource: whether to use resource variables. 1524 synchronization: Indicates when a distributed a variable will be 1525 aggregated. Accepted values are constants defined in the class 1526 `tf.VariableSynchronization`. By default the synchronization is set to 1527 `AUTO` and the current `DistributionStrategy` chooses when to 1528 synchronize. 1529 aggregation: Indicates how a distributed variable will be aggregated. 1530 Accepted values are constants defined in the class 1531 `tf.VariableAggregation`. 1532 shape: (optional) The shape of this variable. If None, the shape of 1533 `initial_value` will be used. When setting this argument to 1534 `tf.TensorShape(None)` (representing an unspecified shape), the variable 1535 can be assigned with values of different shapes. 1536 1537 Raises: 1538 ValueError: If both `variable_def` and initial_value are specified. 
1539 ValueError: If the initial value is not specified, or does not have a 1540 shape and `validate_shape` is `True`. 1541 RuntimeError: If eager execution is enabled. 1542 """ 1543 1544 SaveSliceInfo = Variable.SaveSliceInfo 1545 1546 1547# TODO(apassos): do not repeat all comments here 1548class RefVariable(VariableV1, core.Tensor): 1549 """Ref-based implementation of variables.""" 1550 1551 def __init__( 1552 self, # pylint: disable=super-init-not-called 1553 initial_value=None, 1554 trainable=None, 1555 collections=None, 1556 validate_shape=True, 1557 caching_device=None, 1558 name=None, 1559 variable_def=None, 1560 dtype=None, 1561 expected_shape=None, 1562 import_scope=None, 1563 constraint=None, 1564 synchronization=None, 1565 aggregation=None, 1566 shape=None): 1567 """Creates a new variable with value `initial_value`. 1568 1569 The new variable is added to the graph collections listed in `collections`, 1570 which defaults to `[GraphKeys.GLOBAL_VARIABLES]`. 1571 1572 If `trainable` is `True` the variable is also added to the graph collection 1573 `GraphKeys.TRAINABLE_VARIABLES`. 1574 1575 This constructor creates both a `variable` Op and an `assign` Op to set the 1576 variable to its initial value. 1577 1578 Args: 1579 initial_value: A `Tensor`, or Python object convertible to a `Tensor`, 1580 which is the initial value for the Variable. The initial value must have 1581 a shape specified unless `validate_shape` is set to False. Can also be a 1582 callable with no argument that returns the initial value when called. In 1583 that case, `dtype` must be specified. (Note that initializer functions 1584 from init_ops.py must first be bound to a shape before being used here.) 1585 trainable: If `True`, also adds the variable to the graph collection 1586 `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as the default 1587 list of variables to use by the `Optimizer` classes. Defaults to `True`, 1588 unless `synchronization` is set to `ON_READ`, in which case it defaults 1589 to `False`. 1590 collections: List of graph collections keys. The new variable is added to 1591 these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`. 1592 validate_shape: If `False`, allows the variable to be initialized with a 1593 value of unknown shape. If `True`, the default, the shape of 1594 `initial_value` must be known. 1595 caching_device: Optional device string describing where the Variable 1596 should be cached for reading. Defaults to the Variable's device. If not 1597 `None`, caches on another device. Typical use is to cache on the device 1598 where the Ops using the Variable reside, to deduplicate copying through 1599 `Switch` and other conditional statements. 1600 name: Optional name for the variable. Defaults to `'Variable'` and gets 1601 uniquified automatically. 1602 variable_def: `VariableDef` protocol buffer. If not `None`, recreates the 1603 Variable object with its contents, referencing the variable's nodes in 1604 the graph, which must already exist. The graph is not changed. 1605 `variable_def` and the other arguments are mutually exclusive. 1606 dtype: If set, initial_value will be converted to the given type. If 1607 `None`, either the datatype will be kept (if `initial_value` is a 1608 Tensor), or `convert_to_tensor` will decide. 1609 expected_shape: A TensorShape. If set, initial_value is expected to have 1610 this shape. 1611 import_scope: Optional `string`. Name scope to add to the `Variable.` Only 1612 used when initializing from protocol buffer. 
1613 constraint: An optional projection function to be applied to the variable 1614 after being updated by an `Optimizer` (e.g. used to implement norm 1615 constraints or value constraints for layer weights). The function must 1616 take as input the unprojected Tensor representing the value of the 1617 variable and return the Tensor for the projected value (which must have 1618 the same shape). Constraints are not safe to use when doing asynchronous 1619 distributed training. 1620 synchronization: Indicates when a distributed a variable will be 1621 aggregated. Accepted values are constants defined in the class 1622 `tf.VariableSynchronization`. By default the synchronization is set to 1623 `AUTO` and the current `DistributionStrategy` chooses when to 1624 synchronize. 1625 aggregation: Indicates how a distributed variable will be aggregated. 1626 Accepted values are constants defined in the class 1627 `tf.VariableAggregation`. 1628 shape: (optional) The shape of this variable. If None, the shape of 1629 `initial_value` will be used. When setting this argument to 1630 `tf.TensorShape(None)` (representing an unspecified shape), the variable 1631 can be assigned with values of different shapes. 1632 1633 Raises: 1634 ValueError: If both `variable_def` and initial_value are specified. 1635 ValueError: If the initial value is not specified, or does not have a 1636 shape and `validate_shape` is `True`. 1637 RuntimeError: If eager execution is enabled. 1638 """ 1639 self._in_graph_mode = True 1640 if variable_def: 1641 # If variable_def is provided, recreates the variable from its fields. 1642 if initial_value: 1643 raise ValueError("variable_def and initial_value are mutually " 1644 "exclusive.") 1645 self._init_from_proto(variable_def, import_scope=import_scope) 1646 else: 1647 # Create from initial_value. 1648 self._init_from_args( 1649 initial_value=initial_value, 1650 trainable=trainable, 1651 collections=collections, 1652 validate_shape=validate_shape, 1653 caching_device=caching_device, 1654 name=name, 1655 dtype=dtype, 1656 expected_shape=expected_shape, 1657 constraint=constraint, 1658 synchronization=synchronization, 1659 aggregation=aggregation, 1660 shape=shape) 1661 1662 def __repr__(self): 1663 if context.executing_eagerly() and not self._in_graph_mode: 1664 return "<tf.Variable '%s' shape=%s dtype=%s, numpy=%s>" % ( 1665 self.name, self.get_shape(), self.dtype.name, 1666 ops.numpy_text(self.read_value(), is_repr=True)) 1667 else: 1668 return "<tf.Variable '%s' shape=%s dtype=%s>" % ( 1669 self.name, self.get_shape(), self.dtype.name) 1670 1671 def _init_from_args(self, 1672 initial_value=None, 1673 trainable=None, 1674 collections=None, 1675 validate_shape=True, 1676 caching_device=None, 1677 name=None, 1678 dtype=None, 1679 expected_shape=None, 1680 constraint=None, 1681 synchronization=None, 1682 aggregation=None, 1683 shape=None): 1684 """Creates a new variable from arguments. 1685 1686 Args: 1687 initial_value: A `Tensor`, or Python object convertible to a `Tensor`, 1688 which is the initial value for the Variable. The initial value must have 1689 a shape specified unless `validate_shape` is set to False. Can also be a 1690 callable with no argument that returns the initial value when called. 1691 (Note that initializer functions from init_ops.py must first be bound to 1692 a shape before being used here.) 1693 trainable: If `True`, also adds the variable to the graph collection 1694 `GraphKeys.TRAINABLE_VARIABLES`. 
This collection is used as the default 1695 list of variables to use by the `Optimizer` classes. Defaults to `True`, 1696 unless `synchronization` is set to `ON_READ`, in which case it defaults 1697 to `False`. 1698 collections: List of graph collections keys. The new variable is added to 1699 these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`. 1700 validate_shape: If `False`, allows the variable to be initialized with a 1701 value of unknown shape. If `True`, the default, the shape of 1702 `initial_value` must be known. 1703 caching_device: Optional device string or function describing where the 1704 Variable should be cached for reading. Defaults to the Variable's 1705 device. If not `None`, caches on another device. Typical use is to 1706 cache on the device where the Ops using the Variable reside, to 1707 deduplicate copying through `Switch` and other conditional statements. 1708 name: Optional name for the variable. Defaults to `'Variable'` and gets 1709 uniquified automatically. 1710 dtype: If set, initial_value will be converted to the given type. If None, 1711 either the datatype will be kept (if initial_value is a Tensor) or 1712 float32 will be used (if it is a Python object convertible to a Tensor). 1713 expected_shape: Deprecated. Ignored. 1714 constraint: An optional projection function to be applied to the variable 1715 after being updated by an `Optimizer` (e.g. used to implement norm 1716 constraints or value constraints for layer weights). The function must 1717 take as input the unprojected Tensor representing the value of the 1718 variable and return the Tensor for the projected value (which must have 1719 the same shape). Constraints are not safe to use when doing asynchronous 1720 distributed training. 1721 synchronization: Indicates when a distributed a variable will be 1722 aggregated. Accepted values are constants defined in the class 1723 `tf.VariableSynchronization`. By default the synchronization is set to 1724 `AUTO` and the current `DistributionStrategy` chooses when to 1725 synchronize. 1726 aggregation: Indicates how a distributed variable will be aggregated. 1727 Accepted values are constants defined in the class 1728 `tf.VariableAggregation`. 1729 shape: (optional) The shape of this variable. If None, the shape of 1730 `initial_value` will be used. When setting this argument to 1731 `tf.TensorShape(None)` (representing an unspecified shape), the variable 1732 can be assigned with values of different shapes. 1733 1734 Raises: 1735 ValueError: If the initial value is not specified, or does not have a 1736 shape and `validate_shape` is `True`. 1737 RuntimeError: If lifted into the eager context. 1738 """ 1739 _ = expected_shape 1740 if initial_value is None: 1741 raise ValueError("initial_value must be specified.") 1742 init_from_fn = callable(initial_value) 1743 1744 if collections is None: 1745 collections = [ops.GraphKeys.GLOBAL_VARIABLES] 1746 if not isinstance(collections, (list, tuple, set)): 1747 raise ValueError( 1748 "collections argument to Variable constructor must be a list, tuple, " 1749 "or set. Got %s of type %s" % (collections, type(collections))) 1750 if constraint is not None and not callable(constraint): 1751 raise ValueError("The `constraint` argument must be a callable.") 1752 1753 # Store the graph key so optimizers know how to only retrieve variables from 1754 # this graph. 
1755 self._graph_key = ops.get_default_graph()._graph_key # pylint: disable=protected-access 1756 if isinstance(initial_value, trackable.CheckpointInitialValue): 1757 self._maybe_initialize_trackable() 1758 self._update_uid = initial_value.checkpoint_position.restore_uid 1759 initial_value = initial_value.wrapped_value 1760 1761 synchronization, aggregation, trainable = ( 1762 validate_synchronization_aggregation_trainable(synchronization, 1763 aggregation, trainable, 1764 name)) 1765 self._synchronization = synchronization 1766 self._aggregation = aggregation 1767 self._trainable = trainable 1768 if trainable and ops.GraphKeys.TRAINABLE_VARIABLES not in collections: 1769 collections = list(collections) + [ops.GraphKeys.TRAINABLE_VARIABLES] 1770 with ops.init_scope(): 1771 # Ensure that we weren't lifted into the eager context. 1772 if context.executing_eagerly(): 1773 raise RuntimeError( 1774 "Reference variables are not supported when eager execution is " 1775 "enabled. Please run `tf.compat.v1.enable_resource_variables()` to " 1776 "switch to resource variables.") 1777 with ops.name_scope(name, "Variable", 1778 [] if init_from_fn else [initial_value]) as name: 1779 1780 if init_from_fn: 1781 # Use attr_scope and device(None) to simulate the behavior of 1782 # colocate_with when the variable we want to colocate with doesn't 1783 # yet exist. 1784 true_name = ops.name_from_scope_name(name) # pylint: disable=protected-access 1785 attr = attr_value_pb2.AttrValue( 1786 list=attr_value_pb2.AttrValue.ListValue( 1787 s=[compat.as_bytes("loc:@%s" % true_name)])) 1788 # pylint: disable=protected-access 1789 with ops.get_default_graph()._attr_scope({"_class": attr}): 1790 with ops.name_scope("Initializer"), ops.device(None): 1791 initial_value = initial_value() 1792 if isinstance(initial_value, trackable.CheckpointInitialValue): 1793 self._maybe_initialize_trackable() 1794 self._update_uid = initial_value.checkpoint_position.restore_uid 1795 initial_value = initial_value.wrapped_value 1796 self._initial_value = ops.convert_to_tensor( 1797 initial_value, name="initial_value", dtype=dtype) 1798 if shape is None: 1799 shape = ( 1800 self._initial_value.get_shape() 1801 if validate_shape else tensor_shape.unknown_shape()) 1802 self._variable = state_ops.variable_op_v2( 1803 shape, self._initial_value.dtype.base_dtype, name=name) 1804 # pylint: enable=protected-access 1805 1806 # Or get the initial value from a Tensor or Python object. 1807 else: 1808 self._initial_value = ops.convert_to_tensor( 1809 initial_value, name="initial_value", dtype=dtype) 1810 # pylint: disable=protected-access 1811 if self._initial_value.op._get_control_flow_context() is not None: 1812 raise ValueError( 1813 "Initializer for variable %s is from inside a control-flow " 1814 "construct, such as a loop or conditional. When creating a " 1815 "variable inside a loop or conditional, use a lambda as the " 1816 "initializer." % name) 1817 if shape is None: 1818 # pylint: enable=protected-access 1819 shape = ( 1820 self._initial_value.get_shape() 1821 if validate_shape else tensor_shape.unknown_shape()) 1822 # In this case, the variable op can't be created until after the 1823 # initial_value has been converted to a Tensor with a known type. 1824 self._variable = state_ops.variable_op_v2( 1825 shape, self._initial_value.dtype.base_dtype, name=name) 1826 1827 # Cache the name in `self`, because some APIs call `Variable.name` in a 1828 # tight loop, and this halves the cost. 
1829 self._name = self._variable.name 1830 1831 # Manually overrides the variable's shape with the initial value's. 1832 if validate_shape: 1833 initial_value_shape = self._initial_value.get_shape() 1834 if not initial_value_shape.is_fully_defined(): 1835 raise ValueError("initial_value must have a shape specified: %s" % 1836 self._initial_value) 1837 1838 # If 'initial_value' makes use of other variables, make sure we don't 1839 # have an issue if these other variables aren't initialized first by 1840 # using their initialized_value() method. 1841 self._initializer_op = state_ops.assign( 1842 self._variable, 1843 _try_guard_against_uninitialized_dependencies( 1844 name, self._initial_value), 1845 validate_shape=validate_shape).op 1846 1847 # TODO(vrv): Change this class to not take caching_device, but 1848 # to take the op to colocate the snapshot with, so we can use 1849 # colocation rather than devices. 1850 if caching_device is not None: 1851 with ops.device(caching_device): 1852 self._snapshot = array_ops.identity(self._variable, name="read") 1853 else: 1854 with ops.colocate_with(self._variable.op): 1855 self._snapshot = array_ops.identity(self._variable, name="read") 1856 ops.add_to_collections(collections, self) 1857 1858 self._caching_device = caching_device 1859 self._save_slice_info = None 1860 self._constraint = constraint 1861 1862 def _init_from_proto(self, variable_def, import_scope=None): 1863 """Recreates the Variable object from a `VariableDef` protocol buffer. 1864 1865 Args: 1866 variable_def: `VariableDef` protocol buffer, describing a variable whose 1867 nodes already exists in the graph. 1868 import_scope: Optional `string`. Name scope to add. 1869 """ 1870 assert isinstance(variable_def, variable_pb2.VariableDef) 1871 # Create from variable_def. 1872 g = ops.get_default_graph() 1873 self._variable = g.as_graph_element( 1874 ops.prepend_name_scope( 1875 variable_def.variable_name, import_scope=import_scope)) 1876 self._name = self._variable.name 1877 self._initializer_op = g.as_graph_element( 1878 ops.prepend_name_scope( 1879 variable_def.initializer_name, import_scope=import_scope)) 1880 # Tests whether initial_value_name exists first for backwards compatibility. 1881 if (hasattr(variable_def, "initial_value_name") and 1882 variable_def.initial_value_name): 1883 self._initial_value = g.as_graph_element( 1884 ops.prepend_name_scope( 1885 variable_def.initial_value_name, import_scope=import_scope)) 1886 else: 1887 self._initial_value = None 1888 synchronization, aggregation, trainable = ( 1889 validate_synchronization_aggregation_trainable( 1890 variable_def.synchronization, variable_def.aggregation, 1891 variable_def.trainable, variable_def.variable_name)) 1892 self._synchronization = synchronization 1893 self._aggregation = aggregation 1894 self._trainable = trainable 1895 self._snapshot = g.as_graph_element( 1896 ops.prepend_name_scope( 1897 variable_def.snapshot_name, import_scope=import_scope)) 1898 if variable_def.HasField("save_slice_info_def"): 1899 self._save_slice_info = Variable.SaveSliceInfo( 1900 save_slice_info_def=variable_def.save_slice_info_def, 1901 import_scope=import_scope) 1902 else: 1903 self._save_slice_info = None 1904 self._caching_device = None 1905 self._constraint = None 1906 1907 def _as_graph_element(self): 1908 """Conversion function for Graph.as_graph_element().""" 1909 return self._variable 1910 1911 def value(self): 1912 """Returns the last snapshot of this variable. 
1913 1914 You usually do not need to call this method as all ops that need the value 1915 of the variable call it automatically through a `convert_to_tensor()` call. 1916 1917 Returns a `Tensor` which holds the value of the variable. You can not 1918 assign a new value to this tensor as it is not a reference to the variable. 1919 1920 To avoid copies, if the consumer of the returned value is on the same device 1921 as the variable, this actually returns the live value of the variable, not 1922 a copy. Updates to the variable are seen by the consumer. If the consumer 1923 is on a different device it will get a copy of the variable. 1924 1925 Returns: 1926 A `Tensor` containing the value of the variable. 1927 """ 1928 return self._snapshot 1929 1930 def read_value(self): 1931 """Returns the value of this variable, read in the current context. 1932 1933 Can be different from value() if it's on another device, with control 1934 dependencies, etc. 1935 1936 Returns: 1937 A `Tensor` containing the value of the variable. 1938 """ 1939 return array_ops.identity(self._variable, name="read") 1940 1941 def _ref(self): 1942 """Returns a reference to this variable. 1943 1944 You usually do not need to call this method as all ops that need a reference 1945 to the variable call it automatically. 1946 1947 Returns is a `Tensor` which holds a reference to the variable. You can 1948 assign a new value to the variable by passing the tensor to an assign op. 1949 See `tf.Variable.value` if you want to get the value of the 1950 variable. 1951 1952 Returns: 1953 A `Tensor` that is a reference to the variable. 1954 """ 1955 return self._variable 1956 1957 def set_shape(self, shape): 1958 """Overrides the shape for this variable. 1959 1960 Args: 1961 shape: the `TensorShape` representing the overridden shape. 1962 """ 1963 self._ref().set_shape(shape) 1964 self.value().set_shape(shape) 1965 1966 @property 1967 def trainable(self): 1968 return self._trainable 1969 1970 @property 1971 def synchronization(self): 1972 return self._synchronization 1973 1974 @property 1975 def aggregation(self): 1976 return self._aggregation 1977 1978 def eval(self, session=None): 1979 """In a session, computes and returns the value of this variable. 1980 1981 This is not a graph construction method, it does not add ops to the graph. 1982 1983 This convenience method requires a session where the graph 1984 containing this variable has been launched. If no session is 1985 passed, the default session is used. See `tf.compat.v1.Session` for more 1986 information on launching a graph and on sessions. 1987 1988 ```python 1989 v = tf.Variable([1, 2]) 1990 init = tf.compat.v1.global_variables_initializer() 1991 1992 with tf.compat.v1.Session() as sess: 1993 sess.run(init) 1994 # Usage passing the session explicitly. 1995 print(v.eval(sess)) 1996 # Usage with the default session. The 'with' block 1997 # above makes 'sess' the default session. 1998 print(v.eval()) 1999 ``` 2000 2001 Args: 2002 session: The session to use to evaluate this variable. If none, the 2003 default session is used. 2004 2005 Returns: 2006 A numpy `ndarray` with a copy of the value of this variable. 2007 """ 2008 return self._variable.eval(session=session) 2009 2010 @property 2011 def initial_value(self): 2012 """Returns the Tensor used as the initial value for the variable. 2013 2014 Note that this is different from `initialized_value()` which runs 2015 the op that initializes the variable before returning its value. 
2016 This method returns the tensor that is used by the op that initializes 2017 the variable. 2018 2019 Returns: 2020 A `Tensor`. 2021 """ 2022 return self._initial_value 2023 2024 @property 2025 def constraint(self): 2026 """Returns the constraint function associated with this variable. 2027 2028 Returns: 2029 The constraint function that was passed to the variable constructor. 2030 Can be `None` if no constraint was passed. 2031 """ 2032 return self._constraint 2033 2034 def assign(self, value, use_locking=False, name=None, read_value=True): 2035 """Assigns a new value to the variable. 2036 2037 This is essentially a shortcut for `assign(self, value)`. 2038 2039 Args: 2040 value: A `Tensor`. The new value for this variable. 2041 use_locking: If `True`, use locking during the assignment. 2042 name: The name of the operation to be created 2043 read_value: if True, will return something which evaluates to the new 2044 value of the variable; if False will return the assign op. 2045 2046 Returns: 2047 A `Tensor` that will hold the new value of this variable after 2048 the assignment has completed. 2049 """ 2050 assign = state_ops.assign( 2051 self._variable, value, use_locking=use_locking, name=name) 2052 if read_value: 2053 return assign 2054 return assign.op 2055 2056 def assign_add(self, delta, use_locking=False, name=None, read_value=True): 2057 """Adds a value to this variable. 2058 2059 This is essentially a shortcut for `assign_add(self, delta)`. 2060 2061 Args: 2062 delta: A `Tensor`. The value to add to this variable. 2063 use_locking: If `True`, use locking during the operation. 2064 name: The name of the operation to be created 2065 read_value: if True, will return something which evaluates to the new 2066 value of the variable; if False will return the assign op. 2067 2068 Returns: 2069 A `Tensor` that will hold the new value of this variable after 2070 the addition has completed. 2071 """ 2072 assign = state_ops.assign_add( 2073 self._variable, delta, use_locking=use_locking, name=name) 2074 if read_value: 2075 return assign 2076 return assign.op 2077 2078 def assign_sub(self, delta, use_locking=False, name=None, read_value=True): 2079 """Subtracts a value from this variable. 2080 2081 This is essentially a shortcut for `assign_sub(self, delta)`. 2082 2083 Args: 2084 delta: A `Tensor`. The value to subtract from this variable. 2085 use_locking: If `True`, use locking during the operation. 2086 name: The name of the operation to be created 2087 read_value: if True, will return something which evaluates to the new 2088 value of the variable; if False will return the assign op. 2089 2090 Returns: 2091 A `Tensor` that will hold the new value of this variable after 2092 the subtraction has completed. 2093 """ 2094 assign = state_ops.assign_sub( 2095 self._variable, delta, use_locking=use_locking, name=name) 2096 if read_value: 2097 return assign 2098 return assign.op 2099 2100 def scatter_sub(self, sparse_delta, use_locking=False, name=None): 2101 """Subtracts `tf.IndexedSlices` from this variable. 2102 2103 Args: 2104 sparse_delta: `tf.IndexedSlices` to be subtracted from this variable. 2105 use_locking: If `True`, use locking during the operation. 2106 name: the name of the operation. 2107 2108 Returns: 2109 A `Tensor` that will hold the new value of this variable after 2110 the scattered subtraction has completed. 2111 2112 Raises: 2113 TypeError: if `sparse_delta` is not an `IndexedSlices`. 
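    Example: a minimal graph-mode sketch (the concrete values below are
    illustrative only, not part of the API):

    ```python
    v = tf.Variable([1.0, 2.0, 3.0, 4.0])
    # Subtract 0.5 from element 0 and 1.5 from element 2.
    delta = tf.IndexedSlices(values=tf.constant([0.5, 1.5]),
                             indices=tf.constant([0, 2]))
    op = v.scatter_sub(delta)
    with tf.compat.v1.Session() as sess:
      sess.run(tf.compat.v1.global_variables_initializer())
      print(sess.run(op))  # [0.5, 2.0, 1.5, 4.0]
    ```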
2114 """ 2115 if not isinstance(sparse_delta, indexed_slices.IndexedSlices): 2116 raise TypeError("sparse_delta is not IndexedSlices: %s" % sparse_delta) 2117 return gen_state_ops.scatter_sub( 2118 self._variable, 2119 sparse_delta.indices, 2120 sparse_delta.values, 2121 use_locking=use_locking, 2122 name=name) 2123 2124 def scatter_add(self, sparse_delta, use_locking=False, name=None): 2125 """Adds `tf.IndexedSlices` to this variable. 2126 2127 Args: 2128 sparse_delta: `tf.IndexedSlices` to be added to this variable. 2129 use_locking: If `True`, use locking during the operation. 2130 name: the name of the operation. 2131 2132 Returns: 2133 A `Tensor` that will hold the new value of this variable after 2134 the scattered addition has completed. 2135 2136 Raises: 2137 TypeError: if `sparse_delta` is not an `IndexedSlices`. 2138 """ 2139 if not isinstance(sparse_delta, indexed_slices.IndexedSlices): 2140 raise TypeError("sparse_delta is not IndexedSlices: %s" % sparse_delta) 2141 return gen_state_ops.scatter_add( 2142 self._variable, 2143 sparse_delta.indices, 2144 sparse_delta.values, 2145 use_locking=use_locking, 2146 name=name) 2147 2148 def scatter_max(self, sparse_delta, use_locking=False, name=None): 2149 """Updates this variable with the max of `tf.IndexedSlices` and itself. 2150 2151 Args: 2152 sparse_delta: `tf.IndexedSlices` to use as an argument of max with this 2153 variable. 2154 use_locking: If `True`, use locking during the operation. 2155 name: the name of the operation. 2156 2157 Returns: 2158 A `Tensor` that will hold the new value of this variable after 2159 the scattered maximization has completed. 2160 2161 Raises: 2162 TypeError: if `sparse_delta` is not an `IndexedSlices`. 2163 """ 2164 if not isinstance(sparse_delta, indexed_slices.IndexedSlices): 2165 raise TypeError("sparse_delta is not IndexedSlices: %s" % sparse_delta) 2166 return gen_state_ops.scatter_max( 2167 self._variable, 2168 sparse_delta.indices, 2169 sparse_delta.values, 2170 use_locking=use_locking, 2171 name=name) 2172 2173 def scatter_min(self, sparse_delta, use_locking=False, name=None): 2174 """Updates this variable with the min of `tf.IndexedSlices` and itself. 2175 2176 Args: 2177 sparse_delta: `tf.IndexedSlices` to use as an argument of min with this 2178 variable. 2179 use_locking: If `True`, use locking during the operation. 2180 name: the name of the operation. 2181 2182 Returns: 2183 A `Tensor` that will hold the new value of this variable after 2184 the scattered minimization has completed. 2185 2186 Raises: 2187 TypeError: if `sparse_delta` is not an `IndexedSlices`. 2188 """ 2189 if not isinstance(sparse_delta, indexed_slices.IndexedSlices): 2190 raise TypeError("sparse_delta is not IndexedSlices: %s" % sparse_delta) 2191 return gen_state_ops.scatter_min( 2192 self._variable, 2193 sparse_delta.indices, 2194 sparse_delta.values, 2195 use_locking=use_locking, 2196 name=name) 2197 2198 def scatter_mul(self, sparse_delta, use_locking=False, name=None): 2199 """Multiply this variable by `tf.IndexedSlices`. 2200 2201 Args: 2202 sparse_delta: `tf.IndexedSlices` to multiply this variable by. 2203 use_locking: If `True`, use locking during the operation. 2204 name: the name of the operation. 2205 2206 Returns: 2207 A `Tensor` that will hold the new value of this variable after 2208 the scattered multiplication has completed. 2209 2210 Raises: 2211 TypeError: if `sparse_delta` is not an `IndexedSlices`. 
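    Example: a minimal graph-mode sketch (the concrete values below are
    illustrative only, not part of the API):

    ```python
    v = tf.Variable([1.0, 2.0, 3.0, 4.0])
    # Multiply element 1 by 10.0 and element 3 by 0.5.
    delta = tf.IndexedSlices(values=tf.constant([10.0, 0.5]),
                             indices=tf.constant([1, 3]))
    op = v.scatter_mul(delta)
    with tf.compat.v1.Session() as sess:
      sess.run(tf.compat.v1.global_variables_initializer())
      print(sess.run(op))  # [1.0, 20.0, 3.0, 2.0]
    ```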
    """
    if not isinstance(sparse_delta, indexed_slices.IndexedSlices):
      raise TypeError("sparse_delta is not IndexedSlices: %s" % sparse_delta)
    return gen_state_ops.scatter_mul(
        self._variable,
        sparse_delta.indices,
        sparse_delta.values,
        use_locking=use_locking,
        name=name)

  def scatter_div(self, sparse_delta, use_locking=False, name=None):
    """Divide this variable by `tf.IndexedSlices`.

    Args:
      sparse_delta: `tf.IndexedSlices` to divide this variable by.
      use_locking: If `True`, use locking during the operation.
      name: the name of the operation.

    Returns:
      A `Tensor` that will hold the new value of this variable after
      the scattered division has completed.

    Raises:
      TypeError: if `sparse_delta` is not an `IndexedSlices`.
    """
    if not isinstance(sparse_delta, indexed_slices.IndexedSlices):
      raise TypeError("sparse_delta is not IndexedSlices: %s" % sparse_delta)
    return gen_state_ops.scatter_div(
        self._variable,
        sparse_delta.indices,
        sparse_delta.values,
        use_locking=use_locking,
        name=name)

  def scatter_update(self, sparse_delta, use_locking=False, name=None):
    """Assigns `tf.IndexedSlices` to this variable.

    Args:
      sparse_delta: `tf.IndexedSlices` to be assigned to this variable.
      use_locking: If `True`, use locking during the operation.
      name: the name of the operation.

    Returns:
      A `Tensor` that will hold the new value of this variable after
      the scattered assignment has completed.

    Raises:
      TypeError: if `sparse_delta` is not an `IndexedSlices`.
    """
    if not isinstance(sparse_delta, indexed_slices.IndexedSlices):
      raise TypeError("sparse_delta is not IndexedSlices: %s" % sparse_delta)
    return gen_state_ops.scatter_update(
        self._variable,
        sparse_delta.indices,
        sparse_delta.values,
        use_locking=use_locking,
        name=name)

  def batch_scatter_update(self, sparse_delta, use_locking=False, name=None):
    """Assigns `tf.IndexedSlices` to this variable batch-wise.

    Analogous to `batch_gather`. This assumes that this variable and the
    sparse_delta IndexedSlices have a series of leading dimensions that are the
    same for all of them, and the updates are performed on the last dimension
    of indices. In other words, the dimensions should be the following:

    `num_prefix_dims = sparse_delta.indices.ndims - 1`
    `batch_dim = num_prefix_dims + 1`
    `sparse_delta.updates.shape = sparse_delta.indices.shape + var.shape[
      batch_dim:]`

    where

    `sparse_delta.updates.shape[:num_prefix_dims]`
    `== sparse_delta.indices.shape[:num_prefix_dims]`
    `== var.shape[:num_prefix_dims]`

    And the operation performed can be expressed as:

    `var[i_1, ..., i_n,
         sparse_delta.indices[i_1, ..., i_n, j]] = sparse_delta.updates[
            i_1, ..., i_n, j]`

    When sparse_delta.indices is a 1D tensor, this operation is equivalent to
    `scatter_update`.

    To avoid this operation one can loop over the first `ndims` of the
    variable and use `scatter_update` on the subtensors that result from
    slicing the first dimension. This is a valid option for `ndims = 1`, but
    less efficient than this implementation.

    Args:
      sparse_delta: `tf.IndexedSlices` to be assigned to this variable.
      use_locking: If `True`, use locking during the operation.
      name: the name of the operation.

    Returns:
      A `Tensor` that will hold the new value of this variable after
      the scattered assignment has completed.

    Raises:
      TypeError: if `sparse_delta` is not an `IndexedSlices`.
    """
    return state_ops.batch_scatter_update(
        self,
        sparse_delta.indices,
        sparse_delta.values,
        use_locking=use_locking,
        name=name)

  def scatter_nd_sub(self, indices, updates, name=None):
    """Applies sparse subtraction to individual values or slices in a Variable.

    `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.

    `indices` must be integer tensor, containing indices into `ref`.
    It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.

    The innermost dimension of `indices` (with length `K`) corresponds to
    indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
    dimension of `ref`.

    `updates` is `Tensor` of rank `Q-1+P-K` with shape:

    ```
    [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
    ```

    For example, say we want to subtract 4 scattered elements from a rank-1
    tensor with 8 elements. In Python, that update would look like this:

    ```python
    ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
    indices = tf.constant([[4], [3], [1] ,[7]])
    updates = tf.constant([9, 10, 11, 12])
    op = ref.scatter_nd_sub(indices, updates)
    with tf.compat.v1.Session() as sess:
      print(sess.run(op))
    ```

    The resulting update to ref would look like this:

        [1, -9, 3, -6, -4, 6, 7, -4]

    See `tf.scatter_nd` for more details about how to make updates to
    slices.

    Args:
      indices: The indices to be used in the operation.
      updates: The values to be used in the operation.
      name: the name of the operation.

    Returns:
      A `Tensor` that will hold the new value of this variable after
      the scattered subtraction has completed.
    """
    return gen_state_ops.scatter_nd_sub(
        self._variable, indices, updates, use_locking=True, name=name)

  def scatter_nd_add(self, indices, updates, name=None):
    """Applies sparse addition to individual values or slices in a Variable.

    `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.

    `indices` must be integer tensor, containing indices into `ref`.
    It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.

    The innermost dimension of `indices` (with length `K`) corresponds to
    indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
    dimension of `ref`.

    `updates` is `Tensor` of rank `Q-1+P-K` with shape:

    ```
    [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
    ```

    For example, say we want to add 4 scattered elements to a rank-1 tensor
    with 8 elements. In Python, that update would look like this:

    ```python
    ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
    indices = tf.constant([[4], [3], [1] ,[7]])
    updates = tf.constant([9, 10, 11, 12])
    add = ref.scatter_nd_add(indices, updates)
    with tf.compat.v1.Session() as sess:
      print(sess.run(add))
    ```

    The resulting update to ref would look like this:

        [1, 13, 3, 14, 14, 6, 7, 20]

    See `tf.scatter_nd` for more details about how to make updates to
    slices.

    Args:
      indices: The indices to be used in the operation.
      updates: The values to be used in the operation.
      name: the name of the operation.

    Returns:
      A `Tensor` that will hold the new value of this variable after
      the scattered addition has completed.
    """
    return gen_state_ops.scatter_nd_add(
        self._variable, indices, updates, use_locking=True, name=name)

  def scatter_nd_update(self, indices, updates, name=None):
    """Applies sparse assignment to individual values or slices in a Variable.

    `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.

    `indices` must be integer tensor, containing indices into `ref`.
    It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.

    The innermost dimension of `indices` (with length `K`) corresponds to
    indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
    dimension of `ref`.

    `updates` is `Tensor` of rank `Q-1+P-K` with shape:

    ```
    [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
    ```

    For example, say we want to assign 4 scattered elements to a rank-1 tensor
    with 8 elements. In Python, that update would look like this:

    ```python
    ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
    indices = tf.constant([[4], [3], [1] ,[7]])
    updates = tf.constant([9, 10, 11, 12])
    op = ref.scatter_nd_update(indices, updates)
    with tf.compat.v1.Session() as sess:
      print(sess.run(op))
    ```

    The resulting update to ref would look like this:

        [1, 11, 3, 10, 9, 6, 7, 12]

    See `tf.scatter_nd` for more details about how to make updates to
    slices.

    Args:
      indices: The indices to be used in the operation.
      updates: The values to be used in the operation.
      name: the name of the operation.

    Returns:
      A `Tensor` that will hold the new value of this variable after
      the scattered assignment has completed.
    """
    return gen_state_ops.scatter_nd_update(
        self._variable, indices, updates, use_locking=True, name=name)

  def scatter_nd_max(self, indices, updates, name=None):
    """Updates this variable with the max of `tf.IndexedSlices` and itself.

    `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.

    `indices` must be integer tensor, containing indices into `ref`.
    It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.

    The innermost dimension of `indices` (with length `K`) corresponds to
    indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
    dimension of `ref`.

    `updates` is `Tensor` of rank `Q-1+P-K` with shape:

    ```
    [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
    ```

    See `tf.scatter_nd` for more details about how to make updates to
    slices.

    Args:
      indices: The indices to be used in the operation.
      updates: The values to be used in the operation.
      name: the name of the operation.

    Returns:
      A `Tensor` that will hold the new value of this variable after
      the scattered maximization has completed.
    """
    return gen_state_ops.scatter_nd_max(
        self._variable, indices, updates, use_locking=True, name=name)

  def scatter_nd_min(self, indices, updates, name=None):
    """Updates this variable with the min of `tf.IndexedSlices` and itself.

    `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.
2506 2507 `indices` must be integer tensor, containing indices into `ref`. 2508 It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`. 2509 2510 The innermost dimension of `indices` (with length `K`) corresponds to 2511 indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th 2512 dimension of `ref`. 2513 2514 `updates` is `Tensor` of rank `Q-1+P-K` with shape: 2515 2516 ``` 2517 [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]]. 2518 ``` 2519 2520 See `tf.scatter_nd` for more details about how to make updates to 2521 slices. 2522 2523 Args: 2524 indices: The indices to be used in the operation. 2525 updates: The values to be used in the operation. 2526 name: the name of the operation. 2527 2528 Returns: 2529 A `Tensor` that will hold the new value of this variable after 2530 the scattered addition has completed. 2531 """ 2532 return gen_state_ops.scatter_nd_min( 2533 self._variable, indices, updates, use_locking=True, name=name) 2534 2535 def _strided_slice_assign(self, begin, end, strides, value, name, begin_mask, 2536 end_mask, ellipsis_mask, new_axis_mask, 2537 shrink_axis_mask): 2538 return gen_array_ops.strided_slice_assign( 2539 ref=self._ref(), 2540 begin=begin, 2541 end=end, 2542 strides=strides, 2543 value=value, 2544 name=name, 2545 begin_mask=begin_mask, 2546 end_mask=end_mask, 2547 ellipsis_mask=ellipsis_mask, 2548 new_axis_mask=new_axis_mask, 2549 shrink_axis_mask=shrink_axis_mask) 2550 2551 @deprecated(None, "Prefer Dataset.range instead.") 2552 def count_up_to(self, limit): 2553 """Increments this variable until it reaches `limit`. 2554 2555 When that Op is run it tries to increment the variable by `1`. If 2556 incrementing the variable would bring it above `limit` then the Op raises 2557 the exception `OutOfRangeError`. 2558 2559 If no error is raised, the Op outputs the value of the variable before 2560 the increment. 2561 2562 This is essentially a shortcut for `count_up_to(self, limit)`. 2563 2564 Args: 2565 limit: value at which incrementing the variable raises an error. 2566 2567 Returns: 2568 A `Tensor` that will hold the variable value before the increment. If no 2569 other Op modifies this variable, the values produced will all be 2570 distinct. 2571 """ 2572 return state_ops.count_up_to(self._variable, limit=limit) 2573 2574 # Conversion to tensor. 2575 @staticmethod 2576 def _TensorConversionFunction(v, dtype=None, name=None, as_ref=False): # pylint: disable=invalid-name 2577 """Utility function for converting a Variable to a Tensor.""" 2578 _ = name 2579 if dtype and not dtype.is_compatible_with(v.dtype): 2580 raise ValueError( 2581 "Incompatible type conversion requested to type '%s' for variable " 2582 "of type '%s'" % (dtype.name, v.dtype.name)) 2583 if as_ref: 2584 return v._ref() # pylint: disable=protected-access 2585 else: 2586 return v.value() 2587 2588 # NOTE(mrry): This enables the Variable's overloaded "right" binary 2589 # operators to run when the left operand is an ndarray, because it 2590 # accords the Variable class higher priority than an ndarray, or a 2591 # numpy matrix. 2592 # TODO(mrry): Convert this to using numpy's __numpy_ufunc__ 2593 # mechanism, which allows more control over how Variables interact 2594 # with ndarrays. 
2595 __array_priority__ = 100 2596 2597 @property 2598 def name(self): 2599 """The name of this variable.""" 2600 return self._name 2601 2602 @property 2603 def initializer(self): 2604 """The initializer operation for this variable.""" 2605 return self._initializer_op 2606 2607 @property 2608 def device(self): 2609 """The device of this variable.""" 2610 return self._variable.device 2611 2612 @property 2613 def dtype(self): 2614 """The `DType` of this variable.""" 2615 return self._variable.dtype 2616 2617 @property 2618 def op(self): 2619 """The `Operation` of this variable.""" 2620 return self._variable.op 2621 2622 @property 2623 def graph(self): 2624 """The `Graph` of this variable.""" 2625 return self._variable.graph 2626 2627 @property 2628 def _distribute_strategy(self): 2629 """The `tf.distribute.Strategy` that this variable was created under.""" 2630 return None # Ref variables are never created inside a strategy. 2631 2632 @property 2633 def shape(self): 2634 """The `TensorShape` of this variable. 2635 2636 Returns: 2637 A `TensorShape`. 2638 """ 2639 return self._variable.get_shape() 2640 2641 def to_proto(self, export_scope=None): 2642 """Converts a `Variable` to a `VariableDef` protocol buffer. 2643 2644 Args: 2645 export_scope: Optional `string`. Name scope to remove. 2646 2647 Returns: 2648 A `VariableDef` protocol buffer, or `None` if the `Variable` is not 2649 in the specified name scope. 2650 """ 2651 if (export_scope is None or self._variable.name.startswith(export_scope)): 2652 var_def = variable_pb2.VariableDef() 2653 var_def.variable_name = ops.strip_name_scope(self._variable.name, 2654 export_scope) 2655 if self._initial_value is not None: 2656 # For backwards compatibility. 2657 var_def.initial_value_name = ops.strip_name_scope( 2658 self._initial_value.name, export_scope) 2659 var_def.trainable = self.trainable 2660 var_def.synchronization = self.synchronization.value 2661 var_def.aggregation = self.aggregation.value 2662 var_def.initializer_name = ops.strip_name_scope(self.initializer.name, 2663 export_scope) 2664 var_def.snapshot_name = ops.strip_name_scope(self._snapshot.name, 2665 export_scope) 2666 if self._save_slice_info: 2667 var_def.save_slice_info_def.MergeFrom( 2668 self._save_slice_info.to_proto(export_scope=export_scope)) 2669 return var_def 2670 else: 2671 return None 2672 2673 def __iadd__(self, other): 2674 logging.log_first_n( 2675 logging.WARN, "Variable += will be deprecated. Use variable.assign_add" 2676 " if you want assignment to the variable value or 'x = x + y'" 2677 " if you want a new python Tensor object.", 1) 2678 return self + other 2679 2680 def __isub__(self, other): 2681 logging.log_first_n( 2682 logging.WARN, "Variable -= will be deprecated. Use variable.assign_sub" 2683 " if you want assignment to the variable value or 'x = x - y'" 2684 " if you want a new python Tensor object.", 1) 2685 return self - other 2686 2687 def __imul__(self, other): 2688 logging.log_first_n( 2689 logging.WARN, 2690 "Variable *= will be deprecated. Use `var.assign(var * other)`" 2691 " if you want assignment to the variable value or `x = x * y`" 2692 " if you want a new python Tensor object.", 1) 2693 return self * other 2694 2695 def __idiv__(self, other): 2696 logging.log_first_n( 2697 logging.WARN, 2698 "Variable /= will be deprecated. 
Use `var.assign(var / other)`" 2699 " if you want assignment to the variable value or `x = x / y`" 2700 " if you want a new python Tensor object.", 1) 2701 return self / other 2702 2703 def __itruediv__(self, other): 2704 logging.log_first_n( 2705 logging.WARN, 2706 "Variable /= will be deprecated. Use `var.assign(var / other)`" 2707 " if you want assignment to the variable value or `x = x / y`" 2708 " if you want a new python Tensor object.", 1) 2709 return self / other 2710 2711 def __irealdiv__(self, other): 2712 logging.log_first_n( 2713 logging.WARN, 2714 "Variable /= will be deprecated. Use `var.assign(var / other)`" 2715 " if you want assignment to the variable value or `x = x / y`" 2716 " if you want a new python Tensor object.", 1) 2717 return self / other 2718 2719 def __ipow__(self, other): 2720 logging.log_first_n( 2721 logging.WARN, 2722 "Variable **= will be deprecated. Use `var.assign(var ** other)`" 2723 " if you want assignment to the variable value or `x = x ** y`" 2724 " if you want a new python Tensor object.", 1) 2725 return self**other 2726 2727 2728def _try_guard_against_uninitialized_dependencies(name, initial_value): 2729 """Attempt to guard against dependencies on uninitialized variables. 2730 2731 Replace references to variables in `initial_value` with references to the 2732 variable's initialized values. The initialized values are essentially 2733 conditional TensorFlow graphs that return a variable's value if it is 2734 initialized or its `initial_value` if it hasn't been initialized. This 2735 replacement is done on a best effort basis: 2736 2737 - If the `initial_value` graph contains cycles, we don't do any 2738 replacements for that graph. 2739 - If the variables that `initial_value` depends on are not present in the 2740 `GLOBAL_VARIABLES` or `LOCAL_VARIABLES` we don't replace them. 2741 2742 In these cases, it is up to the caller to ensure that the `initial_value` 2743 graph uses initialized variables or that they guard access to variables 2744 using their `initialized_value` method. 2745 2746 Args: 2747 name: Variable name. 2748 initial_value: `Tensor`. The initial value. 2749 2750 Returns: 2751 A `Tensor` suitable to initialize a variable. 2752 Raises: 2753 TypeError: If `initial_value` is not a `Tensor`. 2754 """ 2755 if not isinstance(initial_value, ops.Tensor): 2756 raise TypeError("initial_value needs to be a Tensor: %s" % initial_value) 2757 2758 # Don't modify initial_value if it contains any cyclic dependencies. 2759 if _has_cycle(initial_value.op, state={}): 2760 return initial_value 2761 return _safe_initial_value_from_tensor(name, initial_value, op_cache={}) 2762 2763 2764_UNKNOWN, _STARTED, _FINISHED = range(3) 2765 2766 2767def _has_cycle(op, state): 2768 """Detect cycles in the dependencies of `initial_value`.""" 2769 op_state = state.get(op.name, _UNKNOWN) 2770 if op_state == _STARTED: 2771 return True 2772 elif op_state == _FINISHED: 2773 return False 2774 2775 state[op.name] = _STARTED 2776 for i in itertools.chain((i.op for i in op.inputs), op.control_inputs): 2777 if _has_cycle(i, state): 2778 return True 2779 state[op.name] = _FINISHED 2780 return False 2781 2782 2783def _safe_initial_value_from_tensor(name, tensor, op_cache): 2784 """Replace dependencies on variables with their initialized values. 2785 2786 Args: 2787 name: Variable name. 2788 tensor: A `Tensor`. The tensor to replace. 2789 op_cache: A dict mapping operation names to `Operation`s. Used to memoize 2790 the results so as to avoid creating redundant operations. 
2791 2792 Returns: 2793 A `Tensor` compatible with `tensor`. Any inputs that lead to variable 2794 values will be replaced with a corresponding graph that uses the 2795 variable's initialized values. This is done on a best-effort basis. If no 2796 modifications need to be made then `tensor` will be returned unchanged. 2797 """ 2798 op = tensor.op 2799 new_op = op_cache.get(op.name) 2800 if new_op is None: 2801 new_op = _safe_initial_value_from_op(name, op, op_cache) 2802 op_cache[op.name] = new_op 2803 return new_op.outputs[tensor.value_index] 2804 2805 2806def _safe_initial_value_from_op(name, op, op_cache): 2807 """Replace dependencies on variables with their initialized values. 2808 2809 Args: 2810 name: Variable name. 2811 op: An `Operation`. The operation to replace. 2812 op_cache: A dict mapping operation names to `Operation`s. Used to memoize 2813 the results so as to avoid creating redundant operations. 2814 2815 Returns: 2816 An `Operation` compatible with `op`. Any inputs that lead to variable 2817 values will be replaced with a corresponding graph that uses the 2818 variable's initialized values. This is done on a best-effort basis. If no 2819 modifications need to be made then `op` will be returned unchanged. 2820 """ 2821 op_type = op.node_def.op 2822 if op_type in ("IsVariableInitialized", "VarIsInitializedOp", 2823 "ReadVariableOp", "If"): 2824 return op 2825 2826 # Attempt to find the initialized_value of any variable reference / handles. 2827 # TODO(b/70206927): Fix handling of ResourceVariables. 2828 if op_type in ("Variable", "VariableV2", "VarHandleOp"): 2829 initialized_value = _find_initialized_value_for_variable(op) 2830 return op if initialized_value is None else initialized_value.op 2831 2832 # Recursively build initializer expressions for inputs. 2833 modified = False 2834 new_op_inputs = [] 2835 for op_input in op.inputs: 2836 new_op_input = _safe_initial_value_from_tensor(name, op_input, op_cache) 2837 new_op_inputs.append(new_op_input) 2838 modified = modified or (new_op_input != op_input) 2839 2840 # If at least one input was modified, replace the op. 2841 if modified: 2842 new_op_type = op_type 2843 if new_op_type == "RefSwitch": 2844 new_op_type = "Switch" 2845 new_op_name = op.node_def.name + "_" + name 2846 new_op_name = new_op_name.replace(":", "_") 2847 return op.graph.create_op( 2848 new_op_type, 2849 new_op_inputs, 2850 op._output_types, # pylint: disable=protected-access 2851 name=new_op_name, 2852 attrs=op.node_def.attr) 2853 2854 return op 2855 2856 2857def _find_initialized_value_for_variable(variable_op): 2858 """Find the initialized value for a variable op. 2859 2860 To do so, lookup the variable op in the variables collection. 2861 2862 Args: 2863 variable_op: A variable `Operation`. 2864 2865 Returns: 2866 A `Tensor` representing the initialized value for the variable or `None` 2867 if the initialized value could not be found. 2868 """ 2869 try: 2870 var_names = [variable_op.node_def.name, variable_op.node_def.name + ":0"] 2871 for collection_name in (ops.GraphKeys.GLOBAL_VARIABLES, 2872 ops.GraphKeys.LOCAL_VARIABLES): 2873 for var in variable_op.graph.get_collection(collection_name): 2874 if var.name in var_names: 2875 return var.initialized_value() 2876 except AttributeError: 2877 # Return None when an incomplete user-defined variable type was put in 2878 # the collection. 2879 return None 2880 return None 2881 2882 2883class PartitionedVariable: 2884 """A container for partitioned `Variable` objects. 
2885 2886 @compatibility(eager) `tf.PartitionedVariable` is not compatible with 2887 eager execution. Use `tf.Variable` instead which is compatible 2888 with both eager execution and graph construction. See [the 2889 TensorFlow Eager Execution 2890 guide](https://www.tensorflow.org/guide/eager#variables_and_optimizers) 2891 for details on how variables work in eager execution. 2892 @end_compatibility 2893 """ 2894 2895 def __init__(self, name, shape, dtype, variable_list, partitions): 2896 """Creates a new partitioned variable wrapper. 2897 2898 Variables passed via the variable_list must contain a save_slice_info 2899 field. Concatenation and iteration is in lexicographic order according 2900 to the var_offset property of the save_slice_info. 2901 2902 Args: 2903 name: String. Overall name of the variables. 2904 shape: List of integers. Overall shape of the variables. 2905 dtype: Type of the variables. 2906 variable_list: List of `Variable` that comprise this partitioned variable. 2907 partitions: List of integers. Number of partitions for each dimension. 2908 2909 Raises: 2910 TypeError: If `variable_list` is not a list of `Variable` objects, or 2911 `partitions` is not a list. 2912 ValueError: If `variable_list` is empty, or the `Variable` shape 2913 information does not match `shape`, or `partitions` has invalid values. 2914 """ 2915 if not isinstance(variable_list, (list, tuple)): 2916 raise TypeError("variable_list is not a list or tuple: %s" % 2917 variable_list) 2918 if not isinstance(partitions, (list, tuple)): 2919 raise TypeError("partitions is not a list or tuple: %s" % partitions) 2920 if not all(p >= 1 for p in partitions): 2921 raise ValueError("partition values must be positive: %s" % partitions) 2922 if not variable_list: 2923 raise ValueError("variable_list may not be empty") 2924 # pylint: disable=protected-access 2925 for v in variable_list: 2926 # Sort the variable_list lexicographically according to var offset value. 2927 if not all(v._get_save_slice_info() is not None for v in variable_list): 2928 raise ValueError( 2929 "All variables must have a save_slice_info available: %s" % 2930 [v.name for v in variable_list]) 2931 if len(shape) != len(partitions): 2932 raise ValueError("len(shape) != len(partitions): %s vs. %s" % 2933 (shape, partitions)) 2934 if v._get_save_slice_info().full_shape != shape: 2935 raise ValueError("All variables' full shapes must match shape: %s; " 2936 "but full shapes were: %s" % 2937 (shape, str([v._get_save_slice_info().full_shape]))) 2938 self._variable_list = sorted( 2939 variable_list, key=lambda v: v._get_save_slice_info().var_offset) 2940 # pylint: enable=protected-access 2941 2942 self._name = name 2943 self._shape = shape 2944 self._dtype = dtype 2945 self._partitions = partitions 2946 self._as_tensor = None 2947 2948 def __iter__(self): 2949 """Return an iterable for accessing the underlying partition Variables.""" 2950 return iter(self._variable_list) 2951 2952 def __len__(self): 2953 num_partition_axes = len(self._partition_axes()) 2954 if num_partition_axes > 1: 2955 raise ValueError("Cannot get a length for %d > 1 partition axes" % 2956 num_partition_axes) 2957 return len(self._variable_list) 2958 2959 def _partition_axes(self): 2960 if all(p == 1 for p in self._partitions): 2961 return [0] 2962 else: 2963 return [i for i, p in enumerate(self._partitions) if p > 1] 2964 2965 def _concat(self): 2966 """Returns the overall concatenated value as a `Tensor`. 
  def _concat(self):
    """Returns the overall concatenated value as a `Tensor`.

    This is different from using the partitioned variable directly as a tensor
    (through tensor conversion and `as_tensor`) in that it creates a new set of
    operations that keeps the control dependencies from its scope.

    Returns:
      `Tensor` containing the concatenated value.
    """
    if len(self._variable_list) == 1:
      with ops.name_scope(None):
        return array_ops.identity(self._variable_list[0], name=self._name)

    partition_axes = self._partition_axes()

    if len(partition_axes) > 1:
      raise NotImplementedError(
          "Cannot concatenate along more than one dimension: %s. "
          "Multi-axis partition concat is not supported" % str(partition_axes))
    partition_ix = partition_axes[0]

    with ops.name_scope(self._name + "/ConcatPartitions/"):
      concatenated = array_ops.concat(self._variable_list, partition_ix)

    with ops.name_scope(None):
      return array_ops.identity(concatenated, name=self._name)

  def as_tensor(self):
    """Returns the overall concatenated value as a `Tensor`.

    The returned tensor will not inherit the control dependencies from the
    scope where the value is used, which is similar to getting the value of
    `Variable`.

    Returns:
      `Tensor` containing the concatenated value.
    """
    with ops.control_dependencies(None):
      return self._concat()

  @staticmethod
  def _TensorConversionFunction(v, dtype=None, name=None, as_ref=False):
    # pylint: disable=invalid-name
    _ = name
    if dtype is not None and not dtype.is_compatible_with(v.dtype):
      raise ValueError(
          "Incompatible type conversion requested to type '%s' for variable "
          "of type '%s'" % (dtype.name, v.dtype.name))
    if as_ref:
      raise NotImplementedError(
          "PartitionedVariable doesn't support being used as a reference.")
    else:
      return v.as_tensor()

  @property
  def name(self):
    return self._name

  @property
  def dtype(self):
    return self._dtype

  @property
  def shape(self):
    return self.get_shape()

  @property
  def _distribute_strategy(self):
    """The `tf.distribute.Strategy` that this variable was created under."""
    # NOTE(yuefengz): Today, no partitioned variables in a distribute strategy.
    return None

  def get_shape(self):
    return self._shape

  def _get_variable_list(self):
    return self._variable_list

  def _get_partitions(self):
    return self._partitions

  def _apply_assign_fn(self, assign_fn, value):
    partition_axes = self._partition_axes()
    if len(partition_axes) > 1:
      raise NotImplementedError(
          "Cannot do assign action along more than one dimension: %s. "
          "Multi-axis partition assign action is not supported " %
          str(partition_axes))
    if isinstance(value, list):
      assert len(value) == len(self._variable_list)
      value_list = value
    elif isinstance(value, PartitionedVariable):
      value_list = [var_part for var_part in value]
    else:
      partition_ix = partition_axes[0]
      size_splits_list = [
          tensor_shape.dimension_value(var.shape[partition_ix])
          for var in self._variable_list
      ]
      value_list = array_ops.split(value, size_splits_list, axis=partition_ix)

    op_list = [
        assign_fn(var, value_list[idx])
        for idx, var in enumerate(self._variable_list)
    ]
    return op_list

  def assign(self, value, use_locking=False, name=None, read_value=True):
    assign_fn = lambda var, r_value: var.assign(
        r_value, use_locking=use_locking, name=name, read_value=read_value)
    assign_list = self._apply_assign_fn(assign_fn, value)
    if read_value:
      return assign_list
    return [assign.op for assign in assign_list]

  def assign_add(self, value, use_locking=False, name=None, read_value=True):
    assign_fn = lambda var, r_value: var.assign_add(
        r_value, use_locking=use_locking, name=name, read_value=read_value)
    assign_list = self._apply_assign_fn(assign_fn, value)
    if read_value:
      return assign_list
    return [assign.op for assign in assign_list]

  def assign_sub(self, value, use_locking=False, name=None, read_value=True):
    assign_fn = lambda var, r_value: var.assign_sub(
        r_value, use_locking=use_locking, name=name, read_value=read_value)
    assign_list = self._apply_assign_fn(assign_fn, value)
    if read_value:
      return assign_list
    return [assign.op for assign in assign_list]


# Register a conversion function which reads the value of the variable,
# allowing instances of the class to be used as tensors.
ops.register_tensor_conversion_function(RefVariable,
                                        RefVariable._TensorConversionFunction)  # pylint: disable=protected-access


@tf_export(v1=["global_variables"])
def global_variables(scope=None):
  """Returns global variables.

  Global variables are variables that are shared across machines in a
  distributed environment. The `Variable()` constructor or `get_variable()`
  automatically adds new variables to the graph collection
  `GraphKeys.GLOBAL_VARIABLES`.
  This convenience function returns the contents of that collection.

  An alternative to global variables is local variables; see
  `tf.compat.v1.local_variables`.

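  For example, a minimal graph-mode sketch (the variable and scope names here
  are illustrative, not part of the API):

  ```python
  import tensorflow.compat.v1 as tf
  tf.disable_eager_execution()

  with tf.variable_scope("layer1"):
    w = tf.get_variable("w", shape=[2, 2])
  b = tf.get_variable("b", shape=[2])

  tf.global_variables()          # [w, b]
  tf.global_variables("layer1")  # `scope` filters by prefix -> [w]
  ```
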
  @compatibility(TF2)
  Not compatible with eager execution and `tf.function`. In particular, Graph
  collections are deprecated in TF2. Instead please create a
  [tf.Module](https://www.tensorflow.org/guide/intro_to_modules)
  container for all your model state, including variables.
  You can then list all the variables in your `tf.Module` through the
  `variables` attribute.
  @end_compatibility

  Args:
    scope: (Optional.) A string. If supplied, the resulting list is filtered to
      include only items whose `name` attribute matches `scope` using
      `re.match`. Items without a `name` attribute are never returned if a
      scope is supplied. The choice of `re.match` means that a `scope` without
      special tokens filters by prefix.

  Returns:
    A list of `Variable` objects.
  """
  return ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES, scope)


@tf_export(v1=["all_variables"])
@deprecated("2017-03-02", "Please use tf.global_variables instead.")
def all_variables():
  """Use `tf.compat.v1.global_variables` instead."""
  return global_variables()


def _all_saveable_objects(scope=None):
  """Returns all variables and `SaveableObject`s that must be checkpointed.

  Args:
    scope: (Optional.) A string. If supplied, the resulting list is filtered to
      include only items whose `name` attribute matches `scope` using
      `re.match`. Items without a `name` attribute are never returned if a
      scope is supplied. The choice of `re.match` means that a `scope` without
      special tokens filters by prefix.

  Returns:
    A list of `Variable` and `SaveableObject` to be checkpointed.
  """
  # TODO(andreasst): make this function public once things are settled.
  return (ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES, scope) +
          ops.get_collection(ops.GraphKeys.SAVEABLE_OBJECTS, scope))


@tf_export(v1=["local_variables"])
def local_variables(scope=None):
  """Returns local variables.

  Local variables are per-process variables, usually not saved/restored to
  checkpoint and used for temporary or intermediate values.
  For example, they can be used as counters for metrics computation or the
  number of epochs this machine has read data.
  The `tf.contrib.framework.local_variable()` function automatically adds the
  new variable to `GraphKeys.LOCAL_VARIABLES`.
  This convenience function returns the contents of that collection.

  An alternative to local variables is global variables; see
  `tf.compat.v1.global_variables`.

  Args:
    scope: (Optional.) A string. If supplied, the resulting list is filtered to
      include only items whose `name` attribute matches `scope` using
      `re.match`. Items without a `name` attribute are never returned if a
      scope is supplied. The choice of `re.match` means that a `scope` without
      special tokens filters by prefix.

  Returns:
    A list of local `Variable` objects.
  """
  return ops.get_collection(ops.GraphKeys.LOCAL_VARIABLES, scope)


@tf_export(v1=["model_variables"])
def model_variables(scope=None):
  """Returns all variables in the MODEL_VARIABLES collection.

  Args:
    scope: (Optional.) A string. If supplied, the resulting list is filtered to
      include only items whose `name` attribute matches `scope` using
      `re.match`. Items without a `name` attribute are never returned if a
      scope is supplied. The choice of `re.match` means that a `scope` without
      special tokens filters by prefix.

  Returns:
    A list of `Variable` objects in the MODEL_VARIABLES collection.
  """
  return ops.get_collection(ops.GraphKeys.MODEL_VARIABLES, scope)


@tf_export(v1=["trainable_variables"])
def trainable_variables(scope=None):
  """Returns all variables created with `trainable=True`.

  When passed `trainable=True`, the `Variable()` constructor automatically
  adds new variables to the graph collection
  `GraphKeys.TRAINABLE_VARIABLES`. This convenience function returns the
  contents of that collection.

  @compatibility(TF2)
  Not compatible with eager execution and `tf.function`. In particular, Graph
  collections are deprecated in TF2. Instead please create a `tf.Module`
  container for all your model state, including variables.
  You can then list all the trainable variables in your `tf.Module` through the
  `trainable_variables` attribute.
  @end_compatibility

  Args:
    scope: (Optional.) A string. If supplied, the resulting list is filtered to
      include only items whose `name` attribute matches `scope` using
      `re.match`. Items without a `name` attribute are never returned if a
      scope is supplied. The choice of `re.match` means that a `scope` without
      special tokens filters by prefix.

  Returns:
    A list of Variable objects.
  """
  return ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES, scope)


@tf_export(v1=["moving_average_variables"])
def moving_average_variables(scope=None):
  """Returns all variables that maintain their moving averages.

  If an `ExponentialMovingAverage` object is created and the `apply()`
  method is called on a list of variables, these variables will
  be added to the `GraphKeys.MOVING_AVERAGE_VARIABLES` collection.
  This convenience function returns the contents of that collection.

  Args:
    scope: (Optional.) A string. If supplied, the resulting list is filtered to
      include only items whose `name` attribute matches `scope` using
      `re.match`. Items without a `name` attribute are never returned if a
      scope is supplied. The choice of `re.match` means that a `scope` without
      special tokens filters by prefix.

  Returns:
    A list of Variable objects.
  """
  return ops.get_collection(ops.GraphKeys.MOVING_AVERAGE_VARIABLES, scope)


@tf_export(v1=["initializers.variables", "variables_initializer"])
def variables_initializer(var_list, name="init"):
  """Returns an Op that initializes a list of variables.

  After you launch the graph in a session, you can run the returned Op to
  initialize all the variables in `var_list`. This Op runs all the
  initializers of the variables in `var_list` in parallel.

  Calling `initialize_variables()` is equivalent to passing the list of
  initializers to `tf.group()`.

  If `var_list` is empty, however, the function still returns an Op that can
  be run. That Op just has no effect.

  @compatibility(TF2)
  In TF2, variables are initialized immediately when they are created. There is
  no longer a need to run variable initializers before using them.
  @end_compatibility

  Args:
    var_list: List of `Variable` objects to initialize.
    name: Optional name for the returned operation.

  Returns:
    An Op that runs the initializers of all the specified variables.
  """
  if var_list and not context.executing_eagerly():
    return control_flow_ops.group(*[v.initializer for v in var_list], name=name)
  return control_flow_ops.no_op(name=name)


@tf_export(v1=["initialize_variables"])
@tf_should_use.should_use_result
@deprecated("2017-03-02", "Use `tf.variables_initializer` instead.")
def initialize_variables(var_list, name="init"):
  """See `tf.compat.v1.variables_initializer`."""
  return variables_initializer(var_list, name=name)


@tf_export(v1=["initializers.global_variables", "global_variables_initializer"])
def global_variables_initializer():
  """Returns an Op that initializes global variables.

  This is just a shortcut for `variables_initializer(global_variables())`.

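  For example, the classic graph-mode pattern, shown as a sketch (the variable
  name and zeros initializer are illustrative):

  ```python
  import tensorflow.compat.v1 as tf
  tf.disable_eager_execution()

  v = tf.get_variable("v", shape=[2], initializer=tf.zeros_initializer())
  init_op = tf.global_variables_initializer()

  with tf.Session() as sess:
    sess.run(init_op)    # runs every global variable's initializer
    print(sess.run(v))   # [0. 0.]
  ```
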
  @compatibility(TF2)
  In TF2, variables are initialized immediately when they are created. There is
  no longer a need to run variable initializers before using them.
  @end_compatibility

  Returns:
    An Op that initializes global variables in the graph.
  """
  if context.executing_eagerly():
    return control_flow_ops.no_op(name="global_variables_initializer")
  return variables_initializer(global_variables())


@tf_export(v1=["initialize_all_variables"])
@tf_should_use.should_use_result
@deprecated("2017-03-02", "Use `tf.global_variables_initializer` instead.")
def initialize_all_variables():
  """See `tf.compat.v1.global_variables_initializer`."""
  return global_variables_initializer()


@tf_export(v1=["initializers.local_variables", "local_variables_initializer"])
def local_variables_initializer():
  """Returns an Op that initializes all local variables.

  This is just a shortcut for `variables_initializer(local_variables())`.

  @compatibility(TF2)
  In TF2, variables are initialized immediately when they are created. There is
  no longer a need to run variable initializers before using them.
  @end_compatibility

  Returns:
    An Op that initializes all local variables in the graph.
  """
  if context.executing_eagerly():
    return control_flow_ops.no_op(name="local_variables_initializer")
  return variables_initializer(local_variables())


@tf_export(v1=["initialize_local_variables"])
@tf_should_use.should_use_result
@deprecated("2017-03-02", "Use `tf.local_variables_initializer` instead.")
def initialize_local_variables():
  """See `tf.compat.v1.local_variables_initializer`."""
  return local_variables_initializer()


@tf_export(v1=["is_variable_initialized"])
@tf_should_use.should_use_result
def is_variable_initialized(variable):
  """Tests if a variable has been initialized.

  Args:
    variable: A `Variable`.

  Returns:
    A scalar boolean Tensor: `True` if the variable has been initialized,
    `False` otherwise.
  """
  return state_ops.is_variable_initialized(variable)


@tf_export(v1=["assert_variables_initialized"])
@tf_should_use.should_use_result
def assert_variables_initialized(var_list=None):
  """Returns an Op to check if variables are initialized.

  NOTE: This function is obsolete and will be removed in 6 months. Please
  change your implementation to use `report_uninitialized_variables()`.

  When run, the returned Op will raise the exception `FailedPreconditionError`
  if any of the variables has not yet been initialized.

  Note: This function is implemented by trying to fetch the values of the
  variables. If one of the variables is not initialized a message may be
  logged by the C++ runtime. This is expected.

  Args:
    var_list: List of `Variable` objects to check. Defaults to the value of
      `global_variables()`.

  Returns:
    An Op, or None if there are no variables.
  """
  if var_list is None:
    var_list = global_variables() + local_variables()
    # Backwards compatibility for old-style variables. TODO(touts): remove.
    if not var_list:
      var_list = []
      for op in ops.get_default_graph().get_operations():
        if op.type in ["Variable", "VariableV2", "AutoReloadVariable"]:
          var_list.append(op.outputs[0])
  if not var_list:
    return None
  else:
    ranks = []
    for var in var_list:
      with ops.colocate_with(var.op):
        ranks.append(array_ops.rank_internal(var, optimize=False))
    if len(ranks) == 1:
      return ranks[0]
    else:
      return array_ops.stack(ranks)


@tf_export(v1=["report_uninitialized_variables"])
@tf_should_use.should_use_result
def report_uninitialized_variables(var_list=None,
                                   name="report_uninitialized_variables"):
  """Adds ops to list the names of uninitialized variables.

  When run, it returns a 1-D tensor containing the names of uninitialized
  variables if there are any, or an empty array if there are none.

  Args:
    var_list: List of `Variable` objects to check. Defaults to the value of
      `global_variables() + local_variables()`
    name: Optional name of the `Operation`.

  Returns:
    A 1-D tensor containing names of the uninitialized variables, or an empty
    1-D tensor if there are no variables or no uninitialized variables.
  """
  if var_list is None:
    var_list = global_variables() + local_variables()
    # Backwards compatibility for old-style variables. TODO(touts): remove.
    if not var_list:
      var_list = []
      for op in ops.get_default_graph().get_operations():
        if op.type in ["Variable", "VariableV2", "AutoReloadVariable"]:
          var_list.append(op.outputs[0])
  with ops.name_scope(name):
    # Run all operations on CPU
    if var_list:
      init_vars = [state_ops.is_variable_initialized(v) for v in var_list]
    local_device = os.environ.get(
        "TF_DEVICE_FOR_UNINITIALIZED_VARIABLE_REPORTING", "/cpu:0")
    with ops.device(local_device):
      if not var_list:
        # Return an empty tensor so we only need to check for returned tensor
        # size being 0 as an indication of model ready.
        return array_ops.constant([], dtype=dtypes.string)
      else:
        # Get a 1-D boolean tensor listing whether each variable is
        # initialized.
        variables_mask = math_ops.logical_not(array_ops.stack(init_vars))
        # Get a 1-D string tensor containing all the variable names.
        variable_names_tensor = array_ops.constant(
            [s.op.name for s in var_list])
        # Return a 1-D tensor containing all the names of
        # uninitialized variables.
        return array_ops.boolean_mask(variable_names_tensor, variables_mask)


ops.register_tensor_conversion_function(
    PartitionedVariable, PartitionedVariable._TensorConversionFunction)  # pylint: disable=protected-access
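
# Illustration (not exercised anywhere in this module): because of the
# registration above, `ops.convert_to_tensor` accepts a `PartitionedVariable`,
# so it can be passed to most ops that expect a `Tensor`; the conversion calls
# `as_tensor()` and operates on the concatenated value. For example (with
# hypothetical names), `math_ops.matmul(x, partitioned_w)` would multiply by
# the concatenated `partitioned_w`.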