1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Variable class.""" 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20import enum # pylint: disable=g-bad-import-order 21import itertools 22import functools 23import os 24 25import six 26 27from tensorflow.core.framework import attr_value_pb2 28from tensorflow.core.framework import variable_pb2 29from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import 30from tensorflow.python.eager import context 31from tensorflow.python.framework import dtypes 32from tensorflow.python.framework import ops 33from tensorflow.python.framework import tensor_shape 34from tensorflow.python.ops import array_ops 35from tensorflow.python.ops import control_flow_ops 36from tensorflow.python.ops import gen_array_ops 37from tensorflow.python.ops import gen_state_ops 38from tensorflow.python.ops import gen_math_ops 39from tensorflow.python.ops import math_ops 40from tensorflow.python.ops import state_ops 41from tensorflow.python.platform import tf_logging as logging 42from tensorflow.python.training.tracking import base as trackable 43from tensorflow.python.util import _pywrap_utils 44from tensorflow.python.util import compat 45from tensorflow.python.util import object_identity 46from tensorflow.python.util import tf_should_use 47from tensorflow.python.util.deprecation import deprecated 48from tensorflow.python.util.deprecation import deprecated_args 49from tensorflow.python.util import traceback_utils 50from tensorflow.python.util.tf_export import tf_export 51from tensorflow.python.types import core 52 53 54def default_variable_creator(_, **kwds): 55 del kwds 56 raise NotImplementedError("variable_scope needs to be imported") 57 58 59def default_variable_creator_v2(_, **kwds): 60 del kwds 61 raise NotImplementedError("variable_scope needs to be imported") 62 63 64def _make_getter(captured_getter, captured_previous): 65 """To avoid capturing loop variables.""" 66 67 def getter(**kwargs): 68 return captured_getter(captured_previous, **kwargs) 69 70 return getter 71 72 73@tf_export("VariableSynchronization") 74class VariableSynchronization(enum.Enum): 75 """Indicates when a distributed variable will be synced. 76 77 * `AUTO`: Indicates that the synchronization will be determined by the current 78 `DistributionStrategy` (eg. With `MirroredStrategy` this would be 79 `ON_WRITE`). 80 * `NONE`: Indicates that there will only be one copy of the variable, so 81 there is no need to sync. 82 * `ON_WRITE`: Indicates that the variable will be updated across devices 83 every time it is written. 84 * `ON_READ`: Indicates that the variable will be aggregated across devices 85 when it is read (eg. when checkpointing or when evaluating an op that uses 86 the variable). 87 88 Example: 89 >>> temp_grad=[tf.Variable([0.], trainable=False, 90 ... 
synchronization=tf.VariableSynchronization.ON_READ, 91 ... aggregation=tf.VariableAggregation.MEAN 92 ... )] 93 """ 94 AUTO = 0 95 NONE = 1 96 ON_WRITE = 2 97 ON_READ = 3 98 99 100# LINT.IfChange 101@tf_export("VariableAggregation", v1=[]) 102class VariableAggregationV2(enum.Enum): 103 """Indicates how a distributed variable will be aggregated. 104 105 `tf.distribute.Strategy` distributes a model by making multiple copies 106 (called "replicas") acting data-parallel on different elements of the input 107 batch. When performing some variable-update operation, say 108 `var.assign_add(x)`, in a model, we need to resolve how to combine the 109 different values for `x` computed in the different replicas. 110 111 * `NONE`: This is the default, giving an error if you use a 112 variable-update operation with multiple replicas. 113 * `SUM`: Add the updates across replicas. 114 * `MEAN`: Take the arithmetic mean ("average") of the updates across replicas. 115 * `ONLY_FIRST_REPLICA`: This is for when every replica is performing the same 116 update, but we only want to perform the update once. Used, e.g., for the 117 global step counter. 118 """ 119 NONE = 0 120 SUM = 1 121 MEAN = 2 122 ONLY_FIRST_REPLICA = 3 123 124 def __hash__(self): 125 return hash(self.value) 126 127 def __eq__(self, other): 128 if self is other: 129 return True 130 elif isinstance(other, VariableAggregation): 131 return int(self.value) == int(other.value) 132 else: 133 return False 134 135 136@tf_export(v1=["VariableAggregation"]) 137class VariableAggregation(enum.Enum): 138 NONE = 0 139 SUM = 1 140 MEAN = 2 141 ONLY_FIRST_REPLICA = 3 142 ONLY_FIRST_TOWER = 3 # DEPRECATED 143 144 def __hash__(self): 145 return hash(self.value) 146 147 148# LINT.ThenChange(//tensorflow/core/framework/variable.proto) 149# 150# Note that we are currently relying on the integer values of the Python enums 151# matching the integer values of the proto enums. 
152 153VariableAggregation.__doc__ = ( 154 VariableAggregationV2.__doc__ + 155 "* `ONLY_FIRST_TOWER`: Deprecated alias for `ONLY_FIRST_REPLICA`.\n ") 156 157 158def validate_synchronization_aggregation_trainable(synchronization, aggregation, 159 trainable, name): 160 """Given user-provided variable properties, sets defaults and validates.""" 161 if aggregation is None: 162 aggregation = VariableAggregation.NONE 163 else: 164 if not isinstance(aggregation, 165 (VariableAggregation, VariableAggregationV2)): 166 try: 167 aggregation = VariableAggregationV2(aggregation) 168 except ValueError: 169 raise ValueError( 170 "Invalid variable aggregation mode: {} for variable: {}".format( 171 aggregation, name)) 172 if synchronization is None: 173 synchronization = VariableSynchronization.AUTO 174 else: 175 try: 176 synchronization = VariableSynchronization(synchronization) 177 except ValueError: 178 raise ValueError( 179 "Invalid variable synchronization mode: {} for variable: {}".format( 180 synchronization, name)) 181 if trainable is None: 182 trainable = synchronization != VariableSynchronization.ON_READ 183 return synchronization, aggregation, trainable 184 185 186class VariableMetaclass(type): 187 """Metaclass to allow construction of tf.Variable to be overridden.""" 188 189 def _variable_v1_call(cls, 190 initial_value=None, 191 trainable=None, 192 collections=None, 193 validate_shape=True, 194 caching_device=None, 195 name=None, 196 variable_def=None, 197 dtype=None, 198 expected_shape=None, 199 import_scope=None, 200 constraint=None, 201 use_resource=None, 202 synchronization=VariableSynchronization.AUTO, 203 aggregation=VariableAggregation.NONE, 204 shape=None): 205 """Call on Variable class. Useful to force the signature.""" 206 previous_getter = lambda **kwargs: default_variable_creator(None, **kwargs) 207 for _, getter in ops.get_default_graph()._variable_creator_stack: # pylint: disable=protected-access 208 previous_getter = _make_getter(getter, previous_getter) 209 210 # Reset `aggregation` that is explicitly set as `None` to the enum NONE. 211 if aggregation is None: 212 aggregation = VariableAggregation.NONE 213 return previous_getter( 214 initial_value=initial_value, 215 trainable=trainable, 216 collections=collections, 217 validate_shape=validate_shape, 218 caching_device=caching_device, 219 name=name, 220 variable_def=variable_def, 221 dtype=dtype, 222 expected_shape=expected_shape, 223 import_scope=import_scope, 224 constraint=constraint, 225 use_resource=use_resource, 226 synchronization=synchronization, 227 aggregation=aggregation, 228 shape=shape) 229 230 def _variable_v2_call(cls, 231 initial_value=None, 232 trainable=None, 233 validate_shape=True, 234 caching_device=None, 235 name=None, 236 variable_def=None, 237 dtype=None, 238 import_scope=None, 239 constraint=None, 240 synchronization=VariableSynchronization.AUTO, 241 aggregation=VariableAggregation.NONE, 242 shape=None): 243 """Call on Variable class. Useful to force the signature.""" 244 previous_getter = lambda **kws: default_variable_creator_v2(None, **kws) 245 for _, getter in ops.get_default_graph()._variable_creator_stack: # pylint: disable=protected-access 246 previous_getter = _make_getter(getter, previous_getter) 247 248 # Reset `aggregation` that is explicitly set as `None` to the enum NONE. 
249 if aggregation is None: 250 aggregation = VariableAggregation.NONE 251 return previous_getter( 252 initial_value=initial_value, 253 trainable=trainable, 254 validate_shape=validate_shape, 255 caching_device=caching_device, 256 name=name, 257 variable_def=variable_def, 258 dtype=dtype, 259 import_scope=import_scope, 260 constraint=constraint, 261 synchronization=synchronization, 262 aggregation=aggregation, 263 shape=shape) 264 265 @traceback_utils.filter_traceback 266 def __call__(cls, *args, **kwargs): 267 if cls is VariableV1: 268 return cls._variable_v1_call(*args, **kwargs) 269 elif cls is Variable: 270 return cls._variable_v2_call(*args, **kwargs) 271 else: 272 return super(VariableMetaclass, cls).__call__(*args, **kwargs) 273 274 275@tf_export("Variable", v1=[]) 276# TODO(mdan): This should subclass core.Tensor, and not all its subclasses? 277class Variable(six.with_metaclass(VariableMetaclass, trackable.Trackable)): 278 """See the [variable guide](https://tensorflow.org/guide/variable). 279 280 A variable maintains shared, persistent state manipulated by a program. 281 282 The `Variable()` constructor requires an initial value for the variable, which 283 can be a `Tensor` of any type and shape. This initial value defines the type 284 and shape of the variable. After construction, the type and shape of the 285 variable are fixed. The value can be changed using one of the assign methods. 286 287 >>> v = tf.Variable(1.) 288 >>> v.assign(2.) 289 <tf.Variable ... shape=() dtype=float32, numpy=2.0> 290 >>> v.assign_add(0.5) 291 <tf.Variable ... shape=() dtype=float32, numpy=2.5> 292 293 The `shape` argument to `Variable`'s constructor allows you to construct a 294 variable with a less defined shape than its `initial_value`: 295 296 >>> v = tf.Variable(1., shape=tf.TensorShape(None)) 297 >>> v.assign([[1.]]) 298 <tf.Variable ... shape=<unknown> dtype=float32, numpy=array([[1.]], ...)> 299 300 Just like any `Tensor`, variables created with `Variable()` can be used as 301 inputs to operations. Additionally, all the operators overloaded for the 302 `Tensor` class are carried over to variables. 303 304 >>> w = tf.Variable([[1.], [2.]]) 305 >>> x = tf.constant([[3., 4.]]) 306 >>> tf.matmul(w, x) 307 <tf.Tensor:... shape=(2, 2), ... numpy= 308 array([[3., 4.], 309 [6., 8.]], dtype=float32)> 310 >>> tf.sigmoid(w + x) 311 <tf.Tensor:... shape=(2, 2), ...> 312 313 When building a machine learning model it is often convenient to distinguish 314 between variables holding trainable model parameters and other variables such 315 as a `step` variable used to count training steps. To make this easier, the 316 variable constructor supports a `trainable=<bool>` 317 parameter. `tf.GradientTape` watches trainable variables by default: 318 319 >>> with tf.GradientTape(persistent=True) as tape: 320 ... trainable = tf.Variable(1.) 321 ... non_trainable = tf.Variable(2., trainable=False) 322 ... x1 = trainable * 2. 323 ... x2 = non_trainable * 3. 324 >>> tape.gradient(x1, trainable) 325 <tf.Tensor:... shape=(), dtype=float32, numpy=2.0> 326 >>> assert tape.gradient(x2, non_trainable) is None # Unwatched 327 328 Variables are automatically tracked when assigned to attributes of types 329 inheriting from `tf.Module`. 330 331 >>> m = tf.Module() 332 >>> m.v = tf.Variable([1.]) 333 >>> m.trainable_variables 334 (<tf.Variable ... shape=(1,) ... 
numpy=array([1.], dtype=float32)>,) 335 336 This tracking then allows saving variable values to 337 [training checkpoints](https://www.tensorflow.org/guide/checkpoint), or to 338 [SavedModels](https://www.tensorflow.org/guide/saved_model) which include 339 serialized TensorFlow graphs. 340 341 Variables are often captured and manipulated by `tf.function`s. This works the 342 same way the un-decorated function would have: 343 344 >>> v = tf.Variable(0.) 345 >>> read_and_decrement = tf.function(lambda: v.assign_sub(0.1)) 346 >>> read_and_decrement() 347 <tf.Tensor: shape=(), dtype=float32, numpy=-0.1> 348 >>> read_and_decrement() 349 <tf.Tensor: shape=(), dtype=float32, numpy=-0.2> 350 351 Variables created inside a `tf.function` must be owned outside the function 352 and be created only once: 353 354 >>> class M(tf.Module): 355 ... @tf.function 356 ... def __call__(self, x): 357 ... if not hasattr(self, "v"): # Or set self.v to None in __init__ 358 ... self.v = tf.Variable(x) 359 ... return self.v * x 360 >>> m = M() 361 >>> m(2.) 362 <tf.Tensor: shape=(), dtype=float32, numpy=4.0> 363 >>> m(3.) 364 <tf.Tensor: shape=(), dtype=float32, numpy=6.0> 365 >>> m.v 366 <tf.Variable ... shape=() dtype=float32, numpy=2.0> 367 368 See the `tf.function` documentation for details. 369 """ 370 371 @deprecated_args( 372 None, "A variable's value can be manually cached by calling " 373 "tf.Variable.read_value() under a tf.device scope. The caching_device " 374 "argument does not work properly.", "caching_device") 375 def __init__(self, 376 initial_value=None, 377 trainable=None, 378 validate_shape=True, 379 caching_device=None, 380 name=None, 381 variable_def=None, 382 dtype=None, 383 import_scope=None, 384 constraint=None, 385 synchronization=VariableSynchronization.AUTO, 386 aggregation=VariableAggregation.NONE, 387 shape=None): 388 """Creates a new variable with value `initial_value`. 389 390 Args: 391 initial_value: A `Tensor`, or Python object convertible to a `Tensor`, 392 which is the initial value for the Variable. The initial value must have 393 a shape specified unless `validate_shape` is set to False. Can also be a 394 callable with no argument that returns the initial value when called. In 395 that case, `dtype` must be specified. (Note that initializer functions 396 from init_ops.py must first be bound to a shape before being used here.) 397 trainable: If `True`, GradientTapes automatically watch uses of this 398 variable. Defaults to `True`, unless `synchronization` is set to 399 `ON_READ`, in which case it defaults to `False`. 400 validate_shape: If `False`, allows the variable to be initialized with a 401 value of unknown shape. If `True`, the default, the shape of 402 `initial_value` must be known. 403 caching_device: Note: This argument is only valid when using a v1-style 404 `Session`. Optional device string describing where the Variable should 405 be cached for reading. Defaults to the Variable's device. If not `None`, 406 caches on another device. Typical use is to cache on the device where 407 the Ops using the Variable reside, to deduplicate copying through 408 `Switch` and other conditional statements. 409 name: Optional name for the variable. Defaults to `'Variable'` and gets 410 uniquified automatically. 411 variable_def: `VariableDef` protocol buffer. If not `None`, recreates the 412 Variable object with its contents, referencing the variable's nodes in 413 the graph, which must already exist. The graph is not changed. 
414 `variable_def` and the other arguments are mutually exclusive. 415 dtype: If set, initial_value will be converted to the given type. If 416 `None`, either the datatype will be kept (if `initial_value` is a 417 Tensor), or `convert_to_tensor` will decide. 418 import_scope: Optional `string`. Name scope to add to the `Variable.` Only 419 used when initializing from protocol buffer. 420 constraint: An optional projection function to be applied to the variable 421 after being updated by an `Optimizer` (e.g. used to implement norm 422 constraints or value constraints for layer weights). The function must 423 take as input the unprojected Tensor representing the value of the 424 variable and return the Tensor for the projected value (which must have 425 the same shape). Constraints are not safe to use when doing asynchronous 426 distributed training. 427 synchronization: Indicates when a distributed a variable will be 428 aggregated. Accepted values are constants defined in the class 429 `tf.VariableSynchronization`. By default the synchronization is set to 430 `AUTO` and the current `DistributionStrategy` chooses when to 431 synchronize. 432 aggregation: Indicates how a distributed variable will be aggregated. 433 Accepted values are constants defined in the class 434 `tf.VariableAggregation`. 435 shape: (optional) The shape of this variable. If None, the shape of 436 `initial_value` will be used. When setting this argument to 437 `tf.TensorShape(None)` (representing an unspecified shape), the variable 438 can be assigned with values of different shapes. 439 440 Raises: 441 ValueError: If both `variable_def` and initial_value are specified. 442 ValueError: If the initial value is not specified, or does not have a 443 shape and `validate_shape` is `True`. 444 """ 445 raise NotImplementedError 446 447 def __repr__(self): 448 raise NotImplementedError 449 450 def value(self): 451 """Returns the last snapshot of this variable. 452 453 You usually do not need to call this method as all ops that need the value 454 of the variable call it automatically through a `convert_to_tensor()` call. 455 456 Returns a `Tensor` which holds the value of the variable. You can not 457 assign a new value to this tensor as it is not a reference to the variable. 458 459 To avoid copies, if the consumer of the returned value is on the same device 460 as the variable, this actually returns the live value of the variable, not 461 a copy. Updates to the variable are seen by the consumer. If the consumer 462 is on a different device it will get a copy of the variable. 463 464 Returns: 465 A `Tensor` containing the value of the variable. 466 """ 467 raise NotImplementedError 468 469 def read_value(self): 470 """Returns the value of this variable, read in the current context. 471 472 Can be different from value() if it's on another device, with control 473 dependencies, etc. 474 475 Returns: 476 A `Tensor` containing the value of the variable. 477 """ 478 raise NotImplementedError 479 480 def set_shape(self, shape): 481 """Overrides the shape for this variable. 482 483 Args: 484 shape: the `TensorShape` representing the overridden shape. 485 """ 486 raise NotImplementedError 487 488 @property 489 def trainable(self): 490 raise NotImplementedError 491 492 @property 493 def synchronization(self): 494 raise NotImplementedError 495 496 @property 497 def aggregation(self): 498 raise NotImplementedError 499 500 def eval(self, session=None): 501 """In a session, computes and returns the value of this variable. 
502 503 This is not a graph construction method, it does not add ops to the graph. 504 505 This convenience method requires a session where the graph 506 containing this variable has been launched. If no session is 507 passed, the default session is used. See `tf.compat.v1.Session` for more 508 information on launching a graph and on sessions. 509 510 ```python 511 v = tf.Variable([1, 2]) 512 init = tf.compat.v1.global_variables_initializer() 513 514 with tf.compat.v1.Session() as sess: 515 sess.run(init) 516 # Usage passing the session explicitly. 517 print(v.eval(sess)) 518 # Usage with the default session. The 'with' block 519 # above makes 'sess' the default session. 520 print(v.eval()) 521 ``` 522 523 Args: 524 session: The session to use to evaluate this variable. If none, the 525 default session is used. 526 527 Returns: 528 A numpy `ndarray` with a copy of the value of this variable. 529 """ 530 raise NotImplementedError 531 532 @deprecated( 533 None, "Use Variable.read_value. Variables in 2.X are initialized " 534 "automatically both in eager and graph (inside tf.defun) contexts.") 535 def initialized_value(self): 536 """Returns the value of the initialized variable. 537 538 You should use this instead of the variable itself to initialize another 539 variable with a value that depends on the value of this variable. 540 541 ```python 542 # Initialize 'v' with a random tensor. 543 v = tf.Variable(tf.random.truncated_normal([10, 40])) 544 # Use `initialized_value` to guarantee that `v` has been 545 # initialized before its value is used to initialize `w`. 546 # The random values are picked only once. 547 w = tf.Variable(v.initialized_value() * 2.0) 548 ``` 549 550 Returns: 551 A `Tensor` holding the value of this variable after its initializer 552 has run. 553 """ 554 with ops.init_scope(): 555 return control_flow_ops.cond( 556 is_variable_initialized(self), self.read_value, 557 lambda: self.initial_value) 558 559 @property 560 def initial_value(self): 561 """Returns the Tensor used as the initial value for the variable. 562 563 Note that this is different from `initialized_value()` which runs 564 the op that initializes the variable before returning its value. 565 This method returns the tensor that is used by the op that initializes 566 the variable. 567 568 Returns: 569 A `Tensor`. 570 """ 571 raise NotImplementedError 572 573 @property 574 def constraint(self): 575 """Returns the constraint function associated with this variable. 576 577 Returns: 578 The constraint function that was passed to the variable constructor. 579 Can be `None` if no constraint was passed. 580 """ 581 raise NotImplementedError 582 583 def assign(self, value, use_locking=False, name=None, read_value=True): 584 """Assigns a new value to the variable. 585 586 This is essentially a shortcut for `assign(self, value)`. 587 588 Args: 589 value: A `Tensor`. The new value for this variable. 590 use_locking: If `True`, use locking during the assignment. 591 name: The name of the operation to be created 592 read_value: if True, will return something which evaluates to the new 593 value of the variable; if False will return the assign op. 594 595 Returns: 596 The updated variable. If `read_value` is false, instead returns None in 597 Eager mode and the assign op in graph mode. 598 """ 599 raise NotImplementedError 600 601 def assign_add(self, delta, use_locking=False, name=None, read_value=True): 602 """Adds a value to this variable. 603 604 This is essentially a shortcut for `assign_add(self, delta)`. 
605 606 Args: 607 delta: A `Tensor`. The value to add to this variable. 608 use_locking: If `True`, use locking during the operation. 609 name: The name of the operation to be created 610 read_value: if True, will return something which evaluates to the new 611 value of the variable; if False will return the assign op. 612 613 Returns: 614 The updated variable. If `read_value` is false, instead returns None in 615 Eager mode and the assign op in graph mode. 616 """ 617 raise NotImplementedError 618 619 def assign_sub(self, delta, use_locking=False, name=None, read_value=True): 620 """Subtracts a value from this variable. 621 622 This is essentially a shortcut for `assign_sub(self, delta)`. 623 624 Args: 625 delta: A `Tensor`. The value to subtract from this variable. 626 use_locking: If `True`, use locking during the operation. 627 name: The name of the operation to be created 628 read_value: if True, will return something which evaluates to the new 629 value of the variable; if False will return the assign op. 630 631 Returns: 632 The updated variable. If `read_value` is false, instead returns None in 633 Eager mode and the assign op in graph mode. 634 """ 635 raise NotImplementedError 636 637 def scatter_sub(self, sparse_delta, use_locking=False, name=None): 638 """Subtracts `tf.IndexedSlices` from this variable. 639 640 Args: 641 sparse_delta: `tf.IndexedSlices` to be subtracted from this variable. 642 use_locking: If `True`, use locking during the operation. 643 name: the name of the operation. 644 645 Returns: 646 The updated variable. 647 648 Raises: 649 TypeError: if `sparse_delta` is not an `IndexedSlices`. 650 """ 651 raise NotImplementedError 652 653 def scatter_add(self, sparse_delta, use_locking=False, name=None): 654 """Adds `tf.IndexedSlices` to this variable. 655 656 Args: 657 sparse_delta: `tf.IndexedSlices` to be added to this variable. 658 use_locking: If `True`, use locking during the operation. 659 name: the name of the operation. 660 661 Returns: 662 The updated variable. 663 664 Raises: 665 TypeError: if `sparse_delta` is not an `IndexedSlices`. 666 """ 667 raise NotImplementedError 668 669 def scatter_max(self, sparse_delta, use_locking=False, name=None): 670 """Updates this variable with the max of `tf.IndexedSlices` and itself. 671 672 Args: 673 sparse_delta: `tf.IndexedSlices` to use as an argument of max with this 674 variable. 675 use_locking: If `True`, use locking during the operation. 676 name: the name of the operation. 677 678 Returns: 679 The updated variable. 680 681 Raises: 682 TypeError: if `sparse_delta` is not an `IndexedSlices`. 683 """ 684 raise NotImplementedError 685 686 def scatter_min(self, sparse_delta, use_locking=False, name=None): 687 """Updates this variable with the min of `tf.IndexedSlices` and itself. 688 689 Args: 690 sparse_delta: `tf.IndexedSlices` to use as an argument of min with this 691 variable. 692 use_locking: If `True`, use locking during the operation. 693 name: the name of the operation. 694 695 Returns: 696 The updated variable. 697 698 Raises: 699 TypeError: if `sparse_delta` is not an `IndexedSlices`. 700 """ 701 raise NotImplementedError 702 703 def scatter_mul(self, sparse_delta, use_locking=False, name=None): 704 """Multiply this variable by `tf.IndexedSlices`. 705 706 Args: 707 sparse_delta: `tf.IndexedSlices` to multiply this variable by. 708 use_locking: If `True`, use locking during the operation. 709 name: the name of the operation. 710 711 Returns: 712 The updated variable. 
713 714 Raises: 715 TypeError: if `sparse_delta` is not an `IndexedSlices`. 716 """ 717 raise NotImplementedError 718 719 def scatter_div(self, sparse_delta, use_locking=False, name=None): 720 """Divide this variable by `tf.IndexedSlices`. 721 722 Args: 723 sparse_delta: `tf.IndexedSlices` to divide this variable by. 724 use_locking: If `True`, use locking during the operation. 725 name: the name of the operation. 726 727 Returns: 728 The updated variable. 729 730 Raises: 731 TypeError: if `sparse_delta` is not an `IndexedSlices`. 732 """ 733 raise NotImplementedError 734 735 def scatter_update(self, sparse_delta, use_locking=False, name=None): 736 """Assigns `tf.IndexedSlices` to this variable. 737 738 Args: 739 sparse_delta: `tf.IndexedSlices` to be assigned to this variable. 740 use_locking: If `True`, use locking during the operation. 741 name: the name of the operation. 742 743 Returns: 744 The updated variable. 745 746 Raises: 747 TypeError: if `sparse_delta` is not an `IndexedSlices`. 748 """ 749 raise NotImplementedError 750 751 def batch_scatter_update(self, sparse_delta, use_locking=False, name=None): 752 """Assigns `tf.IndexedSlices` to this variable batch-wise. 753 754 Analogous to `batch_gather`. This assumes that this variable and the 755 sparse_delta IndexedSlices have a series of leading dimensions that are the 756 same for all of them, and the updates are performed on the last dimension of 757 indices. In other words, the dimensions should be the following: 758 759 `num_prefix_dims = sparse_delta.indices.ndims - 1` 760 `batch_dim = num_prefix_dims + 1` 761 `sparse_delta.updates.shape = sparse_delta.indices.shape + var.shape[ 762 batch_dim:]` 763 764 where 765 766 `sparse_delta.updates.shape[:num_prefix_dims]` 767 `== sparse_delta.indices.shape[:num_prefix_dims]` 768 `== var.shape[:num_prefix_dims]` 769 770 And the operation performed can be expressed as: 771 772 `var[i_1, ..., i_n, 773 sparse_delta.indices[i_1, ..., i_n, j]] = sparse_delta.updates[ 774 i_1, ..., i_n, j]` 775 776 When sparse_delta.indices is a 1D tensor, this operation is equivalent to 777 `scatter_update`. 778 779 To avoid this operation one can looping over the first `ndims` of the 780 variable and using `scatter_update` on the subtensors that result of slicing 781 the first dimension. This is a valid option for `ndims = 1`, but less 782 efficient than this implementation. 783 784 Args: 785 sparse_delta: `tf.IndexedSlices` to be assigned to this variable. 786 use_locking: If `True`, use locking during the operation. 787 name: the name of the operation. 788 789 Returns: 790 The updated variable. 791 792 Raises: 793 TypeError: if `sparse_delta` is not an `IndexedSlices`. 794 """ 795 raise NotImplementedError 796 797 def scatter_nd_sub(self, indices, updates, name=None): 798 """Applies sparse subtraction to individual values or slices in a Variable. 799 800 Assuming the variable has rank `P` and `indices` is a `Tensor` of rank `Q`. 801 802 `indices` must be integer tensor, containing indices into self. 803 It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`. 804 805 The innermost dimension of `indices` (with length `K`) corresponds to 806 indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th 807 dimension of self. 808 809 `updates` is `Tensor` of rank `Q-1+P-K` with shape: 810 811 ``` 812 [d_0, ..., d_{Q-2}, self.shape[K], ..., self.shape[P-1]]. 813 ``` 814 815 For example, say we want to add 4 scattered elements to a rank-1 tensor to 816 8 elements. 
In Python, that update would look like this: 817 818 ```python 819 v = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8]) 820 indices = tf.constant([[4], [3], [1] ,[7]]) 821 updates = tf.constant([9, 10, 11, 12]) 822 v.scatter_nd_sub(indices, updates) 823 print(v) 824 ``` 825 826 After the update `v` would look like this: 827 828 [1, -9, 3, -6, -4, 6, 7, -4] 829 830 See `tf.scatter_nd` for more details about how to make updates to 831 slices. 832 833 Args: 834 indices: The indices to be used in the operation. 835 updates: The values to be used in the operation. 836 name: the name of the operation. 837 838 Returns: 839 The updated variable. 840 """ 841 raise NotImplementedError 842 843 def scatter_nd_add(self, indices, updates, name=None): 844 """Applies sparse addition to individual values or slices in a Variable. 845 846 The Variable has rank `P` and `indices` is a `Tensor` of rank `Q`. 847 848 `indices` must be integer tensor, containing indices into self. 849 It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`. 850 851 The innermost dimension of `indices` (with length `K`) corresponds to 852 indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th 853 dimension of self. 854 855 `updates` is `Tensor` of rank `Q-1+P-K` with shape: 856 857 ``` 858 [d_0, ..., d_{Q-2}, self.shape[K], ..., self.shape[P-1]]. 859 ``` 860 861 For example, say we want to add 4 scattered elements to a rank-1 tensor to 862 8 elements. In Python, that update would look like this: 863 864 ```python 865 v = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8]) 866 indices = tf.constant([[4], [3], [1] ,[7]]) 867 updates = tf.constant([9, 10, 11, 12]) 868 v.scatter_nd_add(indices, updates) 869 print(v) 870 ``` 871 872 The resulting update to v would look like this: 873 874 [1, 13, 3, 14, 14, 6, 7, 20] 875 876 See `tf.scatter_nd` for more details about how to make updates to 877 slices. 878 879 Args: 880 indices: The indices to be used in the operation. 881 updates: The values to be used in the operation. 882 name: the name of the operation. 883 884 Returns: 885 The updated variable. 886 """ 887 raise NotImplementedError 888 889 def scatter_nd_update(self, indices, updates, name=None): 890 """Applies sparse assignment to individual values or slices in a Variable. 891 892 The Variable has rank `P` and `indices` is a `Tensor` of rank `Q`. 893 894 `indices` must be integer tensor, containing indices into self. 895 It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`. 896 897 The innermost dimension of `indices` (with length `K`) corresponds to 898 indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th 899 dimension of self. 900 901 `updates` is `Tensor` of rank `Q-1+P-K` with shape: 902 903 ``` 904 [d_0, ..., d_{Q-2}, self.shape[K], ..., self.shape[P-1]]. 905 ``` 906 907 For example, say we want to add 4 scattered elements to a rank-1 tensor to 908 8 elements. In Python, that update would look like this: 909 910 ```python 911 v = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8]) 912 indices = tf.constant([[4], [3], [1] ,[7]]) 913 updates = tf.constant([9, 10, 11, 12]) 914 v.scatter_nd_update(indices, updates) 915 print(v) 916 ``` 917 918 The resulting update to v would look like this: 919 920 [1, 11, 3, 10, 9, 6, 7, 12] 921 922 See `tf.scatter_nd` for more details about how to make updates to 923 slices. 924 925 Args: 926 indices: The indices to be used in the operation. 927 updates: The values to be used in the operation. 928 name: the name of the operation. 929 930 Returns: 931 The updated variable. 
932 """ 933 raise NotImplementedError 934 935 def sparse_read(self, indices, name=None): 936 r"""Gather slices from params axis axis according to indices. 937 938 This function supports a subset of tf.gather, see tf.gather for details on 939 usage. 940 941 Args: 942 indices: The index `Tensor`. Must be one of the following types: `int32`, 943 `int64`. Must be in range `[0, params.shape[axis])`. 944 name: A name for the operation (optional). 945 946 Returns: 947 A `Tensor`. Has the same type as `params`. 948 """ 949 raise AttributeError 950 951 def gather_nd(self, indices, name=None): 952 r"""Gather slices from `params` into a Tensor with shape specified by `indices`. 953 954 See tf.gather_nd for details. 955 956 Args: 957 indices: A `Tensor`. Must be one of the following types: `int32`, `int64`. 958 Index tensor. 959 name: A name for the operation (optional). 960 961 Returns: 962 A `Tensor`. Has the same type as `params`. 963 """ 964 raise AttributeError 965 966 @deprecated(None, "Prefer Dataset.range instead.") 967 def count_up_to(self, limit): 968 """Increments this variable until it reaches `limit`. 969 970 When that Op is run it tries to increment the variable by `1`. If 971 incrementing the variable would bring it above `limit` then the Op raises 972 the exception `OutOfRangeError`. 973 974 If no error is raised, the Op outputs the value of the variable before 975 the increment. 976 977 This is essentially a shortcut for `count_up_to(self, limit)`. 978 979 Args: 980 limit: value at which incrementing the variable raises an error. 981 982 Returns: 983 A `Tensor` that will hold the variable value before the increment. If no 984 other Op modifies this variable, the values produced will all be 985 distinct. 986 """ 987 raise NotImplementedError 988 989 @deprecated(None, 990 "Prefer Variable.assign which has equivalent behavior in 2.X.") 991 def load(self, value, session=None): 992 """Load new value into this variable. 993 994 Writes new value to variable's memory. Doesn't add ops to the graph. 995 996 This convenience method requires a session where the graph 997 containing this variable has been launched. If no session is 998 passed, the default session is used. See `tf.compat.v1.Session` for more 999 information on launching a graph and on sessions. 1000 1001 ```python 1002 v = tf.Variable([1, 2]) 1003 init = tf.compat.v1.global_variables_initializer() 1004 1005 with tf.compat.v1.Session() as sess: 1006 sess.run(init) 1007 # Usage passing the session explicitly. 1008 v.load([2, 3], sess) 1009 print(v.eval(sess)) # prints [2 3] 1010 # Usage with the default session. The 'with' block 1011 # above makes 'sess' the default session. 1012 v.load([3, 4], sess) 1013 print(v.eval()) # prints [3 4] 1014 ``` 1015 1016 Args: 1017 value: New variable value 1018 session: The session to use to evaluate this variable. If none, the 1019 default session is used. 1020 1021 Raises: 1022 ValueError: Session is not passed and no default session 1023 """ 1024 if context.executing_eagerly(): 1025 self.assign(value) 1026 else: 1027 session = session or ops.get_default_session() 1028 if session is None: 1029 raise ValueError( 1030 "Either session argument should be provided or default session " 1031 "should be established") 1032 session.run(self.initializer, {self.initializer.inputs[1]: value}) 1033 1034 # Conversion to tensor. 
1035 @staticmethod 1036 def _TensorConversionFunction(v, dtype=None, name=None, as_ref=False): # pylint: disable=invalid-name 1037 """Utility function for converting a Variable to a Tensor.""" 1038 _ = name 1039 if dtype and not dtype.is_compatible_with(v.dtype): 1040 raise ValueError( 1041 "Incompatible type conversion requested to type '%s' for variable " 1042 "of type '%s'" % (dtype.name, v.dtype.name)) 1043 if as_ref: 1044 return v._ref() # pylint: disable=protected-access 1045 else: 1046 return v.value() 1047 1048 @classmethod 1049 def _OverloadAllOperators(cls): # pylint: disable=invalid-name 1050 """Register overloads for all operators.""" 1051 for operator in ops.Tensor.OVERLOADABLE_OPERATORS: 1052 cls._OverloadOperator(operator) 1053 # For slicing, bind getitem differently than a tensor (use SliceHelperVar 1054 # instead) 1055 # pylint: disable=protected-access 1056 setattr(cls, "__getitem__", array_ops._SliceHelperVar) 1057 1058 @classmethod 1059 def _OverloadOperator(cls, operator): # pylint: disable=invalid-name 1060 """Defer an operator overload to `ops.Tensor`. 1061 1062 We pull the operator out of ops.Tensor dynamically to avoid ordering issues. 1063 1064 Args: 1065 operator: string. The operator name. 1066 """ 1067 # We can't use the overload mechanism on __eq__ & __ne__ since __eq__ is 1068 # called when adding a variable to sets. As a result we call a.value() which 1069 # causes infinite recursion when operating within a GradientTape 1070 # TODO(gjn): Consider removing this 1071 if operator == "__eq__" or operator == "__ne__": 1072 return 1073 1074 tensor_oper = getattr(ops.Tensor, operator) 1075 1076 def _run_op(a, *args, **kwargs): 1077 # pylint: disable=protected-access 1078 return tensor_oper(a.value(), *args, **kwargs) 1079 1080 functools.update_wrapper(_run_op, tensor_oper) 1081 setattr(cls, operator, _run_op) 1082 1083 def __hash__(self): 1084 if ops.Tensor._USE_EQUALITY and ops.executing_eagerly_outside_functions(): # pylint: disable=protected-access 1085 raise TypeError("Variable is unhashable. " 1086 "Instead, use tensor.ref() as the key.") 1087 else: 1088 return id(self) 1089 1090 # TODO(gjn): duplicate of math_ops.tensor_equals, consider removing 1091 def __eq__(self, other): 1092 """Compares two variables element-wise for equality.""" 1093 if ops.Tensor._USE_EQUALITY and ops.executing_eagerly_outside_functions(): # pylint: disable=protected-access 1094 return gen_math_ops.equal(self, other, incompatible_shape_error=False) 1095 else: 1096 # In legacy graph mode, tensor equality is object equality 1097 return self is other 1098 1099 # TODO(gjn): duplicate of math_ops.tensor_not_equals, consider removing 1100 def __ne__(self, other): 1101 """Compares two variables element-wise for equality.""" 1102 if ops.Tensor._USE_EQUALITY and ops.executing_eagerly_outside_functions(): # pylint: disable=protected-access 1103 return gen_math_ops.not_equal(self, other, incompatible_shape_error=False) 1104 else: 1105 # In legacy graph mode, tensor equality is object equality 1106 return self is not other 1107 1108 def __iter__(self): 1109 """Dummy method to prevent iteration. 1110 1111 Do not call. 1112 1113 NOTE(mrry): If we register __getitem__ as an overloaded operator, 1114 Python will valiantly attempt to iterate over the variable's Tensor from 0 1115 to infinity. Declaring this method prevents this unintended behavior. 1116 1117 Raises: 1118 TypeError: when invoked. 
1119 """ 1120 raise TypeError("'Variable' object is not iterable.") 1121 1122 # NOTE(mrry): This enables the Variable's overloaded "right" binary 1123 # operators to run when the left operand is an ndarray, because it 1124 # accords the Variable class higher priority than an ndarray, or a 1125 # numpy matrix. 1126 # TODO(mrry): Convert this to using numpy's __numpy_ufunc__ 1127 # mechanism, which allows more control over how Variables interact 1128 # with ndarrays. 1129 __array_priority__ = 100 1130 1131 @property 1132 def name(self): 1133 """The name of this variable.""" 1134 raise NotImplementedError 1135 1136 @property 1137 def _shared_name(self): 1138 """The shared name of the variable. 1139 1140 Unlike name(), shared_name doesn't have ":0" suffix. It is user-specified 1141 name with name scope prefix. 1142 1143 Returns: 1144 variable name. 1145 """ 1146 return self.name[:self.name.index(":")] 1147 1148 @property 1149 def initializer(self): 1150 """The initializer operation for this variable.""" 1151 raise NotImplementedError 1152 1153 @property 1154 def device(self): 1155 """The device of this variable.""" 1156 raise NotImplementedError 1157 1158 @property 1159 def dtype(self): 1160 """The `DType` of this variable.""" 1161 raise NotImplementedError 1162 1163 @property 1164 def op(self): 1165 """The `Operation` of this variable.""" 1166 raise NotImplementedError 1167 1168 @property 1169 def graph(self): 1170 """The `Graph` of this variable.""" 1171 raise NotImplementedError 1172 1173 @property 1174 def shape(self): 1175 """The `TensorShape` of this variable. 1176 1177 Returns: 1178 A `TensorShape`. 1179 """ 1180 raise NotImplementedError 1181 1182 def get_shape(self): 1183 """Alias of `Variable.shape`.""" 1184 return self.shape 1185 1186 def _gather_saveables_for_checkpoint(self): 1187 """For implementing `Trackable`. This object is saveable on its own.""" 1188 return {trackable.VARIABLE_VALUE_KEY: self} 1189 1190 def to_proto(self, export_scope=None): 1191 """Converts a `Variable` to a `VariableDef` protocol buffer. 1192 1193 Args: 1194 export_scope: Optional `string`. Name scope to remove. 1195 1196 Returns: 1197 A `VariableDef` protocol buffer, or `None` if the `Variable` is not 1198 in the specified name scope. 1199 """ 1200 raise NotImplementedError 1201 1202 @staticmethod 1203 def from_proto(variable_def, import_scope=None): 1204 """Returns a `Variable` object created from `variable_def`.""" 1205 return RefVariable(variable_def=variable_def, import_scope=import_scope) 1206 1207 def _set_save_slice_info(self, save_slice_info): 1208 """Sets the slice info for this `Variable`. 1209 1210 Args: 1211 save_slice_info: A `Variable.SaveSliceInfo` object. 1212 """ 1213 self._save_slice_info = save_slice_info 1214 1215 def _get_save_slice_info(self): 1216 return self._save_slice_info 1217 1218 @deprecated(None, "Use ref() instead.") 1219 def experimental_ref(self): 1220 return self.ref() 1221 1222 def ref(self): 1223 # tf.Tensor also has the same ref() API. If you update the 1224 # documentation here, please update tf.Tensor.ref() as well. 1225 """Returns a hashable reference object to this Variable. 1226 1227 The primary use case for this API is to put variables in a set/dictionary. 1228 We can't put variables in a set/dictionary as `variable.__hash__()` is no 1229 longer available starting Tensorflow 2.0. 
1230 1231 The following will raise an exception starting 2.0 1232 1233 >>> x = tf.Variable(5) 1234 >>> y = tf.Variable(10) 1235 >>> z = tf.Variable(10) 1236 >>> variable_set = {x, y, z} 1237 Traceback (most recent call last): 1238 ... 1239 TypeError: Variable is unhashable. Instead, use tensor.ref() as the key. 1240 >>> variable_dict = {x: 'five', y: 'ten'} 1241 Traceback (most recent call last): 1242 ... 1243 TypeError: Variable is unhashable. Instead, use tensor.ref() as the key. 1244 1245 Instead, we can use `variable.ref()`. 1246 1247 >>> variable_set = {x.ref(), y.ref(), z.ref()} 1248 >>> x.ref() in variable_set 1249 True 1250 >>> variable_dict = {x.ref(): 'five', y.ref(): 'ten', z.ref(): 'ten'} 1251 >>> variable_dict[y.ref()] 1252 'ten' 1253 1254 Also, the reference object provides `.deref()` function that returns the 1255 original Variable. 1256 1257 >>> x = tf.Variable(5) 1258 >>> x.ref().deref() 1259 <tf.Variable 'Variable:0' shape=() dtype=int32, numpy=5> 1260 """ 1261 return object_identity.Reference(self) 1262 1263 class SaveSliceInfo(object): 1264 """Information on how to save this Variable as a slice. 1265 1266 Provides internal support for saving variables as slices of a larger 1267 variable. This API is not public and is subject to change. 1268 1269 Available properties: 1270 1271 * full_name 1272 * full_shape 1273 * var_offset 1274 * var_shape 1275 """ 1276 1277 def __init__(self, 1278 full_name=None, 1279 full_shape=None, 1280 var_offset=None, 1281 var_shape=None, 1282 save_slice_info_def=None, 1283 import_scope=None): 1284 """Create a `SaveSliceInfo`. 1285 1286 Args: 1287 full_name: Name of the full variable of which this `Variable` is a 1288 slice. 1289 full_shape: Shape of the full variable, as a list of int. 1290 var_offset: Offset of this `Variable` into the full variable, as a list 1291 of int. 1292 var_shape: Shape of this `Variable`, as a list of int. 1293 save_slice_info_def: `SaveSliceInfoDef` protocol buffer. If not `None`, 1294 recreates the SaveSliceInfo object its contents. `save_slice_info_def` 1295 and other arguments are mutually exclusive. 1296 import_scope: Optional `string`. Name scope to add. Only used when 1297 initializing from protocol buffer. 1298 """ 1299 if save_slice_info_def: 1300 assert isinstance(save_slice_info_def, variable_pb2.SaveSliceInfoDef) 1301 self.full_name = ops.prepend_name_scope( 1302 save_slice_info_def.full_name, import_scope=import_scope) 1303 self.full_shape = [i for i in save_slice_info_def.full_shape] 1304 self.var_offset = [i for i in save_slice_info_def.var_offset] 1305 self.var_shape = [i for i in save_slice_info_def.var_shape] 1306 else: 1307 self.full_name = full_name 1308 self.full_shape = full_shape 1309 self.var_offset = var_offset 1310 self.var_shape = var_shape 1311 1312 @property 1313 def spec(self): 1314 """Computes the spec string used for saving.""" 1315 full_shape_str = " ".join("%d" % d for d in self.full_shape) + " " 1316 sl_spec = ":".join( 1317 "%d,%d" % (o, s) for o, s in zip(self.var_offset, self.var_shape)) 1318 return full_shape_str + sl_spec 1319 1320 def to_proto(self, export_scope=None): 1321 """Returns a SaveSliceInfoDef() proto. 1322 1323 Args: 1324 export_scope: Optional `string`. Name scope to remove. 1325 1326 Returns: 1327 A `SaveSliceInfoDef` protocol buffer, or None if the `Variable` is not 1328 in the specified name scope. 
1329 """ 1330 if (export_scope is None or self.full_name.startswith(export_scope)): 1331 save_slice_info_def = variable_pb2.SaveSliceInfoDef() 1332 save_slice_info_def.full_name = ops.strip_name_scope( 1333 self.full_name, export_scope) 1334 for i in self.full_shape: 1335 save_slice_info_def.full_shape.append(i) 1336 for i in self.var_offset: 1337 save_slice_info_def.var_offset.append(i) 1338 for i in self.var_shape: 1339 save_slice_info_def.var_shape.append(i) 1340 return save_slice_info_def 1341 else: 1342 return None 1343 1344 1345Variable._OverloadAllOperators() # pylint: disable=protected-access 1346_pywrap_utils.RegisterType("Variable", Variable) 1347 1348 1349@tf_export(v1=["Variable"]) 1350class VariableV1(Variable): 1351 """See the [Variables Guide](https://tensorflow.org/guide/variables). 1352 1353 A variable maintains state in the graph across calls to `run()`. You add a 1354 variable to the graph by constructing an instance of the class `Variable`. 1355 1356 The `Variable()` constructor requires an initial value for the variable, 1357 which can be a `Tensor` of any type and shape. The initial value defines the 1358 type and shape of the variable. After construction, the type and shape of 1359 the variable are fixed. The value can be changed using one of the assign 1360 methods. 1361 1362 If you want to change the shape of a variable later you have to use an 1363 `assign` Op with `validate_shape=False`. 1364 1365 Just like any `Tensor`, variables created with `Variable()` can be used as 1366 inputs for other Ops in the graph. Additionally, all the operators 1367 overloaded for the `Tensor` class are carried over to variables, so you can 1368 also add nodes to the graph by just doing arithmetic on variables. 1369 1370 ```python 1371 import tensorflow as tf 1372 1373 # Create a variable. 1374 w = tf.Variable(<initial-value>, name=<optional-name>) 1375 1376 # Use the variable in the graph like any Tensor. 1377 y = tf.matmul(w, ...another variable or tensor...) 1378 1379 # The overloaded operators are available too. 1380 z = tf.sigmoid(w + y) 1381 1382 # Assign a new value to the variable with `assign()` or a related method. 1383 w.assign(w + 1.0) 1384 w.assign_add(1.0) 1385 ``` 1386 1387 When you launch the graph, variables have to be explicitly initialized before 1388 you can run Ops that use their value. You can initialize a variable by 1389 running its *initializer op*, restoring the variable from a save file, or 1390 simply running an `assign` Op that assigns a value to the variable. In fact, 1391 the variable *initializer op* is just an `assign` Op that assigns the 1392 variable's initial value to the variable itself. 1393 1394 ```python 1395 # Launch the graph in a session. 1396 with tf.compat.v1.Session() as sess: 1397 # Run the variable initializer. 1398 sess.run(w.initializer) 1399 # ...you now can run ops that use the value of 'w'... 1400 ``` 1401 1402 The most common initialization pattern is to use the convenience function 1403 `global_variables_initializer()` to add an Op to the graph that initializes 1404 all the variables. You then run that Op after launching the graph. 1405 1406 ```python 1407 # Add an Op to initialize global variables. 1408 init_op = tf.compat.v1.global_variables_initializer() 1409 1410 # Launch the graph in a session. 1411 with tf.compat.v1.Session() as sess: 1412 # Run the Op that initializes global variables. 1413 sess.run(init_op) 1414 # ...you can now run any Op that uses variable values... 
1415 ``` 1416 1417 If you need to create a variable with an initial value dependent on another 1418 variable, use the other variable's `initialized_value()`. This ensures that 1419 variables are initialized in the right order. 1420 1421 All variables are automatically collected in the graph where they are 1422 created. By default, the constructor adds the new variable to the graph 1423 collection `GraphKeys.GLOBAL_VARIABLES`. The convenience function 1424 `global_variables()` returns the contents of that collection. 1425 1426 When building a machine learning model it is often convenient to distinguish 1427 between variables holding the trainable model parameters and other variables 1428 such as a `global step` variable used to count training steps. To make this 1429 easier, the variable constructor supports a `trainable=<bool>` parameter. If 1430 `True`, the new variable is also added to the graph collection 1431 `GraphKeys.TRAINABLE_VARIABLES`. The convenience function 1432 `trainable_variables()` returns the contents of this collection. The 1433 various `Optimizer` classes use this collection as the default list of 1434 variables to optimize. 1435 1436 WARNING: tf.Variable objects by default have a non-intuitive memory model. A 1437 Variable is represented internally as a mutable Tensor which can 1438 non-deterministically alias other Tensors in a graph. The set of operations 1439 which consume a Variable and can lead to aliasing is undetermined and can 1440 change across TensorFlow versions. Avoid writing code which relies on the 1441 value of a Variable either changing or not changing as other operations 1442 happen. For example, using Variable objects or simple functions thereof as 1443 predicates in a `tf.cond` is dangerous and error-prone: 1444 1445 ``` 1446 v = tf.Variable(True) 1447 tf.cond(v, lambda: v.assign(False), my_false_fn) # Note: this is broken. 1448 ``` 1449 1450 Here, adding `use_resource=True` when constructing the variable will 1451 fix any nondeterminism issues: 1452 ``` 1453 v = tf.Variable(True, use_resource=True) 1454 tf.cond(v, lambda: v.assign(False), my_false_fn) 1455 ``` 1456 1457 To use the replacement for variables which does 1458 not have these issues: 1459 1460 * Add `use_resource=True` when constructing `tf.Variable`; 1461 * Call `tf.compat.v1.get_variable_scope().set_use_resource(True)` inside a 1462 `tf.compat.v1.variable_scope` before the `tf.compat.v1.get_variable()` call. 1463 """ 1464 1465 def __init__( 1466 self, # pylint: disable=super-init-not-called 1467 initial_value=None, 1468 trainable=None, 1469 collections=None, 1470 validate_shape=True, 1471 caching_device=None, 1472 name=None, 1473 variable_def=None, 1474 dtype=None, 1475 expected_shape=None, 1476 import_scope=None, 1477 constraint=None, 1478 use_resource=None, 1479 synchronization=VariableSynchronization.AUTO, 1480 aggregation=VariableAggregation.NONE, 1481 shape=None): 1482 """Creates a new variable with value `initial_value`. 1483 1484 The new variable is added to the graph collections listed in `collections`, 1485 which defaults to `[GraphKeys.GLOBAL_VARIABLES]`. 1486 1487 If `trainable` is `True` the variable is also added to the graph collection 1488 `GraphKeys.TRAINABLE_VARIABLES`. 1489 1490 This constructor creates both a `variable` Op and an `assign` Op to set the 1491 variable to its initial value. 1492 1493 Args: 1494 initial_value: A `Tensor`, or Python object convertible to a `Tensor`, 1495 which is the initial value for the Variable. 
The initial value must have 1496 a shape specified unless `validate_shape` is set to False. Can also be a 1497 callable with no argument that returns the initial value when called. In 1498 that case, `dtype` must be specified. (Note that initializer functions 1499 from init_ops.py must first be bound to a shape before being used here.) 1500 trainable: If `True`, also adds the variable to the graph collection 1501 `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as the default 1502 list of variables to use by the `Optimizer` classes. Defaults to `True`, 1503 unless `synchronization` is set to `ON_READ`, in which case it defaults 1504 to `False`. 1505 collections: List of graph collections keys. The new variable is added to 1506 these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`. 1507 validate_shape: If `False`, allows the variable to be initialized with a 1508 value of unknown shape. If `True`, the default, the shape of 1509 `initial_value` must be known. 1510 caching_device: Optional device string describing where the Variable 1511 should be cached for reading. Defaults to the Variable's device. If not 1512 `None`, caches on another device. Typical use is to cache on the device 1513 where the Ops using the Variable reside, to deduplicate copying through 1514 `Switch` and other conditional statements. 1515 name: Optional name for the variable. Defaults to `'Variable'` and gets 1516 uniquified automatically. 1517 variable_def: `VariableDef` protocol buffer. If not `None`, recreates the 1518 Variable object with its contents, referencing the variable's nodes in 1519 the graph, which must already exist. The graph is not changed. 1520 `variable_def` and the other arguments are mutually exclusive. 1521 dtype: If set, initial_value will be converted to the given type. If 1522 `None`, either the datatype will be kept (if `initial_value` is a 1523 Tensor), or `convert_to_tensor` will decide. 1524 expected_shape: A TensorShape. If set, initial_value is expected to have 1525 this shape. 1526 import_scope: Optional `string`. Name scope to add to the `Variable.` Only 1527 used when initializing from protocol buffer. 1528 constraint: An optional projection function to be applied to the variable 1529 after being updated by an `Optimizer` (e.g. used to implement norm 1530 constraints or value constraints for layer weights). The function must 1531 take as input the unprojected Tensor representing the value of the 1532 variable and return the Tensor for the projected value (which must have 1533 the same shape). Constraints are not safe to use when doing asynchronous 1534 distributed training. 1535 use_resource: whether to use resource variables. 1536 synchronization: Indicates when a distributed a variable will be 1537 aggregated. Accepted values are constants defined in the class 1538 `tf.VariableSynchronization`. By default the synchronization is set to 1539 `AUTO` and the current `DistributionStrategy` chooses when to 1540 synchronize. 1541 aggregation: Indicates how a distributed variable will be aggregated. 1542 Accepted values are constants defined in the class 1543 `tf.VariableAggregation`. 1544 shape: (optional) The shape of this variable. If None, the shape of 1545 `initial_value` will be used. When setting this argument to 1546 `tf.TensorShape(None)` (representing an unspecified shape), the variable 1547 can be assigned with values of different shapes. 1548 1549 Raises: 1550 ValueError: If both `variable_def` and initial_value are specified. 
1551 ValueError: If the initial value is not specified, or does not have a 1552 shape and `validate_shape` is `True`. 1553 RuntimeError: If eager execution is enabled. 1554 """ 1555 1556 SaveSliceInfo = Variable.SaveSliceInfo 1557 1558 1559# TODO(apassos): do not repeat all comments here 1560class RefVariable(VariableV1, core.Tensor): 1561 """Ref-based implementation of variables.""" 1562 1563 def __init__( 1564 self, # pylint: disable=super-init-not-called 1565 initial_value=None, 1566 trainable=None, 1567 collections=None, 1568 validate_shape=True, 1569 caching_device=None, 1570 name=None, 1571 variable_def=None, 1572 dtype=None, 1573 expected_shape=None, 1574 import_scope=None, 1575 constraint=None, 1576 synchronization=None, 1577 aggregation=None, 1578 shape=None): 1579 """Creates a new variable with value `initial_value`. 1580 1581 The new variable is added to the graph collections listed in `collections`, 1582 which defaults to `[GraphKeys.GLOBAL_VARIABLES]`. 1583 1584 If `trainable` is `True` the variable is also added to the graph collection 1585 `GraphKeys.TRAINABLE_VARIABLES`. 1586 1587 This constructor creates both a `variable` Op and an `assign` Op to set the 1588 variable to its initial value. 1589 1590 Args: 1591 initial_value: A `Tensor`, or Python object convertible to a `Tensor`, 1592 which is the initial value for the Variable. The initial value must have 1593 a shape specified unless `validate_shape` is set to False. Can also be a 1594 callable with no argument that returns the initial value when called. In 1595 that case, `dtype` must be specified. (Note that initializer functions 1596 from init_ops.py must first be bound to a shape before being used here.) 1597 trainable: If `True`, also adds the variable to the graph collection 1598 `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as the default 1599 list of variables to use by the `Optimizer` classes. Defaults to `True`, 1600 unless `synchronization` is set to `ON_READ`, in which case it defaults 1601 to `False`. 1602 collections: List of graph collections keys. The new variable is added to 1603 these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`. 1604 validate_shape: If `False`, allows the variable to be initialized with a 1605 value of unknown shape. If `True`, the default, the shape of 1606 `initial_value` must be known. 1607 caching_device: Optional device string describing where the Variable 1608 should be cached for reading. Defaults to the Variable's device. If not 1609 `None`, caches on another device. Typical use is to cache on the device 1610 where the Ops using the Variable reside, to deduplicate copying through 1611 `Switch` and other conditional statements. 1612 name: Optional name for the variable. Defaults to `'Variable'` and gets 1613 uniquified automatically. 1614 variable_def: `VariableDef` protocol buffer. If not `None`, recreates the 1615 Variable object with its contents, referencing the variable's nodes in 1616 the graph, which must already exist. The graph is not changed. 1617 `variable_def` and the other arguments are mutually exclusive. 1618 dtype: If set, initial_value will be converted to the given type. If 1619 `None`, either the datatype will be kept (if `initial_value` is a 1620 Tensor), or `convert_to_tensor` will decide. 1621 expected_shape: A TensorShape. If set, initial_value is expected to have 1622 this shape. 1623 import_scope: Optional `string`. Name scope to add to the `Variable.` Only 1624 used when initializing from protocol buffer. 
1625 constraint: An optional projection function to be applied to the variable 1626 after being updated by an `Optimizer` (e.g. used to implement norm 1627 constraints or value constraints for layer weights). The function must 1628 take as input the unprojected Tensor representing the value of the 1629 variable and return the Tensor for the projected value (which must have 1630 the same shape). Constraints are not safe to use when doing asynchronous 1631 distributed training. 1632 synchronization: Indicates when a distributed a variable will be 1633 aggregated. Accepted values are constants defined in the class 1634 `tf.VariableSynchronization`. By default the synchronization is set to 1635 `AUTO` and the current `DistributionStrategy` chooses when to 1636 synchronize. 1637 aggregation: Indicates how a distributed variable will be aggregated. 1638 Accepted values are constants defined in the class 1639 `tf.VariableAggregation`. 1640 shape: (optional) The shape of this variable. If None, the shape of 1641 `initial_value` will be used. When setting this argument to 1642 `tf.TensorShape(None)` (representing an unspecified shape), the variable 1643 can be assigned with values of different shapes. 1644 1645 Raises: 1646 ValueError: If both `variable_def` and initial_value are specified. 1647 ValueError: If the initial value is not specified, or does not have a 1648 shape and `validate_shape` is `True`. 1649 RuntimeError: If eager execution is enabled. 1650 """ 1651 self._in_graph_mode = True 1652 if variable_def: 1653 # If variable_def is provided, recreates the variable from its fields. 1654 if initial_value: 1655 raise ValueError("variable_def and initial_value are mutually " 1656 "exclusive.") 1657 self._init_from_proto(variable_def, import_scope=import_scope) 1658 else: 1659 # Create from initial_value. 1660 self._init_from_args( 1661 initial_value=initial_value, 1662 trainable=trainable, 1663 collections=collections, 1664 validate_shape=validate_shape, 1665 caching_device=caching_device, 1666 name=name, 1667 dtype=dtype, 1668 expected_shape=expected_shape, 1669 constraint=constraint, 1670 synchronization=synchronization, 1671 aggregation=aggregation, 1672 shape=shape) 1673 1674 def __repr__(self): 1675 if context.executing_eagerly() and not self._in_graph_mode: 1676 return "<tf.Variable '%s' shape=%s dtype=%s, numpy=%s>" % ( 1677 self.name, self.get_shape(), self.dtype.name, 1678 ops.numpy_text(self.read_value(), is_repr=True)) 1679 else: 1680 return "<tf.Variable '%s' shape=%s dtype=%s>" % ( 1681 self.name, self.get_shape(), self.dtype.name) 1682 1683 def _init_from_args(self, 1684 initial_value=None, 1685 trainable=None, 1686 collections=None, 1687 validate_shape=True, 1688 caching_device=None, 1689 name=None, 1690 dtype=None, 1691 expected_shape=None, 1692 constraint=None, 1693 synchronization=None, 1694 aggregation=None, 1695 shape=None): 1696 """Creates a new variable from arguments. 1697 1698 Args: 1699 initial_value: A `Tensor`, or Python object convertible to a `Tensor`, 1700 which is the initial value for the Variable. The initial value must have 1701 a shape specified unless `validate_shape` is set to False. Can also be a 1702 callable with no argument that returns the initial value when called. 1703 (Note that initializer functions from init_ops.py must first be bound to 1704 a shape before being used here.) 1705 trainable: If `True`, also adds the variable to the graph collection 1706 `GraphKeys.TRAINABLE_VARIABLES`. 
This collection is used as the default 1707 list of variables to use by the `Optimizer` classes. Defaults to `True`, 1708 unless `synchronization` is set to `ON_READ`, in which case it defaults 1709 to `False`. 1710 collections: List of graph collections keys. The new variable is added to 1711 these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`. 1712 validate_shape: If `False`, allows the variable to be initialized with a 1713 value of unknown shape. If `True`, the default, the shape of 1714 `initial_value` must be known. 1715 caching_device: Optional device string or function describing where the 1716 Variable should be cached for reading. Defaults to the Variable's 1717 device. If not `None`, caches on another device. Typical use is to 1718 cache on the device where the Ops using the Variable reside, to 1719 deduplicate copying through `Switch` and other conditional statements. 1720 name: Optional name for the variable. Defaults to `'Variable'` and gets 1721 uniquified automatically. 1722 dtype: If set, initial_value will be converted to the given type. If None, 1723 either the datatype will be kept (if initial_value is a Tensor) or 1724 float32 will be used (if it is a Python object convertible to a Tensor). 1725 expected_shape: Deprecated. Ignored. 1726 constraint: An optional projection function to be applied to the variable 1727 after being updated by an `Optimizer` (e.g. used to implement norm 1728 constraints or value constraints for layer weights). The function must 1729 take as input the unprojected Tensor representing the value of the 1730 variable and return the Tensor for the projected value (which must have 1731 the same shape). Constraints are not safe to use when doing asynchronous 1732 distributed training. 1733 synchronization: Indicates when a distributed a variable will be 1734 aggregated. Accepted values are constants defined in the class 1735 `tf.VariableSynchronization`. By default the synchronization is set to 1736 `AUTO` and the current `DistributionStrategy` chooses when to 1737 synchronize. 1738 aggregation: Indicates how a distributed variable will be aggregated. 1739 Accepted values are constants defined in the class 1740 `tf.VariableAggregation`. 1741 shape: (optional) The shape of this variable. If None, the shape of 1742 `initial_value` will be used. When setting this argument to 1743 `tf.TensorShape(None)` (representing an unspecified shape), the variable 1744 can be assigned with values of different shapes. 1745 1746 Raises: 1747 ValueError: If the initial value is not specified, or does not have a 1748 shape and `validate_shape` is `True`. 1749 RuntimeError: If lifted into the eager context. 1750 """ 1751 _ = expected_shape 1752 if initial_value is None: 1753 raise ValueError("initial_value must be specified.") 1754 init_from_fn = callable(initial_value) 1755 1756 if collections is None: 1757 collections = [ops.GraphKeys.GLOBAL_VARIABLES] 1758 if not isinstance(collections, (list, tuple, set)): 1759 raise ValueError( 1760 "collections argument to Variable constructor must be a list, tuple, " 1761 "or set. Got %s of type %s" % (collections, type(collections))) 1762 if constraint is not None and not callable(constraint): 1763 raise ValueError("The `constraint` argument must be a callable.") 1764 1765 # Store the graph key so optimizers know how to only retrieve variables from 1766 # this graph. 
1767 self._graph_key = ops.get_default_graph()._graph_key # pylint: disable=protected-access 1768 if isinstance(initial_value, trackable.CheckpointInitialValue): 1769 self._maybe_initialize_trackable() 1770 self._update_uid = initial_value.checkpoint_position.restore_uid 1771 initial_value = initial_value.wrapped_value 1772 1773 synchronization, aggregation, trainable = ( 1774 validate_synchronization_aggregation_trainable(synchronization, 1775 aggregation, trainable, 1776 name)) 1777 self._synchronization = synchronization 1778 self._aggregation = aggregation 1779 self._trainable = trainable 1780 if trainable and ops.GraphKeys.TRAINABLE_VARIABLES not in collections: 1781 collections = list(collections) + [ops.GraphKeys.TRAINABLE_VARIABLES] 1782 with ops.init_scope(): 1783 # Ensure that we weren't lifted into the eager context. 1784 if context.executing_eagerly(): 1785 raise RuntimeError( 1786 "RefVariable not supported when eager execution is enabled. ") 1787 with ops.name_scope(name, "Variable", 1788 [] if init_from_fn else [initial_value]) as name: 1789 1790 if init_from_fn: 1791 # Use attr_scope and device(None) to simulate the behavior of 1792 # colocate_with when the variable we want to colocate with doesn't 1793 # yet exist. 1794 true_name = ops.name_from_scope_name(name) # pylint: disable=protected-access 1795 attr = attr_value_pb2.AttrValue( 1796 list=attr_value_pb2.AttrValue.ListValue( 1797 s=[compat.as_bytes("loc:@%s" % true_name)])) 1798 # pylint: disable=protected-access 1799 with ops.get_default_graph()._attr_scope({"_class": attr}): 1800 with ops.name_scope("Initializer"), ops.device(None): 1801 initial_value = initial_value() 1802 if isinstance(initial_value, trackable.CheckpointInitialValue): 1803 self._maybe_initialize_trackable() 1804 self._update_uid = initial_value.checkpoint_position.restore_uid 1805 initial_value = initial_value.wrapped_value 1806 self._initial_value = ops.convert_to_tensor( 1807 initial_value, name="initial_value", dtype=dtype) 1808 if shape is None: 1809 shape = ( 1810 self._initial_value.get_shape() 1811 if validate_shape else tensor_shape.unknown_shape()) 1812 self._variable = state_ops.variable_op_v2( 1813 shape, self._initial_value.dtype.base_dtype, name=name) 1814 # pylint: enable=protected-access 1815 1816 # Or get the initial value from a Tensor or Python object. 1817 else: 1818 self._initial_value = ops.convert_to_tensor( 1819 initial_value, name="initial_value", dtype=dtype) 1820 # pylint: disable=protected-access 1821 if self._initial_value.op._get_control_flow_context() is not None: 1822 raise ValueError( 1823 "Initializer for variable %s is from inside a control-flow " 1824 "construct, such as a loop or conditional. When creating a " 1825 "variable inside a loop or conditional, use a lambda as the " 1826 "initializer." % name) 1827 if shape is None: 1828 # pylint: enable=protected-access 1829 shape = ( 1830 self._initial_value.get_shape() 1831 if validate_shape else tensor_shape.unknown_shape()) 1832 # In this case, the variable op can't be created until after the 1833 # initial_value has been converted to a Tensor with a known type. 1834 self._variable = state_ops.variable_op_v2( 1835 shape, self._initial_value.dtype.base_dtype, name=name) 1836 1837 # Cache the name in `self`, because some APIs call `Variable.name` in a 1838 # tight loop, and this halves the cost. 1839 self._name = self._variable.name 1840 1841 # Manually overrides the variable's shape with the initial value's. 
1842 if validate_shape: 1843 initial_value_shape = self._initial_value.get_shape() 1844 if not initial_value_shape.is_fully_defined(): 1845 raise ValueError("initial_value must have a shape specified: %s" % 1846 self._initial_value) 1847 1848 # If 'initial_value' makes use of other variables, make sure we don't 1849 # have an issue if these other variables aren't initialized first by 1850 # using their initialized_value() method. 1851 self._initializer_op = state_ops.assign( 1852 self._variable, 1853 _try_guard_against_uninitialized_dependencies( 1854 name, self._initial_value), 1855 validate_shape=validate_shape).op 1856 1857 # TODO(vrv): Change this class to not take caching_device, but 1858 # to take the op to colocate the snapshot with, so we can use 1859 # colocation rather than devices. 1860 if caching_device is not None: 1861 with ops.device(caching_device): 1862 self._snapshot = array_ops.identity(self._variable, name="read") 1863 else: 1864 with ops.colocate_with(self._variable.op): 1865 self._snapshot = array_ops.identity(self._variable, name="read") 1866 ops.add_to_collections(collections, self) 1867 1868 self._caching_device = caching_device 1869 self._save_slice_info = None 1870 self._constraint = constraint 1871 1872 def _init_from_proto(self, variable_def, import_scope=None): 1873 """Recreates the Variable object from a `VariableDef` protocol buffer. 1874 1875 Args: 1876 variable_def: `VariableDef` protocol buffer, describing a variable whose 1877 nodes already exists in the graph. 1878 import_scope: Optional `string`. Name scope to add. 1879 """ 1880 assert isinstance(variable_def, variable_pb2.VariableDef) 1881 # Create from variable_def. 1882 g = ops.get_default_graph() 1883 self._variable = g.as_graph_element( 1884 ops.prepend_name_scope( 1885 variable_def.variable_name, import_scope=import_scope)) 1886 self._name = self._variable.name 1887 self._initializer_op = g.as_graph_element( 1888 ops.prepend_name_scope( 1889 variable_def.initializer_name, import_scope=import_scope)) 1890 # Tests whether initial_value_name exists first for backwards compatibility. 1891 if (hasattr(variable_def, "initial_value_name") and 1892 variable_def.initial_value_name): 1893 self._initial_value = g.as_graph_element( 1894 ops.prepend_name_scope( 1895 variable_def.initial_value_name, import_scope=import_scope)) 1896 else: 1897 self._initial_value = None 1898 synchronization, aggregation, trainable = ( 1899 validate_synchronization_aggregation_trainable( 1900 variable_def.synchronization, variable_def.aggregation, 1901 variable_def.trainable, variable_def.variable_name)) 1902 self._synchronization = synchronization 1903 self._aggregation = aggregation 1904 self._trainable = trainable 1905 self._snapshot = g.as_graph_element( 1906 ops.prepend_name_scope( 1907 variable_def.snapshot_name, import_scope=import_scope)) 1908 if variable_def.HasField("save_slice_info_def"): 1909 self._save_slice_info = Variable.SaveSliceInfo( 1910 save_slice_info_def=variable_def.save_slice_info_def, 1911 import_scope=import_scope) 1912 else: 1913 self._save_slice_info = None 1914 self._caching_device = None 1915 self._constraint = None 1916 1917 def _as_graph_element(self): 1918 """Conversion function for Graph.as_graph_element().""" 1919 return self._variable 1920 1921 def value(self): 1922 """Returns the last snapshot of this variable. 1923 1924 You usually do not need to call this method as all ops that need the value 1925 of the variable call it automatically through a `convert_to_tensor()` call. 
1926 1927 Returns a `Tensor` which holds the value of the variable. You can not 1928 assign a new value to this tensor as it is not a reference to the variable. 1929 1930 To avoid copies, if the consumer of the returned value is on the same device 1931 as the variable, this actually returns the live value of the variable, not 1932 a copy. Updates to the variable are seen by the consumer. If the consumer 1933 is on a different device it will get a copy of the variable. 1934 1935 Returns: 1936 A `Tensor` containing the value of the variable. 1937 """ 1938 return self._snapshot 1939 1940 def read_value(self): 1941 """Returns the value of this variable, read in the current context. 1942 1943 Can be different from value() if it's on another device, with control 1944 dependencies, etc. 1945 1946 Returns: 1947 A `Tensor` containing the value of the variable. 1948 """ 1949 return array_ops.identity(self._variable, name="read") 1950 1951 def _ref(self): 1952 """Returns a reference to this variable. 1953 1954 You usually do not need to call this method as all ops that need a reference 1955 to the variable call it automatically. 1956 1957 Returns is a `Tensor` which holds a reference to the variable. You can 1958 assign a new value to the variable by passing the tensor to an assign op. 1959 See `tf.Variable.value` if you want to get the value of the 1960 variable. 1961 1962 Returns: 1963 A `Tensor` that is a reference to the variable. 1964 """ 1965 return self._variable 1966 1967 def set_shape(self, shape): 1968 """Overrides the shape for this variable. 1969 1970 Args: 1971 shape: the `TensorShape` representing the overridden shape. 1972 """ 1973 self._ref().set_shape(shape) 1974 self.value().set_shape(shape) 1975 1976 @property 1977 def trainable(self): 1978 return self._trainable 1979 1980 @property 1981 def synchronization(self): 1982 return self._synchronization 1983 1984 @property 1985 def aggregation(self): 1986 return self._aggregation 1987 1988 def eval(self, session=None): 1989 """In a session, computes and returns the value of this variable. 1990 1991 This is not a graph construction method, it does not add ops to the graph. 1992 1993 This convenience method requires a session where the graph 1994 containing this variable has been launched. If no session is 1995 passed, the default session is used. See `tf.compat.v1.Session` for more 1996 information on launching a graph and on sessions. 1997 1998 ```python 1999 v = tf.Variable([1, 2]) 2000 init = tf.compat.v1.global_variables_initializer() 2001 2002 with tf.compat.v1.Session() as sess: 2003 sess.run(init) 2004 # Usage passing the session explicitly. 2005 print(v.eval(sess)) 2006 # Usage with the default session. The 'with' block 2007 # above makes 'sess' the default session. 2008 print(v.eval()) 2009 ``` 2010 2011 Args: 2012 session: The session to use to evaluate this variable. If none, the 2013 default session is used. 2014 2015 Returns: 2016 A numpy `ndarray` with a copy of the value of this variable. 2017 """ 2018 return self._variable.eval(session=session) 2019 2020 @property 2021 def initial_value(self): 2022 """Returns the Tensor used as the initial value for the variable. 2023 2024 Note that this is different from `initialized_value()` which runs 2025 the op that initializes the variable before returning its value. 2026 This method returns the tensor that is used by the op that initializes 2027 the variable. 2028 2029 Returns: 2030 A `Tensor`. 
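
    For illustration, a minimal sketch (graph mode; the value and names are
    arbitrary):

    ```python
    v = tf.compat.v1.Variable(3.0, name="v")
    init_tensor = v.initial_value      # the Tensor consumed by the initializer op
    guarded = v.initialized_value()    # value of v if initialized, else its initial value
    ```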
2031 """ 2032 return self._initial_value 2033 2034 @property 2035 def constraint(self): 2036 """Returns the constraint function associated with this variable. 2037 2038 Returns: 2039 The constraint function that was passed to the variable constructor. 2040 Can be `None` if no constraint was passed. 2041 """ 2042 return self._constraint 2043 2044 def assign(self, value, use_locking=False, name=None, read_value=True): 2045 """Assigns a new value to the variable. 2046 2047 This is essentially a shortcut for `assign(self, value)`. 2048 2049 Args: 2050 value: A `Tensor`. The new value for this variable. 2051 use_locking: If `True`, use locking during the assignment. 2052 name: The name of the operation to be created 2053 read_value: if True, will return something which evaluates to the new 2054 value of the variable; if False will return the assign op. 2055 2056 Returns: 2057 A `Tensor` that will hold the new value of this variable after 2058 the assignment has completed. 2059 """ 2060 assign = state_ops.assign( 2061 self._variable, value, use_locking=use_locking, name=name) 2062 if read_value: 2063 return assign 2064 return assign.op 2065 2066 def assign_add(self, delta, use_locking=False, name=None, read_value=True): 2067 """Adds a value to this variable. 2068 2069 This is essentially a shortcut for `assign_add(self, delta)`. 2070 2071 Args: 2072 delta: A `Tensor`. The value to add to this variable. 2073 use_locking: If `True`, use locking during the operation. 2074 name: The name of the operation to be created 2075 read_value: if True, will return something which evaluates to the new 2076 value of the variable; if False will return the assign op. 2077 2078 Returns: 2079 A `Tensor` that will hold the new value of this variable after 2080 the addition has completed. 2081 """ 2082 assign = state_ops.assign_add( 2083 self._variable, delta, use_locking=use_locking, name=name) 2084 if read_value: 2085 return assign 2086 return assign.op 2087 2088 def assign_sub(self, delta, use_locking=False, name=None, read_value=True): 2089 """Subtracts a value from this variable. 2090 2091 This is essentially a shortcut for `assign_sub(self, delta)`. 2092 2093 Args: 2094 delta: A `Tensor`. The value to subtract from this variable. 2095 use_locking: If `True`, use locking during the operation. 2096 name: The name of the operation to be created 2097 read_value: if True, will return something which evaluates to the new 2098 value of the variable; if False will return the assign op. 2099 2100 Returns: 2101 A `Tensor` that will hold the new value of this variable after 2102 the subtraction has completed. 2103 """ 2104 assign = state_ops.assign_sub( 2105 self._variable, delta, use_locking=use_locking, name=name) 2106 if read_value: 2107 return assign 2108 return assign.op 2109 2110 def scatter_sub(self, sparse_delta, use_locking=False, name=None): 2111 """Subtracts `tf.IndexedSlices` from this variable. 2112 2113 Args: 2114 sparse_delta: `tf.IndexedSlices` to be subtracted from this variable. 2115 use_locking: If `True`, use locking during the operation. 2116 name: the name of the operation. 2117 2118 Returns: 2119 A `Tensor` that will hold the new value of this variable after 2120 the scattered subtraction has completed. 2121 2122 Raises: 2123 TypeError: if `sparse_delta` is not an `IndexedSlices`. 
2124 """ 2125 if not isinstance(sparse_delta, ops.IndexedSlices): 2126 raise TypeError("sparse_delta is not IndexedSlices: %s" % sparse_delta) 2127 return gen_state_ops.scatter_sub( 2128 self._variable, 2129 sparse_delta.indices, 2130 sparse_delta.values, 2131 use_locking=use_locking, 2132 name=name) 2133 2134 def scatter_add(self, sparse_delta, use_locking=False, name=None): 2135 """Adds `tf.IndexedSlices` to this variable. 2136 2137 Args: 2138 sparse_delta: `tf.IndexedSlices` to be added to this variable. 2139 use_locking: If `True`, use locking during the operation. 2140 name: the name of the operation. 2141 2142 Returns: 2143 A `Tensor` that will hold the new value of this variable after 2144 the scattered addition has completed. 2145 2146 Raises: 2147 TypeError: if `sparse_delta` is not an `IndexedSlices`. 2148 """ 2149 if not isinstance(sparse_delta, ops.IndexedSlices): 2150 raise TypeError("sparse_delta is not IndexedSlices: %s" % sparse_delta) 2151 return gen_state_ops.scatter_add( 2152 self._variable, 2153 sparse_delta.indices, 2154 sparse_delta.values, 2155 use_locking=use_locking, 2156 name=name) 2157 2158 def scatter_max(self, sparse_delta, use_locking=False, name=None): 2159 """Updates this variable with the max of `tf.IndexedSlices` and itself. 2160 2161 Args: 2162 sparse_delta: `tf.IndexedSlices` to use as an argument of max with this 2163 variable. 2164 use_locking: If `True`, use locking during the operation. 2165 name: the name of the operation. 2166 2167 Returns: 2168 A `Tensor` that will hold the new value of this variable after 2169 the scattered maximization has completed. 2170 2171 Raises: 2172 TypeError: if `sparse_delta` is not an `IndexedSlices`. 2173 """ 2174 if not isinstance(sparse_delta, ops.IndexedSlices): 2175 raise TypeError("sparse_delta is not IndexedSlices: %s" % sparse_delta) 2176 return gen_state_ops.scatter_max( 2177 self._variable, 2178 sparse_delta.indices, 2179 sparse_delta.values, 2180 use_locking=use_locking, 2181 name=name) 2182 2183 def scatter_min(self, sparse_delta, use_locking=False, name=None): 2184 """Updates this variable with the min of `tf.IndexedSlices` and itself. 2185 2186 Args: 2187 sparse_delta: `tf.IndexedSlices` to use as an argument of min with this 2188 variable. 2189 use_locking: If `True`, use locking during the operation. 2190 name: the name of the operation. 2191 2192 Returns: 2193 A `Tensor` that will hold the new value of this variable after 2194 the scattered minimization has completed. 2195 2196 Raises: 2197 TypeError: if `sparse_delta` is not an `IndexedSlices`. 2198 """ 2199 if not isinstance(sparse_delta, ops.IndexedSlices): 2200 raise TypeError("sparse_delta is not IndexedSlices: %s" % sparse_delta) 2201 return gen_state_ops.scatter_min( 2202 self._variable, 2203 sparse_delta.indices, 2204 sparse_delta.values, 2205 use_locking=use_locking, 2206 name=name) 2207 2208 def scatter_mul(self, sparse_delta, use_locking=False, name=None): 2209 """Multiply this variable by `tf.IndexedSlices`. 2210 2211 Args: 2212 sparse_delta: `tf.IndexedSlices` to multiply this variable by. 2213 use_locking: If `True`, use locking during the operation. 2214 name: the name of the operation. 2215 2216 Returns: 2217 A `Tensor` that will hold the new value of this variable after 2218 the scattered multiplication has completed. 2219 2220 Raises: 2221 TypeError: if `sparse_delta` is not an `IndexedSlices`. 
2222 """ 2223 if not isinstance(sparse_delta, ops.IndexedSlices): 2224 raise TypeError("sparse_delta is not IndexedSlices: %s" % sparse_delta) 2225 return gen_state_ops.scatter_mul( 2226 self._variable, 2227 sparse_delta.indices, 2228 sparse_delta.values, 2229 use_locking=use_locking, 2230 name=name) 2231 2232 def scatter_div(self, sparse_delta, use_locking=False, name=None): 2233 """Divide this variable by `tf.IndexedSlices`. 2234 2235 Args: 2236 sparse_delta: `tf.IndexedSlices` to divide this variable by. 2237 use_locking: If `True`, use locking during the operation. 2238 name: the name of the operation. 2239 2240 Returns: 2241 A `Tensor` that will hold the new value of this variable after 2242 the scattered division has completed. 2243 2244 Raises: 2245 TypeError: if `sparse_delta` is not an `IndexedSlices`. 2246 """ 2247 if not isinstance(sparse_delta, ops.IndexedSlices): 2248 raise TypeError("sparse_delta is not IndexedSlices: %s" % sparse_delta) 2249 return gen_state_ops.scatter_div( 2250 self._variable, 2251 sparse_delta.indices, 2252 sparse_delta.values, 2253 use_locking=use_locking, 2254 name=name) 2255 2256 def scatter_update(self, sparse_delta, use_locking=False, name=None): 2257 """Assigns `tf.IndexedSlices` to this variable. 2258 2259 Args: 2260 sparse_delta: `tf.IndexedSlices` to be assigned to this variable. 2261 use_locking: If `True`, use locking during the operation. 2262 name: the name of the operation. 2263 2264 Returns: 2265 A `Tensor` that will hold the new value of this variable after 2266 the scattered assignment has completed. 2267 2268 Raises: 2269 TypeError: if `sparse_delta` is not an `IndexedSlices`. 2270 """ 2271 if not isinstance(sparse_delta, ops.IndexedSlices): 2272 raise TypeError("sparse_delta is not IndexedSlices: %s" % sparse_delta) 2273 return gen_state_ops.scatter_update( 2274 self._variable, 2275 sparse_delta.indices, 2276 sparse_delta.values, 2277 use_locking=use_locking, 2278 name=name) 2279 2280 def batch_scatter_update(self, sparse_delta, use_locking=False, name=None): 2281 """Assigns `tf.IndexedSlices` to this variable batch-wise. 2282 2283 Analogous to `batch_gather`. This assumes that this variable and the 2284 sparse_delta IndexedSlices have a series of leading dimensions that are the 2285 same for all of them, and the updates are performed on the last dimension of 2286 indices. In other words, the dimensions should be the following: 2287 2288 `num_prefix_dims = sparse_delta.indices.ndims - 1` 2289 `batch_dim = num_prefix_dims + 1` 2290 `sparse_delta.updates.shape = sparse_delta.indices.shape + var.shape[ 2291 batch_dim:]` 2292 2293 where 2294 2295 `sparse_delta.updates.shape[:num_prefix_dims]` 2296 `== sparse_delta.indices.shape[:num_prefix_dims]` 2297 `== var.shape[:num_prefix_dims]` 2298 2299 And the operation performed can be expressed as: 2300 2301 `var[i_1, ..., i_n, 2302 sparse_delta.indices[i_1, ..., i_n, j]] = sparse_delta.updates[ 2303 i_1, ..., i_n, j]` 2304 2305 When sparse_delta.indices is a 1D tensor, this operation is equivalent to 2306 `scatter_update`. 2307 2308 To avoid this operation one can looping over the first `ndims` of the 2309 variable and using `scatter_update` on the subtensors that result of slicing 2310 the first dimension. This is a valid option for `ndims = 1`, but less 2311 efficient than this implementation. 2312 2313 Args: 2314 sparse_delta: `tf.IndexedSlices` to be assigned to this variable. 2315 use_locking: If `True`, use locking during the operation. 2316 name: the name of the operation. 
2317
2318 Returns:
2319 A `Tensor` that will hold the new value of this variable after
2320 the scattered assignment has completed.
2321
2322 Raises:
2323 TypeError: if `sparse_delta` is not an `IndexedSlices`.
2324 """
2325 return state_ops.batch_scatter_update(
2326 self,
2327 sparse_delta.indices,
2328 sparse_delta.values,
2329 use_locking=use_locking,
2330 name=name)
2331
2332 def scatter_nd_sub(self, indices, updates, name=None):
2333 """Applies sparse subtraction to individual values or slices in a Variable.
2334
2335 `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.
2336
2337 `indices` must be integer tensor, containing indices into `ref`.
2338 It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.
2339
2340 The innermost dimension of `indices` (with length `K`) corresponds to
2341 indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
2342 dimension of `ref`.
2343
2344 `updates` is `Tensor` of rank `Q-1+P-K` with shape:
2345
2346 ```
2347 [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
2348 ```
2349
2350 For example, say we want to subtract 4 scattered elements from a rank-1
2351 tensor with 8 elements. In Python, that update would look like this:
2352
2353 ```python
2354 ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
2355 indices = tf.constant([[4], [3], [1], [7]])
2356 updates = tf.constant([9, 10, 11, 12])
2357 op = ref.scatter_nd_sub(indices, updates)
2358 with tf.compat.v1.Session() as sess:
2359 print(sess.run(op))
2360 ```
2361
2362 The resulting update to ref would look like this:
2363
2364 [1, -9, 3, -6, -4, 6, 7, -4]
2365
2366 See `tf.scatter_nd` for more details about how to make updates to
2367 slices.
2368
2369 Args:
2370 indices: The indices to be used in the operation.
2371 updates: The values to be used in the operation.
2372 name: the name of the operation.
2373
2374 Returns:
2375 A `Tensor` that will hold the new value of this variable after
2376 the scattered subtraction has completed.
2377 """
2378 return gen_state_ops.scatter_nd_sub(
2379 self._variable, indices, updates, use_locking=True, name=name)
2380
2381 def scatter_nd_add(self, indices, updates, name=None):
2382 """Applies sparse addition to individual values or slices in a Variable.
2383
2384 `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.
2385
2386 `indices` must be integer tensor, containing indices into `ref`.
2387 It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.
2388
2389 The innermost dimension of `indices` (with length `K`) corresponds to
2390 indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
2391 dimension of `ref`.
2392
2393 `updates` is `Tensor` of rank `Q-1+P-K` with shape:
2394
2395 ```
2396 [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
2397 ```
2398
2399 For example, say we want to add 4 scattered elements to a rank-1 tensor
2400 with 8 elements. In Python, that update would look like this:
2401
2402 ```python
2403 ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
2404 indices = tf.constant([[4], [3], [1], [7]])
2405 updates = tf.constant([9, 10, 11, 12])
2406 add = ref.scatter_nd_add(indices, updates)
2407 with tf.compat.v1.Session() as sess:
2408 print(sess.run(add))
2409 ```
2410
2411 The resulting update to ref would look like this:
2412
2413 [1, 13, 3, 14, 14, 6, 7, 20]
2414
2415 See `tf.scatter_nd` for more details about how to make updates to
2416 slices.
2417
2418 Args:
2419 indices: The indices to be used in the operation.
2420 updates: The values to be used in the operation.
2421 name: the name of the operation.
2422
2423 Returns:
2424 A `Tensor` that will hold the new value of this variable after
2425 the scattered addition has completed.
2426 """
2427 return gen_state_ops.scatter_nd_add(
2428 self._variable, indices, updates, use_locking=True, name=name)
2429
2430 def scatter_nd_update(self, indices, updates, name=None):
2431 """Applies sparse assignment to individual values or slices in a Variable.
2432
2433 `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.
2434
2435 `indices` must be integer tensor, containing indices into `ref`.
2436 It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.
2437
2438 The innermost dimension of `indices` (with length `K`) corresponds to
2439 indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
2440 dimension of `ref`.
2441
2442 `updates` is `Tensor` of rank `Q-1+P-K` with shape:
2443
2444 ```
2445 [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
2446 ```
2447
2448 For example, say we want to assign 4 scattered elements to a rank-1 tensor
2449 with 8 elements. In Python, that update would look like this:
2450
2451 ```python
2452 ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
2453 indices = tf.constant([[4], [3], [1], [7]])
2454 updates = tf.constant([9, 10, 11, 12])
2455 op = ref.scatter_nd_update(indices, updates)
2456 with tf.compat.v1.Session() as sess:
2457 print(sess.run(op))
2458 ```
2459
2460 The resulting update to ref would look like this:
2461
2462 [1, 11, 3, 10, 9, 6, 7, 12]
2463
2464 See `tf.scatter_nd` for more details about how to make updates to
2465 slices.
2466
2467 Args:
2468 indices: The indices to be used in the operation.
2469 updates: The values to be used in the operation.
2470 name: the name of the operation.
2471
2472 Returns:
2473 A `Tensor` that will hold the new value of this variable after
2474 the scattered assignment has completed.
2475 """
2476 return gen_state_ops.scatter_nd_update(
2477 self._variable, indices, updates, use_locking=True, name=name)
2478
2479 def scatter_nd_max(self, indices, updates, name=None):
2480 """Updates this variable with the element-wise max of `updates` and itself.
2481
2482 `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.
2483
2484 `indices` must be integer tensor, containing indices into `ref`.
2485 It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.
2486
2487 The innermost dimension of `indices` (with length `K`) corresponds to
2488 indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
2489 dimension of `ref`.
2490
2491 `updates` is `Tensor` of rank `Q-1+P-K` with shape:
2492
2493 ```
2494 [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
2495 ```
2496
2497 See `tf.scatter_nd` for more details about how to make updates to
2498 slices.
2499
2500 Args:
2501 indices: The indices to be used in the operation.
2502 updates: The values to be used in the operation.
2503 name: the name of the operation.
2504
2505 Returns:
2506 A `Tensor` that will hold the new value of this variable after
2507 the scattered maximization has completed.
2508 """
2509 return gen_state_ops.scatter_nd_max(
2510 self._variable, indices, updates, use_locking=True, name=name)
2511
2512 def scatter_nd_min(self, indices, updates, name=None):
2513 """Updates this variable with the element-wise min of `updates` and itself.
2514
2515 `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.
2516 2517 `indices` must be integer tensor, containing indices into `ref`. 2518 It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`. 2519 2520 The innermost dimension of `indices` (with length `K`) corresponds to 2521 indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th 2522 dimension of `ref`. 2523 2524 `updates` is `Tensor` of rank `Q-1+P-K` with shape: 2525 2526 ``` 2527 [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]]. 2528 ``` 2529 2530 See `tf.scatter_nd` for more details about how to make updates to 2531 slices. 2532 2533 Args: 2534 indices: The indices to be used in the operation. 2535 updates: The values to be used in the operation. 2536 name: the name of the operation. 2537 2538 Returns: 2539 A `Tensor` that will hold the new value of this variable after 2540 the scattered addition has completed. 2541 """ 2542 return gen_state_ops.scatter_nd_min( 2543 self._variable, indices, updates, use_locking=True, name=name) 2544 2545 def _strided_slice_assign(self, begin, end, strides, value, name, begin_mask, 2546 end_mask, ellipsis_mask, new_axis_mask, 2547 shrink_axis_mask): 2548 return gen_array_ops.strided_slice_assign( 2549 ref=self._ref(), 2550 begin=begin, 2551 end=end, 2552 strides=strides, 2553 value=value, 2554 name=name, 2555 begin_mask=begin_mask, 2556 end_mask=end_mask, 2557 ellipsis_mask=ellipsis_mask, 2558 new_axis_mask=new_axis_mask, 2559 shrink_axis_mask=shrink_axis_mask) 2560 2561 @deprecated(None, "Prefer Dataset.range instead.") 2562 def count_up_to(self, limit): 2563 """Increments this variable until it reaches `limit`. 2564 2565 When that Op is run it tries to increment the variable by `1`. If 2566 incrementing the variable would bring it above `limit` then the Op raises 2567 the exception `OutOfRangeError`. 2568 2569 If no error is raised, the Op outputs the value of the variable before 2570 the increment. 2571 2572 This is essentially a shortcut for `count_up_to(self, limit)`. 2573 2574 Args: 2575 limit: value at which incrementing the variable raises an error. 2576 2577 Returns: 2578 A `Tensor` that will hold the variable value before the increment. If no 2579 other Op modifies this variable, the values produced will all be 2580 distinct. 2581 """ 2582 return state_ops.count_up_to(self._variable, limit=limit) 2583 2584 # Conversion to tensor. 2585 @staticmethod 2586 def _TensorConversionFunction(v, dtype=None, name=None, as_ref=False): # pylint: disable=invalid-name 2587 """Utility function for converting a Variable to a Tensor.""" 2588 _ = name 2589 if dtype and not dtype.is_compatible_with(v.dtype): 2590 raise ValueError( 2591 "Incompatible type conversion requested to type '%s' for variable " 2592 "of type '%s'" % (dtype.name, v.dtype.name)) 2593 if as_ref: 2594 return v._ref() # pylint: disable=protected-access 2595 else: 2596 return v.value() 2597 2598 # NOTE(mrry): This enables the Variable's overloaded "right" binary 2599 # operators to run when the left operand is an ndarray, because it 2600 # accords the Variable class higher priority than an ndarray, or a 2601 # numpy matrix. 2602 # TODO(mrry): Convert this to using numpy's __numpy_ufunc__ 2603 # mechanism, which allows more control over how Variables interact 2604 # with ndarrays. 
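  # (ndarray's default __array_priority__ is 0.0, so the higher value below
  # tells NumPy to defer -- return NotImplemented -- and let Python fall back
  # to the Variable's reflected operator.)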
2605 __array_priority__ = 100 2606 2607 @property 2608 def name(self): 2609 """The name of this variable.""" 2610 return self._name 2611 2612 @property 2613 def initializer(self): 2614 """The initializer operation for this variable.""" 2615 return self._initializer_op 2616 2617 @property 2618 def device(self): 2619 """The device of this variable.""" 2620 return self._variable.device 2621 2622 @property 2623 def dtype(self): 2624 """The `DType` of this variable.""" 2625 return self._variable.dtype 2626 2627 @property 2628 def op(self): 2629 """The `Operation` of this variable.""" 2630 return self._variable.op 2631 2632 @property 2633 def graph(self): 2634 """The `Graph` of this variable.""" 2635 return self._variable.graph 2636 2637 @property 2638 def _distribute_strategy(self): 2639 """The `tf.distribute.Strategy` that this variable was created under.""" 2640 return None # Ref variables are never created inside a strategy. 2641 2642 @property 2643 def shape(self): 2644 """The `TensorShape` of this variable. 2645 2646 Returns: 2647 A `TensorShape`. 2648 """ 2649 return self._variable.get_shape() 2650 2651 def to_proto(self, export_scope=None): 2652 """Converts a `Variable` to a `VariableDef` protocol buffer. 2653 2654 Args: 2655 export_scope: Optional `string`. Name scope to remove. 2656 2657 Returns: 2658 A `VariableDef` protocol buffer, or `None` if the `Variable` is not 2659 in the specified name scope. 2660 """ 2661 if (export_scope is None or self._variable.name.startswith(export_scope)): 2662 var_def = variable_pb2.VariableDef() 2663 var_def.variable_name = ops.strip_name_scope(self._variable.name, 2664 export_scope) 2665 if self._initial_value is not None: 2666 # For backwards compatibility. 2667 var_def.initial_value_name = ops.strip_name_scope( 2668 self._initial_value.name, export_scope) 2669 var_def.trainable = self.trainable 2670 var_def.synchronization = self.synchronization.value 2671 var_def.aggregation = self.aggregation.value 2672 var_def.initializer_name = ops.strip_name_scope(self.initializer.name, 2673 export_scope) 2674 var_def.snapshot_name = ops.strip_name_scope(self._snapshot.name, 2675 export_scope) 2676 if self._save_slice_info: 2677 var_def.save_slice_info_def.MergeFrom( 2678 self._save_slice_info.to_proto(export_scope=export_scope)) 2679 return var_def 2680 else: 2681 return None 2682 2683 def __iadd__(self, other): 2684 logging.log_first_n( 2685 logging.WARN, "Variable += will be deprecated. Use variable.assign_add" 2686 " if you want assignment to the variable value or 'x = x + y'" 2687 " if you want a new python Tensor object.", 1) 2688 return self + other 2689 2690 def __isub__(self, other): 2691 logging.log_first_n( 2692 logging.WARN, "Variable -= will be deprecated. Use variable.assign_sub" 2693 " if you want assignment to the variable value or 'x = x - y'" 2694 " if you want a new python Tensor object.", 1) 2695 return self - other 2696 2697 def __imul__(self, other): 2698 logging.log_first_n( 2699 logging.WARN, 2700 "Variable *= will be deprecated. Use `var.assign(var * other)`" 2701 " if you want assignment to the variable value or `x = x * y`" 2702 " if you want a new python Tensor object.", 1) 2703 return self * other 2704 2705 def __idiv__(self, other): 2706 logging.log_first_n( 2707 logging.WARN, 2708 "Variable /= will be deprecated. 
Use `var.assign(var / other)`" 2709 " if you want assignment to the variable value or `x = x / y`" 2710 " if you want a new python Tensor object.", 1) 2711 return self / other 2712 2713 def __itruediv__(self, other): 2714 logging.log_first_n( 2715 logging.WARN, 2716 "Variable /= will be deprecated. Use `var.assign(var / other)`" 2717 " if you want assignment to the variable value or `x = x / y`" 2718 " if you want a new python Tensor object.", 1) 2719 return self / other 2720 2721 def __irealdiv__(self, other): 2722 logging.log_first_n( 2723 logging.WARN, 2724 "Variable /= will be deprecated. Use `var.assign(var / other)`" 2725 " if you want assignment to the variable value or `x = x / y`" 2726 " if you want a new python Tensor object.", 1) 2727 return self / other 2728 2729 def __ipow__(self, other): 2730 logging.log_first_n( 2731 logging.WARN, 2732 "Variable **= will be deprecated. Use `var.assign(var ** other)`" 2733 " if you want assignment to the variable value or `x = x ** y`" 2734 " if you want a new python Tensor object.", 1) 2735 return self**other 2736 2737 2738def _try_guard_against_uninitialized_dependencies(name, initial_value): 2739 """Attempt to guard against dependencies on uninitialized variables. 2740 2741 Replace references to variables in `initial_value` with references to the 2742 variable's initialized values. The initialized values are essentially 2743 conditional TensorFlow graphs that return a variable's value if it is 2744 initialized or its `initial_value` if it hasn't been initialized. This 2745 replacement is done on a best effort basis: 2746 2747 - If the `initial_value` graph contains cycles, we don't do any 2748 replacements for that graph. 2749 - If the variables that `initial_value` depends on are not present in the 2750 `GLOBAL_VARIABLES` or `LOCAL_VARIABLES` we don't replace them. 2751 2752 In these cases, it is up to the caller to ensure that the `initial_value` 2753 graph uses initialized variables or that they guard access to variables 2754 using their `initialized_value` method. 2755 2756 Args: 2757 name: Variable name. 2758 initial_value: `Tensor`. The initial value. 2759 2760 Returns: 2761 A `Tensor` suitable to initialize a variable. 2762 Raises: 2763 TypeError: If `initial_value` is not a `Tensor`. 2764 """ 2765 if not isinstance(initial_value, ops.Tensor): 2766 raise TypeError("initial_value needs to be a Tensor: %s" % initial_value) 2767 2768 # Don't modify initial_value if it contains any cyclic dependencies. 2769 if _has_cycle(initial_value.op, state={}): 2770 return initial_value 2771 return _safe_initial_value_from_tensor(name, initial_value, op_cache={}) 2772 2773 2774_UNKNOWN, _STARTED, _FINISHED = range(3) 2775 2776 2777def _has_cycle(op, state): 2778 """Detect cycles in the dependencies of `initial_value`.""" 2779 op_state = state.get(op.name, _UNKNOWN) 2780 if op_state == _STARTED: 2781 return True 2782 elif op_state == _FINISHED: 2783 return False 2784 2785 state[op.name] = _STARTED 2786 for i in itertools.chain((i.op for i in op.inputs), op.control_inputs): 2787 if _has_cycle(i, state): 2788 return True 2789 state[op.name] = _FINISHED 2790 return False 2791 2792 2793def _safe_initial_value_from_tensor(name, tensor, op_cache): 2794 """Replace dependencies on variables with their initialized values. 2795 2796 Args: 2797 name: Variable name. 2798 tensor: A `Tensor`. The tensor to replace. 2799 op_cache: A dict mapping operation names to `Operation`s. Used to memoize 2800 the results so as to avoid creating redundant operations. 
2801 2802 Returns: 2803 A `Tensor` compatible with `tensor`. Any inputs that lead to variable 2804 values will be replaced with a corresponding graph that uses the 2805 variable's initialized values. This is done on a best-effort basis. If no 2806 modifications need to be made then `tensor` will be returned unchanged. 2807 """ 2808 op = tensor.op 2809 new_op = op_cache.get(op.name) 2810 if new_op is None: 2811 new_op = _safe_initial_value_from_op(name, op, op_cache) 2812 op_cache[op.name] = new_op 2813 return new_op.outputs[tensor.value_index] 2814 2815 2816def _safe_initial_value_from_op(name, op, op_cache): 2817 """Replace dependencies on variables with their initialized values. 2818 2819 Args: 2820 name: Variable name. 2821 op: An `Operation`. The operation to replace. 2822 op_cache: A dict mapping operation names to `Operation`s. Used to memoize 2823 the results so as to avoid creating redundant operations. 2824 2825 Returns: 2826 An `Operation` compatible with `op`. Any inputs that lead to variable 2827 values will be replaced with a corresponding graph that uses the 2828 variable's initialized values. This is done on a best-effort basis. If no 2829 modifications need to be made then `op` will be returned unchanged. 2830 """ 2831 op_type = op.node_def.op 2832 if op_type in ("IsVariableInitialized", "VarIsInitializedOp", 2833 "ReadVariableOp", "If"): 2834 return op 2835 2836 # Attempt to find the initialized_value of any variable reference / handles. 2837 # TODO(b/70206927): Fix handling of ResourceVariables. 2838 if op_type in ("Variable", "VariableV2", "VarHandleOp"): 2839 initialized_value = _find_initialized_value_for_variable(op) 2840 return op if initialized_value is None else initialized_value.op 2841 2842 # Recursively build initializer expressions for inputs. 2843 modified = False 2844 new_op_inputs = [] 2845 for op_input in op.inputs: 2846 new_op_input = _safe_initial_value_from_tensor(name, op_input, op_cache) 2847 new_op_inputs.append(new_op_input) 2848 modified = modified or (new_op_input != op_input) 2849 2850 # If at least one input was modified, replace the op. 2851 if modified: 2852 new_op_type = op_type 2853 if new_op_type == "RefSwitch": 2854 new_op_type = "Switch" 2855 new_op_name = op.node_def.name + "_" + name 2856 new_op_name = new_op_name.replace(":", "_") 2857 return op.graph.create_op( 2858 new_op_type, 2859 new_op_inputs, 2860 op._output_types, # pylint: disable=protected-access 2861 name=new_op_name, 2862 attrs=op.node_def.attr) 2863 2864 return op 2865 2866 2867def _find_initialized_value_for_variable(variable_op): 2868 """Find the initialized value for a variable op. 2869 2870 To do so, lookup the variable op in the variables collection. 2871 2872 Args: 2873 variable_op: A variable `Operation`. 2874 2875 Returns: 2876 A `Tensor` representing the initialized value for the variable or `None` 2877 if the initialized value could not be found. 2878 """ 2879 try: 2880 var_names = [variable_op.node_def.name, variable_op.node_def.name + ":0"] 2881 for collection_name in (ops.GraphKeys.GLOBAL_VARIABLES, 2882 ops.GraphKeys.LOCAL_VARIABLES): 2883 for var in variable_op.graph.get_collection(collection_name): 2884 if var.name in var_names: 2885 return var.initialized_value() 2886 except AttributeError: 2887 # Return None when an incomplete user-defined variable type was put in 2888 # the collection. 2889 return None 2890 return None 2891 2892 2893class PartitionedVariable(object): 2894 """A container for partitioned `Variable` objects. 
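
  Instances are normally not constructed directly; `tf.compat.v1.get_variable`
  returns one when a partitioner is in effect. A minimal sketch (graph mode;
  the scope name and shape are illustrative):

  ```python
  partitioner = tf.compat.v1.fixed_size_partitioner(num_shards=2)
  with tf.compat.v1.variable_scope("scope", partitioner=partitioner):
    v = tf.compat.v1.get_variable("w", shape=[10, 4])  # a PartitionedVariable
  shards = list(v)       # the underlying per-partition `Variable` objects
  full = v.as_tensor()   # the concatenated value as a single `Tensor`
  ```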
2895 2896 @compatibility(eager) `tf.PartitionedVariable` is not compatible with 2897 eager execution. Use `tf.Variable` instead which is compatible 2898 with both eager execution and graph construction. See [the 2899 TensorFlow Eager Execution 2900 guide](https://www.tensorflow.org/guide/eager#variables_and_optimizers) 2901 for details on how variables work in eager execution. 2902 @end_compatibility 2903 """ 2904 2905 def __init__(self, name, shape, dtype, variable_list, partitions): 2906 """Creates a new partitioned variable wrapper. 2907 2908 Variables passed via the variable_list must contain a save_slice_info 2909 field. Concatenation and iteration is in lexicographic order according 2910 to the var_offset property of the save_slice_info. 2911 2912 Args: 2913 name: String. Overall name of the variables. 2914 shape: List of integers. Overall shape of the variables. 2915 dtype: Type of the variables. 2916 variable_list: List of `Variable` that comprise this partitioned variable. 2917 partitions: List of integers. Number of partitions for each dimension. 2918 2919 Raises: 2920 TypeError: If `variable_list` is not a list of `Variable` objects, or 2921 `partitions` is not a list. 2922 ValueError: If `variable_list` is empty, or the `Variable` shape 2923 information does not match `shape`, or `partitions` has invalid values. 2924 """ 2925 if not isinstance(variable_list, (list, tuple)): 2926 raise TypeError("variable_list is not a list or tuple: %s" % 2927 variable_list) 2928 if not isinstance(partitions, (list, tuple)): 2929 raise TypeError("partitions is not a list or tuple: %s" % partitions) 2930 if not all(p >= 1 for p in partitions): 2931 raise ValueError("partition values must be positive: %s" % partitions) 2932 if not variable_list: 2933 raise ValueError("variable_list may not be empty") 2934 # pylint: disable=protected-access 2935 for v in variable_list: 2936 # Sort the variable_list lexicographically according to var offset value. 2937 if not all(v._get_save_slice_info() is not None for v in variable_list): 2938 raise ValueError( 2939 "All variables must have a save_slice_info available: %s" % 2940 [v.name for v in variable_list]) 2941 if len(shape) != len(partitions): 2942 raise ValueError("len(shape) != len(partitions): %s vs. %s" % 2943 (shape, partitions)) 2944 if v._get_save_slice_info().full_shape != shape: 2945 raise ValueError("All variables' full shapes must match shape: %s; " 2946 "but full shapes were: %s" % 2947 (shape, str([v._get_save_slice_info().full_shape]))) 2948 self._variable_list = sorted( 2949 variable_list, key=lambda v: v._get_save_slice_info().var_offset) 2950 # pylint: enable=protected-access 2951 2952 self._name = name 2953 self._shape = shape 2954 self._dtype = dtype 2955 self._partitions = partitions 2956 self._as_tensor = None 2957 2958 def __iter__(self): 2959 """Return an iterable for accessing the underlying partition Variables.""" 2960 return iter(self._variable_list) 2961 2962 def __len__(self): 2963 num_partition_axes = len(self._partition_axes()) 2964 if num_partition_axes > 1: 2965 raise ValueError("Cannot get a length for %d > 1 partition axes" % 2966 num_partition_axes) 2967 return len(self._variable_list) 2968 2969 def _partition_axes(self): 2970 if all(p == 1 for p in self._partitions): 2971 return [0] 2972 else: 2973 return [i for i, p in enumerate(self._partitions) if p > 1] 2974 2975 def _concat(self): 2976 """Returns the overall concatenated value as a `Tensor`. 
2977 2978 This is different from using the partitioned variable directly as a tensor 2979 (through tensor conversion and `as_tensor`) in that it creates a new set of 2980 operations that keeps the control dependencies from its scope. 2981 2982 Returns: 2983 `Tensor` containing the concatenated value. 2984 """ 2985 if len(self._variable_list) == 1: 2986 with ops.name_scope(None): 2987 return array_ops.identity(self._variable_list[0], name=self._name) 2988 2989 partition_axes = self._partition_axes() 2990 2991 if len(partition_axes) > 1: 2992 raise NotImplementedError( 2993 "Cannot concatenate along more than one dimension: %s. " 2994 "Multi-axis partition concat is not supported" % str(partition_axes)) 2995 partition_ix = partition_axes[0] 2996 2997 with ops.name_scope(self._name + "/ConcatPartitions/"): 2998 concatenated = array_ops.concat(self._variable_list, partition_ix) 2999 3000 with ops.name_scope(None): 3001 return array_ops.identity(concatenated, name=self._name) 3002 3003 def as_tensor(self): 3004 """Returns the overall concatenated value as a `Tensor`. 3005 3006 The returned tensor will not inherit the control dependencies from the scope 3007 where the value is used, which is similar to getting the value of 3008 `Variable`. 3009 3010 Returns: 3011 `Tensor` containing the concatenated value. 3012 """ 3013 with ops.control_dependencies(None): 3014 return self._concat() 3015 3016 @staticmethod 3017 def _TensorConversionFunction(v, dtype=None, name=None, as_ref=False): 3018 # pylint: disable=invalid-name 3019 _ = name 3020 if dtype is not None and not dtype.is_compatible_with(v.dtype): 3021 raise ValueError( 3022 "Incompatible type conversion requested to type '%s' for variable " 3023 "of type '%s'" % (dtype.name, v.dtype.name)) 3024 if as_ref: 3025 raise NotImplementedError( 3026 "PartitionedVariable doesn't support being used as a reference.") 3027 else: 3028 return v.as_tensor() 3029 3030 @property 3031 def name(self): 3032 return self._name 3033 3034 @property 3035 def dtype(self): 3036 return self._dtype 3037 3038 @property 3039 def shape(self): 3040 return self.get_shape() 3041 3042 @property 3043 def _distribute_strategy(self): 3044 """The `tf.distribute.Strategy` that this variable was created under.""" 3045 # NOTE(yuefengz): Today, no partitioned variables in a distribute strategy. 3046 return None 3047 3048 def get_shape(self): 3049 return self._shape 3050 3051 def _get_variable_list(self): 3052 return self._variable_list 3053 3054 def _get_partitions(self): 3055 return self._partitions 3056 3057 def _apply_assign_fn(self, assign_fn, value): 3058 partition_axes = self._partition_axes() 3059 if len(partition_axes) > 1: 3060 raise NotImplementedError( 3061 "Cannot do assign action along more than one dimension: %s. 
" 3062 "Multi-axis partition assign action is not supported " % 3063 str(partition_axes)) 3064 if isinstance(value, list): 3065 assert len(value) == len(self._variable_list) 3066 value_list = value 3067 elif isinstance(value, PartitionedVariable): 3068 value_list = [var_part for var_part in value] 3069 else: 3070 partition_ix = partition_axes[0] 3071 size_splits_list = [ 3072 tensor_shape.dimension_value(var.shape[partition_ix]) 3073 for var in self._variable_list 3074 ] 3075 value_list = array_ops.split(value, size_splits_list, axis=partition_ix) 3076 3077 op_list = [ 3078 assign_fn(var, value_list[idx]) 3079 for idx, var in enumerate(self._variable_list) 3080 ] 3081 return op_list 3082 3083 def assign(self, value, use_locking=False, name=None, read_value=True): 3084 assign_fn = lambda var, r_value: var.assign( 3085 r_value, use_locking=use_locking, name=name, read_value=read_value) 3086 assign_list = self._apply_assign_fn(assign_fn, value) 3087 if read_value: 3088 return assign_list 3089 return [assign.op for assign in assign_list] 3090 3091 def assign_add(self, value, use_locking=False, name=None, read_value=True): 3092 assign_fn = lambda var, r_value: var.assign_add( 3093 r_value, use_locking=use_locking, name=name, read_value=read_value) 3094 assign_list = self._apply_assign_fn(assign_fn, value) 3095 if read_value: 3096 return assign_list 3097 return [assign.op for assign in assign_list] 3098 3099 def assign_sub(self, value, use_locking=False, name=None, read_value=True): 3100 assign_fn = lambda var, r_value: var.assign_sub( 3101 r_value, use_locking=use_locking, name=name, read_value=read_value) 3102 assign_list = self._apply_assign_fn(assign_fn, value) 3103 if read_value: 3104 return assign_list 3105 return [assign.op for assign in assign_list] 3106 3107 3108# Register a conversion function which reads the value of the variable, 3109# allowing instances of the class to be used as tensors. 3110ops.register_tensor_conversion_function(RefVariable, 3111 RefVariable._TensorConversionFunction) # pylint: disable=protected-access 3112 3113 3114@tf_export(v1=["global_variables"]) 3115def global_variables(scope=None): 3116 """Returns global variables. 3117 3118 Global variables are variables that are shared across machines in a 3119 distributed environment. The `Variable()` constructor or `get_variable()` 3120 automatically adds new variables to the graph collection 3121 `GraphKeys.GLOBAL_VARIABLES`. 3122 This convenience function returns the contents of that collection. 3123 3124 An alternative to global variables are local variables. See 3125 `tf.compat.v1.local_variables` 3126 3127 @compatibility(TF2) 3128 Not compatible with eager execution and `tf.function`. In particular, Graph 3129 collections are deprecated in TF2. Instead please create a 3130 [tf.Module](https://www.tensorflow.org/guide/intro_to_modules) 3131 container for all your model state, including variables. 3132 You can then list all the variables in your `tf.Module` through the 3133 `variables` attribute. 3134 @end_compatibility 3135 3136 Args: 3137 scope: (Optional.) A string. If supplied, the resulting list is filtered to 3138 include only items whose `name` attribute matches `scope` using 3139 `re.match`. Items without a `name` attribute are never returned if a scope 3140 is supplied. The choice of `re.match` means that a `scope` without special 3141 tokens filters by prefix. 3142 3143 Returns: 3144 A list of `Variable` objects. 


@tf_export(v1=["all_variables"])
@deprecated("2017-03-02", "Please use tf.global_variables instead.")
def all_variables():
  """Use `tf.compat.v1.global_variables` instead."""
  return global_variables()


def _all_saveable_objects(scope=None):
  """Returns all variables and `SaveableObject`s that must be checkpointed.

  Args:
    scope: (Optional.) A string. If supplied, the resulting list is filtered to
      include only items whose `name` attribute matches `scope` using
      `re.match`. Items without a `name` attribute are never returned if a
      scope is supplied. The choice of `re.match` means that a `scope` without
      special tokens filters by prefix.

  Returns:
    A list of `Variable` and `SaveableObject` objects to be checkpointed.
  """
  # TODO(andreasst): make this function public once things are settled.
  return (ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES, scope) +
          ops.get_collection(ops.GraphKeys.SAVEABLE_OBJECTS, scope))


@tf_export(v1=["local_variables"])
def local_variables(scope=None):
  """Returns local variables.

  Local variables are per-process variables, usually not saved/restored to a
  checkpoint, and used for temporary or intermediate values.
  For example, they can be used as counters for metrics computation or the
  number of epochs this machine has read data.
  The `tf.contrib.framework.local_variable()` function automatically adds the
  new variable to `GraphKeys.LOCAL_VARIABLES`.
  This convenience function returns the contents of that collection.

  An alternative to local variables is global variables; see
  `tf.compat.v1.global_variables`.

  Args:
    scope: (Optional.) A string. If supplied, the resulting list is filtered to
      include only items whose `name` attribute matches `scope` using
      `re.match`. Items without a `name` attribute are never returned if a
      scope is supplied. The choice of `re.match` means that a `scope` without
      special tokens filters by prefix.

  Returns:
    A list of local `Variable` objects.
  """
  return ops.get_collection(ops.GraphKeys.LOCAL_VARIABLES, scope)


@tf_export(v1=["model_variables"])
def model_variables(scope=None):
  """Returns all variables in the MODEL_VARIABLES collection.

  Args:
    scope: (Optional.) A string. If supplied, the resulting list is filtered to
      include only items whose `name` attribute matches `scope` using
      `re.match`. Items without a `name` attribute are never returned if a
      scope is supplied. The choice of `re.match` means that a `scope` without
      special tokens filters by prefix.

  Returns:
    A list of `Variable` objects.
  """
  return ops.get_collection(ops.GraphKeys.MODEL_VARIABLES, scope)


@tf_export(v1=["trainable_variables"])
def trainable_variables(scope=None):
  """Returns all variables created with `trainable=True`.

  When passed `trainable=True`, the `Variable()` constructor automatically
  adds new variables to the graph collection
  `GraphKeys.TRAINABLE_VARIABLES`. This convenience function returns the
  contents of that collection.

  @compatibility(TF2)
  Not compatible with eager execution and `tf.function`. In particular, Graph
  collections are deprecated in TF2. Instead please create a `tf.Module`
  container for all your model state, including variables.
  You can then list all the trainable variables in your `tf.Module` through the
  `trainable_variables` attribute.
  @end_compatibility

  Args:
    scope: (Optional.) A string. If supplied, the resulting list is filtered to
      include only items whose `name` attribute matches `scope` using
      `re.match`. Items without a `name` attribute are never returned if a
      scope is supplied. The choice of `re.match` means that a `scope` without
      special tokens filters by prefix.

  Returns:
    A list of `Variable` objects.
  """
  return ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES, scope)
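

# Usage sketch for `trainable_variables` (illustrative only; the names are
# assumed example values, TF1 graph mode):
#
#   import tensorflow.compat.v1 as tf
#   tf.disable_eager_execution()
#   weights = tf.Variable(tf.zeros([3, 3]), name="weights")   # trainable
#   step = tf.Variable(0, name="step", trainable=False)       # excluded
#   print([v.op.name for v in tf.trainable_variables()])  # ['weights']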


@tf_export(v1=["moving_average_variables"])
def moving_average_variables(scope=None):
  """Returns all variables that maintain their moving averages.

  If an `ExponentialMovingAverage` object is created and the `apply()`
  method is called on a list of variables, these variables will
  be added to the `GraphKeys.MOVING_AVERAGE_VARIABLES` collection.
  This convenience function returns the contents of that collection.

  Args:
    scope: (Optional.) A string. If supplied, the resulting list is filtered to
      include only items whose `name` attribute matches `scope` using
      `re.match`. Items without a `name` attribute are never returned if a
      scope is supplied. The choice of `re.match` means that a `scope` without
      special tokens filters by prefix.

  Returns:
    A list of `Variable` objects.
  """
  return ops.get_collection(ops.GraphKeys.MOVING_AVERAGE_VARIABLES, scope)


@tf_export(v1=["initializers.variables", "variables_initializer"])
def variables_initializer(var_list, name="init"):
  """Returns an Op that initializes a list of variables.

  After you launch the graph in a session, you can run the returned Op to
  initialize all the variables in `var_list`. This Op runs all the
  initializers of the variables in `var_list` in parallel.

  Calling `initialize_variables()` is equivalent to passing the list of
  initializers to `Group()`.

  If `var_list` is empty, however, the function still returns an Op that can
  be run. That Op just has no effect.

  @compatibility(TF2)
  In TF2, variables are initialized immediately when they are created. There is
  no longer a need to run variable initializers before using them.
  @end_compatibility

  Args:
    var_list: List of `Variable` objects to initialize.
    name: Optional name for the returned operation.

  Returns:
    An Op that runs the initializers of all the specified variables.
  """
  if var_list and not context.executing_eagerly():
    return control_flow_ops.group(*[v.initializer for v in var_list], name=name)
  return control_flow_ops.no_op(name=name)


@tf_export(v1=["initialize_variables"])
@tf_should_use.should_use_result
@deprecated("2017-03-02", "Use `tf.variables_initializer` instead.")
def initialize_variables(var_list, name="init"):
  """See `tf.compat.v1.variables_initializer`."""
  return variables_initializer(var_list, name=name)


@tf_export(v1=["initializers.global_variables", "global_variables_initializer"])
def global_variables_initializer():
  """Returns an Op that initializes global variables.

  This is just a shortcut for `variables_initializer(global_variables())`.

  @compatibility(TF2)
  In TF2, variables are initialized immediately when they are created. There is
  no longer a need to run variable initializers before using them.
  @end_compatibility

  Returns:
    An Op that initializes global variables in the graph.
  """
  if context.executing_eagerly():
    return control_flow_ops.no_op(name="global_variables_initializer")
  return variables_initializer(global_variables())
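

# Usage sketch for `global_variables_initializer` (illustrative only; TF1
# graph mode, the variable shape and name are assumed example values):
#
#   import tensorflow.compat.v1 as tf
#   tf.disable_eager_execution()
#   v = tf.Variable(tf.random_normal([2, 2]), name="v")
#   init_op = tf.global_variables_initializer()
#   with tf.Session() as sess:
#     sess.run(init_op)   # v now holds its initial value
#     print(sess.run(v))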


@tf_export(v1=["initialize_all_variables"])
@tf_should_use.should_use_result
@deprecated("2017-03-02", "Use `tf.global_variables_initializer` instead.")
def initialize_all_variables():
  """See `tf.compat.v1.global_variables_initializer`."""
  return global_variables_initializer()


@tf_export(v1=["initializers.local_variables", "local_variables_initializer"])
def local_variables_initializer():
  """Returns an Op that initializes all local variables.

  This is just a shortcut for `variables_initializer(local_variables())`.

  @compatibility(TF2)
  In TF2, variables are initialized immediately when they are created. There is
  no longer a need to run variable initializers before using them.
  @end_compatibility

  Returns:
    An Op that initializes all local variables in the graph.
  """
  if context.executing_eagerly():
    return control_flow_ops.no_op(name="local_variables_initializer")
  return variables_initializer(local_variables())


@tf_export(v1=["initialize_local_variables"])
@tf_should_use.should_use_result
@deprecated("2017-03-02", "Use `tf.local_variables_initializer` instead.")
def initialize_local_variables():
  """See `tf.compat.v1.local_variables_initializer`."""
  return local_variables_initializer()


@tf_export(v1=["is_variable_initialized"])
@tf_should_use.should_use_result
def is_variable_initialized(variable):
  """Tests if a variable has been initialized.

  Args:
    variable: A `Variable`.

  Returns:
    A scalar boolean Tensor: `True` if the variable has been initialized,
    `False` otherwise.
  """
  return state_ops.is_variable_initialized(variable)


@tf_export(v1=["assert_variables_initialized"])
@tf_should_use.should_use_result
def assert_variables_initialized(var_list=None):
  """Returns an Op to check if variables are initialized.

  NOTE: This function is obsolete and will be removed in 6 months. Please
  change your implementation to use `report_uninitialized_variables()`.

  When run, the returned Op will raise the exception `FailedPreconditionError`
  if any of the variables has not yet been initialized.

  Note: This function is implemented by trying to fetch the values of the
  variables. If one of the variables is not initialized a message may be
  logged by the C++ runtime. This is expected.

  Args:
    var_list: List of `Variable` objects to check. Defaults to the value of
      `global_variables() + local_variables()`.

  Returns:
    An Op, or None if there are no variables.
  """
  if var_list is None:
    var_list = global_variables() + local_variables()
  # Backwards compatibility for old-style variables. TODO(touts): remove.
  if not var_list:
    var_list = []
    for op in ops.get_default_graph().get_operations():
      if op.type in ["Variable", "VariableV2", "AutoReloadVariable"]:
        var_list.append(op.outputs[0])
  if not var_list:
    return None
  else:
    ranks = []
    for var in var_list:
      with ops.colocate_with(var.op):
        ranks.append(array_ops.rank_internal(var, optimize=False))
    if len(ranks) == 1:
      return ranks[0]
    else:
      return array_ops.stack(ranks)


@tf_export(v1=["report_uninitialized_variables"])
@tf_should_use.should_use_result
def report_uninitialized_variables(var_list=None,
                                   name="report_uninitialized_variables"):
  """Adds ops to list the names of uninitialized variables.

  When run, it returns a 1-D tensor containing the names of uninitialized
  variables if there are any, or an empty array if there are none.

  Args:
    var_list: List of `Variable` objects to check. Defaults to the value of
      `global_variables() + local_variables()`.
    name: Optional name of the `Operation`.

  Returns:
    A 1-D tensor containing names of the uninitialized variables, or an empty
    1-D tensor if there are no variables or no uninitialized variables.
  """
  if var_list is None:
    var_list = global_variables() + local_variables()
    # Backwards compatibility for old-style variables. TODO(touts): remove.
    if not var_list:
      var_list = []
      for op in ops.get_default_graph().get_operations():
        if op.type in ["Variable", "VariableV2", "AutoReloadVariable"]:
          var_list.append(op.outputs[0])
  with ops.name_scope(name):
    # Run all operations on the CPU by default; the device can be overridden
    # via the TF_DEVICE_FOR_UNINITIALIZED_VARIABLE_REPORTING environment
    # variable.
    if var_list:
      init_vars = [state_ops.is_variable_initialized(v) for v in var_list]
    local_device = os.environ.get(
        "TF_DEVICE_FOR_UNINITIALIZED_VARIABLE_REPORTING", "/cpu:0")
    with ops.device(local_device):
      if not var_list:
        # Return an empty tensor so we only need to check for returned tensor
        # size being 0 as an indication of model ready.
        return array_ops.constant([], dtype=dtypes.string)
      else:
        # Get a 1-D boolean tensor listing whether each variable is
        # initialized.
        variables_mask = math_ops.logical_not(array_ops.stack(init_vars))
        # Get a 1-D string tensor containing all the variable names.
        variable_names_tensor = array_ops.constant(
            [s.op.name for s in var_list])
        # Return a 1-D tensor containing all the names of
        # uninitialized variables.
        return array_ops.boolean_mask(variable_names_tensor, variables_mask)


ops.register_tensor_conversion_function(
    PartitionedVariable, PartitionedVariable._TensorConversionFunction)  # pylint: disable=protected-access
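

# Usage sketch for `report_uninitialized_variables` (illustrative only; TF1
# graph mode, the variable names are assumed example values):
#
#   import tensorflow.compat.v1 as tf
#   tf.disable_eager_execution()
#   a = tf.Variable(1.0, name="a")
#   b = tf.Variable(2.0, name="b")
#   with tf.Session() as sess:
#     sess.run(a.initializer)  # initialize only 'a'
#     print(sess.run(tf.report_uninitialized_variables()))  # [b'b']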