# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Variable class."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import enum  # pylint: disable=g-bad-import-order
import itertools
import functools
import os

import six

from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import variable_pb2
from tensorflow.python import pywrap_tensorflow  # pylint: disable=unused-import
from tensorflow.python import _pywrap_utils
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import gen_state_ops
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.util import compat
from tensorflow.python.util import object_identity
from tensorflow.python.util import tf_should_use
from tensorflow.python.util.deprecation import deprecated
from tensorflow.python.util.deprecation import deprecated_args
from tensorflow.python.util.tf_export import tf_export


def default_variable_creator(_, **kwds):
  del kwds
  raise NotImplementedError("variable_scope needs to be imported")


def default_variable_creator_v2(_, **kwds):
  del kwds
  raise NotImplementedError("variable_scope needs to be imported")


def _make_getter(captured_getter, captured_previous):
  """To avoid capturing loop variables."""

  def getter(**kwargs):
    return captured_getter(captured_previous, **kwargs)

  return getter
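

# Note: `_make_getter` works around Python's late binding of closure
# variables. An illustrative sketch of the pitfall it avoids, with a
# hypothetical list of `creators` (not part of this module):
#
#   getters = []
#   for creator in creators:
#     getters.append(lambda **kw: creator(previous, **kw))
#   # Every lambda above sees the *last* value of `creator` and `previous`.
#
# Routing the capture through a factory function, as `_make_getter` does,
# freezes the pair of values at each loop iteration.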


@tf_export("VariableSynchronization")
class VariableSynchronization(enum.Enum):
  """Indicates when a distributed variable will be synced.

  * `AUTO`: Indicates that the synchronization will be determined by the
    current `DistributionStrategy` (e.g. with `MirroredStrategy` this would be
    `ON_WRITE`).
  * `NONE`: Indicates that there will only be one copy of the variable, so
    there is no need to sync.
  * `ON_WRITE`: Indicates that the variable will be updated across devices
    every time it is written.
  * `ON_READ`: Indicates that the variable will be aggregated across devices
    when it is read (e.g. when checkpointing or when evaluating an op that
    uses the variable).
  """
  AUTO = 0
  NONE = 1
  ON_WRITE = 2
  ON_READ = 3


# LINT.IfChange
@tf_export("VariableAggregation", v1=[])
class VariableAggregationV2(enum.Enum):
  """Indicates how a distributed variable will be aggregated.

  `tf.distribute.Strategy` distributes a model by making multiple copies
  (called "replicas") acting data-parallel on different elements of the input
  batch. When performing some variable-update operation, say
  `var.assign_add(x)`, in a model, we need to resolve how to combine the
  different values for `x` computed in the different replicas.

  * `NONE`: This is the default, giving an error if you use a
    variable-update operation with multiple replicas.
  * `SUM`: Add the updates across replicas.
  * `MEAN`: Take the arithmetic mean ("average") of the updates across
    replicas.
  * `ONLY_FIRST_REPLICA`: This is for when every replica is performing the same
    update, but we only want to perform the update once. Used, e.g., for the
    global step counter.
  """
  NONE = 0
  SUM = 1
  MEAN = 2
  ONLY_FIRST_REPLICA = 3

  def __hash__(self):
    return hash(self.value)

  def __eq__(self, other):
    if self is other:
      return True
    elif isinstance(other, VariableAggregation):
      return int(self.value) == int(other.value)
    else:
      return False


@tf_export(v1=["VariableAggregation"])
class VariableAggregation(enum.Enum):
  NONE = 0
  SUM = 1
  MEAN = 2
  ONLY_FIRST_REPLICA = 3
  ONLY_FIRST_TOWER = 3  # DEPRECATED

  def __hash__(self):
    return hash(self.value)


# LINT.ThenChange(//tensorflow/core/framework/variable.proto)
#
# Note that we are currently relying on the integer values of the Python enums
# matching the integer values of the proto enums.

VariableAggregation.__doc__ = (
    VariableAggregationV2.__doc__ +
    "* `ONLY_FIRST_TOWER`: Deprecated alias for `ONLY_FIRST_REPLICA`.\n  ")


def validate_synchronization_aggregation_trainable(synchronization, aggregation,
                                                   trainable, name):
  """Given user-provided variable properties, sets defaults and validates."""
  if aggregation is None:
    aggregation = VariableAggregation.NONE
  else:
    if not isinstance(aggregation,
                      (VariableAggregation, VariableAggregationV2)):
      try:
        aggregation = VariableAggregationV2(aggregation)
      except ValueError:
        raise ValueError(
            "Invalid variable aggregation mode: {} for variable: {}".format(
                aggregation, name))
  if synchronization is None:
    synchronization = VariableSynchronization.AUTO
  else:
    try:
      synchronization = VariableSynchronization(synchronization)
    except ValueError:
      raise ValueError(
          "Invalid variable synchronization mode: {} for variable: {}".format(
              synchronization, name))
  if trainable is None:
    trainable = synchronization != VariableSynchronization.ON_READ
  return synchronization, aggregation, trainable
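

# Illustrative sketch of how the helper above resolves defaults, for a
# hypothetical variable named "v" (this snippet is not exercised anywhere in
# this module):
#
#   sync, agg, trainable = validate_synchronization_aggregation_trainable(
#       synchronization=VariableSynchronization.ON_READ,
#       aggregation=None,   # falls back to VariableAggregation.NONE
#       trainable=None,     # ON_READ variables default to trainable=False
#       name="v")
#
# Raw values are also accepted (e.g. aggregation=1); they are coerced through
# the enum constructors and an invalid value raises ValueError.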


class VariableMetaclass(type):
  """Metaclass to allow construction of tf.Variable to be overridden."""

  def _variable_v1_call(cls,
                        initial_value=None,
                        trainable=None,
                        collections=None,
                        validate_shape=True,
                        caching_device=None,
                        name=None,
                        variable_def=None,
                        dtype=None,
                        expected_shape=None,
                        import_scope=None,
                        constraint=None,
                        use_resource=None,
                        synchronization=VariableSynchronization.AUTO,
                        aggregation=VariableAggregation.NONE,
                        shape=None):
    """Call on Variable class. Useful to force the signature."""
    previous_getter = lambda **kwargs: default_variable_creator(None, **kwargs)
    for _, getter in ops.get_default_graph()._variable_creator_stack:  # pylint: disable=protected-access
      previous_getter = _make_getter(getter, previous_getter)

    # Reset `aggregation` that is explicitly set as `None` to the enum NONE.
    if aggregation is None:
      aggregation = VariableAggregation.NONE
    return previous_getter(
        initial_value=initial_value,
        trainable=trainable,
        collections=collections,
        validate_shape=validate_shape,
        caching_device=caching_device,
        name=name,
        variable_def=variable_def,
        dtype=dtype,
        expected_shape=expected_shape,
        import_scope=import_scope,
        constraint=constraint,
        use_resource=use_resource,
        synchronization=synchronization,
        aggregation=aggregation,
        shape=shape)

  def _variable_v2_call(cls,
                        initial_value=None,
                        trainable=None,
                        validate_shape=True,
                        caching_device=None,
                        name=None,
                        variable_def=None,
                        dtype=None,
                        import_scope=None,
                        constraint=None,
                        synchronization=VariableSynchronization.AUTO,
                        aggregation=VariableAggregation.NONE,
                        shape=None):
    """Call on Variable class. Useful to force the signature."""
    previous_getter = lambda **kws: default_variable_creator_v2(None, **kws)
    for _, getter in ops.get_default_graph()._variable_creator_stack:  # pylint: disable=protected-access
      previous_getter = _make_getter(getter, previous_getter)

    # Reset `aggregation` that is explicitly set as `None` to the enum NONE.
    if aggregation is None:
      aggregation = VariableAggregation.NONE
    return previous_getter(
        initial_value=initial_value,
        trainable=trainable,
        validate_shape=validate_shape,
        caching_device=caching_device,
        name=name,
        variable_def=variable_def,
        dtype=dtype,
        import_scope=import_scope,
        constraint=constraint,
        synchronization=synchronization,
        aggregation=aggregation,
        shape=shape)

  def __call__(cls, *args, **kwargs):
    if cls is VariableV1:
      return cls._variable_v1_call(*args, **kwargs)
    elif cls is Variable:
      return cls._variable_v2_call(*args, **kwargs)
    else:
      return super(VariableMetaclass, cls).__call__(*args, **kwargs)
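

# Illustrative sketch of the creator stack consulted above (assumes the TF 2.x
# public API; `my_creator` is a hypothetical example, not part of this module):
#
#   def my_creator(next_creator, **kwargs):
#     kwargs["trainable"] = False   # tweak arguments, then delegate
#     return next_creator(**kwargs)
#
#   with tf.variable_creator_scope(my_creator):
#     v = tf.Variable(1.0)          # routed through my_creator
#
# Each registered creator is wrapped around the previously built chain via
# `_make_getter`, so the most recently entered scope sees the call first and
# can delegate onward through `next_creator` down to the default creator.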


@tf_export("Variable", v1=[])
class Variable(six.with_metaclass(VariableMetaclass, trackable.Trackable)):
  """See the [variable guide](https://tensorflow.org/guide/variable).

  A variable maintains shared, persistent state manipulated by a program.

  The `Variable()` constructor requires an initial value for the variable,
  which can be a `Tensor` of any type and shape. This initial value defines the
  type and shape of the variable. After construction, the type and shape of the
  variable are fixed. The value can be changed using one of the assign methods.

  >>> v = tf.Variable(1.)
  >>> v.assign(2.)
  <tf.Variable ... shape=() dtype=float32, numpy=2.0>
  >>> v.assign_add(0.5)
  <tf.Variable ... shape=() dtype=float32, numpy=2.5>

  The `shape` argument to `Variable`'s constructor allows you to construct a
  variable with a less defined shape than its `initial_value`:

  >>> v = tf.Variable(1., shape=tf.TensorShape(None))
  >>> v.assign([[1.]])
  <tf.Variable ... shape=<unknown> dtype=float32, numpy=array([[1.]], ...)>

  Just like any `Tensor`, variables created with `Variable()` can be used as
  inputs to operations. Additionally, all the operators overloaded for the
  `Tensor` class are carried over to variables.

  >>> w = tf.Variable([[1.], [2.]])
  >>> x = tf.constant([[3., 4.]])
  >>> tf.matmul(w, x)
  <tf.Tensor:... shape=(2, 2), ... numpy=
  array([[3., 4.],
         [6., 8.]], dtype=float32)>
  >>> tf.sigmoid(w + x)
  <tf.Tensor:... shape=(2, 2), ...>

  When building a machine learning model it is often convenient to distinguish
  between variables holding trainable model parameters and other variables such
  as a `step` variable used to count training steps. To make this easier, the
  variable constructor supports a `trainable=<bool>`
  parameter. `tf.GradientTape` watches trainable variables by default:

  >>> with tf.GradientTape(persistent=True) as tape:
  ...   trainable = tf.Variable(1.)
  ...   non_trainable = tf.Variable(2., trainable=False)
  ...   x1 = trainable * 2.
  ...   x2 = non_trainable * 3.
  >>> tape.gradient(x1, trainable)
  <tf.Tensor:... shape=(), dtype=float32, numpy=2.0>
  >>> assert tape.gradient(x2, non_trainable) is None  # Unwatched

  Variables are automatically tracked when assigned to attributes of types
  inheriting from `tf.Module`.

  >>> m = tf.Module()
  >>> m.v = tf.Variable([1.])
  >>> m.trainable_variables
  (<tf.Variable ... shape=(1,) ... numpy=array([1.], dtype=float32)>,)

  This tracking then allows saving variable values to
  [training checkpoints](https://www.tensorflow.org/guide/checkpoint), or to
  [SavedModels](https://www.tensorflow.org/guide/saved_model) which include
  serialized TensorFlow graphs.

  Variables are often captured and manipulated by `tf.function`s. This works
  the same way the un-decorated function would have:

  >>> v = tf.Variable(0.)
  >>> read_and_decrement = tf.function(lambda: v.assign_sub(0.1))
  >>> read_and_decrement()
  <tf.Tensor: shape=(), dtype=float32, numpy=-0.1>
  >>> read_and_decrement()
  <tf.Tensor: shape=(), dtype=float32, numpy=-0.2>

  Variables created inside a `tf.function` must be owned outside the function
  and be created only once:

  >>> class M(tf.Module):
  ...   @tf.function
  ...   def __call__(self, x):
  ...     if not hasattr(self, "v"):  # Or set self.v to None in __init__
  ...       self.v = tf.Variable(x)
  ...     return self.v * x
  >>> m = M()
  >>> m(2.)
  <tf.Tensor: shape=(), dtype=float32, numpy=4.0>
  >>> m(3.)
  <tf.Tensor: shape=(), dtype=float32, numpy=6.0>
  >>> m.v
  <tf.Variable ... shape=() dtype=float32, numpy=2.0>

  See the `tf.function` documentation for details.
  """

  @deprecated_args(
      None,
      "A variable's value can be manually cached by calling "
      "tf.Variable.read_value() under a tf.device scope. The caching_device "
      "argument does not work properly.",
      "caching_device")
  def __init__(self,
               initial_value=None,
               trainable=None,
               validate_shape=True,
               caching_device=None,
               name=None,
               variable_def=None,
               dtype=None,
               import_scope=None,
               constraint=None,
               synchronization=VariableSynchronization.AUTO,
               aggregation=VariableAggregation.NONE,
               shape=None):
    """Creates a new variable with value `initial_value`.

    Args:
      initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
        which is the initial value for the Variable. The initial value must
        have a shape specified unless `validate_shape` is set to False.
        Can also be a callable with no argument that returns the initial value
        when called. In that case, `dtype` must be specified. (Note that
        initializer functions from init_ops.py must first be bound to a shape
        before being used here.)
      trainable: If `True`, GradientTapes automatically watch uses of this
        variable. Defaults to `True`, unless `synchronization` is set to
        `ON_READ`, in which case it defaults to `False`.
      validate_shape: If `False`, allows the variable to be initialized with a
        value of unknown shape. If `True`, the default, the shape of
        `initial_value` must be known.
      caching_device: Optional device string describing where the Variable
        should be cached for reading. Defaults to the Variable's device. If not
        `None`, caches on another device. Typical use is to cache on the device
        where the Ops using the Variable reside, to deduplicate copying through
        `Switch` and other conditional statements.
      name: Optional name for the variable. Defaults to `'Variable'` and gets
        uniquified automatically.
      variable_def: `VariableDef` protocol buffer. If not `None`, recreates the
        Variable object with its contents, referencing the variable's nodes in
        the graph, which must already exist. The graph is not changed.
        `variable_def` and the other arguments are mutually exclusive.
      dtype: If set, initial_value will be converted to the given type. If
        `None`, either the datatype will be kept (if `initial_value` is a
        Tensor), or `convert_to_tensor` will decide.
      import_scope: Optional `string`. Name scope to add to the `Variable`.
        Only used when initializing from protocol buffer.
      constraint: An optional projection function to be applied to the variable
        after being updated by an `Optimizer` (e.g. used to implement norm
        constraints or value constraints for layer weights). The function must
        take as input the unprojected Tensor representing the value of the
        variable and return the Tensor for the projected value (which must have
        the same shape). Constraints are not safe to use when doing
        asynchronous distributed training.
      synchronization: Indicates when a distributed variable will be
        aggregated. Accepted values are constants defined in the class
        `tf.VariableSynchronization`. By default the synchronization is set to
        `AUTO` and the current `DistributionStrategy` chooses when to
        synchronize.
      aggregation: Indicates how a distributed variable will be aggregated.
        Accepted values are constants defined in the class
        `tf.VariableAggregation`.
      shape: (optional) The shape of this variable. If None, the shape of
        `initial_value` will be used. When setting this argument to
        `tf.TensorShape(None)` (representing an unspecified shape), the
        variable can be assigned with values of different shapes.

    Raises:
      ValueError: If both `variable_def` and initial_value are specified.
      ValueError: If the initial value is not specified, or does not have a
        shape and `validate_shape` is `True`.
    """
    raise NotImplementedError

  def __repr__(self):
    raise NotImplementedError

  def value(self):
    """Returns the last snapshot of this variable.

    You usually do not need to call this method as all ops that need the value
    of the variable call it automatically through a `convert_to_tensor()` call.

    Returns a `Tensor` which holds the value of the variable.
    You cannot assign a new value to this tensor as it is not a reference to
    the variable.

    To avoid copies, if the consumer of the returned value is on the same
    device as the variable, this actually returns the live value of the
    variable, not a copy. Updates to the variable are seen by the consumer. If
    the consumer is on a different device it will get a copy of the variable.

    Returns:
      A `Tensor` containing the value of the variable.
    """
    raise NotImplementedError

  def read_value(self):
    """Returns the value of this variable, read in the current context.

    Can be different from value() if it's on another device, with control
    dependencies, etc.

    Returns:
      A `Tensor` containing the value of the variable.
    """
    raise NotImplementedError

  def set_shape(self, shape):
    """Overrides the shape for this variable.

    Args:
      shape: the `TensorShape` representing the overridden shape.
    """
    raise NotImplementedError

  @property
  def trainable(self):
    raise NotImplementedError

  @property
  def synchronization(self):
    raise NotImplementedError

  @property
  def aggregation(self):
    raise NotImplementedError

  def eval(self, session=None):
    """In a session, computes and returns the value of this variable.

    This is not a graph construction method; it does not add ops to the graph.

    This convenience method requires a session where the graph
    containing this variable has been launched. If no session is
    passed, the default session is used. See `tf.compat.v1.Session` for more
    information on launching a graph and on sessions.

    ```python
    v = tf.Variable([1, 2])
    init = tf.compat.v1.global_variables_initializer()

    with tf.compat.v1.Session() as sess:
        sess.run(init)
        # Usage passing the session explicitly.
        print(v.eval(sess))
        # Usage with the default session. The 'with' block
        # above makes 'sess' the default session.
        print(v.eval())
    ```

    Args:
      session: The session to use to evaluate this variable. If none, the
        default session is used.

    Returns:
      A numpy `ndarray` with a copy of the value of this variable.
    """
    raise NotImplementedError

  @deprecated(
      None, "Use Variable.read_value. Variables in 2.X are initialized "
      "automatically both in eager and graph (inside tf.defun) contexts.")
  def initialized_value(self):
    """Returns the value of the initialized variable.

    You should use this instead of the variable itself to initialize another
    variable with a value that depends on the value of this variable.

    ```python
    # Initialize 'v' with a random tensor.
    v = tf.Variable(tf.random.truncated_normal([10, 40]))
    # Use `initialized_value` to guarantee that `v` has been
    # initialized before its value is used to initialize `w`.
    # The random values are picked only once.
    w = tf.Variable(v.initialized_value() * 2.0)
    ```

    Returns:
      A `Tensor` holding the value of this variable after its initializer
      has run.
    """
    with ops.init_scope():
      return control_flow_ops.cond(
          is_variable_initialized(self), self.read_value,
          lambda: self.initial_value)

  @property
  def initial_value(self):
    """Returns the Tensor used as the initial value for the variable.

    Note that this is different from `initialized_value()` which runs
    the op that initializes the variable before returning its value.
    This method returns the tensor that is used by the op that initializes
    the variable.

    Returns:
      A `Tensor`.
    """
    raise NotImplementedError

  @property
  def constraint(self):
    """Returns the constraint function associated with this variable.

    Returns:
      The constraint function that was passed to the variable constructor.
      Can be `None` if no constraint was passed.
    """
    raise NotImplementedError

  def assign(self, value, use_locking=False, name=None, read_value=True):
    """Assigns a new value to the variable.

    This is essentially a shortcut for `assign(self, value)`.

    Args:
      value: A `Tensor`. The new value for this variable.
      use_locking: If `True`, use locking during the assignment.
      name: The name of the operation to be created.
      read_value: if True, will return something which evaluates to the new
        value of the variable; if False will return the assign op.

    Returns:
      The updated variable. If `read_value` is false, instead returns None in
      Eager mode and the assign op in graph mode.
    """
    raise NotImplementedError

  def assign_add(self, delta, use_locking=False, name=None, read_value=True):
    """Adds a value to this variable.

    This is essentially a shortcut for `assign_add(self, delta)`.

    Args:
      delta: A `Tensor`. The value to add to this variable.
      use_locking: If `True`, use locking during the operation.
      name: The name of the operation to be created.
      read_value: if True, will return something which evaluates to the new
        value of the variable; if False will return the assign op.

    Returns:
      The updated variable. If `read_value` is false, instead returns None in
      Eager mode and the assign op in graph mode.
    """
    raise NotImplementedError

  def assign_sub(self, delta, use_locking=False, name=None, read_value=True):
    """Subtracts a value from this variable.

    This is essentially a shortcut for `assign_sub(self, delta)`.

    Args:
      delta: A `Tensor`. The value to subtract from this variable.
      use_locking: If `True`, use locking during the operation.
      name: The name of the operation to be created.
      read_value: if True, will return something which evaluates to the new
        value of the variable; if False will return the assign op.

    Returns:
      The updated variable. If `read_value` is false, instead returns None in
      Eager mode and the assign op in graph mode.
    """
    raise NotImplementedError

  def scatter_sub(self, sparse_delta, use_locking=False, name=None):
    """Subtracts `tf.IndexedSlices` from this variable.

    Args:
      sparse_delta: `tf.IndexedSlices` to be subtracted from this variable.
      use_locking: If `True`, use locking during the operation.
      name: the name of the operation.

    Returns:
      The updated variable.

    Raises:
      TypeError: if `sparse_delta` is not an `IndexedSlices`.
    """
    raise NotImplementedError
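
  # Illustrative sketch (assumes a TF 2.x eager resource variable): all of the
  # scatter_* methods take a `tf.IndexedSlices` describing which rows to touch
  # and with what values, e.g.
  #
  #   v = tf.Variable([[1., 2.], [3., 4.]])
  #   delta = tf.IndexedSlices(values=tf.constant([[1., 1.]]),
  #                            indices=tf.constant([0]))
  #   v.scatter_sub(delta)   # row 0 becomes [0., 1.]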

  def scatter_add(self, sparse_delta, use_locking=False, name=None):
    """Adds `tf.IndexedSlices` to this variable.

    Args:
      sparse_delta: `tf.IndexedSlices` to be added to this variable.
      use_locking: If `True`, use locking during the operation.
      name: the name of the operation.

    Returns:
      The updated variable.

    Raises:
      TypeError: if `sparse_delta` is not an `IndexedSlices`.
    """
    raise NotImplementedError

  def scatter_max(self, sparse_delta, use_locking=False, name=None):
    """Updates this variable with the max of `tf.IndexedSlices` and itself.

    Args:
      sparse_delta: `tf.IndexedSlices` to use as an argument of max with this
        variable.
      use_locking: If `True`, use locking during the operation.
      name: the name of the operation.

    Returns:
      The updated variable.

    Raises:
      TypeError: if `sparse_delta` is not an `IndexedSlices`.
    """
    raise NotImplementedError

  def scatter_min(self, sparse_delta, use_locking=False, name=None):
    """Updates this variable with the min of `tf.IndexedSlices` and itself.

    Args:
      sparse_delta: `tf.IndexedSlices` to use as an argument of min with this
        variable.
      use_locking: If `True`, use locking during the operation.
      name: the name of the operation.

    Returns:
      The updated variable.

    Raises:
      TypeError: if `sparse_delta` is not an `IndexedSlices`.
    """
    raise NotImplementedError

  def scatter_mul(self, sparse_delta, use_locking=False, name=None):
    """Multiply this variable by `tf.IndexedSlices`.

    Args:
      sparse_delta: `tf.IndexedSlices` to multiply this variable by.
      use_locking: If `True`, use locking during the operation.
      name: the name of the operation.

    Returns:
      The updated variable.

    Raises:
      TypeError: if `sparse_delta` is not an `IndexedSlices`.
    """
    raise NotImplementedError

  def scatter_div(self, sparse_delta, use_locking=False, name=None):
    """Divide this variable by `tf.IndexedSlices`.

    Args:
      sparse_delta: `tf.IndexedSlices` to divide this variable by.
      use_locking: If `True`, use locking during the operation.
      name: the name of the operation.

    Returns:
      The updated variable.

    Raises:
      TypeError: if `sparse_delta` is not an `IndexedSlices`.
    """
    raise NotImplementedError

  def scatter_update(self, sparse_delta, use_locking=False, name=None):
    """Assigns `tf.IndexedSlices` to this variable.

    Args:
      sparse_delta: `tf.IndexedSlices` to be assigned to this variable.
      use_locking: If `True`, use locking during the operation.
      name: the name of the operation.

    Returns:
      The updated variable.

    Raises:
      TypeError: if `sparse_delta` is not an `IndexedSlices`.
    """
    raise NotImplementedError

  def batch_scatter_update(self, sparse_delta, use_locking=False, name=None):
    """Assigns `tf.IndexedSlices` to this variable batch-wise.

    Analogous to `batch_gather`. This assumes that this variable and the
    sparse_delta IndexedSlices have a series of leading dimensions that are the
    same for all of them, and the updates are performed on the last dimension
    of indices. In other words, the dimensions should be the following:

    `num_prefix_dims = sparse_delta.indices.ndims - 1`
    `batch_dim = num_prefix_dims + 1`
    `sparse_delta.updates.shape = sparse_delta.indices.shape + var.shape[
      batch_dim:]`

    where

    `sparse_delta.updates.shape[:num_prefix_dims]`
    `== sparse_delta.indices.shape[:num_prefix_dims]`
    `== var.shape[:num_prefix_dims]`

    And the operation performed can be expressed as:

    `var[i_1, ..., i_n,
         sparse_delta.indices[i_1, ..., i_n, j]] = sparse_delta.updates[
            i_1, ..., i_n, j]`

    When sparse_delta.indices is a 1D tensor, this operation is equivalent to
    `scatter_update`.

    To avoid this operation one can loop over the first `ndims` of the
    variable and use `scatter_update` on the subtensors that result from
    slicing the first dimension. This is a valid option for `ndims = 1`, but
    less efficient than this implementation.

    Args:
      sparse_delta: `tf.IndexedSlices` to be assigned to this variable.
      use_locking: If `True`, use locking during the operation.
      name: the name of the operation.

    Returns:
      The updated variable.

    Raises:
      TypeError: if `sparse_delta` is not an `IndexedSlices`.
    """
    raise NotImplementedError

  def scatter_nd_sub(self, indices, updates, name=None):
    """Applies sparse subtraction to individual values or slices in a Variable.

    Assuming the variable has rank `P` and `indices` is a `Tensor` of rank `Q`.

    `indices` must be integer tensor, containing indices into self.
    It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.

    The innermost dimension of `indices` (with length `K`) corresponds to
    indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
    dimension of self.

    `updates` is `Tensor` of rank `Q-1+P-K` with shape:

    ```
    [d_0, ..., d_{Q-2}, self.shape[K], ..., self.shape[P-1]].
    ```

    For example, say we want to add 4 scattered elements to a rank-1 tensor
    with 8 elements. In Python, that update would look like this:

    ```python
    v = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
    indices = tf.constant([[4], [3], [1], [7]])
    updates = tf.constant([9, 10, 11, 12])
    op = v.scatter_nd_sub(indices, updates)
    with tf.compat.v1.Session() as sess:
      print(sess.run(op))
    ```

    The resulting update to v would look like this:

        [1, -9, 3, -6, -6, 6, 7, -4]

    See `tf.scatter_nd` for more details about how to make updates to
    slices.

    Args:
      indices: The indices to be used in the operation.
      updates: The values to be used in the operation.
      name: the name of the operation.

    Returns:
      The updated variable.
    """
    raise NotImplementedError

  def scatter_nd_add(self, indices, updates, name=None):
    """Applies sparse addition to individual values or slices in a Variable.

    The Variable has rank `P` and `indices` is a `Tensor` of rank `Q`.

    `indices` must be integer tensor, containing indices into self.
    It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.

    The innermost dimension of `indices` (with length `K`) corresponds to
    indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
    dimension of self.

    `updates` is `Tensor` of rank `Q-1+P-K` with shape:

    ```
    [d_0, ..., d_{Q-2}, self.shape[K], ..., self.shape[P-1]].
    ```

    For example, say we want to add 4 scattered elements to a rank-1 tensor
    with 8 elements. In Python, that update would look like this:

    ```python
    v = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
    indices = tf.constant([[4], [3], [1], [7]])
    updates = tf.constant([9, 10, 11, 12])
    add = v.scatter_nd_add(indices, updates)
    with tf.compat.v1.Session() as sess:
      print(sess.run(add))
    ```

    The resulting update to v would look like this:

        [1, 13, 3, 14, 14, 6, 7, 20]

    See `tf.scatter_nd` for more details about how to make updates to
    slices.

    Args:
      indices: The indices to be used in the operation.
      updates: The values to be used in the operation.
      name: the name of the operation.

    Returns:
      The updated variable.
    """
    raise NotImplementedError

  def scatter_nd_update(self, indices, updates, name=None):
    """Applies sparse assignment to individual values or slices in a Variable.

    The Variable has rank `P` and `indices` is a `Tensor` of rank `Q`.

    `indices` must be integer tensor, containing indices into self.
    It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.

    The innermost dimension of `indices` (with length `K`) corresponds to
    indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
    dimension of self.

    `updates` is `Tensor` of rank `Q-1+P-K` with shape:

    ```
    [d_0, ..., d_{Q-2}, self.shape[K], ..., self.shape[P-1]].
    ```

    For example, say we want to assign 4 scattered elements in a rank-1 tensor
    with 8 elements. In Python, that update would look like this:

    ```python
    v = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
    indices = tf.constant([[4], [3], [1], [7]])
    updates = tf.constant([9, 10, 11, 12])
    op = v.scatter_nd_update(indices, updates)
    with tf.compat.v1.Session() as sess:
      print(sess.run(op))
    ```

    The resulting update to v would look like this:

        [1, 11, 3, 10, 9, 6, 7, 12]

    See `tf.scatter_nd` for more details about how to make updates to
    slices.

    Args:
      indices: The indices to be used in the operation.
      updates: The values to be used in the operation.
      name: the name of the operation.

    Returns:
      The updated variable.
    """
    raise NotImplementedError

  def sparse_read(self, indices, name=None):
    r"""Gather slices from params axis `axis` according to indices.

    This function supports a subset of tf.gather, see tf.gather for details on
    usage.

    Args:
      indices: The index `Tensor`. Must be one of the following types: `int32`,
        `int64`. Must be in range `[0, params.shape[axis])`.
      name: A name for the operation (optional).

    Returns:
      A `Tensor`. Has the same type as `params`.
    """
    raise AttributeError

  def gather_nd(self, indices, name=None):
    r"""Gather slices from `params` into a Tensor with shape specified by `indices`.

    See tf.gather_nd for details.

    Args:
      indices: A `Tensor`. Must be one of the following types: `int32`, `int64`.
        Index tensor.
      name: A name for the operation (optional).

    Returns:
      A `Tensor`. Has the same type as `params`.
    """
    raise AttributeError

  @deprecated(None, "Prefer Dataset.range instead.")
  def count_up_to(self, limit):
    """Increments this variable until it reaches `limit`.

    When that Op is run it tries to increment the variable by `1`.
    If incrementing the variable would bring it above `limit` then the Op
    raises the exception `OutOfRangeError`.

    If no error is raised, the Op outputs the value of the variable before
    the increment.

    This is essentially a shortcut for `count_up_to(self, limit)`.

    Args:
      limit: value at which incrementing the variable raises an error.

    Returns:
      A `Tensor` that will hold the variable value before the increment. If no
      other Op modifies this variable, the values produced will all be
      distinct.
    """
    raise NotImplementedError

  @deprecated(None,
              "Prefer Variable.assign which has equivalent behavior in 2.X.")
  def load(self, value, session=None):
    """Load new value into this variable.

    Writes new value to variable's memory. Doesn't add ops to the graph.

    This convenience method requires a session where the graph
    containing this variable has been launched. If no session is
    passed, the default session is used. See `tf.compat.v1.Session` for more
    information on launching a graph and on sessions.

    ```python
    v = tf.Variable([1, 2])
    init = tf.compat.v1.global_variables_initializer()

    with tf.compat.v1.Session() as sess:
        sess.run(init)
        # Usage passing the session explicitly.
        v.load([2, 3], sess)
        print(v.eval(sess))  # prints [2 3]
        # Usage with the default session. The 'with' block
        # above makes 'sess' the default session.
        v.load([3, 4], sess)
        print(v.eval())  # prints [3 4]
    ```

    Args:
      value: New variable value.
      session: The session to use to evaluate this variable. If none, the
        default session is used.

    Raises:
      ValueError: Session is not passed and no default session.
    """
    if context.executing_eagerly():
      self.assign(value)
    else:
      session = session or ops.get_default_session()
      if session is None:
        raise ValueError(
            "Either session argument should be provided or default session "
            "should be established")
      session.run(self.initializer, {self.initializer.inputs[1]: value})

  # Conversion to tensor.
  @staticmethod
  def _TensorConversionFunction(v, dtype=None, name=None, as_ref=False):  # pylint: disable=invalid-name
    """Utility function for converting a Variable to a Tensor."""
    _ = name
    if dtype and not dtype.is_compatible_with(v.dtype):
      raise ValueError(
          "Incompatible type conversion requested to type '%s' for variable "
          "of type '%s'" % (dtype.name, v.dtype.name))
    if as_ref:
      return v._ref()  # pylint: disable=protected-access
    else:
      return v.value()

  @classmethod
  def _OverloadAllOperators(cls):  # pylint: disable=invalid-name
    """Register overloads for all operators."""
    for operator in ops.Tensor.OVERLOADABLE_OPERATORS:
      cls._OverloadOperator(operator)
    # For slicing, bind getitem differently than a tensor (use SliceHelperVar
    # instead)
    # pylint: disable=protected-access
    setattr(cls, "__getitem__", array_ops._SliceHelperVar)

  @classmethod
  def _OverloadOperator(cls, operator):  # pylint: disable=invalid-name
    """Defer an operator overload to `ops.Tensor`.

    We pull the operator out of ops.Tensor dynamically to avoid ordering issues.

    Args:
      operator: string. The operator name.
    """
    # We can't use the overload mechanism on __eq__ & __ne__ since __eq__ is
    # called when adding a variable to sets. As a result the overload would
    # call a.value(), which causes infinite recursion when operating within a
    # GradientTape.
    # TODO(gjn): Consider removing this
    if operator == "__eq__" or operator == "__ne__":
      return

    tensor_oper = getattr(ops.Tensor, operator)

    def _run_op(a, *args, **kwargs):
      # pylint: disable=protected-access
      return tensor_oper(a.value(), *args, **kwargs)

    functools.update_wrapper(_run_op, tensor_oper)
    setattr(cls, operator, _run_op)

  def __hash__(self):
    if ops.Tensor._USE_EQUALITY and ops.executing_eagerly_outside_functions():  # pylint: disable=protected-access
      raise TypeError("Variable is unhashable if Tensor equality is enabled. "
                      "Instead, use tensor.experimental_ref() as the key.")
    else:
      return id(self)

  # TODO(gjn): duplicate of math_ops.tensor_equals, consider removing
  def __eq__(self, other):
    """Compares two variables element-wise for equality."""
    if ops.Tensor._USE_EQUALITY and ops.executing_eagerly_outside_functions():  # pylint: disable=protected-access
      return gen_math_ops.equal(self, other, incompatible_shape_error=False)
    else:
      # In legacy graph mode, tensor equality is object equality
      return self is other

  # TODO(gjn): duplicate of math_ops.tensor_not_equals, consider removing
  def __ne__(self, other):
    """Compares two variables element-wise for inequality."""
    if ops.Tensor._USE_EQUALITY and ops.executing_eagerly_outside_functions():  # pylint: disable=protected-access
      return gen_math_ops.not_equal(self, other, incompatible_shape_error=False)
    else:
      # In legacy graph mode, tensor equality is object equality
      return self is not other

  def __iter__(self):
    """Dummy method to prevent iteration.

    Do not call.

    NOTE(mrry): If we register __getitem__ as an overloaded operator,
    Python will valiantly attempt to iterate over the variable's Tensor from 0
    to infinity. Declaring this method prevents this unintended behavior.

    Raises:
      TypeError: when invoked.
    """
    raise TypeError("'Variable' object is not iterable.")

  # NOTE(mrry): This enables the Variable's overloaded "right" binary
  # operators to run when the left operand is an ndarray, because it
  # accords the Variable class higher priority than an ndarray, or a
  # numpy matrix.
  # TODO(mrry): Convert this to using numpy's __numpy_ufunc__
  # mechanism, which allows more control over how Variables interact
  # with ndarrays.
  __array_priority__ = 100

  @property
  def name(self):
    """The name of this variable."""
    raise NotImplementedError

  @property
  def _shared_name(self):
    """The shared name of the variable.

    Unlike name(), shared_name doesn't have ":0" suffix. It is user-specified
    name with name scope prefix.

    Returns:
      variable name.
    """
    return self.name[:self.name.index(":")]

  @property
  def initializer(self):
    """The initializer operation for this variable."""
    raise NotImplementedError

  @property
  def device(self):
    """The device of this variable."""
    raise NotImplementedError

  @property
  def dtype(self):
    """The `DType` of this variable."""
    raise NotImplementedError

  @property
  def op(self):
    """The `Operation` of this variable."""
    raise NotImplementedError

  @property
  def graph(self):
    """The `Graph` of this variable."""
    raise NotImplementedError

  @property
  def shape(self):
    """The `TensorShape` of this variable.

    Returns:
      A `TensorShape`.
    """
    raise NotImplementedError

  def get_shape(self):
    """Alias of `Variable.shape`."""
    return self.shape

  def _gather_saveables_for_checkpoint(self):
    """For implementing `Trackable`. This object is saveable on its own."""
    return {trackable.VARIABLE_VALUE_KEY: self}

  def to_proto(self, export_scope=None):
    """Converts a `Variable` to a `VariableDef` protocol buffer.

    Args:
      export_scope: Optional `string`. Name scope to remove.

    Returns:
      A `VariableDef` protocol buffer, or `None` if the `Variable` is not
      in the specified name scope.
    """
    raise NotImplementedError

  @staticmethod
  def from_proto(variable_def, import_scope=None):
    """Returns a `Variable` object created from `variable_def`."""
    return RefVariable(variable_def=variable_def, import_scope=import_scope)

  def _set_save_slice_info(self, save_slice_info):
    """Sets the slice info for this `Variable`.

    Args:
      save_slice_info: A `Variable.SaveSliceInfo` object.
    """
    self._save_slice_info = save_slice_info

  def _get_save_slice_info(self):
    return self._save_slice_info

  def experimental_ref(self):
    # tf.Tensor also has the same experimental_ref() API. If you update the
    # documentation here, please update tf.Tensor.experimental_ref() as well.
    """Returns a hashable reference object to this Variable.

    Warning: Experimental API that could be changed or removed.

    The primary use case for this API is to put variables in a set/dictionary.
    We can't put variables in a set/dictionary as `variable.__hash__()` is no
    longer available starting TensorFlow 2.0.

    ```python
    import tensorflow as tf

    x = tf.Variable(5)
    y = tf.Variable(10)
    z = tf.Variable(10)

    # The following will raise an exception starting 2.0
    # TypeError: Variable is unhashable if Variable equality is enabled.
    variable_set = {x, y, z}
    variable_dict = {x: 'five', y: 'ten'}
    ```

    Instead, we can use `variable.experimental_ref()`.

    ```python
    variable_set = {x.experimental_ref(),
                    y.experimental_ref(),
                    z.experimental_ref()}

    print(x.experimental_ref() in variable_set)
    ==> True

    variable_dict = {x.experimental_ref(): 'five',
                     y.experimental_ref(): 'ten',
                     z.experimental_ref(): 'ten'}

    print(variable_dict[y.experimental_ref()])
    ==> ten
    ```

    Also, the reference object provides a `.deref()` function that returns the
    original Variable.

    ```python
    x = tf.Variable(5)
    print(x.experimental_ref().deref())
    ==> <tf.Variable 'Variable:0' shape=() dtype=int32, numpy=5>
    ```
    """
    return object_identity.Reference(self)

  class SaveSliceInfo(object):
    """Information on how to save this Variable as a slice.

    Provides internal support for saving variables as slices of a larger
    variable. This API is not public and is subject to change.

    Available properties:

    * full_name
    * full_shape
    * var_offset
    * var_shape
    """

    def __init__(self,
                 full_name=None,
                 full_shape=None,
                 var_offset=None,
                 var_shape=None,
                 save_slice_info_def=None,
                 import_scope=None):
      """Create a `SaveSliceInfo`.

      Args:
        full_name: Name of the full variable of which this `Variable` is a
          slice.
        full_shape: Shape of the full variable, as a list of int.
        var_offset: Offset of this `Variable` into the full variable, as a list
          of int.
        var_shape: Shape of this `Variable`, as a list of int.
        save_slice_info_def: `SaveSliceInfoDef` protocol buffer. If not `None`,
          recreates the SaveSliceInfo object from its contents.
          `save_slice_info_def` and other arguments are mutually exclusive.
        import_scope: Optional `string`. Name scope to add. Only used when
          initializing from protocol buffer.
      """
      if save_slice_info_def:
        assert isinstance(save_slice_info_def, variable_pb2.SaveSliceInfoDef)
        self.full_name = ops.prepend_name_scope(
            save_slice_info_def.full_name, import_scope=import_scope)
        self.full_shape = [i for i in save_slice_info_def.full_shape]
        self.var_offset = [i for i in save_slice_info_def.var_offset]
        self.var_shape = [i for i in save_slice_info_def.var_shape]
      else:
        self.full_name = full_name
        self.full_shape = full_shape
        self.var_offset = var_offset
        self.var_shape = var_shape

    @property
    def spec(self):
      """Computes the spec string used for saving."""
      full_shape_str = " ".join("%d" % d for d in self.full_shape) + " "
      sl_spec = ":".join(
          "%d,%d" % (o, s) for o, s in zip(self.var_offset, self.var_shape))
      return full_shape_str + sl_spec

    def to_proto(self, export_scope=None):
      """Returns a SaveSliceInfoDef() proto.

      Args:
        export_scope: Optional `string`. Name scope to remove.

      Returns:
        A `SaveSliceInfoDef` protocol buffer, or None if the `Variable` is not
        in the specified name scope.
      """
      if (export_scope is None or self.full_name.startswith(export_scope)):
        save_slice_info_def = variable_pb2.SaveSliceInfoDef()
        save_slice_info_def.full_name = ops.strip_name_scope(
            self.full_name, export_scope)
        for i in self.full_shape:
          save_slice_info_def.full_shape.append(i)
        for i in self.var_offset:
          save_slice_info_def.var_offset.append(i)
        for i in self.var_shape:
          save_slice_info_def.var_shape.append(i)
        return save_slice_info_def
      else:
        return None


Variable._OverloadAllOperators()  # pylint: disable=protected-access
_pywrap_utils.RegisterType("Variable", Variable)
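

# Illustrative note on the overloads registered above (a hedged sketch; it
# assumes TF 2.x eager execution and is not executed by this module):
# arithmetic on a Variable delegates to ops.Tensor via value(), so the result
# is a Tensor, not a Variable:
#
#   v = tf.Variable([[1., 2.]])
#   t = v * 3.0     # tf.Tensor([[3., 6.]]), produced by the Tensor overload
#   row = v[0, :]   # slicing goes through array_ops._SliceHelperVar
#
# __eq__ and __ne__ are intentionally skipped by _OverloadOperator and are
# defined explicitly on the class (element-wise only when Tensor equality is
# enabled).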


@tf_export(v1=["Variable"])
class VariableV1(Variable):
  """See the [Variables Guide](https://tensorflow.org/guide/variables).

  A variable maintains state in the graph across calls to `run()`. You add a
  variable to the graph by constructing an instance of the class `Variable`.

  The `Variable()` constructor requires an initial value for the variable,
  which can be a `Tensor` of any type and shape. The initial value defines the
  type and shape of the variable. After construction, the type and shape of
  the variable are fixed. The value can be changed using one of the assign
  methods.

  If you want to change the shape of a variable later you have to use an
  `assign` Op with `validate_shape=False`.

  Just like any `Tensor`, variables created with `Variable()` can be used as
  inputs for other Ops in the graph. Additionally, all the operators
  overloaded for the `Tensor` class are carried over to variables, so you can
  also add nodes to the graph by just doing arithmetic on variables.

  ```python
  import tensorflow as tf

  # Create a variable.
  w = tf.Variable(<initial-value>, name=<optional-name>)

  # Use the variable in the graph like any Tensor.
  y = tf.matmul(w, ...another variable or tensor...)

  # The overloaded operators are available too.
  z = tf.sigmoid(w + y)

  # Assign a new value to the variable with `assign()` or a related method.
  w.assign(w + 1.0)
  w.assign_add(1.0)
  ```

  When you launch the graph, variables have to be explicitly initialized before
  you can run Ops that use their value. You can initialize a variable by
  running its *initializer op*, restoring the variable from a save file, or
  simply running an `assign` Op that assigns a value to the variable. In fact,
  the variable *initializer op* is just an `assign` Op that assigns the
  variable's initial value to the variable itself.

  ```python
  # Launch the graph in a session.
  with tf.compat.v1.Session() as sess:
      # Run the variable initializer.
      sess.run(w.initializer)
      # ...you now can run ops that use the value of 'w'...
  ```

  The most common initialization pattern is to use the convenience function
  `global_variables_initializer()` to add an Op to the graph that initializes
  all the variables. You then run that Op after launching the graph.

  ```python
  # Add an Op to initialize global variables.
  init_op = tf.compat.v1.global_variables_initializer()

  # Launch the graph in a session.
  with tf.compat.v1.Session() as sess:
      # Run the Op that initializes global variables.
      sess.run(init_op)
      # ...you can now run any Op that uses variable values...
  ```

  If you need to create a variable with an initial value dependent on another
  variable, use the other variable's `initialized_value()`. This ensures that
  variables are initialized in the right order.

  All variables are automatically collected in the graph where they are
  created. By default, the constructor adds the new variable to the graph
  collection `GraphKeys.GLOBAL_VARIABLES`. The convenience function
  `global_variables()` returns the contents of that collection.

  When building a machine learning model it is often convenient to distinguish
  between variables holding the trainable model parameters and other variables
  such as a `global step` variable used to count training steps. To make this
  easier, the variable constructor supports a `trainable=<bool>` parameter. If
  `True`, the new variable is also added to the graph collection
  `GraphKeys.TRAINABLE_VARIABLES`.
  The convenience function `trainable_variables()` returns the contents of
  this collection. The various `Optimizer` classes use this collection as the
  default list of variables to optimize.

  WARNING: tf.Variable objects by default have a non-intuitive memory model. A
  Variable is represented internally as a mutable Tensor which can
  non-deterministically alias other Tensors in a graph. The set of operations
  which consume a Variable and can lead to aliasing is undetermined and can
  change across TensorFlow versions. Avoid writing code which relies on the
  value of a Variable either changing or not changing as other operations
  happen. For example, using Variable objects or simple functions thereof as
  predicates in a `tf.cond` is dangerous and error-prone:

  ```
  v = tf.Variable(True)
  tf.cond(v, lambda: v.assign(False), my_false_fn)  # Note: this is broken.
  ```

  Here, adding `use_resource=True` when constructing the variable will
  fix any nondeterminism issues:

  ```
  v = tf.Variable(True, use_resource=True)
  tf.cond(v, lambda: v.assign(False), my_false_fn)
  ```

  To use the replacement for variables which does
  not have these issues:

  * Add `use_resource=True` when constructing `tf.Variable`;
  * Call `tf.compat.v1.get_variable_scope().set_use_resource(True)` inside a
    `tf.compat.v1.variable_scope` before the `tf.compat.v1.get_variable()`
    call.
  """

  def __init__(
      self,  # pylint: disable=super-init-not-called
      initial_value=None,
      trainable=None,
      collections=None,
      validate_shape=True,
      caching_device=None,
      name=None,
      variable_def=None,
      dtype=None,
      expected_shape=None,
      import_scope=None,
      constraint=None,
      use_resource=None,
      synchronization=VariableSynchronization.AUTO,
      aggregation=VariableAggregation.NONE,
      shape=None):
    """Creates a new variable with value `initial_value`.

    The new variable is added to the graph collections listed in `collections`,
    which defaults to `[GraphKeys.GLOBAL_VARIABLES]`.

    If `trainable` is `True` the variable is also added to the graph collection
    `GraphKeys.TRAINABLE_VARIABLES`.

    This constructor creates both a `variable` Op and an `assign` Op to set the
    variable to its initial value.

    Args:
      initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
        which is the initial value for the Variable. The initial value must
        have a shape specified unless `validate_shape` is set to False. Can
        also be a callable with no argument that returns the initial value
        when called. In that case, `dtype` must be specified. (Note that
        initializer functions from init_ops.py must first be bound to a shape
        before being used here.)
      trainable: If `True`, also adds the variable to the graph collection
        `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as the default
        list of variables to use by the `Optimizer` classes. Defaults to
        `True`, unless `synchronization` is set to `ON_READ`, in which case it
        defaults to `False`.
      collections: List of graph collections keys. The new variable is added to
        these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
      validate_shape: If `False`, allows the variable to be initialized with a
        value of unknown shape. If `True`, the default, the shape of
        `initial_value` must be known.
      caching_device: Optional device string describing where the Variable
        should be cached for reading. Defaults to the Variable's device. If not
        `None`, caches on another device. Typical use is to cache on the device
        where the Ops using the Variable reside, to deduplicate copying through
        `Switch` and other conditional statements.
      name: Optional name for the variable. Defaults to `'Variable'` and gets
        uniquified automatically.
      variable_def: `VariableDef` protocol buffer. If not `None`, recreates the
        Variable object with its contents, referencing the variable's nodes in
        the graph, which must already exist. The graph is not changed.
        `variable_def` and the other arguments are mutually exclusive.
      dtype: If set, initial_value will be converted to the given type. If
        `None`, either the datatype will be kept (if `initial_value` is a
        Tensor), or `convert_to_tensor` will decide.
      expected_shape: A TensorShape. If set, initial_value is expected to have
        this shape.
      import_scope: Optional `string`. Name scope to add to the `Variable`.
        Only used when initializing from protocol buffer.
      constraint: An optional projection function to be applied to the variable
        after being updated by an `Optimizer` (e.g. used to implement norm
        constraints or value constraints for layer weights). The function must
        take as input the unprojected Tensor representing the value of the
        variable and return the Tensor for the projected value (which must have
        the same shape). Constraints are not safe to use when doing
        asynchronous distributed training.
      use_resource: whether to use resource variables.
      synchronization: Indicates when a distributed variable will be
        aggregated. Accepted values are constants defined in the class
        `tf.VariableSynchronization`. By default the synchronization is set to
        `AUTO` and the current `DistributionStrategy` chooses when to
        synchronize.
      aggregation: Indicates how a distributed variable will be aggregated.
        Accepted values are constants defined in the class
        `tf.VariableAggregation`.
      shape: (optional) The shape of this variable. If None, the shape of
        `initial_value` will be used. When setting this argument to
        `tf.TensorShape(None)` (representing an unspecified shape), the
        variable can be assigned with values of different shapes.

    Raises:
      ValueError: If both `variable_def` and initial_value are specified.
      ValueError: If the initial value is not specified, or does not have a
        shape and `validate_shape` is `True`.
      RuntimeError: If eager execution is enabled.
    """

  SaveSliceInfo = Variable.SaveSliceInfo


# TODO(apassos): do not repeat all comments here
class RefVariable(VariableV1):
  """Ref-based implementation of variables."""

  def __init__(
      self,  # pylint: disable=super-init-not-called
      initial_value=None,
      trainable=None,
      collections=None,
      validate_shape=True,
      caching_device=None,
      name=None,
      variable_def=None,
      dtype=None,
      expected_shape=None,
      import_scope=None,
      constraint=None,
      synchronization=None,
      aggregation=None,
      shape=None):
    """Creates a new variable with value `initial_value`.

    The new variable is added to the graph collections listed in `collections`,
    which defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
1585 1586 If `trainable` is `True` the variable is also added to the graph collection 1587 `GraphKeys.TRAINABLE_VARIABLES`. 1588 1589 This constructor creates both a `variable` Op and an `assign` Op to set the 1590 variable to its initial value. 1591 1592 Args: 1593 initial_value: A `Tensor`, or Python object convertible to a `Tensor`, 1594 which is the initial value for the Variable. The initial value must have 1595 a shape specified unless `validate_shape` is set to False. Can also be a 1596 callable with no argument that returns the initial value when called. In 1597 that case, `dtype` must be specified. (Note that initializer functions 1598 from init_ops.py must first be bound to a shape before being used here.) 1599 trainable: If `True`, also adds the variable to the graph collection 1600 `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as the default 1601 list of variables to use by the `Optimizer` classes. Defaults to `True`, 1602 unless `synchronization` is set to `ON_READ`, in which case it defaults 1603 to `False`. 1604 collections: List of graph collections keys. The new variable is added to 1605 these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`. 1606 validate_shape: If `False`, allows the variable to be initialized with a 1607 value of unknown shape. If `True`, the default, the shape of 1608 `initial_value` must be known. 1609 caching_device: Optional device string describing where the Variable 1610 should be cached for reading. Defaults to the Variable's device. If not 1611 `None`, caches on another device. Typical use is to cache on the device 1612 where the Ops using the Variable reside, to deduplicate copying through 1613 `Switch` and other conditional statements. 1614 name: Optional name for the variable. Defaults to `'Variable'` and gets 1615 uniquified automatically. 1616 variable_def: `VariableDef` protocol buffer. If not `None`, recreates the 1617 Variable object with its contents, referencing the variable's nodes in 1618 the graph, which must already exist. The graph is not changed. 1619 `variable_def` and the other arguments are mutually exclusive. 1620 dtype: If set, initial_value will be converted to the given type. If 1621 `None`, either the datatype will be kept (if `initial_value` is a 1622 Tensor), or `convert_to_tensor` will decide. 1623 expected_shape: A TensorShape. If set, initial_value is expected to have 1624 this shape. 1625 import_scope: Optional `string`. Name scope to add to the `Variable.` Only 1626 used when initializing from protocol buffer. 1627 constraint: An optional projection function to be applied to the variable 1628 after being updated by an `Optimizer` (e.g. used to implement norm 1629 constraints or value constraints for layer weights). The function must 1630 take as input the unprojected Tensor representing the value of the 1631 variable and return the Tensor for the projected value (which must have 1632 the same shape). Constraints are not safe to use when doing asynchronous 1633 distributed training. 1634 synchronization: Indicates when a distributed a variable will be 1635 aggregated. Accepted values are constants defined in the class 1636 `tf.VariableSynchronization`. By default the synchronization is set to 1637 `AUTO` and the current `DistributionStrategy` chooses when to 1638 synchronize. 1639 aggregation: Indicates how a distributed variable will be aggregated. 1640 Accepted values are constants defined in the class 1641 `tf.VariableAggregation`. 1642 shape: (optional) The shape of this variable. 
If None, the shape of 1643 `initial_value` will be used. When setting this argument to 1644 `tf.TensorShape(None)` (representing an unspecified shape), the variable 1645 can be assigned with values of different shapes. 1646 1647 Raises: 1648 ValueError: If both `variable_def` and initial_value are specified. 1649 ValueError: If the initial value is not specified, or does not have a 1650 shape and `validate_shape` is `True`. 1651 RuntimeError: If eager execution is enabled. 1652 """ 1653 self._in_graph_mode = True 1654 if variable_def: 1655 # If variable_def is provided, recreates the variable from its fields. 1656 if initial_value: 1657 raise ValueError("variable_def and initial_value are mutually " 1658 "exclusive.") 1659 self._init_from_proto(variable_def, import_scope=import_scope) 1660 else: 1661 # Create from initial_value. 1662 self._init_from_args( 1663 initial_value=initial_value, 1664 trainable=trainable, 1665 collections=collections, 1666 validate_shape=validate_shape, 1667 caching_device=caching_device, 1668 name=name, 1669 dtype=dtype, 1670 expected_shape=expected_shape, 1671 constraint=constraint, 1672 synchronization=synchronization, 1673 aggregation=aggregation, 1674 shape=shape) 1675 1676 def __repr__(self): 1677 if context.executing_eagerly() and not self._in_graph_mode: 1678 return "<tf.Variable '%s' shape=%s dtype=%s, numpy=%s>" % ( 1679 self.name, self.get_shape(), self.dtype.name, 1680 ops.numpy_text(self.read_value(), is_repr=True)) 1681 else: 1682 return "<tf.Variable '%s' shape=%s dtype=%s>" % ( 1683 self.name, self.get_shape(), self.dtype.name) 1684 1685 def _init_from_args(self, 1686 initial_value=None, 1687 trainable=None, 1688 collections=None, 1689 validate_shape=True, 1690 caching_device=None, 1691 name=None, 1692 dtype=None, 1693 expected_shape=None, 1694 constraint=None, 1695 synchronization=None, 1696 aggregation=None, 1697 shape=None): 1698 """Creates a new variable from arguments. 1699 1700 Args: 1701 initial_value: A `Tensor`, or Python object convertible to a `Tensor`, 1702 which is the initial value for the Variable. The initial value must have 1703 a shape specified unless `validate_shape` is set to False. Can also be a 1704 callable with no argument that returns the initial value when called. 1705 (Note that initializer functions from init_ops.py must first be bound to 1706 a shape before being used here.) 1707 trainable: If `True`, also adds the variable to the graph collection 1708 `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as the default 1709 list of variables to use by the `Optimizer` classes. Defaults to `True`, 1710 unless `synchronization` is set to `ON_READ`, in which case it defaults 1711 to `False`. 1712 collections: List of graph collections keys. The new variable is added to 1713 these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`. 1714 validate_shape: If `False`, allows the variable to be initialized with a 1715 value of unknown shape. If `True`, the default, the shape of 1716 `initial_value` must be known. 1717 caching_device: Optional device string or function describing where the 1718 Variable should be cached for reading. Defaults to the Variable's 1719 device. If not `None`, caches on another device. Typical use is to 1720 cache on the device where the Ops using the Variable reside, to 1721 deduplicate copying through `Switch` and other conditional statements. 1722 name: Optional name for the variable. Defaults to `'Variable'` and gets 1723 uniquified automatically. 
1724 dtype: If set, initial_value will be converted to the given type. If None, 1725 either the datatype will be kept (if initial_value is a Tensor) or 1726 float32 will be used (if it is a Python object convertible to a Tensor). 1727 expected_shape: Deprecated. Ignored. 1728 constraint: An optional projection function to be applied to the variable 1729 after being updated by an `Optimizer` (e.g. used to implement norm 1730 constraints or value constraints for layer weights). The function must 1731 take as input the unprojected Tensor representing the value of the 1732 variable and return the Tensor for the projected value (which must have 1733 the same shape). Constraints are not safe to use when doing asynchronous 1734 distributed training. 1735 synchronization: Indicates when a distributed a variable will be 1736 aggregated. Accepted values are constants defined in the class 1737 `tf.VariableSynchronization`. By default the synchronization is set to 1738 `AUTO` and the current `DistributionStrategy` chooses when to 1739 synchronize. 1740 aggregation: Indicates how a distributed variable will be aggregated. 1741 Accepted values are constants defined in the class 1742 `tf.VariableAggregation`. 1743 shape: (optional) The shape of this variable. If None, the shape of 1744 `initial_value` will be used. When setting this argument to 1745 `tf.TensorShape(None)` (representing an unspecified shape), the variable 1746 can be assigned with values of different shapes. 1747 1748 Raises: 1749 ValueError: If the initial value is not specified, or does not have a 1750 shape and `validate_shape` is `True`. 1751 RuntimeError: If lifted into the eager context. 1752 """ 1753 _ = expected_shape 1754 if initial_value is None: 1755 raise ValueError("initial_value must be specified.") 1756 init_from_fn = callable(initial_value) 1757 1758 if collections is None: 1759 collections = [ops.GraphKeys.GLOBAL_VARIABLES] 1760 if not isinstance(collections, (list, tuple, set)): 1761 raise ValueError( 1762 "collections argument to Variable constructor must be a list, tuple, " 1763 "or set. Got %s of type %s" % (collections, type(collections))) 1764 if constraint is not None and not callable(constraint): 1765 raise ValueError("The `constraint` argument must be a callable.") 1766 1767 # Store the graph key so optimizers know how to only retrieve variables from 1768 # this graph. 1769 self._graph_key = ops.get_default_graph()._graph_key # pylint: disable=protected-access 1770 if isinstance(initial_value, trackable.CheckpointInitialValue): 1771 self._maybe_initialize_trackable() 1772 self._update_uid = initial_value.checkpoint_position.restore_uid 1773 initial_value = initial_value.wrapped_value 1774 1775 synchronization, aggregation, trainable = ( 1776 validate_synchronization_aggregation_trainable(synchronization, 1777 aggregation, trainable, 1778 name)) 1779 self._synchronization = synchronization 1780 self._aggregation = aggregation 1781 self._trainable = trainable 1782 if trainable and ops.GraphKeys.TRAINABLE_VARIABLES not in collections: 1783 collections = list(collections) + [ops.GraphKeys.TRAINABLE_VARIABLES] 1784 with ops.init_scope(): 1785 # Ensure that we weren't lifted into the eager context. 1786 if context.executing_eagerly(): 1787 raise RuntimeError( 1788 "RefVariable not supported when eager execution is enabled. 
") 1789 with ops.name_scope(name, "Variable", 1790 [] if init_from_fn else [initial_value]) as name: 1791 1792 if init_from_fn: 1793 # Use attr_scope and device(None) to simulate the behavior of 1794 # colocate_with when the variable we want to colocate with doesn't 1795 # yet exist. 1796 true_name = ops.name_from_scope_name(name) # pylint: disable=protected-access 1797 attr = attr_value_pb2.AttrValue( 1798 list=attr_value_pb2.AttrValue.ListValue( 1799 s=[compat.as_bytes("loc:@%s" % true_name)])) 1800 # pylint: disable=protected-access 1801 with ops.get_default_graph()._attr_scope({"_class": attr}): 1802 with ops.name_scope("Initializer"), ops.device(None): 1803 self._initial_value = ops.convert_to_tensor( 1804 initial_value(), name="initial_value", dtype=dtype) 1805 if shape is None: 1806 shape = ( 1807 self._initial_value.get_shape() 1808 if validate_shape else tensor_shape.unknown_shape()) 1809 self._variable = state_ops.variable_op_v2( 1810 shape, self._initial_value.dtype.base_dtype, name=name) 1811 # pylint: enable=protected-access 1812 1813 # Or get the initial value from a Tensor or Python object. 1814 else: 1815 self._initial_value = ops.convert_to_tensor( 1816 initial_value, name="initial_value", dtype=dtype) 1817 # pylint: disable=protected-access 1818 if self._initial_value.op._get_control_flow_context() is not None: 1819 raise ValueError( 1820 "Initializer for variable %s is from inside a control-flow " 1821 "construct, such as a loop or conditional. When creating a " 1822 "variable inside a loop or conditional, use a lambda as the " 1823 "initializer." % name) 1824 if shape is None: 1825 # pylint: enable=protected-access 1826 shape = ( 1827 self._initial_value.get_shape() 1828 if validate_shape else tensor_shape.unknown_shape()) 1829 # In this case, the variable op can't be created until after the 1830 # initial_value has been converted to a Tensor with a known type. 1831 self._variable = state_ops.variable_op_v2( 1832 shape, self._initial_value.dtype.base_dtype, name=name) 1833 1834 # Cache the name in `self`, because some APIs call `Variable.name` in a 1835 # tight loop, and this halves the cost. 1836 self._name = self._variable.name 1837 1838 # Manually overrides the variable's shape with the initial value's. 1839 if validate_shape: 1840 initial_value_shape = self._initial_value.get_shape() 1841 if not initial_value_shape.is_fully_defined(): 1842 raise ValueError("initial_value must have a shape specified: %s" % 1843 self._initial_value) 1844 1845 # If 'initial_value' makes use of other variables, make sure we don't 1846 # have an issue if these other variables aren't initialized first by 1847 # using their initialized_value() method. 1848 self._initializer_op = state_ops.assign( 1849 self._variable, 1850 _try_guard_against_uninitialized_dependencies( 1851 name, self._initial_value), 1852 validate_shape=validate_shape).op 1853 1854 # TODO(vrv): Change this class to not take caching_device, but 1855 # to take the op to colocate the snapshot with, so we can use 1856 # colocation rather than devices. 
1857 if caching_device is not None: 1858 with ops.device(caching_device): 1859 self._snapshot = array_ops.identity(self._variable, name="read") 1860 else: 1861 with ops.colocate_with(self._variable.op): 1862 self._snapshot = array_ops.identity(self._variable, name="read") 1863 ops.add_to_collections(collections, self) 1864 1865 self._caching_device = caching_device 1866 self._save_slice_info = None 1867 self._constraint = constraint 1868 1869 def _init_from_proto(self, variable_def, import_scope=None): 1870 """Recreates the Variable object from a `VariableDef` protocol buffer. 1871 1872 Args: 1873 variable_def: `VariableDef` protocol buffer, describing a variable whose 1874 nodes already exists in the graph. 1875 import_scope: Optional `string`. Name scope to add. 1876 """ 1877 assert isinstance(variable_def, variable_pb2.VariableDef) 1878 # Create from variable_def. 1879 g = ops.get_default_graph() 1880 self._variable = g.as_graph_element( 1881 ops.prepend_name_scope( 1882 variable_def.variable_name, import_scope=import_scope)) 1883 self._name = self._variable.name 1884 self._initializer_op = g.as_graph_element( 1885 ops.prepend_name_scope( 1886 variable_def.initializer_name, import_scope=import_scope)) 1887 # Tests whether initial_value_name exists first for backwards compatibility. 1888 if (hasattr(variable_def, "initial_value_name") and 1889 variable_def.initial_value_name): 1890 self._initial_value = g.as_graph_element( 1891 ops.prepend_name_scope( 1892 variable_def.initial_value_name, import_scope=import_scope)) 1893 else: 1894 self._initial_value = None 1895 synchronization, aggregation, trainable = ( 1896 validate_synchronization_aggregation_trainable( 1897 variable_def.synchronization, variable_def.aggregation, 1898 variable_def.trainable, variable_def.variable_name)) 1899 self._synchronization = synchronization 1900 self._aggregation = aggregation 1901 self._trainable = trainable 1902 self._snapshot = g.as_graph_element( 1903 ops.prepend_name_scope( 1904 variable_def.snapshot_name, import_scope=import_scope)) 1905 if variable_def.HasField("save_slice_info_def"): 1906 self._save_slice_info = Variable.SaveSliceInfo( 1907 save_slice_info_def=variable_def.save_slice_info_def, 1908 import_scope=import_scope) 1909 else: 1910 self._save_slice_info = None 1911 self._caching_device = None 1912 self._constraint = None 1913 1914 def _as_graph_element(self): 1915 """Conversion function for Graph.as_graph_element().""" 1916 return self._variable 1917 1918 def value(self): 1919 """Returns the last snapshot of this variable. 1920 1921 You usually do not need to call this method as all ops that need the value 1922 of the variable call it automatically through a `convert_to_tensor()` call. 1923 1924 Returns a `Tensor` which holds the value of the variable. You can not 1925 assign a new value to this tensor as it is not a reference to the variable. 1926 1927 To avoid copies, if the consumer of the returned value is on the same device 1928 as the variable, this actually returns the live value of the variable, not 1929 a copy. Updates to the variable are seen by the consumer. If the consumer 1930 is on a different device it will get a copy of the variable. 1931 1932 Returns: 1933 A `Tensor` containing the value of the variable. 1934 """ 1935 return self._snapshot 1936 1937 def read_value(self): 1938 """Returns the value of this variable, read in the current context. 1939 1940 Can be different from value() if it's on another device, with control 1941 dependencies, etc. 
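    For example, a minimal sketch (the initial value is illustrative):

    ```python
    v = tf.compat.v1.Variable([1.0, 2.0])
    r = v.read_value()  # A fresh read of the variable in the current context.
    ```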
1942 1943 Returns: 1944 A `Tensor` containing the value of the variable. 1945 """ 1946 return array_ops.identity(self._variable, name="read") 1947 1948 def _ref(self): 1949 """Returns a reference to this variable. 1950 1951 You usually do not need to call this method as all ops that need a reference 1952 to the variable call it automatically. 1953 1954 Returns is a `Tensor` which holds a reference to the variable. You can 1955 assign a new value to the variable by passing the tensor to an assign op. 1956 See `tf.Variable.value` if you want to get the value of the 1957 variable. 1958 1959 Returns: 1960 A `Tensor` that is a reference to the variable. 1961 """ 1962 return self._variable 1963 1964 def set_shape(self, shape): 1965 """Overrides the shape for this variable. 1966 1967 Args: 1968 shape: the `TensorShape` representing the overridden shape. 1969 """ 1970 self._ref().set_shape(shape) 1971 self.value().set_shape(shape) 1972 1973 @property 1974 def trainable(self): 1975 return self._trainable 1976 1977 @property 1978 def synchronization(self): 1979 return self._synchronization 1980 1981 @property 1982 def aggregation(self): 1983 return self._aggregation 1984 1985 def eval(self, session=None): 1986 """In a session, computes and returns the value of this variable. 1987 1988 This is not a graph construction method, it does not add ops to the graph. 1989 1990 This convenience method requires a session where the graph 1991 containing this variable has been launched. If no session is 1992 passed, the default session is used. See `tf.compat.v1.Session` for more 1993 information on launching a graph and on sessions. 1994 1995 ```python 1996 v = tf.Variable([1, 2]) 1997 init = tf.compat.v1.global_variables_initializer() 1998 1999 with tf.compat.v1.Session() as sess: 2000 sess.run(init) 2001 # Usage passing the session explicitly. 2002 print(v.eval(sess)) 2003 # Usage with the default session. The 'with' block 2004 # above makes 'sess' the default session. 2005 print(v.eval()) 2006 ``` 2007 2008 Args: 2009 session: The session to use to evaluate this variable. If none, the 2010 default session is used. 2011 2012 Returns: 2013 A numpy `ndarray` with a copy of the value of this variable. 2014 """ 2015 return self._variable.eval(session=session) 2016 2017 @property 2018 def initial_value(self): 2019 """Returns the Tensor used as the initial value for the variable. 2020 2021 Note that this is different from `initialized_value()` which runs 2022 the op that initializes the variable before returning its value. 2023 This method returns the tensor that is used by the op that initializes 2024 the variable. 2025 2026 Returns: 2027 A `Tensor`. 2028 """ 2029 return self._initial_value 2030 2031 @property 2032 def constraint(self): 2033 """Returns the constraint function associated with this variable. 2034 2035 Returns: 2036 The constraint function that was passed to the variable constructor. 2037 Can be `None` if no constraint was passed. 2038 """ 2039 return self._constraint 2040 2041 def assign(self, value, use_locking=False, name=None, read_value=True): 2042 """Assigns a new value to the variable. 2043 2044 This is essentially a shortcut for `assign(self, value)`. 2045 2046 Args: 2047 value: A `Tensor`. The new value for this variable. 2048 use_locking: If `True`, use locking during the assignment. 2049 name: The name of the operation to be created 2050 read_value: if True, will return something which evaluates to the new 2051 value of the variable; if False will return the assign op. 
2052 2053 Returns: 2054 A `Tensor` that will hold the new value of this variable after 2055 the assignment has completed. 2056 """ 2057 assign = state_ops.assign( 2058 self._variable, value, use_locking=use_locking, name=name) 2059 if read_value: 2060 return assign 2061 return assign.op 2062 2063 def assign_add(self, delta, use_locking=False, name=None, read_value=True): 2064 """Adds a value to this variable. 2065 2066 This is essentially a shortcut for `assign_add(self, delta)`. 2067 2068 Args: 2069 delta: A `Tensor`. The value to add to this variable. 2070 use_locking: If `True`, use locking during the operation. 2071 name: The name of the operation to be created 2072 read_value: if True, will return something which evaluates to the new 2073 value of the variable; if False will return the assign op. 2074 2075 Returns: 2076 A `Tensor` that will hold the new value of this variable after 2077 the addition has completed. 2078 """ 2079 assign = state_ops.assign_add( 2080 self._variable, delta, use_locking=use_locking, name=name) 2081 if read_value: 2082 return assign 2083 return assign.op 2084 2085 def assign_sub(self, delta, use_locking=False, name=None, read_value=True): 2086 """Subtracts a value from this variable. 2087 2088 This is essentially a shortcut for `assign_sub(self, delta)`. 2089 2090 Args: 2091 delta: A `Tensor`. The value to subtract from this variable. 2092 use_locking: If `True`, use locking during the operation. 2093 name: The name of the operation to be created 2094 read_value: if True, will return something which evaluates to the new 2095 value of the variable; if False will return the assign op. 2096 2097 Returns: 2098 A `Tensor` that will hold the new value of this variable after 2099 the subtraction has completed. 2100 """ 2101 assign = state_ops.assign_sub( 2102 self._variable, delta, use_locking=use_locking, name=name) 2103 if read_value: 2104 return assign 2105 return assign.op 2106 2107 def scatter_sub(self, sparse_delta, use_locking=False, name=None): 2108 """Subtracts `tf.IndexedSlices` from this variable. 2109 2110 Args: 2111 sparse_delta: `tf.IndexedSlices` to be subtracted from this variable. 2112 use_locking: If `True`, use locking during the operation. 2113 name: the name of the operation. 2114 2115 Returns: 2116 A `Tensor` that will hold the new value of this variable after 2117 the scattered subtraction has completed. 2118 2119 Raises: 2120 TypeError: if `sparse_delta` is not an `IndexedSlices`. 2121 """ 2122 if not isinstance(sparse_delta, ops.IndexedSlices): 2123 raise TypeError("sparse_delta is not IndexedSlices: %s" % sparse_delta) 2124 return gen_state_ops.scatter_sub( 2125 self._variable, 2126 sparse_delta.indices, 2127 sparse_delta.values, 2128 use_locking=use_locking, 2129 name=name) 2130 2131 def scatter_add(self, sparse_delta, use_locking=False, name=None): 2132 """Adds `tf.IndexedSlices` to this variable. 2133 2134 Args: 2135 sparse_delta: `tf.IndexedSlices` to be added to this variable. 2136 use_locking: If `True`, use locking during the operation. 2137 name: the name of the operation. 2138 2139 Returns: 2140 A `Tensor` that will hold the new value of this variable after 2141 the scattered addition has completed. 2142 2143 Raises: 2144 TypeError: if `sparse_delta` is not an `IndexedSlices`. 
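    For example, a minimal sketch (the values, indices, and dense shape below
    are illustrative):

    ```python
    v = tf.compat.v1.Variable([1, 2, 3, 4])
    delta = tf.IndexedSlices(values=tf.constant([10, 20]),
                             indices=tf.constant([0, 2]),
                             dense_shape=tf.constant([4]))
    update = v.scatter_add(delta)
    # After running `update` (e.g. in a session), v holds [11, 2, 23, 4].
    ```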
2145 """ 2146 if not isinstance(sparse_delta, ops.IndexedSlices): 2147 raise TypeError("sparse_delta is not IndexedSlices: %s" % sparse_delta) 2148 return gen_state_ops.scatter_add( 2149 self._variable, 2150 sparse_delta.indices, 2151 sparse_delta.values, 2152 use_locking=use_locking, 2153 name=name) 2154 2155 def scatter_max(self, sparse_delta, use_locking=False, name=None): 2156 """Updates this variable with the max of `tf.IndexedSlices` and itself. 2157 2158 Args: 2159 sparse_delta: `tf.IndexedSlices` to use as an argument of max with this 2160 variable. 2161 use_locking: If `True`, use locking during the operation. 2162 name: the name of the operation. 2163 2164 Returns: 2165 A `Tensor` that will hold the new value of this variable after 2166 the scattered maximization has completed. 2167 2168 Raises: 2169 TypeError: if `sparse_delta` is not an `IndexedSlices`. 2170 """ 2171 if not isinstance(sparse_delta, ops.IndexedSlices): 2172 raise TypeError("sparse_delta is not IndexedSlices: %s" % sparse_delta) 2173 return gen_state_ops.scatter_max( 2174 self._variable, 2175 sparse_delta.indices, 2176 sparse_delta.values, 2177 use_locking=use_locking, 2178 name=name) 2179 2180 def scatter_min(self, sparse_delta, use_locking=False, name=None): 2181 """Updates this variable with the min of `tf.IndexedSlices` and itself. 2182 2183 Args: 2184 sparse_delta: `tf.IndexedSlices` to use as an argument of min with this 2185 variable. 2186 use_locking: If `True`, use locking during the operation. 2187 name: the name of the operation. 2188 2189 Returns: 2190 A `Tensor` that will hold the new value of this variable after 2191 the scattered minimization has completed. 2192 2193 Raises: 2194 TypeError: if `sparse_delta` is not an `IndexedSlices`. 2195 """ 2196 if not isinstance(sparse_delta, ops.IndexedSlices): 2197 raise TypeError("sparse_delta is not IndexedSlices: %s" % sparse_delta) 2198 return gen_state_ops.scatter_min( 2199 self._variable, 2200 sparse_delta.indices, 2201 sparse_delta.values, 2202 use_locking=use_locking, 2203 name=name) 2204 2205 def scatter_mul(self, sparse_delta, use_locking=False, name=None): 2206 """Multiply this variable by `tf.IndexedSlices`. 2207 2208 Args: 2209 sparse_delta: `tf.IndexedSlices` to multiply this variable by. 2210 use_locking: If `True`, use locking during the operation. 2211 name: the name of the operation. 2212 2213 Returns: 2214 A `Tensor` that will hold the new value of this variable after 2215 the scattered multiplication has completed. 2216 2217 Raises: 2218 TypeError: if `sparse_delta` is not an `IndexedSlices`. 2219 """ 2220 if not isinstance(sparse_delta, ops.IndexedSlices): 2221 raise TypeError("sparse_delta is not IndexedSlices: %s" % sparse_delta) 2222 return gen_state_ops.scatter_mul( 2223 self._variable, 2224 sparse_delta.indices, 2225 sparse_delta.values, 2226 use_locking=use_locking, 2227 name=name) 2228 2229 def scatter_div(self, sparse_delta, use_locking=False, name=None): 2230 """Divide this variable by `tf.IndexedSlices`. 2231 2232 Args: 2233 sparse_delta: `tf.IndexedSlices` to divide this variable by. 2234 use_locking: If `True`, use locking during the operation. 2235 name: the name of the operation. 2236 2237 Returns: 2238 A `Tensor` that will hold the new value of this variable after 2239 the scattered division has completed. 2240 2241 Raises: 2242 TypeError: if `sparse_delta` is not an `IndexedSlices`. 
2243 """ 2244 if not isinstance(sparse_delta, ops.IndexedSlices): 2245 raise TypeError("sparse_delta is not IndexedSlices: %s" % sparse_delta) 2246 return gen_state_ops.scatter_div( 2247 self._variable, 2248 sparse_delta.indices, 2249 sparse_delta.values, 2250 use_locking=use_locking, 2251 name=name) 2252 2253 def scatter_update(self, sparse_delta, use_locking=False, name=None): 2254 """Assigns `tf.IndexedSlices` to this variable. 2255 2256 Args: 2257 sparse_delta: `tf.IndexedSlices` to be assigned to this variable. 2258 use_locking: If `True`, use locking during the operation. 2259 name: the name of the operation. 2260 2261 Returns: 2262 A `Tensor` that will hold the new value of this variable after 2263 the scattered assignment has completed. 2264 2265 Raises: 2266 TypeError: if `sparse_delta` is not an `IndexedSlices`. 2267 """ 2268 if not isinstance(sparse_delta, ops.IndexedSlices): 2269 raise TypeError("sparse_delta is not IndexedSlices: %s" % sparse_delta) 2270 return gen_state_ops.scatter_update( 2271 self._variable, 2272 sparse_delta.indices, 2273 sparse_delta.values, 2274 use_locking=use_locking, 2275 name=name) 2276 2277 def batch_scatter_update(self, sparse_delta, use_locking=False, name=None): 2278 """Assigns `tf.IndexedSlices` to this variable batch-wise. 2279 2280 Analogous to `batch_gather`. This assumes that this variable and the 2281 sparse_delta IndexedSlices have a series of leading dimensions that are the 2282 same for all of them, and the updates are performed on the last dimension of 2283 indices. In other words, the dimensions should be the following: 2284 2285 `num_prefix_dims = sparse_delta.indices.ndims - 1` 2286 `batch_dim = num_prefix_dims + 1` 2287 `sparse_delta.updates.shape = sparse_delta.indices.shape + var.shape[ 2288 batch_dim:]` 2289 2290 where 2291 2292 `sparse_delta.updates.shape[:num_prefix_dims]` 2293 `== sparse_delta.indices.shape[:num_prefix_dims]` 2294 `== var.shape[:num_prefix_dims]` 2295 2296 And the operation performed can be expressed as: 2297 2298 `var[i_1, ..., i_n, 2299 sparse_delta.indices[i_1, ..., i_n, j]] = sparse_delta.updates[ 2300 i_1, ..., i_n, j]` 2301 2302 When sparse_delta.indices is a 1D tensor, this operation is equivalent to 2303 `scatter_update`. 2304 2305 To avoid this operation one can looping over the first `ndims` of the 2306 variable and using `scatter_update` on the subtensors that result of slicing 2307 the first dimension. This is a valid option for `ndims = 1`, but less 2308 efficient than this implementation. 2309 2310 Args: 2311 sparse_delta: `tf.IndexedSlices` to be assigned to this variable. 2312 use_locking: If `True`, use locking during the operation. 2313 name: the name of the operation. 2314 2315 Returns: 2316 A `Tensor` that will hold the new value of this variable after 2317 the scattered assignment has completed. 2318 2319 Raises: 2320 TypeError: if `sparse_delta` is not an `IndexedSlices`. 2321 """ 2322 return state_ops.batch_scatter_update( 2323 self, 2324 sparse_delta.indices, 2325 sparse_delta.values, 2326 use_locking=use_locking, 2327 name=name) 2328 2329 def scatter_nd_sub(self, indices, updates, name=None): 2330 """Applies sparse subtraction to individual values or slices in a Variable. 2331 2332 `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`. 2333 2334 `indices` must be integer tensor, containing indices into `ref`. 2335 It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`. 

    The innermost dimension of `indices` (with length `K`) corresponds to
    indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
    dimension of `ref`.

    `updates` is a `Tensor` of rank `Q-1+P-K` with shape:

    ```
    [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
    ```

    For example, say we want to subtract 4 scattered elements from a rank-1
    tensor with 8 elements. In Python, that update would look like this:

    ```python
    ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
    indices = tf.constant([[4], [3], [1], [7]])
    updates = tf.constant([9, 10, 11, 12])
    op = ref.scatter_nd_sub(indices, updates)
    with tf.compat.v1.Session() as sess:
      print(sess.run(op))
    ```

    The resulting update to ref would look like this:

    [1, -9, 3, -6, -4, 6, 7, -4]

    See `tf.scatter_nd` for more details about how to make updates to
    slices.

    Args:
      indices: The indices to be used in the operation.
      updates: The values to be used in the operation.
      name: the name of the operation.

    Returns:
      A `Tensor` that will hold the new value of this variable after
      the scattered subtraction has completed.
    """
    return gen_state_ops.scatter_nd_sub(
        self._variable, indices, updates, use_locking=True, name=name)

  def scatter_nd_add(self, indices, updates, name=None):
    """Applies sparse addition to individual values or slices in a Variable.

    `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.

    `indices` must be an integer tensor, containing indices into `ref`.
    It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.

    The innermost dimension of `indices` (with length `K`) corresponds to
    indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
    dimension of `ref`.

    `updates` is a `Tensor` of rank `Q-1+P-K` with shape:

    ```
    [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
    ```

    For example, say we want to add 4 scattered elements to a rank-1 tensor
    with 8 elements. In Python, that update would look like this:

    ```python
    ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
    indices = tf.constant([[4], [3], [1], [7]])
    updates = tf.constant([9, 10, 11, 12])
    add = ref.scatter_nd_add(indices, updates)
    with tf.compat.v1.Session() as sess:
      print(sess.run(add))
    ```

    The resulting update to ref would look like this:

    [1, 13, 3, 14, 14, 6, 7, 20]

    See `tf.scatter_nd` for more details about how to make updates to
    slices.

    Args:
      indices: The indices to be used in the operation.
      updates: The values to be used in the operation.
      name: the name of the operation.

    Returns:
      A `Tensor` that will hold the new value of this variable after
      the scattered addition has completed.
    """
    return gen_state_ops.scatter_nd_add(
        self._variable, indices, updates, use_locking=True, name=name)

  def scatter_nd_update(self, indices, updates, name=None):
    """Applies sparse assignment to individual values or slices in a Variable.

    `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.

    `indices` must be an integer tensor, containing indices into `ref`.
    It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.

    The innermost dimension of `indices` (with length `K`) corresponds to
    indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
    dimension of `ref`.

    `updates` is a `Tensor` of rank `Q-1+P-K` with shape:

    ```
    [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
    ```

    For example, say we want to assign 4 scattered elements to a rank-1 tensor
    with 8 elements. In Python, that update would look like this:

    ```python
    ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
    indices = tf.constant([[4], [3], [1], [7]])
    updates = tf.constant([9, 10, 11, 12])
    op = ref.scatter_nd_update(indices, updates)
    with tf.compat.v1.Session() as sess:
      print(sess.run(op))
    ```

    The resulting update to ref would look like this:

    [1, 11, 3, 10, 9, 6, 7, 12]

    See `tf.scatter_nd` for more details about how to make updates to
    slices.

    Args:
      indices: The indices to be used in the operation.
      updates: The values to be used in the operation.
      name: the name of the operation.

    Returns:
      A `Tensor` that will hold the new value of this variable after
      the scattered assignment has completed.
    """
    return gen_state_ops.scatter_nd_update(
        self._variable, indices, updates, use_locking=True, name=name)

  def _strided_slice_assign(self, begin, end, strides, value, name, begin_mask,
                            end_mask, ellipsis_mask, new_axis_mask,
                            shrink_axis_mask):
    return gen_array_ops.strided_slice_assign(
        ref=self._ref(),
        begin=begin,
        end=end,
        strides=strides,
        value=value,
        name=name,
        begin_mask=begin_mask,
        end_mask=end_mask,
        ellipsis_mask=ellipsis_mask,
        new_axis_mask=new_axis_mask,
        shrink_axis_mask=shrink_axis_mask)

  @deprecated(None, "Prefer Dataset.range instead.")
  def count_up_to(self, limit):
    """Increments this variable until it reaches `limit`.

    When that Op is run it tries to increment the variable by `1`. If
    incrementing the variable would bring it above `limit` then the Op raises
    the exception `OutOfRangeError`.

    If no error is raised, the Op outputs the value of the variable before
    the increment.

    This is essentially a shortcut for `count_up_to(self, limit)`.

    Args:
      limit: value at which incrementing the variable raises an error.

    Returns:
      A `Tensor` that will hold the variable value before the increment. If no
      other Op modifies this variable, the values produced will all be
      distinct.
    """
    return state_ops.count_up_to(self._variable, limit=limit)

  # Conversion to tensor.
  @staticmethod
  def _TensorConversionFunction(v, dtype=None, name=None, as_ref=False):  # pylint: disable=invalid-name
    """Utility function for converting a Variable to a Tensor."""
    _ = name
    if dtype and not dtype.is_compatible_with(v.dtype):
      raise ValueError(
          "Incompatible type conversion requested to type '%s' for variable "
          "of type '%s'" % (dtype.name, v.dtype.name))
    if as_ref:
      return v._ref()  # pylint: disable=protected-access
    else:
      return v.value()

  # NOTE(mrry): This enables the Variable's overloaded "right" binary
  # operators to run when the left operand is an ndarray, because it
  # accords the Variable class higher priority than an ndarray, or a
  # numpy matrix.
2533 # TODO(mrry): Convert this to using numpy's __numpy_ufunc__ 2534 # mechanism, which allows more control over how Variables interact 2535 # with ndarrays. 2536 __array_priority__ = 100 2537 2538 @property 2539 def name(self): 2540 """The name of this variable.""" 2541 return self._name 2542 2543 @property 2544 def initializer(self): 2545 """The initializer operation for this variable.""" 2546 return self._initializer_op 2547 2548 @property 2549 def device(self): 2550 """The device of this variable.""" 2551 return self._variable.device 2552 2553 @property 2554 def dtype(self): 2555 """The `DType` of this variable.""" 2556 return self._variable.dtype 2557 2558 @property 2559 def op(self): 2560 """The `Operation` of this variable.""" 2561 return self._variable.op 2562 2563 @property 2564 def graph(self): 2565 """The `Graph` of this variable.""" 2566 return self._variable.graph 2567 2568 @property 2569 def _distribute_strategy(self): 2570 """The `tf.distribute.Strategy` that this variable was created under.""" 2571 return None # Ref variables are never created inside a strategy. 2572 2573 @property 2574 def shape(self): 2575 """The `TensorShape` of this variable. 2576 2577 Returns: 2578 A `TensorShape`. 2579 """ 2580 return self._variable.get_shape() 2581 2582 def to_proto(self, export_scope=None): 2583 """Converts a `Variable` to a `VariableDef` protocol buffer. 2584 2585 Args: 2586 export_scope: Optional `string`. Name scope to remove. 2587 2588 Returns: 2589 A `VariableDef` protocol buffer, or `None` if the `Variable` is not 2590 in the specified name scope. 2591 """ 2592 if (export_scope is None or self._variable.name.startswith(export_scope)): 2593 var_def = variable_pb2.VariableDef() 2594 var_def.variable_name = ops.strip_name_scope(self._variable.name, 2595 export_scope) 2596 if self._initial_value is not None: 2597 # For backwards compatibility. 2598 var_def.initial_value_name = ops.strip_name_scope( 2599 self._initial_value.name, export_scope) 2600 var_def.trainable = self.trainable 2601 var_def.synchronization = self.synchronization.value 2602 var_def.aggregation = self.aggregation.value 2603 var_def.initializer_name = ops.strip_name_scope(self.initializer.name, 2604 export_scope) 2605 var_def.snapshot_name = ops.strip_name_scope(self._snapshot.name, 2606 export_scope) 2607 if self._save_slice_info: 2608 var_def.save_slice_info_def.MergeFrom( 2609 self._save_slice_info.to_proto(export_scope=export_scope)) 2610 return var_def 2611 else: 2612 return None 2613 2614 def __iadd__(self, other): 2615 logging.log_first_n( 2616 logging.WARN, "Variable += will be deprecated. Use variable.assign_add" 2617 " if you want assignment to the variable value or 'x = x + y'" 2618 " if you want a new python Tensor object.", 1) 2619 return self + other 2620 2621 def __isub__(self, other): 2622 logging.log_first_n( 2623 logging.WARN, "Variable -= will be deprecated. Use variable.assign_sub" 2624 " if you want assignment to the variable value or 'x = x - y'" 2625 " if you want a new python Tensor object.", 1) 2626 return self - other 2627 2628 def __imul__(self, other): 2629 logging.log_first_n( 2630 logging.WARN, 2631 "Variable *= will be deprecated. Use `var.assign(var * other)`" 2632 " if you want assignment to the variable value or `x = x * y`" 2633 " if you want a new python Tensor object.", 1) 2634 return self * other 2635 2636 def __idiv__(self, other): 2637 logging.log_first_n( 2638 logging.WARN, 2639 "Variable /= will be deprecated. 
Use `var.assign(var / other)`" 2640 " if you want assignment to the variable value or `x = x / y`" 2641 " if you want a new python Tensor object.", 1) 2642 return self / other 2643 2644 def __itruediv__(self, other): 2645 logging.log_first_n( 2646 logging.WARN, 2647 "Variable /= will be deprecated. Use `var.assign(var / other)`" 2648 " if you want assignment to the variable value or `x = x / y`" 2649 " if you want a new python Tensor object.", 1) 2650 return self / other 2651 2652 def __irealdiv__(self, other): 2653 logging.log_first_n( 2654 logging.WARN, 2655 "Variable /= will be deprecated. Use `var.assign(var / other)`" 2656 " if you want assignment to the variable value or `x = x / y`" 2657 " if you want a new python Tensor object.", 1) 2658 return self / other 2659 2660 def __ipow__(self, other): 2661 logging.log_first_n( 2662 logging.WARN, 2663 "Variable **= will be deprecated. Use `var.assign(var ** other)`" 2664 " if you want assignment to the variable value or `x = x ** y`" 2665 " if you want a new python Tensor object.", 1) 2666 return self**other 2667 2668 2669def _try_guard_against_uninitialized_dependencies(name, initial_value): 2670 """Attempt to guard against dependencies on uninitialized variables. 2671 2672 Replace references to variables in `initial_value` with references to the 2673 variable's initialized values. The initialized values are essentially 2674 conditional TensorFlow graphs that return a variable's value if it is 2675 initialized or its `initial_value` if it hasn't been initialized. This 2676 replacement is done on a best effort basis: 2677 2678 - If the `initial_value` graph contains cycles, we don't do any 2679 replacements for that graph. 2680 - If the variables that `initial_value` depends on are not present in the 2681 `GLOBAL_VARIABLES` or `LOCAL_VARIABLES` we don't replace them. 2682 2683 In these cases, it is up to the caller to ensure that the `initial_value` 2684 graph uses initialized variables or that they guard access to variables 2685 using their `initialized_value` method. 2686 2687 Args: 2688 name: Variable name. 2689 initial_value: `Tensor`. The initial value. 2690 2691 Returns: 2692 A `Tensor` suitable to initialize a variable. 2693 Raises: 2694 TypeError: If `initial_value` is not a `Tensor`. 2695 """ 2696 if not isinstance(initial_value, ops.Tensor): 2697 raise TypeError("initial_value needs to be a Tensor: %s" % initial_value) 2698 2699 # Don't modify initial_value if it contains any cyclic dependencies. 2700 if _has_cycle(initial_value.op, state={}): 2701 return initial_value 2702 return _safe_initial_value_from_tensor(name, initial_value, op_cache={}) 2703 2704 2705_UNKNOWN, _STARTED, _FINISHED = range(3) 2706 2707 2708def _has_cycle(op, state): 2709 """Detect cycles in the dependencies of `initial_value`.""" 2710 op_state = state.get(op.name, _UNKNOWN) 2711 if op_state == _STARTED: 2712 return True 2713 elif op_state == _FINISHED: 2714 return False 2715 2716 state[op.name] = _STARTED 2717 for i in itertools.chain((i.op for i in op.inputs), op.control_inputs): 2718 if _has_cycle(i, state): 2719 return True 2720 state[op.name] = _FINISHED 2721 return False 2722 2723 2724def _safe_initial_value_from_tensor(name, tensor, op_cache): 2725 """Replace dependencies on variables with their initialized values. 2726 2727 Args: 2728 name: Variable name. 2729 tensor: A `Tensor`. The tensor to replace. 2730 op_cache: A dict mapping operation names to `Operation`s. Used to memoize 2731 the results so as to avoid creating redundant operations. 
2732 2733 Returns: 2734 A `Tensor` compatible with `tensor`. Any inputs that lead to variable 2735 values will be replaced with a corresponding graph that uses the 2736 variable's initialized values. This is done on a best-effort basis. If no 2737 modifications need to be made then `tensor` will be returned unchanged. 2738 """ 2739 op = tensor.op 2740 new_op = op_cache.get(op.name) 2741 if new_op is None: 2742 new_op = _safe_initial_value_from_op(name, op, op_cache) 2743 op_cache[op.name] = new_op 2744 return new_op.outputs[tensor.value_index] 2745 2746 2747def _safe_initial_value_from_op(name, op, op_cache): 2748 """Replace dependencies on variables with their initialized values. 2749 2750 Args: 2751 name: Variable name. 2752 op: An `Operation`. The operation to replace. 2753 op_cache: A dict mapping operation names to `Operation`s. Used to memoize 2754 the results so as to avoid creating redundant operations. 2755 2756 Returns: 2757 An `Operation` compatible with `op`. Any inputs that lead to variable 2758 values will be replaced with a corresponding graph that uses the 2759 variable's initialized values. This is done on a best-effort basis. If no 2760 modifications need to be made then `op` will be returned unchanged. 2761 """ 2762 op_type = op.node_def.op 2763 if op_type in ("IsVariableInitialized", "VarIsInitializedOp", 2764 "ReadVariableOp", "If"): 2765 return op 2766 2767 # Attempt to find the initialized_value of any variable reference / handles. 2768 # TODO(b/70206927): Fix handling of ResourceVariables. 2769 if op_type in ("Variable", "VariableV2", "VarHandleOp"): 2770 initialized_value = _find_initialized_value_for_variable(op) 2771 return op if initialized_value is None else initialized_value.op 2772 2773 # Recursively build initializer expressions for inputs. 2774 modified = False 2775 new_op_inputs = [] 2776 for op_input in op.inputs: 2777 new_op_input = _safe_initial_value_from_tensor(name, op_input, op_cache) 2778 new_op_inputs.append(new_op_input) 2779 modified = modified or (new_op_input != op_input) 2780 2781 # If at least one input was modified, replace the op. 2782 if modified: 2783 new_op_type = op_type 2784 if new_op_type == "RefSwitch": 2785 new_op_type = "Switch" 2786 new_op_name = op.node_def.name + "_" + name 2787 new_op_name = new_op_name.replace(":", "_") 2788 return op.graph.create_op( 2789 new_op_type, 2790 new_op_inputs, 2791 op._output_types, # pylint: disable=protected-access 2792 name=new_op_name, 2793 attrs=op.node_def.attr) 2794 2795 return op 2796 2797 2798def _find_initialized_value_for_variable(variable_op): 2799 """Find the initialized value for a variable op. 2800 2801 To do so, lookup the variable op in the variables collection. 2802 2803 Args: 2804 variable_op: A variable `Operation`. 2805 2806 Returns: 2807 A `Tensor` representing the initialized value for the variable or `None` 2808 if the initialized value could not be found. 2809 """ 2810 try: 2811 var_names = [variable_op.node_def.name, variable_op.node_def.name + ":0"] 2812 for collection_name in (ops.GraphKeys.GLOBAL_VARIABLES, 2813 ops.GraphKeys.LOCAL_VARIABLES): 2814 for var in variable_op.graph.get_collection(collection_name): 2815 if var.name in var_names: 2816 return var.initialized_value() 2817 except AttributeError: 2818 # Return None when an incomplete user-defined variable type was put in 2819 # the collection. 2820 return None 2821 return None 2822 2823 2824class PartitionedVariable(object): 2825 """A container for partitioned `Variable` objects. 
2826 2827 @compatibility(eager) `tf.PartitionedVariable` is not compatible with 2828 eager execution. Use `tf.Variable` instead which is compatible 2829 with both eager execution and graph construction. See [the 2830 TensorFlow Eager Execution 2831 guide](https://www.tensorflow.org/guide/eager#variables_and_optimizers) 2832 for details on how variables work in eager execution. 2833 @end_compatibility 2834 """ 2835 2836 def __init__(self, name, shape, dtype, variable_list, partitions): 2837 """Creates a new partitioned variable wrapper. 2838 2839 Variables passed via the variable_list must contain a save_slice_info 2840 field. Concatenation and iteration is in lexicographic order according 2841 to the var_offset property of the save_slice_info. 2842 2843 Args: 2844 name: String. Overall name of the variables. 2845 shape: List of integers. Overall shape of the variables. 2846 dtype: Type of the variables. 2847 variable_list: List of `Variable` that comprise this partitioned variable. 2848 partitions: List of integers. Number of partitions for each dimension. 2849 2850 Raises: 2851 TypeError: If `variable_list` is not a list of `Variable` objects, or 2852 `partitions` is not a list. 2853 ValueError: If `variable_list` is empty, or the `Variable` shape 2854 information does not match `shape`, or `partitions` has invalid values. 2855 """ 2856 if not isinstance(variable_list, (list, tuple)): 2857 raise TypeError("variable_list is not a list or tuple: %s" % 2858 variable_list) 2859 if not isinstance(partitions, (list, tuple)): 2860 raise TypeError("partitions is not a list or tuple: %s" % partitions) 2861 if not all(p >= 1 for p in partitions): 2862 raise ValueError("partition values must be positive: %s" % partitions) 2863 if not variable_list: 2864 raise ValueError("variable_list may not be empty") 2865 # pylint: disable=protected-access 2866 for v in variable_list: 2867 # Sort the variable_list lexicographically according to var offset value. 2868 if not all(v._get_save_slice_info() is not None for v in variable_list): 2869 raise ValueError( 2870 "All variables must have a save_slice_info available: %s" % 2871 [v.name for v in variable_list]) 2872 if len(shape) != len(partitions): 2873 raise ValueError("len(shape) != len(partitions): %s vs. %s" % 2874 (shape, partitions)) 2875 if v._get_save_slice_info().full_shape != shape: 2876 raise ValueError("All variables' full shapes must match shape: %s; " 2877 "but full shapes were: %s" % 2878 (shape, str([v._get_save_slice_info().full_shape]))) 2879 self._variable_list = sorted( 2880 variable_list, key=lambda v: v._get_save_slice_info().var_offset) 2881 # pylint: enable=protected-access 2882 2883 self._name = name 2884 self._shape = shape 2885 self._dtype = dtype 2886 self._partitions = partitions 2887 self._as_tensor = None 2888 2889 def __iter__(self): 2890 """Return an iterable for accessing the underlying partition Variables.""" 2891 return iter(self._variable_list) 2892 2893 def __len__(self): 2894 num_partition_axes = len(self._partition_axes()) 2895 if num_partition_axes > 1: 2896 raise ValueError("Cannot get a length for %d > 1 partition axes" % 2897 num_partition_axes) 2898 return len(self._variable_list) 2899 2900 def _partition_axes(self): 2901 if all(p == 1 for p in self._partitions): 2902 return [0] 2903 else: 2904 return [i for i, p in enumerate(self._partitions) if p > 1] 2905 2906 def _concat(self): 2907 """Returns the overall concatenated value as a `Tensor`. 
2908 2909 This is different from using the partitioned variable directly as a tensor 2910 (through tensor conversion and `as_tensor`) in that it creates a new set of 2911 operations that keeps the control dependencies from its scope. 2912 2913 Returns: 2914 `Tensor` containing the concatenated value. 2915 """ 2916 if len(self._variable_list) == 1: 2917 with ops.name_scope(None): 2918 return array_ops.identity(self._variable_list[0], name=self._name) 2919 2920 partition_axes = self._partition_axes() 2921 2922 if len(partition_axes) > 1: 2923 raise NotImplementedError( 2924 "Cannot concatenate along more than one dimension: %s. " 2925 "Multi-axis partition concat is not supported" % str(partition_axes)) 2926 partition_ix = partition_axes[0] 2927 2928 with ops.name_scope(self._name + "/ConcatPartitions/"): 2929 concatenated = array_ops.concat(self._variable_list, partition_ix) 2930 2931 with ops.name_scope(None): 2932 return array_ops.identity(concatenated, name=self._name) 2933 2934 def as_tensor(self): 2935 """Returns the overall concatenated value as a `Tensor`. 2936 2937 The returned tensor will not inherit the control dependencies from the scope 2938 where the value is used, which is similar to getting the value of 2939 `Variable`. 2940 2941 Returns: 2942 `Tensor` containing the concatenated value. 2943 """ 2944 with ops.control_dependencies(None): 2945 return self._concat() 2946 2947 @staticmethod 2948 def _TensorConversionFunction(v, dtype=None, name=None, as_ref=False): 2949 # pylint: disable=invalid-name 2950 _ = name 2951 if dtype is not None and not dtype.is_compatible_with(v.dtype): 2952 raise ValueError( 2953 "Incompatible type conversion requested to type '%s' for variable " 2954 "of type '%s'" % (dtype.name, v.dtype.name)) 2955 if as_ref: 2956 raise NotImplementedError( 2957 "PartitionedVariable doesn't support being used as a reference.") 2958 else: 2959 return v.as_tensor() 2960 2961 @property 2962 def name(self): 2963 return self._name 2964 2965 @property 2966 def dtype(self): 2967 return self._dtype 2968 2969 @property 2970 def shape(self): 2971 return self.get_shape() 2972 2973 @property 2974 def _distribute_strategy(self): 2975 """The `tf.distribute.Strategy` that this variable was created under.""" 2976 # NOTE(yuefengz): Today, no partitioned variables in a distribute strategy. 2977 return None 2978 2979 def get_shape(self): 2980 return self._shape 2981 2982 def _get_variable_list(self): 2983 return self._variable_list 2984 2985 def _get_partitions(self): 2986 return self._partitions 2987 2988 def _apply_assign_fn(self, assign_fn, value): 2989 partition_axes = self._partition_axes() 2990 if len(partition_axes) > 1: 2991 raise NotImplementedError( 2992 "Cannot do assign action along more than one dimension: %s. 
" 2993 "Multi-axis partition assign action is not supported " % 2994 str(partition_axes)) 2995 if isinstance(value, list): 2996 assert len(value) == len(self._variable_list) 2997 value_list = value 2998 elif isinstance(value, PartitionedVariable): 2999 value_list = [var_part for var_part in value] 3000 else: 3001 partition_ix = partition_axes[0] 3002 size_splits_list = [ 3003 tensor_shape.dimension_value(var.shape[partition_ix]) 3004 for var in self._variable_list 3005 ] 3006 value_list = array_ops.split(value, size_splits_list, axis=partition_ix) 3007 3008 op_list = [ 3009 assign_fn(var, value_list[idx]) 3010 for idx, var in enumerate(self._variable_list) 3011 ] 3012 return op_list 3013 3014 def assign(self, value, use_locking=False, name=None, read_value=True): 3015 assign_fn = lambda var, r_value: var.assign( 3016 r_value, use_locking=use_locking, name=name, read_value=read_value) 3017 assign_list = self._apply_assign_fn(assign_fn, value) 3018 if read_value: 3019 return assign_list 3020 return [assign.op for assign in assign_list] 3021 3022 def assign_add(self, value, use_locking=False, name=None, read_value=True): 3023 assign_fn = lambda var, r_value: var.assign_add( 3024 r_value, use_locking=use_locking, name=name, read_value=read_value) 3025 assign_list = self._apply_assign_fn(assign_fn, value) 3026 if read_value: 3027 return assign_list 3028 return [assign.op for assign in assign_list] 3029 3030 def assign_sub(self, value, use_locking=False, name=None, read_value=True): 3031 assign_fn = lambda var, r_value: var.assign_sub( 3032 r_value, use_locking=use_locking, name=name, read_value=read_value) 3033 assign_list = self._apply_assign_fn(assign_fn, value) 3034 if read_value: 3035 return assign_list 3036 return [assign.op for assign in assign_list] 3037 3038 3039# Register a conversion function which reads the value of the variable, 3040# allowing instances of the class to be used as tensors. 3041ops.register_tensor_conversion_function(RefVariable, 3042 RefVariable._TensorConversionFunction) # pylint: disable=protected-access 3043ops.register_dense_tensor_like_type(RefVariable) 3044 3045 3046@tf_export(v1=["global_variables"]) 3047def global_variables(scope=None): 3048 """Returns global variables. 3049 3050 Global variables are variables that are shared across machines in a 3051 distributed environment. The `Variable()` constructor or `get_variable()` 3052 automatically adds new variables to the graph collection 3053 `GraphKeys.GLOBAL_VARIABLES`. 3054 This convenience function returns the contents of that collection. 3055 3056 An alternative to global variables are local variables. See 3057 `tf.compat.v1.local_variables` 3058 3059 Args: 3060 scope: (Optional.) A string. If supplied, the resulting list is filtered to 3061 include only items whose `name` attribute matches `scope` using 3062 `re.match`. Items without a `name` attribute are never returned if a scope 3063 is supplied. The choice of `re.match` means that a `scope` without special 3064 tokens filters by prefix. 3065 3066 Returns: 3067 A list of `Variable` objects. 3068 """ 3069 return ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES, scope) 3070 3071 3072@tf_export(v1=["all_variables"]) 3073@deprecated("2017-03-02", "Please use tf.global_variables instead.") 3074def all_variables(): 3075 """Use `tf.compat.v1.global_variables` instead.""" 3076 return global_variables() 3077 3078 3079def _all_saveable_objects(scope=None): 3080 """Returns all variables and `SaveableObject`s that must be checkpointed. 

  Args:
    scope: (Optional.) A string. If supplied, the resulting list is filtered to
      include only items whose `name` attribute matches `scope` using
      `re.match`. Items without a `name` attribute are never returned if a scope
      is supplied. The choice of `re.match` means that a `scope` without special
      tokens filters by prefix.

  Returns:
    A list of `Variable` and `SaveableObject` objects to be checkpointed.
  """
  # TODO(andreasst): make this function public once things are settled.
  return (ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES, scope) +
          ops.get_collection(ops.GraphKeys.SAVEABLE_OBJECTS, scope))


@tf_export(v1=["local_variables"])
def local_variables(scope=None):
  """Returns local variables.

  Local variables are per-process variables, usually not saved/restored to
  checkpoint and used for temporary or intermediate values.
  For example, they can be used as counters for metrics computation or the
  number of epochs this machine has read data.
  The `tf.contrib.framework.local_variable()` function automatically adds the
  new variable to `GraphKeys.LOCAL_VARIABLES`.
  This convenience function returns the contents of that collection.

  An alternative to local variables is global variables. See
  `tf.compat.v1.global_variables`.

  Args:
    scope: (Optional.) A string. If supplied, the resulting list is filtered to
      include only items whose `name` attribute matches `scope` using
      `re.match`. Items without a `name` attribute are never returned if a scope
      is supplied. The choice of `re.match` means that a `scope` without special
      tokens filters by prefix.

  Returns:
    A list of local `Variable` objects.
  """
  return ops.get_collection(ops.GraphKeys.LOCAL_VARIABLES, scope)


@tf_export(v1=["model_variables"])
def model_variables(scope=None):
  """Returns all variables in the MODEL_VARIABLES collection.

  Args:
    scope: (Optional.) A string. If supplied, the resulting list is filtered to
      include only items whose `name` attribute matches `scope` using
      `re.match`. Items without a `name` attribute are never returned if a scope
      is supplied. The choice of `re.match` means that a `scope` without special
      tokens filters by prefix.

  Returns:
    A list of `Variable` objects in the `MODEL_VARIABLES` collection.
  """
  return ops.get_collection(ops.GraphKeys.MODEL_VARIABLES, scope)


@tf_export(v1=["trainable_variables"])
def trainable_variables(scope=None):
  """Returns all variables created with `trainable=True`.

  When passed `trainable=True`, the `Variable()` constructor automatically
  adds new variables to the graph collection
  `GraphKeys.TRAINABLE_VARIABLES`. This convenience function returns the
  contents of that collection.

  Args:
    scope: (Optional.) A string. If supplied, the resulting list is filtered to
      include only items whose `name` attribute matches `scope` using
      `re.match`. Items without a `name` attribute are never returned if a scope
      is supplied. The choice of `re.match` means that a `scope` without special
      tokens filters by prefix.

  Returns:
    A list of Variable objects.
3160 """ 3161 return ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES, scope) 3162 3163 3164@tf_export(v1=["moving_average_variables"]) 3165def moving_average_variables(scope=None): 3166 """Returns all variables that maintain their moving averages. 3167 3168 If an `ExponentialMovingAverage` object is created and the `apply()` 3169 method is called on a list of variables, these variables will 3170 be added to the `GraphKeys.MOVING_AVERAGE_VARIABLES` collection. 3171 This convenience function returns the contents of that collection. 3172 3173 Args: 3174 scope: (Optional.) A string. If supplied, the resulting list is filtered to 3175 include only items whose `name` attribute matches `scope` using 3176 `re.match`. Items without a `name` attribute are never returned if a scope 3177 is supplied. The choice of `re.match` means that a `scope` without special 3178 tokens filters by prefix. 3179 3180 Returns: 3181 A list of Variable objects. 3182 """ 3183 return ops.get_collection(ops.GraphKeys.MOVING_AVERAGE_VARIABLES, scope) 3184 3185 3186@tf_export(v1=["initializers.variables", "variables_initializer"]) 3187def variables_initializer(var_list, name="init"): 3188 """Returns an Op that initializes a list of variables. 3189 3190 After you launch the graph in a session, you can run the returned Op to 3191 initialize all the variables in `var_list`. This Op runs all the 3192 initializers of the variables in `var_list` in parallel. 3193 3194 Calling `initialize_variables()` is equivalent to passing the list of 3195 initializers to `Group()`. 3196 3197 If `var_list` is empty, however, the function still returns an Op that can 3198 be run. That Op just has no effect. 3199 3200 Args: 3201 var_list: List of `Variable` objects to initialize. 3202 name: Optional name for the returned operation. 3203 3204 Returns: 3205 An Op that run the initializers of all the specified variables. 3206 """ 3207 if var_list and not context.executing_eagerly(): 3208 return control_flow_ops.group(*[v.initializer for v in var_list], name=name) 3209 return control_flow_ops.no_op(name=name) 3210 3211 3212@tf_export(v1=["initialize_variables"]) 3213@tf_should_use.should_use_result 3214@deprecated("2017-03-02", "Use `tf.variables_initializer` instead.") 3215def initialize_variables(var_list, name="init"): 3216 """See `tf.compat.v1.variables_initializer`.""" 3217 return variables_initializer(var_list, name=name) 3218 3219 3220@tf_export(v1=["initializers.global_variables", "global_variables_initializer"]) 3221def global_variables_initializer(): 3222 """Returns an Op that initializes global variables. 3223 3224 This is just a shortcut for `variables_initializer(global_variables())` 3225 3226 Returns: 3227 An Op that initializes global variables in the graph. 3228 """ 3229 if context.executing_eagerly(): 3230 return control_flow_ops.no_op(name="global_variables_initializer") 3231 return variables_initializer(global_variables()) 3232 3233 3234@tf_export(v1=["initialize_all_variables"]) 3235@tf_should_use.should_use_result 3236@deprecated("2017-03-02", "Use `tf.global_variables_initializer` instead.") 3237def initialize_all_variables(): 3238 """See `tf.compat.v1.global_variables_initializer`.""" 3239 return global_variables_initializer() 3240 3241 3242@tf_export(v1=["initializers.local_variables", "local_variables_initializer"]) 3243def local_variables_initializer(): 3244 """Returns an Op that initializes all local variables. 

  This is just a shortcut for `variables_initializer(local_variables())`.

  Returns:
    An Op that initializes all local variables in the graph.
  """
  if context.executing_eagerly():
    return control_flow_ops.no_op(name="local_variables_initializer")
  return variables_initializer(local_variables())


@tf_export(v1=["initialize_local_variables"])
@tf_should_use.should_use_result
@deprecated("2017-03-02", "Use `tf.local_variables_initializer` instead.")
def initialize_local_variables():
  """See `tf.compat.v1.local_variables_initializer`."""
  return local_variables_initializer()


@tf_export(v1=["is_variable_initialized"])
@tf_should_use.should_use_result
def is_variable_initialized(variable):
  """Tests if a variable has been initialized.

  Args:
    variable: A `Variable`.

  Returns:
    A scalar boolean Tensor: `True` if the variable has been initialized,
    `False` otherwise.
  """
  return state_ops.is_variable_initialized(variable)


@tf_export(v1=["assert_variables_initialized"])
@tf_should_use.should_use_result
def assert_variables_initialized(var_list=None):
  """Returns an Op to check if variables are initialized.

  NOTE: This function is obsolete and will be removed in 6 months. Please
  change your implementation to use `report_uninitialized_variables()`.

  When run, the returned Op will raise the exception `FailedPreconditionError`
  if any of the variables has not yet been initialized.

  Note: This function is implemented by trying to fetch the values of the
  variables. If one of the variables is not initialized, a message may be
  logged by the C++ runtime. This is expected.

  Args:
    var_list: List of `Variable` objects to check. Defaults to the value of
      `global_variables()`.

  Returns:
    An Op, or None if there are no variables.
  """
  if var_list is None:
    var_list = global_variables() + local_variables()
  # Backwards compatibility for old-style variables. TODO(touts): remove.
  if not var_list:
    var_list = []
    for op in ops.get_default_graph().get_operations():
      if op.type in ["Variable", "VariableV2", "AutoReloadVariable"]:
        var_list.append(op.outputs[0])
  if not var_list:
    return None
  else:
    ranks = []
    for var in var_list:
      with ops.colocate_with(var.op):
        ranks.append(array_ops.rank_internal(var, optimize=False))
    if len(ranks) == 1:
      return ranks[0]
    else:
      return array_ops.stack(ranks)


@tf_export(v1=["report_uninitialized_variables"])
@tf_should_use.should_use_result
def report_uninitialized_variables(var_list=None,
                                   name="report_uninitialized_variables"):
  """Adds ops to list the names of uninitialized variables.

  When run, it returns a 1-D tensor containing the names of uninitialized
  variables if there are any, or an empty array if there are none.

  Args:
    var_list: List of `Variable` objects to check. Defaults to the value of
      `global_variables() + local_variables()`.
    name: Optional name of the `Operation`.

  Returns:
    A 1-D tensor containing names of the uninitialized variables, or an empty
    1-D tensor if there are no variables or no uninitialized variables.
  """
  if var_list is None:
    var_list = global_variables() + local_variables()
    # Backwards compatibility for old-style variables. TODO(touts): remove.
    if not var_list:
      var_list = []
      for op in ops.get_default_graph().get_operations():
        if op.type in ["Variable", "VariableV2", "AutoReloadVariable"]:
          var_list.append(op.outputs[0])
  with ops.name_scope(name):
    # Run all operations on CPU
    if var_list:
      init_vars = [state_ops.is_variable_initialized(v) for v in var_list]
    local_device = os.environ.get(
        "TF_DEVICE_FOR_UNINITIALIZED_VARIABLE_REPORTING", "/cpu:0")
    with ops.device(local_device):
      if not var_list:
        # Return an empty tensor so we only need to check for returned tensor
        # size being 0 as an indication of model ready.
        return array_ops.constant([], dtype=dtypes.string)
      else:
        # Get a 1-D boolean tensor listing whether each variable is initialized.
        variables_mask = math_ops.logical_not(array_ops.stack(init_vars))
        # Get a 1-D string tensor containing all the variable names.
        variable_names_tensor = array_ops.constant(
            [s.op.name for s in var_list])
        # Return a 1-D tensor containing all the names of
        # uninitialized variables.
        return array_ops.boolean_mask(variable_names_tensor, variables_mask)


ops.register_tensor_conversion_function(
    PartitionedVariable, PartitionedVariable._TensorConversionFunction)  # pylint: disable=protected-access
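
# The comment block below is an illustrative, non-authoritative sketch of the
# user-facing behavior of `PartitionedVariable.assign` defined above: a
# full-shaped value is split along the partition axis into per-shard pieces
# (mirroring `_apply_assign_fn`) before each shard's own `assign` runs. The
# variable, shapes, and partitioner mentioned here are assumptions made for
# the example only.
#
#   # Suppose `pv` is a partitioned variable with overall shape [6], split
#   # along axis 0 into three shards of shape [2] each (for instance created
#   # with tf.compat.v1.get_variable and a fixed_size_partitioner(3)).
#   full_value = [0., 1., 2., 3., 4., 5.]
#   # Internally the value is split with per-shard sizes [2, 2, 2] along
#   # axis 0 (roughly array_ops.split(full_value, [2, 2, 2], axis=0)), and
#   # each piece is assigned to the corresponding shard; one assign op (or
#   # read value) is returned per partition.
#   per_shard_ops = pv.assign(full_value)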
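
# A minimal, hedged usage sketch (TF1 graph mode via tf.compat.v1; the
# variable name "v" and its shape are assumptions, not part of this module)
# showing how the initializer and reporting ops defined above are typically
# driven:
#
#   import tensorflow.compat.v1 as tf
#   tf.disable_eager_execution()
#
#   v = tf.get_variable("v", shape=[2], initializer=tf.zeros_initializer())
#   init_op = tf.global_variables_initializer()
#   report_op = tf.report_uninitialized_variables()
#
#   with tf.Session() as sess:
#     print(sess.run(report_op))  # [b'v']: `v` has not been initialized yet.
#     sess.run(init_op)
#     print(sess.run(report_op))  # []: nothing left uninitialized.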