# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Variable class."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import enum  # pylint: disable=g-bad-import-order
import itertools
import functools
import os

import six

from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import variable_pb2
from tensorflow.python import pywrap_tensorflow  # pylint: disable=unused-import
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import gen_state_ops
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.util import _pywrap_utils
from tensorflow.python.util import compat
from tensorflow.python.util import object_identity
from tensorflow.python.util import tf_should_use
from tensorflow.python.util.deprecation import deprecated
from tensorflow.python.util.deprecation import deprecated_args
from tensorflow.python.util.tf_export import tf_export
from tensorflow.python.types import core


def default_variable_creator(_, **kwds):
  del kwds
  raise NotImplementedError("variable_scope needs to be imported")


def default_variable_creator_v2(_, **kwds):
  del kwds
  raise NotImplementedError("variable_scope needs to be imported")


def _make_getter(captured_getter, captured_previous):
  """To avoid capturing loop variables."""

  def getter(**kwargs):
    return captured_getter(captured_previous, **kwargs)

  return getter


@tf_export("VariableSynchronization")
class VariableSynchronization(enum.Enum):
  """Indicates when a distributed variable will be synced.

  * `AUTO`: Indicates that the synchronization will be determined by the
    current `DistributionStrategy` (e.g. with `MirroredStrategy` this would
    be `ON_WRITE`).
  * `NONE`: Indicates that there will only be one copy of the variable, so
    there is no need to sync.
  * `ON_WRITE`: Indicates that the variable will be updated across devices
    every time it is written.
  * `ON_READ`: Indicates that the variable will be aggregated across devices
    when it is read (e.g. when checkpointing or when evaluating an op that
    uses the variable).
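
  For example, a synchronization mode can be requested directly from the
  variable constructor (see `tf.Variable` below for the full argument list),
  typically inside a `tf.distribute.Strategy` scope:

  ```python
  v = tf.Variable(
      1.0,
      synchronization=tf.VariableSynchronization.ON_READ,
      aggregation=tf.VariableAggregation.MEAN)
  ```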
86 """ 87 AUTO = 0 88 NONE = 1 89 ON_WRITE = 2 90 ON_READ = 3 91 92 93# LINT.IfChange 94@tf_export("VariableAggregation", v1=[]) 95class VariableAggregationV2(enum.Enum): 96 """Indicates how a distributed variable will be aggregated. 97 98 `tf.distribute.Strategy` distributes a model by making multiple copies 99 (called "replicas") acting data-parallel on different elements of the input 100 batch. When performing some variable-update operation, say 101 `var.assign_add(x)`, in a model, we need to resolve how to combine the 102 different values for `x` computed in the different replicas. 103 104 * `NONE`: This is the default, giving an error if you use a 105 variable-update operation with multiple replicas. 106 * `SUM`: Add the updates across replicas. 107 * `MEAN`: Take the arithmetic mean ("average") of the updates across replicas. 108 * `ONLY_FIRST_REPLICA`: This is for when every replica is performing the same 109 update, but we only want to perform the update once. Used, e.g., for the 110 global step counter. 111 """ 112 NONE = 0 113 SUM = 1 114 MEAN = 2 115 ONLY_FIRST_REPLICA = 3 116 117 def __hash__(self): 118 return hash(self.value) 119 120 def __eq__(self, other): 121 if self is other: 122 return True 123 elif isinstance(other, VariableAggregation): 124 return int(self.value) == int(other.value) 125 else: 126 return False 127 128 129@tf_export(v1=["VariableAggregation"]) 130class VariableAggregation(enum.Enum): 131 NONE = 0 132 SUM = 1 133 MEAN = 2 134 ONLY_FIRST_REPLICA = 3 135 ONLY_FIRST_TOWER = 3 # DEPRECATED 136 137 def __hash__(self): 138 return hash(self.value) 139 140 141# LINT.ThenChange(//tensorflow/core/framework/variable.proto) 142# 143# Note that we are currently relying on the integer values of the Python enums 144# matching the integer values of the proto enums. 145 146VariableAggregation.__doc__ = ( 147 VariableAggregationV2.__doc__ + 148 "* `ONLY_FIRST_TOWER`: Deprecated alias for `ONLY_FIRST_REPLICA`.\n ") 149 150 151def validate_synchronization_aggregation_trainable(synchronization, aggregation, 152 trainable, name): 153 """Given user-provided variable properties, sets defaults and validates.""" 154 if aggregation is None: 155 aggregation = VariableAggregation.NONE 156 else: 157 if not isinstance(aggregation, 158 (VariableAggregation, VariableAggregationV2)): 159 try: 160 aggregation = VariableAggregationV2(aggregation) 161 except ValueError: 162 raise ValueError( 163 "Invalid variable aggregation mode: {} for variable: {}".format( 164 aggregation, name)) 165 if synchronization is None: 166 synchronization = VariableSynchronization.AUTO 167 else: 168 try: 169 synchronization = VariableSynchronization(synchronization) 170 except ValueError: 171 raise ValueError( 172 "Invalid variable synchronization mode: {} for variable: {}".format( 173 synchronization, name)) 174 if trainable is None: 175 trainable = synchronization != VariableSynchronization.ON_READ 176 return synchronization, aggregation, trainable 177 178 179class VariableMetaclass(type): 180 """Metaclass to allow construction of tf.Variable to be overridden.""" 181 182 def _variable_v1_call(cls, 183 initial_value=None, 184 trainable=None, 185 collections=None, 186 validate_shape=True, 187 caching_device=None, 188 name=None, 189 variable_def=None, 190 dtype=None, 191 expected_shape=None, 192 import_scope=None, 193 constraint=None, 194 use_resource=None, 195 synchronization=VariableSynchronization.AUTO, 196 aggregation=VariableAggregation.NONE, 197 shape=None): 198 """Call on Variable class. 


class VariableMetaclass(type):
  """Metaclass to allow construction of tf.Variable to be overridden."""

  def _variable_v1_call(cls,
                        initial_value=None,
                        trainable=None,
                        collections=None,
                        validate_shape=True,
                        caching_device=None,
                        name=None,
                        variable_def=None,
                        dtype=None,
                        expected_shape=None,
                        import_scope=None,
                        constraint=None,
                        use_resource=None,
                        synchronization=VariableSynchronization.AUTO,
                        aggregation=VariableAggregation.NONE,
                        shape=None):
    """Call on Variable class. Useful to force the signature."""
    previous_getter = lambda **kwargs: default_variable_creator(None, **kwargs)
    for _, getter in ops.get_default_graph()._variable_creator_stack:  # pylint: disable=protected-access
      previous_getter = _make_getter(getter, previous_getter)

    # Reset `aggregation` that is explicitly set as `None` to the enum NONE.
    if aggregation is None:
      aggregation = VariableAggregation.NONE
    return previous_getter(
        initial_value=initial_value,
        trainable=trainable,
        collections=collections,
        validate_shape=validate_shape,
        caching_device=caching_device,
        name=name,
        variable_def=variable_def,
        dtype=dtype,
        expected_shape=expected_shape,
        import_scope=import_scope,
        constraint=constraint,
        use_resource=use_resource,
        synchronization=synchronization,
        aggregation=aggregation,
        shape=shape)

  def _variable_v2_call(cls,
                        initial_value=None,
                        trainable=None,
                        validate_shape=True,
                        caching_device=None,
                        name=None,
                        variable_def=None,
                        dtype=None,
                        import_scope=None,
                        constraint=None,
                        synchronization=VariableSynchronization.AUTO,
                        aggregation=VariableAggregation.NONE,
                        shape=None):
    """Call on Variable class. Useful to force the signature."""
    previous_getter = lambda **kws: default_variable_creator_v2(None, **kws)
    for _, getter in ops.get_default_graph()._variable_creator_stack:  # pylint: disable=protected-access
      previous_getter = _make_getter(getter, previous_getter)

    # Reset `aggregation` that is explicitly set as `None` to the enum NONE.
    if aggregation is None:
      aggregation = VariableAggregation.NONE
    return previous_getter(
        initial_value=initial_value,
        trainable=trainable,
        validate_shape=validate_shape,
        caching_device=caching_device,
        name=name,
        variable_def=variable_def,
        dtype=dtype,
        import_scope=import_scope,
        constraint=constraint,
        synchronization=synchronization,
        aggregation=aggregation,
        shape=shape)

  def __call__(cls, *args, **kwargs):
    if cls is VariableV1:
      return cls._variable_v1_call(*args, **kwargs)
    elif cls is Variable:
      return cls._variable_v2_call(*args, **kwargs)
    else:
      return super(VariableMetaclass, cls).__call__(*args, **kwargs)


@tf_export("Variable", v1=[])
# TODO(mdan): This should subclass core.Tensor, and not all its subclasses?
class Variable(six.with_metaclass(VariableMetaclass, trackable.Trackable)):
  """See the [variable guide](https://tensorflow.org/guide/variable).

  A variable maintains shared, persistent state manipulated by a program.

  The `Variable()` constructor requires an initial value for the variable,
  which can be a `Tensor` of any type and shape. This initial value defines
  the type and shape of the variable. After construction, the type and shape
  of the variable are fixed. The value can be changed using one of the assign
  methods.

  >>> v = tf.Variable(1.)
  >>> v.assign(2.)
  <tf.Variable ... shape=() dtype=float32, numpy=2.0>
  >>> v.assign_add(0.5)
  <tf.Variable ... shape=() dtype=float32, numpy=2.5>

  The `shape` argument to `Variable`'s constructor allows you to construct a
  variable with a less defined shape than its `initial_value`:

  >>> v = tf.Variable(1., shape=tf.TensorShape(None))
  >>> v.assign([[1.]])
  <tf.Variable ... shape=<unknown> dtype=float32, numpy=array([[1.]], ...)>

  Just like any `Tensor`, variables created with `Variable()` can be used as
  inputs to operations. Additionally, all the operators overloaded for the
  `Tensor` class are carried over to variables.

  >>> w = tf.Variable([[1.], [2.]])
  >>> x = tf.constant([[3., 4.]])
  >>> tf.matmul(w, x)
  <tf.Tensor:... shape=(2, 2), ... numpy=
  array([[3., 4.],
         [6., 8.]], dtype=float32)>
  >>> tf.sigmoid(w + x)
  <tf.Tensor:... shape=(2, 2), ...>

  When building a machine learning model it is often convenient to distinguish
  between variables holding trainable model parameters and other variables
  such as a `step` variable used to count training steps. To make this easier,
  the variable constructor supports a `trainable=<bool>`
  parameter. `tf.GradientTape` watches trainable variables by default:

  >>> with tf.GradientTape(persistent=True) as tape:
  ...   trainable = tf.Variable(1.)
  ...   non_trainable = tf.Variable(2., trainable=False)
  ...   x1 = trainable * 2.
  ...   x2 = non_trainable * 3.
  >>> tape.gradient(x1, trainable)
  <tf.Tensor:... shape=(), dtype=float32, numpy=2.0>
  >>> assert tape.gradient(x2, non_trainable) is None  # Unwatched

  Variables are automatically tracked when assigned to attributes of types
  inheriting from `tf.Module`.

  >>> m = tf.Module()
  >>> m.v = tf.Variable([1.])
  >>> m.trainable_variables
  (<tf.Variable ... shape=(1,) ... numpy=array([1.], dtype=float32)>,)

  This tracking then allows saving variable values to
  [training checkpoints](https://www.tensorflow.org/guide/checkpoint), or to
  [SavedModels](https://www.tensorflow.org/guide/saved_model) which include
  serialized TensorFlow graphs.

  Variables are often captured and manipulated by `tf.function`s. This works
  the same way the un-decorated function would have:

  >>> v = tf.Variable(0.)
  >>> read_and_decrement = tf.function(lambda: v.assign_sub(0.1))
  >>> read_and_decrement()
  <tf.Tensor: shape=(), dtype=float32, numpy=-0.1>
  >>> read_and_decrement()
  <tf.Tensor: shape=(), dtype=float32, numpy=-0.2>

  Variables created inside a `tf.function` must be owned outside the function
  and be created only once:

  >>> class M(tf.Module):
  ...   @tf.function
  ...   def __call__(self, x):
  ...     if not hasattr(self, "v"):  # Or set self.v to None in __init__
  ...       self.v = tf.Variable(x)
  ...     return self.v * x
  >>> m = M()
  >>> m(2.)
  <tf.Tensor: shape=(), dtype=float32, numpy=4.0>
  >>> m(3.)
  <tf.Tensor: shape=(), dtype=float32, numpy=6.0>
  >>> m.v
  <tf.Variable ... shape=() dtype=float32, numpy=2.0>

  See the `tf.function` documentation for details.
  """

  @deprecated_args(
      None,
      "A variable's value can be manually cached by calling "
      "tf.Variable.read_value() under a tf.device scope. The caching_device "
      "argument does not work properly.",
      "caching_device")
  def __init__(self,
               initial_value=None,
               trainable=None,
               validate_shape=True,
               caching_device=None,
               name=None,
               variable_def=None,
               dtype=None,
               import_scope=None,
               constraint=None,
               synchronization=VariableSynchronization.AUTO,
               aggregation=VariableAggregation.NONE,
               shape=None):
    """Creates a new variable with value `initial_value`.
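
    For example, `initial_value` can be a zero-argument callable, in which
    case `dtype` must be specified (see the `initial_value` argument below):

    ```python
    w = tf.Variable(lambda: tf.random.truncated_normal([10, 40]),
                    dtype=tf.float32)
    ```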

    Args:
      initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
        which is the initial value for the Variable. The initial value must
        have a shape specified unless `validate_shape` is set to False. Can
        also be a callable with no argument that returns the initial value
        when called. In that case, `dtype` must be specified. (Note that
        initializer functions from init_ops.py must first be bound to a shape
        before being used here.)
      trainable: If `True`, GradientTapes automatically watch uses of this
        variable. Defaults to `True`, unless `synchronization` is set to
        `ON_READ`, in which case it defaults to `False`.
      validate_shape: If `False`, allows the variable to be initialized with a
        value of unknown shape. If `True`, the default, the shape of
        `initial_value` must be known.
      caching_device: Optional device string describing where the Variable
        should be cached for reading. Defaults to the Variable's device. If
        not `None`, caches on another device. Typical use is to cache on the
        device where the Ops using the Variable reside, to deduplicate copying
        through `Switch` and other conditional statements.
      name: Optional name for the variable. Defaults to `'Variable'` and gets
        uniquified automatically.
      variable_def: `VariableDef` protocol buffer. If not `None`, recreates
        the Variable object with its contents, referencing the variable's
        nodes in the graph, which must already exist. The graph is not
        changed. `variable_def` and the other arguments are mutually
        exclusive.
      dtype: If set, initial_value will be converted to the given type. If
        `None`, either the datatype will be kept (if `initial_value` is a
        Tensor), or `convert_to_tensor` will decide.
      import_scope: Optional `string`. Name scope to add to the `Variable.`
        Only used when initializing from protocol buffer.
      constraint: An optional projection function to be applied to the
        variable after being updated by an `Optimizer` (e.g. used to implement
        norm constraints or value constraints for layer weights). The function
        must take as input the unprojected Tensor representing the value of
        the variable and return the Tensor for the projected value (which must
        have the same shape). Constraints are not safe to use when doing
        asynchronous distributed training.
      synchronization: Indicates when a distributed variable will be
        aggregated. Accepted values are constants defined in the class
        `tf.VariableSynchronization`. By default the synchronization is set to
        `AUTO` and the current `DistributionStrategy` chooses when to
        synchronize.
      aggregation: Indicates how a distributed variable will be aggregated.
        Accepted values are constants defined in the class
        `tf.VariableAggregation`.
      shape: (optional) The shape of this variable. If None, the shape of
        `initial_value` will be used. When setting this argument to
        `tf.TensorShape(None)` (representing an unspecified shape), the
        variable can be assigned with values of different shapes.

    Raises:
      ValueError: If both `variable_def` and initial_value are specified.
      ValueError: If the initial value is not specified, or does not have a
        shape and `validate_shape` is `True`.
    """
    raise NotImplementedError

  def __repr__(self):
    raise NotImplementedError

  def value(self):
    """Returns the last snapshot of this variable.
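
    For example:

    ```python
    v = tf.Variable(3.0)
    t = v.value()  # `t` is a `Tensor` holding the current value, 3.0.
    ```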

    You usually do not need to call this method as all ops that need the value
    of the variable call it automatically through a `convert_to_tensor()`
    call.

    Returns a `Tensor` which holds the value of the variable. You can not
    assign a new value to this tensor as it is not a reference to the
    variable.

    To avoid copies, if the consumer of the returned value is on the same
    device as the variable, this actually returns the live value of the
    variable, not a copy. Updates to the variable are seen by the consumer.
    If the consumer is on a different device it will get a copy of the
    variable.

    Returns:
      A `Tensor` containing the value of the variable.
    """
    raise NotImplementedError

  def read_value(self):
    """Returns the value of this variable, read in the current context.

    Can be different from value() if it's on another device, with control
    dependencies, etc.

    Returns:
      A `Tensor` containing the value of the variable.
    """
    raise NotImplementedError

  def set_shape(self, shape):
    """Overrides the shape for this variable.

    Args:
      shape: the `TensorShape` representing the overridden shape.
    """
    raise NotImplementedError

  @property
  def trainable(self):
    raise NotImplementedError

  @property
  def synchronization(self):
    raise NotImplementedError

  @property
  def aggregation(self):
    raise NotImplementedError

  def eval(self, session=None):
    """In a session, computes and returns the value of this variable.

    This is not a graph construction method; it does not add ops to the graph.

    This convenience method requires a session where the graph
    containing this variable has been launched. If no session is
    passed, the default session is used. See `tf.compat.v1.Session` for more
    information on launching a graph and on sessions.

    ```python
    v = tf.Variable([1, 2])
    init = tf.compat.v1.global_variables_initializer()

    with tf.compat.v1.Session() as sess:
      sess.run(init)
      # Usage passing the session explicitly.
      print(v.eval(sess))
      # Usage with the default session. The 'with' block
      # above makes 'sess' the default session.
      print(v.eval())
    ```

    Args:
      session: The session to use to evaluate this variable. If none, the
        default session is used.

    Returns:
      A numpy `ndarray` with a copy of the value of this variable.
    """
    raise NotImplementedError

  @deprecated(
      None, "Use Variable.read_value. Variables in 2.X are initialized "
      "automatically both in eager and graph (inside tf.defun) contexts.")
  def initialized_value(self):
    """Returns the value of the initialized variable.

    You should use this instead of the variable itself to initialize another
    variable with a value that depends on the value of this variable.

    ```python
    # Initialize 'v' with a random tensor.
    v = tf.Variable(tf.random.truncated_normal([10, 40]))
    # Use `initialized_value` to guarantee that `v` has been
    # initialized before its value is used to initialize `w`.
    # The random values are picked only once.
    w = tf.Variable(v.initialized_value() * 2.0)
    ```

    Returns:
      A `Tensor` holding the value of this variable after its initializer
      has run.
546 """ 547 with ops.init_scope(): 548 return control_flow_ops.cond( 549 is_variable_initialized(self), self.read_value, 550 lambda: self.initial_value) 551 552 @property 553 def initial_value(self): 554 """Returns the Tensor used as the initial value for the variable. 555 556 Note that this is different from `initialized_value()` which runs 557 the op that initializes the variable before returning its value. 558 This method returns the tensor that is used by the op that initializes 559 the variable. 560 561 Returns: 562 A `Tensor`. 563 """ 564 raise NotImplementedError 565 566 @property 567 def constraint(self): 568 """Returns the constraint function associated with this variable. 569 570 Returns: 571 The constraint function that was passed to the variable constructor. 572 Can be `None` if no constraint was passed. 573 """ 574 raise NotImplementedError 575 576 def assign(self, value, use_locking=False, name=None, read_value=True): 577 """Assigns a new value to the variable. 578 579 This is essentially a shortcut for `assign(self, value)`. 580 581 Args: 582 value: A `Tensor`. The new value for this variable. 583 use_locking: If `True`, use locking during the assignment. 584 name: The name of the operation to be created 585 read_value: if True, will return something which evaluates to the new 586 value of the variable; if False will return the assign op. 587 588 Returns: 589 The updated variable. If `read_value` is false, instead returns None in 590 Eager mode and the assign op in graph mode. 591 """ 592 raise NotImplementedError 593 594 def assign_add(self, delta, use_locking=False, name=None, read_value=True): 595 """Adds a value to this variable. 596 597 This is essentially a shortcut for `assign_add(self, delta)`. 598 599 Args: 600 delta: A `Tensor`. The value to add to this variable. 601 use_locking: If `True`, use locking during the operation. 602 name: The name of the operation to be created 603 read_value: if True, will return something which evaluates to the new 604 value of the variable; if False will return the assign op. 605 606 Returns: 607 The updated variable. If `read_value` is false, instead returns None in 608 Eager mode and the assign op in graph mode. 609 """ 610 raise NotImplementedError 611 612 def assign_sub(self, delta, use_locking=False, name=None, read_value=True): 613 """Subtracts a value from this variable. 614 615 This is essentially a shortcut for `assign_sub(self, delta)`. 616 617 Args: 618 delta: A `Tensor`. The value to subtract from this variable. 619 use_locking: If `True`, use locking during the operation. 620 name: The name of the operation to be created 621 read_value: if True, will return something which evaluates to the new 622 value of the variable; if False will return the assign op. 623 624 Returns: 625 The updated variable. If `read_value` is false, instead returns None in 626 Eager mode and the assign op in graph mode. 627 """ 628 raise NotImplementedError 629 630 def scatter_sub(self, sparse_delta, use_locking=False, name=None): 631 """Subtracts `tf.IndexedSlices` from this variable. 632 633 Args: 634 sparse_delta: `tf.IndexedSlices` to be subtracted from this variable. 635 use_locking: If `True`, use locking during the operation. 636 name: the name of the operation. 637 638 Returns: 639 The updated variable. 640 641 Raises: 642 TypeError: if `sparse_delta` is not an `IndexedSlices`. 643 """ 644 raise NotImplementedError 645 646 def scatter_add(self, sparse_delta, use_locking=False, name=None): 647 """Adds `tf.IndexedSlices` to this variable. 

    Args:
      sparse_delta: `tf.IndexedSlices` to be added to this variable.
      use_locking: If `True`, use locking during the operation.
      name: the name of the operation.

    Returns:
      The updated variable.

    Raises:
      TypeError: if `sparse_delta` is not an `IndexedSlices`.
    """
    raise NotImplementedError

  def scatter_max(self, sparse_delta, use_locking=False, name=None):
    """Updates this variable with the max of `tf.IndexedSlices` and itself.

    Args:
      sparse_delta: `tf.IndexedSlices` to use as an argument of max with this
        variable.
      use_locking: If `True`, use locking during the operation.
      name: the name of the operation.

    Returns:
      The updated variable.

    Raises:
      TypeError: if `sparse_delta` is not an `IndexedSlices`.
    """
    raise NotImplementedError

  def scatter_min(self, sparse_delta, use_locking=False, name=None):
    """Updates this variable with the min of `tf.IndexedSlices` and itself.

    Args:
      sparse_delta: `tf.IndexedSlices` to use as an argument of min with this
        variable.
      use_locking: If `True`, use locking during the operation.
      name: the name of the operation.

    Returns:
      The updated variable.

    Raises:
      TypeError: if `sparse_delta` is not an `IndexedSlices`.
    """
    raise NotImplementedError

  def scatter_mul(self, sparse_delta, use_locking=False, name=None):
    """Multiply this variable by `tf.IndexedSlices`.

    Args:
      sparse_delta: `tf.IndexedSlices` to multiply this variable by.
      use_locking: If `True`, use locking during the operation.
      name: the name of the operation.

    Returns:
      The updated variable.

    Raises:
      TypeError: if `sparse_delta` is not an `IndexedSlices`.
    """
    raise NotImplementedError

  def scatter_div(self, sparse_delta, use_locking=False, name=None):
    """Divide this variable by `tf.IndexedSlices`.

    Args:
      sparse_delta: `tf.IndexedSlices` to divide this variable by.
      use_locking: If `True`, use locking during the operation.
      name: the name of the operation.

    Returns:
      The updated variable.

    Raises:
      TypeError: if `sparse_delta` is not an `IndexedSlices`.
    """
    raise NotImplementedError

  def scatter_update(self, sparse_delta, use_locking=False, name=None):
    """Assigns `tf.IndexedSlices` to this variable.

    Args:
      sparse_delta: `tf.IndexedSlices` to be assigned to this variable.
      use_locking: If `True`, use locking during the operation.
      name: the name of the operation.

    Returns:
      The updated variable.

    Raises:
      TypeError: if `sparse_delta` is not an `IndexedSlices`.
    """
    raise NotImplementedError

  def batch_scatter_update(self, sparse_delta, use_locking=False, name=None):
    """Assigns `tf.IndexedSlices` to this variable batch-wise.

    Analogous to `batch_gather`. This assumes that this variable and the
    sparse_delta IndexedSlices have a series of leading dimensions that are
    the same for all of them, and the updates are performed on the last
    dimension of indices. In other words, the dimensions should be the
    following:

    `num_prefix_dims = sparse_delta.indices.ndims - 1`
    `batch_dim = num_prefix_dims + 1`
    `sparse_delta.updates.shape = sparse_delta.indices.shape + var.shape[
         batch_dim:]`

    where

    `sparse_delta.updates.shape[:num_prefix_dims]`
    `== sparse_delta.indices.shape[:num_prefix_dims]`
    `== var.shape[:num_prefix_dims]`

    And the operation performed can be expressed as:

    `var[i_1, ..., i_n,
         sparse_delta.indices[i_1, ..., i_n, j]] = sparse_delta.updates[
            i_1, ..., i_n, j]`

    When sparse_delta.indices is a 1D tensor, this operation is equivalent to
    `scatter_update`.

    To avoid this operation one can loop over the first `ndims` of the
    variable and use `scatter_update` on the subtensors that result from
    slicing the first dimension. This is a valid option for `ndims = 1`, but
    less efficient than this implementation.

    Args:
      sparse_delta: `tf.IndexedSlices` to be assigned to this variable.
      use_locking: If `True`, use locking during the operation.
      name: the name of the operation.

    Returns:
      The updated variable.

    Raises:
      TypeError: if `sparse_delta` is not an `IndexedSlices`.
    """
    raise NotImplementedError

  def scatter_nd_sub(self, indices, updates, name=None):
    """Applies sparse subtraction to individual values or slices in a Variable.

    Assuming the variable has rank `P` and `indices` is a `Tensor` of rank `Q`.

    `indices` must be integer tensor, containing indices into self.
    It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.

    The innermost dimension of `indices` (with length `K`) corresponds to
    indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
    dimension of self.

    `updates` is `Tensor` of rank `Q-1+P-K` with shape:

    ```
    [d_0, ..., d_{Q-2}, self.shape[K], ..., self.shape[P-1]].
    ```

    For example, say we want to subtract 4 scattered elements from a rank-1
    tensor with 8 elements. In Python, that update would look like this:

    ```python
    v = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
    indices = tf.constant([[4], [3], [1] ,[7]])
    updates = tf.constant([9, 10, 11, 12])
    op = v.scatter_nd_sub(indices, updates)
    with tf.compat.v1.Session() as sess:
      print(sess.run(op))
    ```

    The resulting update to v would look like this:

        [1, -9, 3, -6, -6, 6, 7, -4]

    See `tf.scatter_nd` for more details about how to make updates to
    slices.

    Args:
      indices: The indices to be used in the operation.
      updates: The values to be used in the operation.
      name: the name of the operation.

    Returns:
      The updated variable.
    """
    raise NotImplementedError

  def scatter_nd_add(self, indices, updates, name=None):
    """Applies sparse addition to individual values or slices in a Variable.

    The Variable has rank `P` and `indices` is a `Tensor` of rank `Q`.

    `indices` must be integer tensor, containing indices into self.
    It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.

    The innermost dimension of `indices` (with length `K`) corresponds to
    indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
    dimension of self.

    `updates` is `Tensor` of rank `Q-1+P-K` with shape:

    ```
    [d_0, ..., d_{Q-2}, self.shape[K], ..., self.shape[P-1]].
    ```

    For example, say we want to add 4 scattered elements to a rank-1 tensor
    with 8 elements. In Python, that update would look like this:

    ```python
    v = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
    indices = tf.constant([[4], [3], [1] ,[7]])
    updates = tf.constant([9, 10, 11, 12])
    add = v.scatter_nd_add(indices, updates)
    with tf.compat.v1.Session() as sess:
      print(sess.run(add))
    ```

    The resulting update to v would look like this:

        [1, 13, 3, 14, 14, 6, 7, 20]

    See `tf.scatter_nd` for more details about how to make updates to
    slices.

    Args:
      indices: The indices to be used in the operation.
      updates: The values to be used in the operation.
      name: the name of the operation.

    Returns:
      The updated variable.
    """
    raise NotImplementedError

  def scatter_nd_update(self, indices, updates, name=None):
    """Applies sparse assignment to individual values or slices in a Variable.

    The Variable has rank `P` and `indices` is a `Tensor` of rank `Q`.

    `indices` must be integer tensor, containing indices into self.
    It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.

    The innermost dimension of `indices` (with length `K`) corresponds to
    indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
    dimension of self.

    `updates` is `Tensor` of rank `Q-1+P-K` with shape:

    ```
    [d_0, ..., d_{Q-2}, self.shape[K], ..., self.shape[P-1]].
    ```

    For example, say we want to assign 4 scattered elements in a rank-1 tensor
    with 8 elements. In Python, that update would look like this:

    ```python
    v = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
    indices = tf.constant([[4], [3], [1] ,[7]])
    updates = tf.constant([9, 10, 11, 12])
    op = v.scatter_nd_update(indices, updates)
    with tf.compat.v1.Session() as sess:
      print(sess.run(op))
    ```

    The resulting update to v would look like this:

        [1, 11, 3, 10, 9, 6, 7, 12]

    See `tf.scatter_nd` for more details about how to make updates to
    slices.

    Args:
      indices: The indices to be used in the operation.
      updates: The values to be used in the operation.
      name: the name of the operation.

    Returns:
      The updated variable.
    """
    raise NotImplementedError

  def sparse_read(self, indices, name=None):
    r"""Gather slices from params axis `axis` according to indices.

    This function supports a subset of tf.gather; see tf.gather for details
    on usage.

    Args:
      indices: The index `Tensor`. Must be one of the following types: `int32`,
        `int64`. Must be in range `[0, params.shape[axis])`.
      name: A name for the operation (optional).

    Returns:
      A `Tensor`. Has the same type as `params`.
    """
    raise AttributeError

  def gather_nd(self, indices, name=None):
    r"""Gather slices from `params` into a Tensor with shape specified by `indices`.

    See tf.gather_nd for details.

    Args:
      indices: A `Tensor`. Must be one of the following types: `int32`,
        `int64`. Index tensor.
      name: A name for the operation (optional).

    Returns:
      A `Tensor`. Has the same type as `params`.
    """
    raise AttributeError

  @deprecated(None, "Prefer Dataset.range instead.")
  def count_up_to(self, limit):
    """Increments this variable until it reaches `limit`.

    When that Op is run it tries to increment the variable by `1`. If
    incrementing the variable would bring it above `limit` then the Op raises
    the exception `OutOfRangeError`.

    If no error is raised, the Op outputs the value of the variable before
    the increment.

    This is essentially a shortcut for `count_up_to(self, limit)`.

    Args:
      limit: value at which incrementing the variable raises an error.

    Returns:
      A `Tensor` that will hold the variable value before the increment. If no
      other Op modifies this variable, the values produced will all be
      distinct.
    """
    raise NotImplementedError

  @deprecated(None,
              "Prefer Variable.assign which has equivalent behavior in 2.X.")
  def load(self, value, session=None):
    """Load new value into this variable.

    Writes new value to variable's memory. Doesn't add ops to the graph.

    This convenience method requires a session where the graph
    containing this variable has been launched. If no session is
    passed, the default session is used. See `tf.compat.v1.Session` for more
    information on launching a graph and on sessions.

    ```python
    v = tf.Variable([1, 2])
    init = tf.compat.v1.global_variables_initializer()

    with tf.compat.v1.Session() as sess:
      sess.run(init)
      # Usage passing the session explicitly.
      v.load([2, 3], sess)
      print(v.eval(sess))  # prints [2 3]
      # Usage with the default session. The 'with' block
      # above makes 'sess' the default session.
      v.load([3, 4], sess)
      print(v.eval())  # prints [3 4]
    ```

    Args:
      value: New variable value
      session: The session to use to evaluate this variable. If none, the
        default session is used.

    Raises:
      ValueError: Session is not passed and no default session
    """
    if context.executing_eagerly():
      self.assign(value)
    else:
      session = session or ops.get_default_session()
      if session is None:
        raise ValueError(
            "Either session argument should be provided or default session "
            "should be established")
      session.run(self.initializer, {self.initializer.inputs[1]: value})

  # Conversion to tensor.
  @staticmethod
  def _TensorConversionFunction(v, dtype=None, name=None, as_ref=False):  # pylint: disable=invalid-name
    """Utility function for converting a Variable to a Tensor."""
    _ = name
    if dtype and not dtype.is_compatible_with(v.dtype):
      raise ValueError(
          "Incompatible type conversion requested to type '%s' for variable "
          "of type '%s'" % (dtype.name, v.dtype.name))
    if as_ref:
      return v._ref()  # pylint: disable=protected-access
    else:
      return v.value()

  @classmethod
  def _OverloadAllOperators(cls):  # pylint: disable=invalid-name
    """Register overloads for all operators."""
    for operator in ops.Tensor.OVERLOADABLE_OPERATORS:
      cls._OverloadOperator(operator)
    # For slicing, bind getitem differently than a tensor (use SliceHelperVar
    # instead)
    # pylint: disable=protected-access
    setattr(cls, "__getitem__", array_ops._SliceHelperVar)

  @classmethod
  def _OverloadOperator(cls, operator):  # pylint: disable=invalid-name
    """Defer an operator overload to `ops.Tensor`.

    We pull the operator out of ops.Tensor dynamically to avoid ordering
    issues.

    Args:
      operator: string. The operator name.
    """
    # We can't use the overload mechanism on __eq__ & __ne__ since __eq__ is
    # called when adding a variable to sets. As a result we call a.value()
    # which causes infinite recursion when operating within a GradientTape
    # TODO(gjn): Consider removing this
    if operator == "__eq__" or operator == "__ne__":
      return

    tensor_oper = getattr(ops.Tensor, operator)

    def _run_op(a, *args, **kwargs):
      # pylint: disable=protected-access
      return tensor_oper(a.value(), *args, **kwargs)

    functools.update_wrapper(_run_op, tensor_oper)
    setattr(cls, operator, _run_op)

  def __hash__(self):
    if ops.Tensor._USE_EQUALITY and ops.executing_eagerly_outside_functions():  # pylint: disable=protected-access
      raise TypeError("Variable is unhashable. "
                      "Instead, use tensor.ref() as the key.")
    else:
      return id(self)

  # TODO(gjn): duplicate of math_ops.tensor_equals, consider removing
  def __eq__(self, other):
    """Compares two variables element-wise for equality."""
    if ops.Tensor._USE_EQUALITY and ops.executing_eagerly_outside_functions():  # pylint: disable=protected-access
      return gen_math_ops.equal(self, other, incompatible_shape_error=False)
    else:
      # In legacy graph mode, tensor equality is object equality
      return self is other

  # TODO(gjn): duplicate of math_ops.tensor_not_equals, consider removing
  def __ne__(self, other):
    """Compares two variables element-wise for inequality."""
    if ops.Tensor._USE_EQUALITY and ops.executing_eagerly_outside_functions():  # pylint: disable=protected-access
      return gen_math_ops.not_equal(self, other,
                                    incompatible_shape_error=False)
    else:
      # In legacy graph mode, tensor equality is object equality
      return self is not other

  def __iter__(self):
    """Dummy method to prevent iteration.

    Do not call.

    NOTE(mrry): If we register __getitem__ as an overloaded operator,
    Python will valiantly attempt to iterate over the variable's Tensor from 0
    to infinity. Declaring this method prevents this unintended behavior.

    Raises:
      TypeError: when invoked.
    """
    raise TypeError("'Variable' object is not iterable.")

  # NOTE(mrry): This enables the Variable's overloaded "right" binary
  # operators to run when the left operand is an ndarray, because it
  # accords the Variable class higher priority than an ndarray, or a
  # numpy matrix.
  # TODO(mrry): Convert this to using numpy's __numpy_ufunc__
  # mechanism, which allows more control over how Variables interact
  # with ndarrays.
  __array_priority__ = 100

  @property
  def name(self):
    """The name of this variable."""
    raise NotImplementedError

  @property
  def _shared_name(self):
    """The shared name of the variable.

    Unlike name(), shared_name doesn't have ":0" suffix. It is the
    user-specified name with the name scope prefix.
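
    For example, a variable created with name `"layer/kernel"` typically has
    `name` `"layer/kernel:0"` and shared name `"layer/kernel"`.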
1141 """ 1142 return self.name[:self.name.index(":")] 1143 1144 @property 1145 def initializer(self): 1146 """The initializer operation for this variable.""" 1147 raise NotImplementedError 1148 1149 @property 1150 def device(self): 1151 """The device of this variable.""" 1152 raise NotImplementedError 1153 1154 @property 1155 def dtype(self): 1156 """The `DType` of this variable.""" 1157 raise NotImplementedError 1158 1159 @property 1160 def op(self): 1161 """The `Operation` of this variable.""" 1162 raise NotImplementedError 1163 1164 @property 1165 def graph(self): 1166 """The `Graph` of this variable.""" 1167 raise NotImplementedError 1168 1169 @property 1170 def shape(self): 1171 """The `TensorShape` of this variable. 1172 1173 Returns: 1174 A `TensorShape`. 1175 """ 1176 raise NotImplementedError 1177 1178 def get_shape(self): 1179 """Alias of `Variable.shape`.""" 1180 return self.shape 1181 1182 def _gather_saveables_for_checkpoint(self): 1183 """For implementing `Trackable`. This object is saveable on its own.""" 1184 return {trackable.VARIABLE_VALUE_KEY: self} 1185 1186 def to_proto(self, export_scope=None): 1187 """Converts a `Variable` to a `VariableDef` protocol buffer. 1188 1189 Args: 1190 export_scope: Optional `string`. Name scope to remove. 1191 1192 Returns: 1193 A `VariableDef` protocol buffer, or `None` if the `Variable` is not 1194 in the specified name scope. 1195 """ 1196 raise NotImplementedError 1197 1198 @staticmethod 1199 def from_proto(variable_def, import_scope=None): 1200 """Returns a `Variable` object created from `variable_def`.""" 1201 return RefVariable(variable_def=variable_def, import_scope=import_scope) 1202 1203 def _set_save_slice_info(self, save_slice_info): 1204 """Sets the slice info for this `Variable`. 1205 1206 Args: 1207 save_slice_info: A `Variable.SaveSliceInfo` object. 1208 """ 1209 self._save_slice_info = save_slice_info 1210 1211 def _get_save_slice_info(self): 1212 return self._save_slice_info 1213 1214 @deprecated(None, "Use ref() instead.") 1215 def experimental_ref(self): 1216 return self.ref() 1217 1218 def ref(self): 1219 # tf.Tensor also has the same ref() API. If you update the 1220 # documentation here, please update tf.Tensor.ref() as well. 1221 """Returns a hashable reference object to this Variable. 1222 1223 The primary use case for this API is to put variables in a set/dictionary. 1224 We can't put variables in a set/dictionary as `variable.__hash__()` is no 1225 longer available starting Tensorflow 2.0. 1226 1227 The following will raise an exception starting 2.0 1228 1229 >>> x = tf.Variable(5) 1230 >>> y = tf.Variable(10) 1231 >>> z = tf.Variable(10) 1232 >>> variable_set = {x, y, z} 1233 Traceback (most recent call last): 1234 ... 1235 TypeError: Variable is unhashable. Instead, use tensor.ref() as the key. 1236 >>> variable_dict = {x: 'five', y: 'ten'} 1237 Traceback (most recent call last): 1238 ... 1239 TypeError: Variable is unhashable. Instead, use tensor.ref() as the key. 1240 1241 Instead, we can use `variable.ref()`. 1242 1243 >>> variable_set = {x.ref(), y.ref(), z.ref()} 1244 >>> x.ref() in variable_set 1245 True 1246 >>> variable_dict = {x.ref(): 'five', y.ref(): 'ten', z.ref(): 'ten'} 1247 >>> variable_dict[y.ref()] 1248 'ten' 1249 1250 Also, the reference object provides `.deref()` function that returns the 1251 original Variable. 

    >>> x = tf.Variable(5)
    >>> x.ref().deref()
    <tf.Variable 'Variable:0' shape=() dtype=int32, numpy=5>
    """
    return object_identity.Reference(self)

  class SaveSliceInfo(object):
    """Information on how to save this Variable as a slice.

    Provides internal support for saving variables as slices of a larger
    variable. This API is not public and is subject to change.

    Available properties:

    * full_name
    * full_shape
    * var_offset
    * var_shape
    """

    def __init__(self,
                 full_name=None,
                 full_shape=None,
                 var_offset=None,
                 var_shape=None,
                 save_slice_info_def=None,
                 import_scope=None):
      """Create a `SaveSliceInfo`.

      Args:
        full_name: Name of the full variable of which this `Variable` is a
          slice.
        full_shape: Shape of the full variable, as a list of int.
        var_offset: Offset of this `Variable` into the full variable, as a
          list of int.
        var_shape: Shape of this `Variable`, as a list of int.
        save_slice_info_def: `SaveSliceInfoDef` protocol buffer. If not
          `None`, recreates the SaveSliceInfo object from its contents.
          `save_slice_info_def` and other arguments are mutually exclusive.
        import_scope: Optional `string`. Name scope to add. Only used when
          initializing from protocol buffer.
      """
      if save_slice_info_def:
        assert isinstance(save_slice_info_def, variable_pb2.SaveSliceInfoDef)
        self.full_name = ops.prepend_name_scope(
            save_slice_info_def.full_name, import_scope=import_scope)
        self.full_shape = [i for i in save_slice_info_def.full_shape]
        self.var_offset = [i for i in save_slice_info_def.var_offset]
        self.var_shape = [i for i in save_slice_info_def.var_shape]
      else:
        self.full_name = full_name
        self.full_shape = full_shape
        self.var_offset = var_offset
        self.var_shape = var_shape

    @property
    def spec(self):
      """Computes the spec string used for saving."""
      full_shape_str = " ".join("%d" % d for d in self.full_shape) + " "
      sl_spec = ":".join(
          "%d,%d" % (o, s) for o, s in zip(self.var_offset, self.var_shape))
      return full_shape_str + sl_spec

    def to_proto(self, export_scope=None):
      """Returns a SaveSliceInfoDef() proto.

      Args:
        export_scope: Optional `string`. Name scope to remove.

      Returns:
        A `SaveSliceInfoDef` protocol buffer, or None if the `Variable` is
        not in the specified name scope.
      """
      if (export_scope is None or self.full_name.startswith(export_scope)):
        save_slice_info_def = variable_pb2.SaveSliceInfoDef()
        save_slice_info_def.full_name = ops.strip_name_scope(
            self.full_name, export_scope)
        for i in self.full_shape:
          save_slice_info_def.full_shape.append(i)
        for i in self.var_offset:
          save_slice_info_def.var_offset.append(i)
        for i in self.var_shape:
          save_slice_info_def.var_shape.append(i)
        return save_slice_info_def
      else:
        return None


Variable._OverloadAllOperators()  # pylint: disable=protected-access
_pywrap_utils.RegisterType("Variable", Variable)


@tf_export(v1=["Variable"])
class VariableV1(Variable):
  """See the [Variables Guide](https://tensorflow.org/guide/variables).

  A variable maintains state in the graph across calls to `run()`. You add a
  variable to the graph by constructing an instance of the class `Variable`.

  The `Variable()` constructor requires an initial value for the variable,
  which can be a `Tensor` of any type and shape. The initial value defines the
  type and shape of the variable. After construction, the type and shape of
  the variable are fixed. The value can be changed using one of the assign
  methods.

  If you want to change the shape of a variable later you have to use an
  `assign` Op with `validate_shape=False`.

  Just like any `Tensor`, variables created with `Variable()` can be used as
  inputs for other Ops in the graph. Additionally, all the operators
  overloaded for the `Tensor` class are carried over to variables, so you can
  also add nodes to the graph by just doing arithmetic on variables.

  ```python
  import tensorflow as tf

  # Create a variable.
  w = tf.Variable(<initial-value>, name=<optional-name>)

  # Use the variable in the graph like any Tensor.
  y = tf.matmul(w, ...another variable or tensor...)

  # The overloaded operators are available too.
  z = tf.sigmoid(w + y)

  # Assign a new value to the variable with `assign()` or a related method.
  w.assign(w + 1.0)
  w.assign_add(1.0)
  ```

  When you launch the graph, variables have to be explicitly initialized
  before you can run Ops that use their value. You can initialize a variable
  by running its *initializer op*, restoring the variable from a save file,
  or simply running an `assign` Op that assigns a value to the variable. In
  fact, the variable *initializer op* is just an `assign` Op that assigns the
  variable's initial value to the variable itself.

  ```python
  # Launch the graph in a session.
  with tf.compat.v1.Session() as sess:
    # Run the variable initializer.
    sess.run(w.initializer)
    # ...you now can run ops that use the value of 'w'...
  ```

  The most common initialization pattern is to use the convenience function
  `global_variables_initializer()` to add an Op to the graph that initializes
  all the variables. You then run that Op after launching the graph.

  ```python
  # Add an Op to initialize global variables.
  init_op = tf.compat.v1.global_variables_initializer()

  # Launch the graph in a session.
  with tf.compat.v1.Session() as sess:
    # Run the Op that initializes global variables.
    sess.run(init_op)
    # ...you can now run any Op that uses variable values...
  ```

  If you need to create a variable with an initial value dependent on another
  variable, use the other variable's `initialized_value()`. This ensures that
  variables are initialized in the right order.

  All variables are automatically collected in the graph where they are
  created. By default, the constructor adds the new variable to the graph
  collection `GraphKeys.GLOBAL_VARIABLES`. The convenience function
  `global_variables()` returns the contents of that collection.

  When building a machine learning model it is often convenient to distinguish
  between variables holding the trainable model parameters and other variables
  such as a `global step` variable used to count training steps. To make this
  easier, the variable constructor supports a `trainable=<bool>` parameter. If
  `True`, the new variable is also added to the graph collection
  `GraphKeys.TRAINABLE_VARIABLES`. The convenience function
  `trainable_variables()` returns the contents of this collection. The
  various `Optimizer` classes use this collection as the default list of
  variables to optimize.

  WARNING: tf.Variable objects by default have a non-intuitive memory model. A
  Variable is represented internally as a mutable Tensor which can
  non-deterministically alias other Tensors in a graph. The set of operations
  which consume a Variable and can lead to aliasing is undetermined and can
  change across TensorFlow versions. Avoid writing code which relies on the
  value of a Variable either changing or not changing as other operations
  happen. For example, using Variable objects or simple functions thereof as
  predicates in a `tf.cond` is dangerous and error-prone:

  ```
  v = tf.Variable(True)
  tf.cond(v, lambda: v.assign(False), my_false_fn)  # Note: this is broken.
  ```

  Here, adding `use_resource=True` when constructing the variable will
  fix any nondeterminism issues:

  ```
  v = tf.Variable(True, use_resource=True)
  tf.cond(v, lambda: v.assign(False), my_false_fn)
  ```

  To use the replacement for variables which does not have these issues:

  * Add `use_resource=True` when constructing `tf.Variable`;
  * Call `tf.compat.v1.get_variable_scope().set_use_resource(True)` inside a
    `tf.compat.v1.variable_scope` before the `tf.compat.v1.get_variable()`
    call.
  """

  def __init__(
      self,  # pylint: disable=super-init-not-called
      initial_value=None,
      trainable=None,
      collections=None,
      validate_shape=True,
      caching_device=None,
      name=None,
      variable_def=None,
      dtype=None,
      expected_shape=None,
      import_scope=None,
      constraint=None,
      use_resource=None,
      synchronization=VariableSynchronization.AUTO,
      aggregation=VariableAggregation.NONE,
      shape=None):
    """Creates a new variable with value `initial_value`.

    The new variable is added to the graph collections listed in
    `collections`, which defaults to `[GraphKeys.GLOBAL_VARIABLES]`.

    If `trainable` is `True` the variable is also added to the graph
    collection `GraphKeys.TRAINABLE_VARIABLES`.

    This constructor creates both a `variable` Op and an `assign` Op to set
    the variable to its initial value.

    Args:
      initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
        which is the initial value for the Variable. The initial value must
        have a shape specified unless `validate_shape` is set to False. Can
        also be a callable with no argument that returns the initial value
        when called. In that case, `dtype` must be specified. (Note that
        initializer functions from init_ops.py must first be bound to a shape
        before being used here.)
      trainable: If `True`, also adds the variable to the graph collection
        `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as the
        default list of variables to use by the `Optimizer` classes. Defaults
        to `True`, unless `synchronization` is set to `ON_READ`, in which case
        it defaults to `False`.
      collections: List of graph collections keys. The new variable is added
        to these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
      validate_shape: If `False`, allows the variable to be initialized with a
        value of unknown shape. If `True`, the default, the shape of
        `initial_value` must be known.
      caching_device: Optional device string describing where the Variable
        should be cached for reading. Defaults to the Variable's device. If
        not `None`, caches on another device. Typical use is to cache on the
        device where the Ops using the Variable reside, to deduplicate copying
        through `Switch` and other conditional statements.
      name: Optional name for the variable. Defaults to `'Variable'` and gets
        uniquified automatically.
      variable_def: `VariableDef` protocol buffer. If not `None`, recreates
        the Variable object with its contents, referencing the variable's
        nodes in the graph, which must already exist. The graph is not
        changed. `variable_def` and the other arguments are mutually
        exclusive.
      dtype: If set, initial_value will be converted to the given type. If
        `None`, either the datatype will be kept (if `initial_value` is a
        Tensor), or `convert_to_tensor` will decide.
      expected_shape: A TensorShape. If set, initial_value is expected to have
        this shape.
      import_scope: Optional `string`. Name scope to add to the `Variable.`
        Only used when initializing from protocol buffer.
      constraint: An optional projection function to be applied to the
        variable after being updated by an `Optimizer` (e.g. used to implement
        norm constraints or value constraints for layer weights). The function
        must take as input the unprojected Tensor representing the value of
        the variable and return the Tensor for the projected value (which must
        have the same shape). Constraints are not safe to use when doing
        asynchronous distributed training.
      use_resource: whether to use resource variables.
      synchronization: Indicates when a distributed variable will be
        aggregated. Accepted values are constants defined in the class
        `tf.VariableSynchronization`. By default the synchronization is set to
        `AUTO` and the current `DistributionStrategy` chooses when to
        synchronize.
      aggregation: Indicates how a distributed variable will be aggregated.
        Accepted values are constants defined in the class
        `tf.VariableAggregation`.
      shape: (optional) The shape of this variable. If None, the shape of
        `initial_value` will be used. When setting this argument to
        `tf.TensorShape(None)` (representing an unspecified shape), the
        variable can be assigned with values of different shapes.

    Raises:
      ValueError: If both `variable_def` and initial_value are specified.
      ValueError: If the initial value is not specified, or does not have a
        shape and `validate_shape` is `True`.
      RuntimeError: If eager execution is enabled.
    """

  SaveSliceInfo = Variable.SaveSliceInfo


# TODO(apassos): do not repeat all comments here
class RefVariable(VariableV1, core.Tensor):
  """Ref-based implementation of variables."""

  def __init__(
      self,  # pylint: disable=super-init-not-called
      initial_value=None,
      trainable=None,
      collections=None,
      validate_shape=True,
      caching_device=None,
      name=None,
      variable_def=None,
      dtype=None,
      expected_shape=None,
      import_scope=None,
      constraint=None,
      synchronization=None,
      aggregation=None,
      shape=None):
    """Creates a new variable with value `initial_value`.

    The new variable is added to the graph collections listed in
    `collections`, which defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
1579 1580 If `trainable` is `True` the variable is also added to the graph collection 1581 `GraphKeys.TRAINABLE_VARIABLES`. 1582 1583 This constructor creates both a `variable` Op and an `assign` Op to set the 1584 variable to its initial value. 1585 1586 Args: 1587 initial_value: A `Tensor`, or Python object convertible to a `Tensor`, 1588 which is the initial value for the Variable. The initial value must have 1589 a shape specified unless `validate_shape` is set to False. Can also be a 1590 callable with no argument that returns the initial value when called. In 1591 that case, `dtype` must be specified. (Note that initializer functions 1592 from init_ops.py must first be bound to a shape before being used here.) 1593 trainable: If `True`, also adds the variable to the graph collection 1594 `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as the default 1595 list of variables to use by the `Optimizer` classes. Defaults to `True`, 1596 unless `synchronization` is set to `ON_READ`, in which case it defaults 1597 to `False`. 1598 collections: List of graph collections keys. The new variable is added to 1599 these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`. 1600 validate_shape: If `False`, allows the variable to be initialized with a 1601 value of unknown shape. If `True`, the default, the shape of 1602 `initial_value` must be known. 1603 caching_device: Optional device string describing where the Variable 1604 should be cached for reading. Defaults to the Variable's device. If not 1605 `None`, caches on another device. Typical use is to cache on the device 1606 where the Ops using the Variable reside, to deduplicate copying through 1607 `Switch` and other conditional statements. 1608 name: Optional name for the variable. Defaults to `'Variable'` and gets 1609 uniquified automatically. 1610 variable_def: `VariableDef` protocol buffer. If not `None`, recreates the 1611 Variable object with its contents, referencing the variable's nodes in 1612 the graph, which must already exist. The graph is not changed. 1613 `variable_def` and the other arguments are mutually exclusive. 1614 dtype: If set, initial_value will be converted to the given type. If 1615 `None`, either the datatype will be kept (if `initial_value` is a 1616 Tensor), or `convert_to_tensor` will decide. 1617 expected_shape: A TensorShape. If set, initial_value is expected to have 1618 this shape. 1619 import_scope: Optional `string`. Name scope to add to the `Variable.` Only 1620 used when initializing from protocol buffer. 1621 constraint: An optional projection function to be applied to the variable 1622 after being updated by an `Optimizer` (e.g. used to implement norm 1623 constraints or value constraints for layer weights). The function must 1624 take as input the unprojected Tensor representing the value of the 1625 variable and return the Tensor for the projected value (which must have 1626 the same shape). Constraints are not safe to use when doing asynchronous 1627 distributed training. 1628 synchronization: Indicates when a distributed a variable will be 1629 aggregated. Accepted values are constants defined in the class 1630 `tf.VariableSynchronization`. By default the synchronization is set to 1631 `AUTO` and the current `DistributionStrategy` chooses when to 1632 synchronize. 1633 aggregation: Indicates how a distributed variable will be aggregated. 1634 Accepted values are constants defined in the class 1635 `tf.VariableAggregation`. 1636 shape: (optional) The shape of this variable. 
If None, the shape of 1637 `initial_value` will be used. When setting this argument to 1638 `tf.TensorShape(None)` (representing an unspecified shape), the variable 1639 can be assigned with values of different shapes. 1640 1641 Raises: 1642 ValueError: If both `variable_def` and initial_value are specified. 1643 ValueError: If the initial value is not specified, or does not have a 1644 shape and `validate_shape` is `True`. 1645 RuntimeError: If eager execution is enabled. 1646 """ 1647 self._in_graph_mode = True 1648 if variable_def: 1649 # If variable_def is provided, recreates the variable from its fields. 1650 if initial_value: 1651 raise ValueError("variable_def and initial_value are mutually " 1652 "exclusive.") 1653 self._init_from_proto(variable_def, import_scope=import_scope) 1654 else: 1655 # Create from initial_value. 1656 self._init_from_args( 1657 initial_value=initial_value, 1658 trainable=trainable, 1659 collections=collections, 1660 validate_shape=validate_shape, 1661 caching_device=caching_device, 1662 name=name, 1663 dtype=dtype, 1664 expected_shape=expected_shape, 1665 constraint=constraint, 1666 synchronization=synchronization, 1667 aggregation=aggregation, 1668 shape=shape) 1669 1670 def __repr__(self): 1671 if context.executing_eagerly() and not self._in_graph_mode: 1672 return "<tf.Variable '%s' shape=%s dtype=%s, numpy=%s>" % ( 1673 self.name, self.get_shape(), self.dtype.name, 1674 ops.numpy_text(self.read_value(), is_repr=True)) 1675 else: 1676 return "<tf.Variable '%s' shape=%s dtype=%s>" % ( 1677 self.name, self.get_shape(), self.dtype.name) 1678 1679 def _init_from_args(self, 1680 initial_value=None, 1681 trainable=None, 1682 collections=None, 1683 validate_shape=True, 1684 caching_device=None, 1685 name=None, 1686 dtype=None, 1687 expected_shape=None, 1688 constraint=None, 1689 synchronization=None, 1690 aggregation=None, 1691 shape=None): 1692 """Creates a new variable from arguments. 1693 1694 Args: 1695 initial_value: A `Tensor`, or Python object convertible to a `Tensor`, 1696 which is the initial value for the Variable. The initial value must have 1697 a shape specified unless `validate_shape` is set to False. Can also be a 1698 callable with no argument that returns the initial value when called. 1699 (Note that initializer functions from init_ops.py must first be bound to 1700 a shape before being used here.) 1701 trainable: If `True`, also adds the variable to the graph collection 1702 `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as the default 1703 list of variables to use by the `Optimizer` classes. Defaults to `True`, 1704 unless `synchronization` is set to `ON_READ`, in which case it defaults 1705 to `False`. 1706 collections: List of graph collections keys. The new variable is added to 1707 these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`. 1708 validate_shape: If `False`, allows the variable to be initialized with a 1709 value of unknown shape. If `True`, the default, the shape of 1710 `initial_value` must be known. 1711 caching_device: Optional device string or function describing where the 1712 Variable should be cached for reading. Defaults to the Variable's 1713 device. If not `None`, caches on another device. Typical use is to 1714 cache on the device where the Ops using the Variable reside, to 1715 deduplicate copying through `Switch` and other conditional statements. 1716 name: Optional name for the variable. Defaults to `'Variable'` and gets 1717 uniquified automatically. 
1718 dtype: If set, initial_value will be converted to the given type. If None, 1719 either the datatype will be kept (if initial_value is a Tensor) or 1720 float32 will be used (if it is a Python object convertible to a Tensor). 1721 expected_shape: Deprecated. Ignored. 1722 constraint: An optional projection function to be applied to the variable 1723 after being updated by an `Optimizer` (e.g. used to implement norm 1724 constraints or value constraints for layer weights). The function must 1725 take as input the unprojected Tensor representing the value of the 1726 variable and return the Tensor for the projected value (which must have 1727 the same shape). Constraints are not safe to use when doing asynchronous 1728 distributed training. 1729 synchronization: Indicates when a distributed a variable will be 1730 aggregated. Accepted values are constants defined in the class 1731 `tf.VariableSynchronization`. By default the synchronization is set to 1732 `AUTO` and the current `DistributionStrategy` chooses when to 1733 synchronize. 1734 aggregation: Indicates how a distributed variable will be aggregated. 1735 Accepted values are constants defined in the class 1736 `tf.VariableAggregation`. 1737 shape: (optional) The shape of this variable. If None, the shape of 1738 `initial_value` will be used. When setting this argument to 1739 `tf.TensorShape(None)` (representing an unspecified shape), the variable 1740 can be assigned with values of different shapes. 1741 1742 Raises: 1743 ValueError: If the initial value is not specified, or does not have a 1744 shape and `validate_shape` is `True`. 1745 RuntimeError: If lifted into the eager context. 1746 """ 1747 _ = expected_shape 1748 if initial_value is None: 1749 raise ValueError("initial_value must be specified.") 1750 init_from_fn = callable(initial_value) 1751 1752 if collections is None: 1753 collections = [ops.GraphKeys.GLOBAL_VARIABLES] 1754 if not isinstance(collections, (list, tuple, set)): 1755 raise ValueError( 1756 "collections argument to Variable constructor must be a list, tuple, " 1757 "or set. Got %s of type %s" % (collections, type(collections))) 1758 if constraint is not None and not callable(constraint): 1759 raise ValueError("The `constraint` argument must be a callable.") 1760 1761 # Store the graph key so optimizers know how to only retrieve variables from 1762 # this graph. 1763 self._graph_key = ops.get_default_graph()._graph_key # pylint: disable=protected-access 1764 if isinstance(initial_value, trackable.CheckpointInitialValue): 1765 self._maybe_initialize_trackable() 1766 self._update_uid = initial_value.checkpoint_position.restore_uid 1767 initial_value = initial_value.wrapped_value 1768 1769 synchronization, aggregation, trainable = ( 1770 validate_synchronization_aggregation_trainable(synchronization, 1771 aggregation, trainable, 1772 name)) 1773 self._synchronization = synchronization 1774 self._aggregation = aggregation 1775 self._trainable = trainable 1776 if trainable and ops.GraphKeys.TRAINABLE_VARIABLES not in collections: 1777 collections = list(collections) + [ops.GraphKeys.TRAINABLE_VARIABLES] 1778 with ops.init_scope(): 1779 # Ensure that we weren't lifted into the eager context. 1780 if context.executing_eagerly(): 1781 raise RuntimeError( 1782 "RefVariable not supported when eager execution is enabled. 
") 1783 with ops.name_scope(name, "Variable", 1784 [] if init_from_fn else [initial_value]) as name: 1785 1786 if init_from_fn: 1787 # Use attr_scope and device(None) to simulate the behavior of 1788 # colocate_with when the variable we want to colocate with doesn't 1789 # yet exist. 1790 true_name = ops.name_from_scope_name(name) # pylint: disable=protected-access 1791 attr = attr_value_pb2.AttrValue( 1792 list=attr_value_pb2.AttrValue.ListValue( 1793 s=[compat.as_bytes("loc:@%s" % true_name)])) 1794 # pylint: disable=protected-access 1795 with ops.get_default_graph()._attr_scope({"_class": attr}): 1796 with ops.name_scope("Initializer"), ops.device(None): 1797 initial_value = initial_value() 1798 if isinstance(initial_value, trackable.CheckpointInitialValue): 1799 self._maybe_initialize_trackable() 1800 self._update_uid = initial_value.checkpoint_position.restore_uid 1801 initial_value = initial_value.wrapped_value 1802 self._initial_value = ops.convert_to_tensor( 1803 initial_value, name="initial_value", dtype=dtype) 1804 if shape is None: 1805 shape = ( 1806 self._initial_value.get_shape() 1807 if validate_shape else tensor_shape.unknown_shape()) 1808 self._variable = state_ops.variable_op_v2( 1809 shape, self._initial_value.dtype.base_dtype, name=name) 1810 # pylint: enable=protected-access 1811 1812 # Or get the initial value from a Tensor or Python object. 1813 else: 1814 self._initial_value = ops.convert_to_tensor( 1815 initial_value, name="initial_value", dtype=dtype) 1816 # pylint: disable=protected-access 1817 if self._initial_value.op._get_control_flow_context() is not None: 1818 raise ValueError( 1819 "Initializer for variable %s is from inside a control-flow " 1820 "construct, such as a loop or conditional. When creating a " 1821 "variable inside a loop or conditional, use a lambda as the " 1822 "initializer." % name) 1823 if shape is None: 1824 # pylint: enable=protected-access 1825 shape = ( 1826 self._initial_value.get_shape() 1827 if validate_shape else tensor_shape.unknown_shape()) 1828 # In this case, the variable op can't be created until after the 1829 # initial_value has been converted to a Tensor with a known type. 1830 self._variable = state_ops.variable_op_v2( 1831 shape, self._initial_value.dtype.base_dtype, name=name) 1832 1833 # Cache the name in `self`, because some APIs call `Variable.name` in a 1834 # tight loop, and this halves the cost. 1835 self._name = self._variable.name 1836 1837 # Manually overrides the variable's shape with the initial value's. 1838 if validate_shape: 1839 initial_value_shape = self._initial_value.get_shape() 1840 if not initial_value_shape.is_fully_defined(): 1841 raise ValueError("initial_value must have a shape specified: %s" % 1842 self._initial_value) 1843 1844 # If 'initial_value' makes use of other variables, make sure we don't 1845 # have an issue if these other variables aren't initialized first by 1846 # using their initialized_value() method. 1847 self._initializer_op = state_ops.assign( 1848 self._variable, 1849 _try_guard_against_uninitialized_dependencies( 1850 name, self._initial_value), 1851 validate_shape=validate_shape).op 1852 1853 # TODO(vrv): Change this class to not take caching_device, but 1854 # to take the op to colocate the snapshot with, so we can use 1855 # colocation rather than devices. 
1856 if caching_device is not None: 1857 with ops.device(caching_device): 1858 self._snapshot = array_ops.identity(self._variable, name="read") 1859 else: 1860 with ops.colocate_with(self._variable.op): 1861 self._snapshot = array_ops.identity(self._variable, name="read") 1862 ops.add_to_collections(collections, self) 1863 1864 self._caching_device = caching_device 1865 self._save_slice_info = None 1866 self._constraint = constraint 1867 1868 def _init_from_proto(self, variable_def, import_scope=None): 1869 """Recreates the Variable object from a `VariableDef` protocol buffer. 1870 1871 Args: 1872 variable_def: `VariableDef` protocol buffer, describing a variable whose 1873 nodes already exists in the graph. 1874 import_scope: Optional `string`. Name scope to add. 1875 """ 1876 assert isinstance(variable_def, variable_pb2.VariableDef) 1877 # Create from variable_def. 1878 g = ops.get_default_graph() 1879 self._variable = g.as_graph_element( 1880 ops.prepend_name_scope( 1881 variable_def.variable_name, import_scope=import_scope)) 1882 self._name = self._variable.name 1883 self._initializer_op = g.as_graph_element( 1884 ops.prepend_name_scope( 1885 variable_def.initializer_name, import_scope=import_scope)) 1886 # Tests whether initial_value_name exists first for backwards compatibility. 1887 if (hasattr(variable_def, "initial_value_name") and 1888 variable_def.initial_value_name): 1889 self._initial_value = g.as_graph_element( 1890 ops.prepend_name_scope( 1891 variable_def.initial_value_name, import_scope=import_scope)) 1892 else: 1893 self._initial_value = None 1894 synchronization, aggregation, trainable = ( 1895 validate_synchronization_aggregation_trainable( 1896 variable_def.synchronization, variable_def.aggregation, 1897 variable_def.trainable, variable_def.variable_name)) 1898 self._synchronization = synchronization 1899 self._aggregation = aggregation 1900 self._trainable = trainable 1901 self._snapshot = g.as_graph_element( 1902 ops.prepend_name_scope( 1903 variable_def.snapshot_name, import_scope=import_scope)) 1904 if variable_def.HasField("save_slice_info_def"): 1905 self._save_slice_info = Variable.SaveSliceInfo( 1906 save_slice_info_def=variable_def.save_slice_info_def, 1907 import_scope=import_scope) 1908 else: 1909 self._save_slice_info = None 1910 self._caching_device = None 1911 self._constraint = None 1912 1913 def _as_graph_element(self): 1914 """Conversion function for Graph.as_graph_element().""" 1915 return self._variable 1916 1917 def value(self): 1918 """Returns the last snapshot of this variable. 1919 1920 You usually do not need to call this method as all ops that need the value 1921 of the variable call it automatically through a `convert_to_tensor()` call. 1922 1923 Returns a `Tensor` which holds the value of the variable. You can not 1924 assign a new value to this tensor as it is not a reference to the variable. 1925 1926 To avoid copies, if the consumer of the returned value is on the same device 1927 as the variable, this actually returns the live value of the variable, not 1928 a copy. Updates to the variable are seen by the consumer. If the consumer 1929 is on a different device it will get a copy of the variable. 1930 1931 Returns: 1932 A `Tensor` containing the value of the variable. 1933 """ 1934 return self._snapshot 1935 1936 def read_value(self): 1937 """Returns the value of this variable, read in the current context. 1938 1939 Can be different from value() if it's on another device, with control 1940 dependencies, etc. 
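
    A minimal sketch of the difference (graph mode; the device strings are
    only illustrative):

    ```python
    with tf.Graph().as_default():
      with tf.device("/cpu:0"):
        v = tf.compat.v1.Variable([1.0, 2.0])
      with tf.device("/cpu:1"):
        r = v.read_value()  # identity op created in the current device context
      s = v.value()         # pre-existing snapshot, colocated with the variable
    ```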
1941 1942 Returns: 1943 A `Tensor` containing the value of the variable. 1944 """ 1945 return array_ops.identity(self._variable, name="read") 1946 1947 def _ref(self): 1948 """Returns a reference to this variable. 1949 1950 You usually do not need to call this method as all ops that need a reference 1951 to the variable call it automatically. 1952 1953 Returns is a `Tensor` which holds a reference to the variable. You can 1954 assign a new value to the variable by passing the tensor to an assign op. 1955 See `tf.Variable.value` if you want to get the value of the 1956 variable. 1957 1958 Returns: 1959 A `Tensor` that is a reference to the variable. 1960 """ 1961 return self._variable 1962 1963 def set_shape(self, shape): 1964 """Overrides the shape for this variable. 1965 1966 Args: 1967 shape: the `TensorShape` representing the overridden shape. 1968 """ 1969 self._ref().set_shape(shape) 1970 self.value().set_shape(shape) 1971 1972 @property 1973 def trainable(self): 1974 return self._trainable 1975 1976 @property 1977 def synchronization(self): 1978 return self._synchronization 1979 1980 @property 1981 def aggregation(self): 1982 return self._aggregation 1983 1984 def eval(self, session=None): 1985 """In a session, computes and returns the value of this variable. 1986 1987 This is not a graph construction method, it does not add ops to the graph. 1988 1989 This convenience method requires a session where the graph 1990 containing this variable has been launched. If no session is 1991 passed, the default session is used. See `tf.compat.v1.Session` for more 1992 information on launching a graph and on sessions. 1993 1994 ```python 1995 v = tf.Variable([1, 2]) 1996 init = tf.compat.v1.global_variables_initializer() 1997 1998 with tf.compat.v1.Session() as sess: 1999 sess.run(init) 2000 # Usage passing the session explicitly. 2001 print(v.eval(sess)) 2002 # Usage with the default session. The 'with' block 2003 # above makes 'sess' the default session. 2004 print(v.eval()) 2005 ``` 2006 2007 Args: 2008 session: The session to use to evaluate this variable. If none, the 2009 default session is used. 2010 2011 Returns: 2012 A numpy `ndarray` with a copy of the value of this variable. 2013 """ 2014 return self._variable.eval(session=session) 2015 2016 @property 2017 def initial_value(self): 2018 """Returns the Tensor used as the initial value for the variable. 2019 2020 Note that this is different from `initialized_value()` which runs 2021 the op that initializes the variable before returning its value. 2022 This method returns the tensor that is used by the op that initializes 2023 the variable. 2024 2025 Returns: 2026 A `Tensor`. 2027 """ 2028 return self._initial_value 2029 2030 @property 2031 def constraint(self): 2032 """Returns the constraint function associated with this variable. 2033 2034 Returns: 2035 The constraint function that was passed to the variable constructor. 2036 Can be `None` if no constraint was passed. 2037 """ 2038 return self._constraint 2039 2040 def assign(self, value, use_locking=False, name=None, read_value=True): 2041 """Assigns a new value to the variable. 2042 2043 This is essentially a shortcut for `assign(self, value)`. 2044 2045 Args: 2046 value: A `Tensor`. The new value for this variable. 2047 use_locking: If `True`, use locking during the assignment. 2048 name: The name of the operation to be created 2049 read_value: if True, will return something which evaluates to the new 2050 value of the variable; if False will return the assign op. 
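
    A small sketch of how `read_value` changes what is returned (graph mode,
    illustrative values):

    ```python
    v = tf.compat.v1.Variable(0.0)
    new_value = v.assign(1.0)                    # a Tensor holding the new value
    assign_op = v.assign(2.0, read_value=False)  # just the assign Operation
    ```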
2051 2052 Returns: 2053 A `Tensor` that will hold the new value of this variable after 2054 the assignment has completed. 2055 """ 2056 assign = state_ops.assign( 2057 self._variable, value, use_locking=use_locking, name=name) 2058 if read_value: 2059 return assign 2060 return assign.op 2061 2062 def assign_add(self, delta, use_locking=False, name=None, read_value=True): 2063 """Adds a value to this variable. 2064 2065 This is essentially a shortcut for `assign_add(self, delta)`. 2066 2067 Args: 2068 delta: A `Tensor`. The value to add to this variable. 2069 use_locking: If `True`, use locking during the operation. 2070 name: The name of the operation to be created 2071 read_value: if True, will return something which evaluates to the new 2072 value of the variable; if False will return the assign op. 2073 2074 Returns: 2075 A `Tensor` that will hold the new value of this variable after 2076 the addition has completed. 2077 """ 2078 assign = state_ops.assign_add( 2079 self._variable, delta, use_locking=use_locking, name=name) 2080 if read_value: 2081 return assign 2082 return assign.op 2083 2084 def assign_sub(self, delta, use_locking=False, name=None, read_value=True): 2085 """Subtracts a value from this variable. 2086 2087 This is essentially a shortcut for `assign_sub(self, delta)`. 2088 2089 Args: 2090 delta: A `Tensor`. The value to subtract from this variable. 2091 use_locking: If `True`, use locking during the operation. 2092 name: The name of the operation to be created 2093 read_value: if True, will return something which evaluates to the new 2094 value of the variable; if False will return the assign op. 2095 2096 Returns: 2097 A `Tensor` that will hold the new value of this variable after 2098 the subtraction has completed. 2099 """ 2100 assign = state_ops.assign_sub( 2101 self._variable, delta, use_locking=use_locking, name=name) 2102 if read_value: 2103 return assign 2104 return assign.op 2105 2106 def scatter_sub(self, sparse_delta, use_locking=False, name=None): 2107 """Subtracts `tf.IndexedSlices` from this variable. 2108 2109 Args: 2110 sparse_delta: `tf.IndexedSlices` to be subtracted from this variable. 2111 use_locking: If `True`, use locking during the operation. 2112 name: the name of the operation. 2113 2114 Returns: 2115 A `Tensor` that will hold the new value of this variable after 2116 the scattered subtraction has completed. 2117 2118 Raises: 2119 TypeError: if `sparse_delta` is not an `IndexedSlices`. 2120 """ 2121 if not isinstance(sparse_delta, ops.IndexedSlices): 2122 raise TypeError("sparse_delta is not IndexedSlices: %s" % sparse_delta) 2123 return gen_state_ops.scatter_sub( 2124 self._variable, 2125 sparse_delta.indices, 2126 sparse_delta.values, 2127 use_locking=use_locking, 2128 name=name) 2129 2130 def scatter_add(self, sparse_delta, use_locking=False, name=None): 2131 """Adds `tf.IndexedSlices` to this variable. 2132 2133 Args: 2134 sparse_delta: `tf.IndexedSlices` to be added to this variable. 2135 use_locking: If `True`, use locking during the operation. 2136 name: the name of the operation. 2137 2138 Returns: 2139 A `Tensor` that will hold the new value of this variable after 2140 the scattered addition has completed. 2141 2142 Raises: 2143 TypeError: if `sparse_delta` is not an `IndexedSlices`. 
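
    For example (an illustrative sketch; the values are made up):

    ```python
    v = tf.compat.v1.Variable([1.0, 2.0, 3.0])
    delta = tf.IndexedSlices(values=tf.constant([10.0, 20.0]),
                             indices=tf.constant([0, 2]),
                             dense_shape=tf.constant([3]))
    update = v.scatter_add(delta)  # evaluates to [11.0, 2.0, 23.0] once run
    ```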
2144 """ 2145 if not isinstance(sparse_delta, ops.IndexedSlices): 2146 raise TypeError("sparse_delta is not IndexedSlices: %s" % sparse_delta) 2147 return gen_state_ops.scatter_add( 2148 self._variable, 2149 sparse_delta.indices, 2150 sparse_delta.values, 2151 use_locking=use_locking, 2152 name=name) 2153 2154 def scatter_max(self, sparse_delta, use_locking=False, name=None): 2155 """Updates this variable with the max of `tf.IndexedSlices` and itself. 2156 2157 Args: 2158 sparse_delta: `tf.IndexedSlices` to use as an argument of max with this 2159 variable. 2160 use_locking: If `True`, use locking during the operation. 2161 name: the name of the operation. 2162 2163 Returns: 2164 A `Tensor` that will hold the new value of this variable after 2165 the scattered maximization has completed. 2166 2167 Raises: 2168 TypeError: if `sparse_delta` is not an `IndexedSlices`. 2169 """ 2170 if not isinstance(sparse_delta, ops.IndexedSlices): 2171 raise TypeError("sparse_delta is not IndexedSlices: %s" % sparse_delta) 2172 return gen_state_ops.scatter_max( 2173 self._variable, 2174 sparse_delta.indices, 2175 sparse_delta.values, 2176 use_locking=use_locking, 2177 name=name) 2178 2179 def scatter_min(self, sparse_delta, use_locking=False, name=None): 2180 """Updates this variable with the min of `tf.IndexedSlices` and itself. 2181 2182 Args: 2183 sparse_delta: `tf.IndexedSlices` to use as an argument of min with this 2184 variable. 2185 use_locking: If `True`, use locking during the operation. 2186 name: the name of the operation. 2187 2188 Returns: 2189 A `Tensor` that will hold the new value of this variable after 2190 the scattered minimization has completed. 2191 2192 Raises: 2193 TypeError: if `sparse_delta` is not an `IndexedSlices`. 2194 """ 2195 if not isinstance(sparse_delta, ops.IndexedSlices): 2196 raise TypeError("sparse_delta is not IndexedSlices: %s" % sparse_delta) 2197 return gen_state_ops.scatter_min( 2198 self._variable, 2199 sparse_delta.indices, 2200 sparse_delta.values, 2201 use_locking=use_locking, 2202 name=name) 2203 2204 def scatter_mul(self, sparse_delta, use_locking=False, name=None): 2205 """Multiply this variable by `tf.IndexedSlices`. 2206 2207 Args: 2208 sparse_delta: `tf.IndexedSlices` to multiply this variable by. 2209 use_locking: If `True`, use locking during the operation. 2210 name: the name of the operation. 2211 2212 Returns: 2213 A `Tensor` that will hold the new value of this variable after 2214 the scattered multiplication has completed. 2215 2216 Raises: 2217 TypeError: if `sparse_delta` is not an `IndexedSlices`. 2218 """ 2219 if not isinstance(sparse_delta, ops.IndexedSlices): 2220 raise TypeError("sparse_delta is not IndexedSlices: %s" % sparse_delta) 2221 return gen_state_ops.scatter_mul( 2222 self._variable, 2223 sparse_delta.indices, 2224 sparse_delta.values, 2225 use_locking=use_locking, 2226 name=name) 2227 2228 def scatter_div(self, sparse_delta, use_locking=False, name=None): 2229 """Divide this variable by `tf.IndexedSlices`. 2230 2231 Args: 2232 sparse_delta: `tf.IndexedSlices` to divide this variable by. 2233 use_locking: If `True`, use locking during the operation. 2234 name: the name of the operation. 2235 2236 Returns: 2237 A `Tensor` that will hold the new value of this variable after 2238 the scattered division has completed. 2239 2240 Raises: 2241 TypeError: if `sparse_delta` is not an `IndexedSlices`. 
2242 """ 2243 if not isinstance(sparse_delta, ops.IndexedSlices): 2244 raise TypeError("sparse_delta is not IndexedSlices: %s" % sparse_delta) 2245 return gen_state_ops.scatter_div( 2246 self._variable, 2247 sparse_delta.indices, 2248 sparse_delta.values, 2249 use_locking=use_locking, 2250 name=name) 2251 2252 def scatter_update(self, sparse_delta, use_locking=False, name=None): 2253 """Assigns `tf.IndexedSlices` to this variable. 2254 2255 Args: 2256 sparse_delta: `tf.IndexedSlices` to be assigned to this variable. 2257 use_locking: If `True`, use locking during the operation. 2258 name: the name of the operation. 2259 2260 Returns: 2261 A `Tensor` that will hold the new value of this variable after 2262 the scattered assignment has completed. 2263 2264 Raises: 2265 TypeError: if `sparse_delta` is not an `IndexedSlices`. 2266 """ 2267 if not isinstance(sparse_delta, ops.IndexedSlices): 2268 raise TypeError("sparse_delta is not IndexedSlices: %s" % sparse_delta) 2269 return gen_state_ops.scatter_update( 2270 self._variable, 2271 sparse_delta.indices, 2272 sparse_delta.values, 2273 use_locking=use_locking, 2274 name=name) 2275 2276 def batch_scatter_update(self, sparse_delta, use_locking=False, name=None): 2277 """Assigns `tf.IndexedSlices` to this variable batch-wise. 2278 2279 Analogous to `batch_gather`. This assumes that this variable and the 2280 sparse_delta IndexedSlices have a series of leading dimensions that are the 2281 same for all of them, and the updates are performed on the last dimension of 2282 indices. In other words, the dimensions should be the following: 2283 2284 `num_prefix_dims = sparse_delta.indices.ndims - 1` 2285 `batch_dim = num_prefix_dims + 1` 2286 `sparse_delta.updates.shape = sparse_delta.indices.shape + var.shape[ 2287 batch_dim:]` 2288 2289 where 2290 2291 `sparse_delta.updates.shape[:num_prefix_dims]` 2292 `== sparse_delta.indices.shape[:num_prefix_dims]` 2293 `== var.shape[:num_prefix_dims]` 2294 2295 And the operation performed can be expressed as: 2296 2297 `var[i_1, ..., i_n, 2298 sparse_delta.indices[i_1, ..., i_n, j]] = sparse_delta.updates[ 2299 i_1, ..., i_n, j]` 2300 2301 When sparse_delta.indices is a 1D tensor, this operation is equivalent to 2302 `scatter_update`. 2303 2304 To avoid this operation one can looping over the first `ndims` of the 2305 variable and using `scatter_update` on the subtensors that result of slicing 2306 the first dimension. This is a valid option for `ndims = 1`, but less 2307 efficient than this implementation. 2308 2309 Args: 2310 sparse_delta: `tf.IndexedSlices` to be assigned to this variable. 2311 use_locking: If `True`, use locking during the operation. 2312 name: the name of the operation. 2313 2314 Returns: 2315 A `Tensor` that will hold the new value of this variable after 2316 the scattered assignment has completed. 2317 2318 Raises: 2319 TypeError: if `sparse_delta` is not an `IndexedSlices`. 2320 """ 2321 return state_ops.batch_scatter_update( 2322 self, 2323 sparse_delta.indices, 2324 sparse_delta.values, 2325 use_locking=use_locking, 2326 name=name) 2327 2328 def scatter_nd_sub(self, indices, updates, name=None): 2329 """Applies sparse subtraction to individual values or slices in a Variable. 2330 2331 `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`. 2332 2333 `indices` must be integer tensor, containing indices into `ref`. 2334 It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`. 

    The innermost dimension of `indices` (with length `K`) corresponds to
    indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
    dimension of `ref`.

    `updates` is a `Tensor` of rank `Q-1+P-K` with shape:

    ```
    [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
    ```

    For example, say we want to subtract 4 scattered elements from a rank-1
    tensor with 8 elements. In Python, that update would look like this:

    ```python
        ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
        indices = tf.constant([[4], [3], [1], [7]])
        updates = tf.constant([9, 10, 11, 12])
        op = ref.scatter_nd_sub(indices, updates)
        with tf.compat.v1.Session() as sess:
          print(sess.run(op))
    ```

    The resulting update to ref would look like this:

        [1, -9, 3, -6, -4, 6, 7, -4]

    See `tf.scatter_nd` for more details about how to make updates to
    slices.

    Args:
      indices: The indices to be used in the operation.
      updates: The values to be used in the operation.
      name: the name of the operation.

    Returns:
      A `Tensor` that will hold the new value of this variable after
      the scattered subtraction has completed.
    """
    return gen_state_ops.scatter_nd_sub(
        self._variable, indices, updates, use_locking=True, name=name)

  def scatter_nd_add(self, indices, updates, name=None):
    """Applies sparse addition to individual values or slices in a Variable.

    `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.

    `indices` must be an integer tensor containing indices into `ref`.
    It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.

    The innermost dimension of `indices` (with length `K`) corresponds to
    indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
    dimension of `ref`.

    `updates` is a `Tensor` of rank `Q-1+P-K` with shape:

    ```
    [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
    ```

    For example, say we want to add 4 scattered elements to a rank-1 tensor
    with 8 elements. In Python, that update would look like this:

    ```python
        ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
        indices = tf.constant([[4], [3], [1], [7]])
        updates = tf.constant([9, 10, 11, 12])
        add = ref.scatter_nd_add(indices, updates)
        with tf.compat.v1.Session() as sess:
          print(sess.run(add))
    ```

    The resulting update to ref would look like this:

        [1, 13, 3, 14, 14, 6, 7, 20]

    See `tf.scatter_nd` for more details about how to make updates to
    slices.

    Args:
      indices: The indices to be used in the operation.
      updates: The values to be used in the operation.
      name: the name of the operation.

    Returns:
      A `Tensor` that will hold the new value of this variable after
      the scattered addition has completed.
    """
    return gen_state_ops.scatter_nd_add(
        self._variable, indices, updates, use_locking=True, name=name)

  def scatter_nd_update(self, indices, updates, name=None):
    """Applies sparse assignment to individual values or slices in a Variable.

    `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.

    `indices` must be an integer tensor containing indices into `ref`.
    It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.

    The innermost dimension of `indices` (with length `K`) corresponds to
    indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
    dimension of `ref`.

    `updates` is a `Tensor` of rank `Q-1+P-K` with shape:

    ```
    [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
    ```

    For example, say we want to update 4 scattered elements in a rank-1 tensor
    with 8 elements. In Python, that update would look like this:

    ```python
        ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
        indices = tf.constant([[4], [3], [1], [7]])
        updates = tf.constant([9, 10, 11, 12])
        op = ref.scatter_nd_update(indices, updates)
        with tf.compat.v1.Session() as sess:
          print(sess.run(op))
    ```

    The resulting update to ref would look like this:

        [1, 11, 3, 10, 9, 6, 7, 12]

    See `tf.scatter_nd` for more details about how to make updates to
    slices.

    Args:
      indices: The indices to be used in the operation.
      updates: The values to be used in the operation.
      name: the name of the operation.

    Returns:
      A `Tensor` that will hold the new value of this variable after
      the scattered assignment has completed.
    """
    return gen_state_ops.scatter_nd_update(
        self._variable, indices, updates, use_locking=True, name=name)

  def scatter_nd_max(self, indices, updates, name=None):
    """Updates this variable with the element-wise max of itself and `updates`.

    `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.

    `indices` must be an integer tensor containing indices into `ref`.
    It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.

    The innermost dimension of `indices` (with length `K`) corresponds to
    indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
    dimension of `ref`.

    `updates` is a `Tensor` of rank `Q-1+P-K` with shape:

    ```
    [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
    ```

    See `tf.scatter_nd` for more details about how to make updates to
    slices.

    Args:
      indices: The indices to be used in the operation.
      updates: The values to be used in the operation.
      name: the name of the operation.

    Returns:
      A `Tensor` that will hold the new value of this variable after
      the scattered maximization has completed.
    """
    return gen_state_ops.scatter_nd_max(
        self._variable, indices, updates, use_locking=True, name=name)

  def scatter_nd_min(self, indices, updates, name=None):
    """Updates this variable with the element-wise min of itself and `updates`.

    `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.

    `indices` must be an integer tensor containing indices into `ref`.
    It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.

    The innermost dimension of `indices` (with length `K`) corresponds to
    indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
    dimension of `ref`.

    `updates` is a `Tensor` of rank `Q-1+P-K` with shape:

    ```
    [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
    ```

    See `tf.scatter_nd` for more details about how to make updates to
    slices.

    Args:
      indices: The indices to be used in the operation.
      updates: The values to be used in the operation.
      name: the name of the operation.
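
    For example (an illustrative graph-mode sketch):

    ```python
    v = tf.compat.v1.Variable([1, 2, 3, 4])
    op = v.scatter_nd_min(indices=tf.constant([[1], [3]]),
                          updates=tf.constant([0, 10]))
    with tf.compat.v1.Session() as sess:
      sess.run(v.initializer)
      # Each targeted element becomes the minimum of its current value and
      # the corresponding update.
      print(sess.run(op))  # [1, 0, 3, 4]
    ```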
2533 2534 Returns: 2535 A `Tensor` that will hold the new value of this variable after 2536 the scattered addition has completed. 2537 """ 2538 return gen_state_ops.scatter_nd_min( 2539 self._variable, indices, updates, use_locking=True, name=name) 2540 2541 def _strided_slice_assign(self, begin, end, strides, value, name, begin_mask, 2542 end_mask, ellipsis_mask, new_axis_mask, 2543 shrink_axis_mask): 2544 return gen_array_ops.strided_slice_assign( 2545 ref=self._ref(), 2546 begin=begin, 2547 end=end, 2548 strides=strides, 2549 value=value, 2550 name=name, 2551 begin_mask=begin_mask, 2552 end_mask=end_mask, 2553 ellipsis_mask=ellipsis_mask, 2554 new_axis_mask=new_axis_mask, 2555 shrink_axis_mask=shrink_axis_mask) 2556 2557 @deprecated(None, "Prefer Dataset.range instead.") 2558 def count_up_to(self, limit): 2559 """Increments this variable until it reaches `limit`. 2560 2561 When that Op is run it tries to increment the variable by `1`. If 2562 incrementing the variable would bring it above `limit` then the Op raises 2563 the exception `OutOfRangeError`. 2564 2565 If no error is raised, the Op outputs the value of the variable before 2566 the increment. 2567 2568 This is essentially a shortcut for `count_up_to(self, limit)`. 2569 2570 Args: 2571 limit: value at which incrementing the variable raises an error. 2572 2573 Returns: 2574 A `Tensor` that will hold the variable value before the increment. If no 2575 other Op modifies this variable, the values produced will all be 2576 distinct. 2577 """ 2578 return state_ops.count_up_to(self._variable, limit=limit) 2579 2580 # Conversion to tensor. 2581 @staticmethod 2582 def _TensorConversionFunction(v, dtype=None, name=None, as_ref=False): # pylint: disable=invalid-name 2583 """Utility function for converting a Variable to a Tensor.""" 2584 _ = name 2585 if dtype and not dtype.is_compatible_with(v.dtype): 2586 raise ValueError( 2587 "Incompatible type conversion requested to type '%s' for variable " 2588 "of type '%s'" % (dtype.name, v.dtype.name)) 2589 if as_ref: 2590 return v._ref() # pylint: disable=protected-access 2591 else: 2592 return v.value() 2593 2594 # NOTE(mrry): This enables the Variable's overloaded "right" binary 2595 # operators to run when the left operand is an ndarray, because it 2596 # accords the Variable class higher priority than an ndarray, or a 2597 # numpy matrix. 2598 # TODO(mrry): Convert this to using numpy's __numpy_ufunc__ 2599 # mechanism, which allows more control over how Variables interact 2600 # with ndarrays. 2601 __array_priority__ = 100 2602 2603 @property 2604 def name(self): 2605 """The name of this variable.""" 2606 return self._name 2607 2608 @property 2609 def initializer(self): 2610 """The initializer operation for this variable.""" 2611 return self._initializer_op 2612 2613 @property 2614 def device(self): 2615 """The device of this variable.""" 2616 return self._variable.device 2617 2618 @property 2619 def dtype(self): 2620 """The `DType` of this variable.""" 2621 return self._variable.dtype 2622 2623 @property 2624 def op(self): 2625 """The `Operation` of this variable.""" 2626 return self._variable.op 2627 2628 @property 2629 def graph(self): 2630 """The `Graph` of this variable.""" 2631 return self._variable.graph 2632 2633 @property 2634 def _distribute_strategy(self): 2635 """The `tf.distribute.Strategy` that this variable was created under.""" 2636 return None # Ref variables are never created inside a strategy. 
2637 2638 @property 2639 def shape(self): 2640 """The `TensorShape` of this variable. 2641 2642 Returns: 2643 A `TensorShape`. 2644 """ 2645 return self._variable.get_shape() 2646 2647 def to_proto(self, export_scope=None): 2648 """Converts a `Variable` to a `VariableDef` protocol buffer. 2649 2650 Args: 2651 export_scope: Optional `string`. Name scope to remove. 2652 2653 Returns: 2654 A `VariableDef` protocol buffer, or `None` if the `Variable` is not 2655 in the specified name scope. 2656 """ 2657 if (export_scope is None or self._variable.name.startswith(export_scope)): 2658 var_def = variable_pb2.VariableDef() 2659 var_def.variable_name = ops.strip_name_scope(self._variable.name, 2660 export_scope) 2661 if self._initial_value is not None: 2662 # For backwards compatibility. 2663 var_def.initial_value_name = ops.strip_name_scope( 2664 self._initial_value.name, export_scope) 2665 var_def.trainable = self.trainable 2666 var_def.synchronization = self.synchronization.value 2667 var_def.aggregation = self.aggregation.value 2668 var_def.initializer_name = ops.strip_name_scope(self.initializer.name, 2669 export_scope) 2670 var_def.snapshot_name = ops.strip_name_scope(self._snapshot.name, 2671 export_scope) 2672 if self._save_slice_info: 2673 var_def.save_slice_info_def.MergeFrom( 2674 self._save_slice_info.to_proto(export_scope=export_scope)) 2675 return var_def 2676 else: 2677 return None 2678 2679 def __iadd__(self, other): 2680 logging.log_first_n( 2681 logging.WARN, "Variable += will be deprecated. Use variable.assign_add" 2682 " if you want assignment to the variable value or 'x = x + y'" 2683 " if you want a new python Tensor object.", 1) 2684 return self + other 2685 2686 def __isub__(self, other): 2687 logging.log_first_n( 2688 logging.WARN, "Variable -= will be deprecated. Use variable.assign_sub" 2689 " if you want assignment to the variable value or 'x = x - y'" 2690 " if you want a new python Tensor object.", 1) 2691 return self - other 2692 2693 def __imul__(self, other): 2694 logging.log_first_n( 2695 logging.WARN, 2696 "Variable *= will be deprecated. Use `var.assign(var * other)`" 2697 " if you want assignment to the variable value or `x = x * y`" 2698 " if you want a new python Tensor object.", 1) 2699 return self * other 2700 2701 def __idiv__(self, other): 2702 logging.log_first_n( 2703 logging.WARN, 2704 "Variable /= will be deprecated. Use `var.assign(var / other)`" 2705 " if you want assignment to the variable value or `x = x / y`" 2706 " if you want a new python Tensor object.", 1) 2707 return self / other 2708 2709 def __itruediv__(self, other): 2710 logging.log_first_n( 2711 logging.WARN, 2712 "Variable /= will be deprecated. Use `var.assign(var / other)`" 2713 " if you want assignment to the variable value or `x = x / y`" 2714 " if you want a new python Tensor object.", 1) 2715 return self / other 2716 2717 def __irealdiv__(self, other): 2718 logging.log_first_n( 2719 logging.WARN, 2720 "Variable /= will be deprecated. Use `var.assign(var / other)`" 2721 " if you want assignment to the variable value or `x = x / y`" 2722 " if you want a new python Tensor object.", 1) 2723 return self / other 2724 2725 def __ipow__(self, other): 2726 logging.log_first_n( 2727 logging.WARN, 2728 "Variable **= will be deprecated. 
Use `var.assign(var ** other)`" 2729 " if you want assignment to the variable value or `x = x ** y`" 2730 " if you want a new python Tensor object.", 1) 2731 return self**other 2732 2733 2734def _try_guard_against_uninitialized_dependencies(name, initial_value): 2735 """Attempt to guard against dependencies on uninitialized variables. 2736 2737 Replace references to variables in `initial_value` with references to the 2738 variable's initialized values. The initialized values are essentially 2739 conditional TensorFlow graphs that return a variable's value if it is 2740 initialized or its `initial_value` if it hasn't been initialized. This 2741 replacement is done on a best effort basis: 2742 2743 - If the `initial_value` graph contains cycles, we don't do any 2744 replacements for that graph. 2745 - If the variables that `initial_value` depends on are not present in the 2746 `GLOBAL_VARIABLES` or `LOCAL_VARIABLES` we don't replace them. 2747 2748 In these cases, it is up to the caller to ensure that the `initial_value` 2749 graph uses initialized variables or that they guard access to variables 2750 using their `initialized_value` method. 2751 2752 Args: 2753 name: Variable name. 2754 initial_value: `Tensor`. The initial value. 2755 2756 Returns: 2757 A `Tensor` suitable to initialize a variable. 2758 Raises: 2759 TypeError: If `initial_value` is not a `Tensor`. 2760 """ 2761 if not isinstance(initial_value, ops.Tensor): 2762 raise TypeError("initial_value needs to be a Tensor: %s" % initial_value) 2763 2764 # Don't modify initial_value if it contains any cyclic dependencies. 2765 if _has_cycle(initial_value.op, state={}): 2766 return initial_value 2767 return _safe_initial_value_from_tensor(name, initial_value, op_cache={}) 2768 2769 2770_UNKNOWN, _STARTED, _FINISHED = range(3) 2771 2772 2773def _has_cycle(op, state): 2774 """Detect cycles in the dependencies of `initial_value`.""" 2775 op_state = state.get(op.name, _UNKNOWN) 2776 if op_state == _STARTED: 2777 return True 2778 elif op_state == _FINISHED: 2779 return False 2780 2781 state[op.name] = _STARTED 2782 for i in itertools.chain((i.op for i in op.inputs), op.control_inputs): 2783 if _has_cycle(i, state): 2784 return True 2785 state[op.name] = _FINISHED 2786 return False 2787 2788 2789def _safe_initial_value_from_tensor(name, tensor, op_cache): 2790 """Replace dependencies on variables with their initialized values. 2791 2792 Args: 2793 name: Variable name. 2794 tensor: A `Tensor`. The tensor to replace. 2795 op_cache: A dict mapping operation names to `Operation`s. Used to memoize 2796 the results so as to avoid creating redundant operations. 2797 2798 Returns: 2799 A `Tensor` compatible with `tensor`. Any inputs that lead to variable 2800 values will be replaced with a corresponding graph that uses the 2801 variable's initialized values. This is done on a best-effort basis. If no 2802 modifications need to be made then `tensor` will be returned unchanged. 2803 """ 2804 op = tensor.op 2805 new_op = op_cache.get(op.name) 2806 if new_op is None: 2807 new_op = _safe_initial_value_from_op(name, op, op_cache) 2808 op_cache[op.name] = new_op 2809 return new_op.outputs[tensor.value_index] 2810 2811 2812def _safe_initial_value_from_op(name, op, op_cache): 2813 """Replace dependencies on variables with their initialized values. 2814 2815 Args: 2816 name: Variable name. 2817 op: An `Operation`. The operation to replace. 2818 op_cache: A dict mapping operation names to `Operation`s. 
Used to memoize 2819 the results so as to avoid creating redundant operations. 2820 2821 Returns: 2822 An `Operation` compatible with `op`. Any inputs that lead to variable 2823 values will be replaced with a corresponding graph that uses the 2824 variable's initialized values. This is done on a best-effort basis. If no 2825 modifications need to be made then `op` will be returned unchanged. 2826 """ 2827 op_type = op.node_def.op 2828 if op_type in ("IsVariableInitialized", "VarIsInitializedOp", 2829 "ReadVariableOp", "If"): 2830 return op 2831 2832 # Attempt to find the initialized_value of any variable reference / handles. 2833 # TODO(b/70206927): Fix handling of ResourceVariables. 2834 if op_type in ("Variable", "VariableV2", "VarHandleOp"): 2835 initialized_value = _find_initialized_value_for_variable(op) 2836 return op if initialized_value is None else initialized_value.op 2837 2838 # Recursively build initializer expressions for inputs. 2839 modified = False 2840 new_op_inputs = [] 2841 for op_input in op.inputs: 2842 new_op_input = _safe_initial_value_from_tensor(name, op_input, op_cache) 2843 new_op_inputs.append(new_op_input) 2844 modified = modified or (new_op_input != op_input) 2845 2846 # If at least one input was modified, replace the op. 2847 if modified: 2848 new_op_type = op_type 2849 if new_op_type == "RefSwitch": 2850 new_op_type = "Switch" 2851 new_op_name = op.node_def.name + "_" + name 2852 new_op_name = new_op_name.replace(":", "_") 2853 return op.graph.create_op( 2854 new_op_type, 2855 new_op_inputs, 2856 op._output_types, # pylint: disable=protected-access 2857 name=new_op_name, 2858 attrs=op.node_def.attr) 2859 2860 return op 2861 2862 2863def _find_initialized_value_for_variable(variable_op): 2864 """Find the initialized value for a variable op. 2865 2866 To do so, lookup the variable op in the variables collection. 2867 2868 Args: 2869 variable_op: A variable `Operation`. 2870 2871 Returns: 2872 A `Tensor` representing the initialized value for the variable or `None` 2873 if the initialized value could not be found. 2874 """ 2875 try: 2876 var_names = [variable_op.node_def.name, variable_op.node_def.name + ":0"] 2877 for collection_name in (ops.GraphKeys.GLOBAL_VARIABLES, 2878 ops.GraphKeys.LOCAL_VARIABLES): 2879 for var in variable_op.graph.get_collection(collection_name): 2880 if var.name in var_names: 2881 return var.initialized_value() 2882 except AttributeError: 2883 # Return None when an incomplete user-defined variable type was put in 2884 # the collection. 2885 return None 2886 return None 2887 2888 2889class PartitionedVariable(object): 2890 """A container for partitioned `Variable` objects. 2891 2892 @compatibility(eager) `tf.PartitionedVariable` is not compatible with 2893 eager execution. Use `tf.Variable` instead which is compatible 2894 with both eager execution and graph construction. See [the 2895 TensorFlow Eager Execution 2896 guide](https://www.tensorflow.org/guide/eager#variables_and_optimizers) 2897 for details on how variables work in eager execution. 2898 @end_compatibility 2899 """ 2900 2901 def __init__(self, name, shape, dtype, variable_list, partitions): 2902 """Creates a new partitioned variable wrapper. 2903 2904 Variables passed via the variable_list must contain a save_slice_info 2905 field. Concatenation and iteration is in lexicographic order according 2906 to the var_offset property of the save_slice_info. 2907 2908 Args: 2909 name: String. Overall name of the variables. 2910 shape: List of integers. 
Overall shape of the variables. 2911 dtype: Type of the variables. 2912 variable_list: List of `Variable` that comprise this partitioned variable. 2913 partitions: List of integers. Number of partitions for each dimension. 2914 2915 Raises: 2916 TypeError: If `variable_list` is not a list of `Variable` objects, or 2917 `partitions` is not a list. 2918 ValueError: If `variable_list` is empty, or the `Variable` shape 2919 information does not match `shape`, or `partitions` has invalid values. 2920 """ 2921 if not isinstance(variable_list, (list, tuple)): 2922 raise TypeError("variable_list is not a list or tuple: %s" % 2923 variable_list) 2924 if not isinstance(partitions, (list, tuple)): 2925 raise TypeError("partitions is not a list or tuple: %s" % partitions) 2926 if not all(p >= 1 for p in partitions): 2927 raise ValueError("partition values must be positive: %s" % partitions) 2928 if not variable_list: 2929 raise ValueError("variable_list may not be empty") 2930 # pylint: disable=protected-access 2931 for v in variable_list: 2932 # Sort the variable_list lexicographically according to var offset value. 2933 if not all(v._get_save_slice_info() is not None for v in variable_list): 2934 raise ValueError( 2935 "All variables must have a save_slice_info available: %s" % 2936 [v.name for v in variable_list]) 2937 if len(shape) != len(partitions): 2938 raise ValueError("len(shape) != len(partitions): %s vs. %s" % 2939 (shape, partitions)) 2940 if v._get_save_slice_info().full_shape != shape: 2941 raise ValueError("All variables' full shapes must match shape: %s; " 2942 "but full shapes were: %s" % 2943 (shape, str([v._get_save_slice_info().full_shape]))) 2944 self._variable_list = sorted( 2945 variable_list, key=lambda v: v._get_save_slice_info().var_offset) 2946 # pylint: enable=protected-access 2947 2948 self._name = name 2949 self._shape = shape 2950 self._dtype = dtype 2951 self._partitions = partitions 2952 self._as_tensor = None 2953 2954 def __iter__(self): 2955 """Return an iterable for accessing the underlying partition Variables.""" 2956 return iter(self._variable_list) 2957 2958 def __len__(self): 2959 num_partition_axes = len(self._partition_axes()) 2960 if num_partition_axes > 1: 2961 raise ValueError("Cannot get a length for %d > 1 partition axes" % 2962 num_partition_axes) 2963 return len(self._variable_list) 2964 2965 def _partition_axes(self): 2966 if all(p == 1 for p in self._partitions): 2967 return [0] 2968 else: 2969 return [i for i, p in enumerate(self._partitions) if p > 1] 2970 2971 def _concat(self): 2972 """Returns the overall concatenated value as a `Tensor`. 2973 2974 This is different from using the partitioned variable directly as a tensor 2975 (through tensor conversion and `as_tensor`) in that it creates a new set of 2976 operations that keeps the control dependencies from its scope. 2977 2978 Returns: 2979 `Tensor` containing the concatenated value. 2980 """ 2981 if len(self._variable_list) == 1: 2982 with ops.name_scope(None): 2983 return array_ops.identity(self._variable_list[0], name=self._name) 2984 2985 partition_axes = self._partition_axes() 2986 2987 if len(partition_axes) > 1: 2988 raise NotImplementedError( 2989 "Cannot concatenate along more than one dimension: %s. 
" 2990 "Multi-axis partition concat is not supported" % str(partition_axes)) 2991 partition_ix = partition_axes[0] 2992 2993 with ops.name_scope(self._name + "/ConcatPartitions/"): 2994 concatenated = array_ops.concat(self._variable_list, partition_ix) 2995 2996 with ops.name_scope(None): 2997 return array_ops.identity(concatenated, name=self._name) 2998 2999 def as_tensor(self): 3000 """Returns the overall concatenated value as a `Tensor`. 3001 3002 The returned tensor will not inherit the control dependencies from the scope 3003 where the value is used, which is similar to getting the value of 3004 `Variable`. 3005 3006 Returns: 3007 `Tensor` containing the concatenated value. 3008 """ 3009 with ops.control_dependencies(None): 3010 return self._concat() 3011 3012 @staticmethod 3013 def _TensorConversionFunction(v, dtype=None, name=None, as_ref=False): 3014 # pylint: disable=invalid-name 3015 _ = name 3016 if dtype is not None and not dtype.is_compatible_with(v.dtype): 3017 raise ValueError( 3018 "Incompatible type conversion requested to type '%s' for variable " 3019 "of type '%s'" % (dtype.name, v.dtype.name)) 3020 if as_ref: 3021 raise NotImplementedError( 3022 "PartitionedVariable doesn't support being used as a reference.") 3023 else: 3024 return v.as_tensor() 3025 3026 @property 3027 def name(self): 3028 return self._name 3029 3030 @property 3031 def dtype(self): 3032 return self._dtype 3033 3034 @property 3035 def shape(self): 3036 return self.get_shape() 3037 3038 @property 3039 def _distribute_strategy(self): 3040 """The `tf.distribute.Strategy` that this variable was created under.""" 3041 # NOTE(yuefengz): Today, no partitioned variables in a distribute strategy. 3042 return None 3043 3044 def get_shape(self): 3045 return self._shape 3046 3047 def _get_variable_list(self): 3048 return self._variable_list 3049 3050 def _get_partitions(self): 3051 return self._partitions 3052 3053 def _apply_assign_fn(self, assign_fn, value): 3054 partition_axes = self._partition_axes() 3055 if len(partition_axes) > 1: 3056 raise NotImplementedError( 3057 "Cannot do assign action along more than one dimension: %s. 
" 3058 "Multi-axis partition assign action is not supported " % 3059 str(partition_axes)) 3060 if isinstance(value, list): 3061 assert len(value) == len(self._variable_list) 3062 value_list = value 3063 elif isinstance(value, PartitionedVariable): 3064 value_list = [var_part for var_part in value] 3065 else: 3066 partition_ix = partition_axes[0] 3067 size_splits_list = [ 3068 tensor_shape.dimension_value(var.shape[partition_ix]) 3069 for var in self._variable_list 3070 ] 3071 value_list = array_ops.split(value, size_splits_list, axis=partition_ix) 3072 3073 op_list = [ 3074 assign_fn(var, value_list[idx]) 3075 for idx, var in enumerate(self._variable_list) 3076 ] 3077 return op_list 3078 3079 def assign(self, value, use_locking=False, name=None, read_value=True): 3080 assign_fn = lambda var, r_value: var.assign( 3081 r_value, use_locking=use_locking, name=name, read_value=read_value) 3082 assign_list = self._apply_assign_fn(assign_fn, value) 3083 if read_value: 3084 return assign_list 3085 return [assign.op for assign in assign_list] 3086 3087 def assign_add(self, value, use_locking=False, name=None, read_value=True): 3088 assign_fn = lambda var, r_value: var.assign_add( 3089 r_value, use_locking=use_locking, name=name, read_value=read_value) 3090 assign_list = self._apply_assign_fn(assign_fn, value) 3091 if read_value: 3092 return assign_list 3093 return [assign.op for assign in assign_list] 3094 3095 def assign_sub(self, value, use_locking=False, name=None, read_value=True): 3096 assign_fn = lambda var, r_value: var.assign_sub( 3097 r_value, use_locking=use_locking, name=name, read_value=read_value) 3098 assign_list = self._apply_assign_fn(assign_fn, value) 3099 if read_value: 3100 return assign_list 3101 return [assign.op for assign in assign_list] 3102 3103 3104# Register a conversion function which reads the value of the variable, 3105# allowing instances of the class to be used as tensors. 3106ops.register_tensor_conversion_function(RefVariable, 3107 RefVariable._TensorConversionFunction) # pylint: disable=protected-access 3108 3109 3110@tf_export(v1=["global_variables"]) 3111def global_variables(scope=None): 3112 """Returns global variables. 3113 3114 Global variables are variables that are shared across machines in a 3115 distributed environment. The `Variable()` constructor or `get_variable()` 3116 automatically adds new variables to the graph collection 3117 `GraphKeys.GLOBAL_VARIABLES`. 3118 This convenience function returns the contents of that collection. 3119 3120 An alternative to global variables are local variables. See 3121 `tf.compat.v1.local_variables` 3122 3123 Args: 3124 scope: (Optional.) A string. If supplied, the resulting list is filtered to 3125 include only items whose `name` attribute matches `scope` using 3126 `re.match`. Items without a `name` attribute are never returned if a scope 3127 is supplied. The choice of `re.match` means that a `scope` without special 3128 tokens filters by prefix. 3129 3130 Returns: 3131 A list of `Variable` objects. 3132 """ 3133 return ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES, scope) 3134 3135 3136@tf_export(v1=["all_variables"]) 3137@deprecated("2017-03-02", "Please use tf.global_variables instead.") 3138def all_variables(): 3139 """Use `tf.compat.v1.global_variables` instead.""" 3140 return global_variables() 3141 3142 3143def _all_saveable_objects(scope=None): 3144 """Returns all variables and `SaveableObject`s that must be checkpointed. 3145 3146 Args: 3147 scope: (Optional.) A string. 
If supplied, the resulting list is filtered to 3148 include only items whose `name` attribute matches `scope` using 3149 `re.match`. Items without a `name` attribute are never returned if a scope 3150 is supplied. The choice of `re.match` means that a `scope` without special 3151 tokens filters by prefix. 3152 3153 Returns: 3154 A list of `Variable` and `SaveableObject` to be checkpointed 3155 """ 3156 # TODO(andreasst): make this function public once things are settled. 3157 return (ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES, scope) + 3158 ops.get_collection(ops.GraphKeys.SAVEABLE_OBJECTS, scope)) 3159 3160 3161@tf_export(v1=["local_variables"]) 3162def local_variables(scope=None): 3163 """Returns local variables. 3164 3165 Local variables - per process variables, usually not saved/restored to 3166 checkpoint and used for temporary or intermediate values. 3167 For example, they can be used as counters for metrics computation or 3168 number of epochs this machine has read data. 3169 The `tf.contrib.framework.local_variable()` function automatically adds the 3170 new variable to `GraphKeys.LOCAL_VARIABLES`. 3171 This convenience function returns the contents of that collection. 3172 3173 An alternative to local variables are global variables. See 3174 `tf.compat.v1.global_variables` 3175 3176 Args: 3177 scope: (Optional.) A string. If supplied, the resulting list is filtered to 3178 include only items whose `name` attribute matches `scope` using 3179 `re.match`. Items without a `name` attribute are never returned if a scope 3180 is supplied. The choice of `re.match` means that a `scope` without special 3181 tokens filters by prefix. 3182 3183 Returns: 3184 A list of local `Variable` objects. 3185 """ 3186 return ops.get_collection(ops.GraphKeys.LOCAL_VARIABLES, scope) 3187 3188 3189@tf_export(v1=["model_variables"]) 3190def model_variables(scope=None): 3191 """Returns all variables in the MODEL_VARIABLES collection. 3192 3193 Args: 3194 scope: (Optional.) A string. If supplied, the resulting list is filtered to 3195 include only items whose `name` attribute matches `scope` using 3196 `re.match`. Items without a `name` attribute are never returned if a scope 3197 is supplied. The choice of `re.match` means that a `scope` without special 3198 tokens filters by prefix. 3199 3200 Returns: 3201 A list of local Variable objects. 3202 """ 3203 return ops.get_collection(ops.GraphKeys.MODEL_VARIABLES, scope) 3204 3205 3206@tf_export(v1=["trainable_variables"]) 3207def trainable_variables(scope=None): 3208 """Returns all variables created with `trainable=True`. 3209 3210 When passed `trainable=True`, the `Variable()` constructor automatically 3211 adds new variables to the graph collection 3212 `GraphKeys.TRAINABLE_VARIABLES`. This convenience function returns the 3213 contents of that collection. 3214 3215 Args: 3216 scope: (Optional.) A string. If supplied, the resulting list is filtered to 3217 include only items whose `name` attribute matches `scope` using 3218 `re.match`. Items without a `name` attribute are never returned if a scope 3219 is supplied. The choice of `re.match` means that a `scope` without special 3220 tokens filters by prefix. 3221 3222 Returns: 3223 A list of Variable objects. 3224 """ 3225 return ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES, scope) 3226 3227 3228@tf_export(v1=["moving_average_variables"]) 3229def moving_average_variables(scope=None): 3230 """Returns all variables that maintain their moving averages. 
3231 3232 If an `ExponentialMovingAverage` object is created and the `apply()` 3233 method is called on a list of variables, these variables will 3234 be added to the `GraphKeys.MOVING_AVERAGE_VARIABLES` collection. 3235 This convenience function returns the contents of that collection. 3236 3237 Args: 3238 scope: (Optional.) A string. If supplied, the resulting list is filtered to 3239 include only items whose `name` attribute matches `scope` using 3240 `re.match`. Items without a `name` attribute are never returned if a scope 3241 is supplied. The choice of `re.match` means that a `scope` without special 3242 tokens filters by prefix. 3243 3244 Returns: 3245 A list of Variable objects. 3246 """ 3247 return ops.get_collection(ops.GraphKeys.MOVING_AVERAGE_VARIABLES, scope) 3248 3249 3250@tf_export(v1=["initializers.variables", "variables_initializer"]) 3251def variables_initializer(var_list, name="init"): 3252 """Returns an Op that initializes a list of variables. 3253 3254 After you launch the graph in a session, you can run the returned Op to 3255 initialize all the variables in `var_list`. This Op runs all the 3256 initializers of the variables in `var_list` in parallel. 3257 3258 Calling `initialize_variables()` is equivalent to passing the list of 3259 initializers to `Group()`. 3260 3261 If `var_list` is empty, however, the function still returns an Op that can 3262 be run. That Op just has no effect. 3263 3264 Args: 3265 var_list: List of `Variable` objects to initialize. 3266 name: Optional name for the returned operation. 3267 3268 Returns: 3269 An Op that run the initializers of all the specified variables. 3270 """ 3271 if var_list and not context.executing_eagerly(): 3272 return control_flow_ops.group(*[v.initializer for v in var_list], name=name) 3273 return control_flow_ops.no_op(name=name) 3274 3275 3276@tf_export(v1=["initialize_variables"]) 3277@tf_should_use.should_use_result 3278@deprecated("2017-03-02", "Use `tf.variables_initializer` instead.") 3279def initialize_variables(var_list, name="init"): 3280 """See `tf.compat.v1.variables_initializer`.""" 3281 return variables_initializer(var_list, name=name) 3282 3283 3284@tf_export(v1=["initializers.global_variables", "global_variables_initializer"]) 3285def global_variables_initializer(): 3286 """Returns an Op that initializes global variables. 3287 3288 This is just a shortcut for `variables_initializer(global_variables())` 3289 3290 Returns: 3291 An Op that initializes global variables in the graph. 3292 """ 3293 if context.executing_eagerly(): 3294 return control_flow_ops.no_op(name="global_variables_initializer") 3295 return variables_initializer(global_variables()) 3296 3297 3298@tf_export(v1=["initialize_all_variables"]) 3299@tf_should_use.should_use_result 3300@deprecated("2017-03-02", "Use `tf.global_variables_initializer` instead.") 3301def initialize_all_variables(): 3302 """See `tf.compat.v1.global_variables_initializer`.""" 3303 return global_variables_initializer() 3304 3305 3306@tf_export(v1=["initializers.local_variables", "local_variables_initializer"]) 3307def local_variables_initializer(): 3308 """Returns an Op that initializes all local variables. 3309 3310 This is just a shortcut for `variables_initializer(local_variables())` 3311 3312 Returns: 3313 An Op that initializes all local variables in the graph. 
3314 """ 3315 if context.executing_eagerly(): 3316 return control_flow_ops.no_op(name="local_variables_initializer") 3317 return variables_initializer(local_variables()) 3318 3319 3320@tf_export(v1=["initialize_local_variables"]) 3321@tf_should_use.should_use_result 3322@deprecated("2017-03-02", "Use `tf.local_variables_initializer` instead.") 3323def initialize_local_variables(): 3324 """See `tf.compat.v1.local_variables_initializer`.""" 3325 return local_variables_initializer() 3326 3327 3328@tf_export(v1=["is_variable_initialized"]) 3329@tf_should_use.should_use_result 3330def is_variable_initialized(variable): 3331 """Tests if a variable has been initialized. 3332 3333 Args: 3334 variable: A `Variable`. 3335 3336 Returns: 3337 Returns a scalar boolean Tensor, `True` if the variable has been 3338 initialized, `False` otherwise. 3339 """ 3340 return state_ops.is_variable_initialized(variable) 3341 3342 3343@tf_export(v1=["assert_variables_initialized"]) 3344@tf_should_use.should_use_result 3345def assert_variables_initialized(var_list=None): 3346 """Returns an Op to check if variables are initialized. 3347 3348 NOTE: This function is obsolete and will be removed in 6 months. Please 3349 change your implementation to use `report_uninitialized_variables()`. 3350 3351 When run, the returned Op will raise the exception `FailedPreconditionError` 3352 if any of the variables has not yet been initialized. 3353 3354 Note: This function is implemented by trying to fetch the values of the 3355 variables. If one of the variables is not initialized a message may be 3356 logged by the C++ runtime. This is expected. 3357 3358 Args: 3359 var_list: List of `Variable` objects to check. Defaults to the value of 3360 `global_variables().` 3361 3362 Returns: 3363 An Op, or None if there are no variables. 3364 """ 3365 if var_list is None: 3366 var_list = global_variables() + local_variables() 3367 # Backwards compatibility for old-style variables. TODO(touts): remove. 3368 if not var_list: 3369 var_list = [] 3370 for op in ops.get_default_graph().get_operations(): 3371 if op.type in ["Variable", "VariableV2", "AutoReloadVariable"]: 3372 var_list.append(op.outputs[0]) 3373 if not var_list: 3374 return None 3375 else: 3376 ranks = [] 3377 for var in var_list: 3378 with ops.colocate_with(var.op): 3379 ranks.append(array_ops.rank_internal(var, optimize=False)) 3380 if len(ranks) == 1: 3381 return ranks[0] 3382 else: 3383 return array_ops.stack(ranks) 3384 3385 3386@tf_export(v1=["report_uninitialized_variables"]) 3387@tf_should_use.should_use_result 3388def report_uninitialized_variables(var_list=None, 3389 name="report_uninitialized_variables"): 3390 """Adds ops to list the names of uninitialized variables. 3391 3392 When run, it returns a 1-D tensor containing the names of uninitialized 3393 variables if there are any, or an empty array if there are none. 3394 3395 Args: 3396 var_list: List of `Variable` objects to check. Defaults to the value of 3397 `global_variables() + local_variables()` 3398 name: Optional name of the `Operation`. 3399 3400 Returns: 3401 A 1-D tensor containing names of the uninitialized variables, or an empty 3402 1-D tensor if there are no variables or no uninitialized variables. 3403 """ 3404 if var_list is None: 3405 var_list = global_variables() + local_variables() 3406 # Backwards compatibility for old-style variables. TODO(touts): remove. 
3407 if not var_list: 3408 var_list = [] 3409 for op in ops.get_default_graph().get_operations(): 3410 if op.type in ["Variable", "VariableV2", "AutoReloadVariable"]: 3411 var_list.append(op.outputs[0]) 3412 with ops.name_scope(name): 3413 # Run all operations on CPU 3414 if var_list: 3415 init_vars = [state_ops.is_variable_initialized(v) for v in var_list] 3416 local_device = os.environ.get( 3417 "TF_DEVICE_FOR_UNINITIALIZED_VARIABLE_REPORTING", "/cpu:0") 3418 with ops.device(local_device): 3419 if not var_list: 3420 # Return an empty tensor so we only need to check for returned tensor 3421 # size being 0 as an indication of model ready. 3422 return array_ops.constant([], dtype=dtypes.string) 3423 else: 3424 # Get a 1-D boolean tensor listing whether each variable is initialized. 3425 variables_mask = math_ops.logical_not(array_ops.stack(init_vars)) 3426 # Get a 1-D string tensor containing all the variable names. 3427 variable_names_tensor = array_ops.constant( 3428 [s.op.name for s in var_list]) 3429 # Return a 1-D tensor containing all the names of 3430 # uninitialized variables. 3431 return array_ops.boolean_mask(variable_names_tensor, variables_mask) 3432 3433 3434ops.register_tensor_conversion_function( 3435 PartitionedVariable, PartitionedVariable._TensorConversionFunction) # pylint: disable=protected-access 3436