# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Variable class."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import enum  # pylint: disable=g-bad-import-order
import functools
import os
import six

from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import variable_pb2
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import gen_state_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.util import compat
from tensorflow.python.util import tf_should_use
from tensorflow.python.util.deprecation import deprecated
from tensorflow.python.util.tf_export import tf_export


def default_variable_creator(_, **kwds):
  del kwds
  raise NotImplementedError("variable_scope needs to be imported")


def default_variable_creator_v2(_, **kwds):
  del kwds
  raise NotImplementedError("variable_scope needs to be imported")


def _make_getter(captured_getter, captured_previous):
  """To avoid capturing loop variables."""
  def getter(**kwargs):
    return captured_getter(captured_previous, **kwargs)
  return getter


@tf_export("VariableSynchronization")
class VariableSynchronization(enum.Enum):
  """Indicates when a distributed variable will be synced.

  * `AUTO`: Indicates that the synchronization will be determined by the
    current `DistributionStrategy` (e.g., with `MirroredStrategy` this would
    be `ON_WRITE`).
  * `NONE`: Indicates that there will only be one copy of the variable, so
    there is no need to sync.
  * `ON_WRITE`: Indicates that the variable will be updated across devices
    every time it is written.
  * `ON_READ`: Indicates that the variable will be aggregated across devices
    when it is read (e.g., when checkpointing or when evaluating an op that
    uses the variable).
  """
  AUTO = 0
  NONE = 1
  ON_WRITE = 2
  ON_READ = 3


@tf_export("VariableAggregation", v1=[])
class VariableAggregationV2(enum.Enum):
  """Indicates how a distributed variable will be aggregated.

  `tf.contrib.distribute.DistributionStrategy` distributes a model by making
  multiple copies (called "replicas") acting data-parallel on different
  elements of the input batch.
  When performing some variable-update operation, say `var.assign_add(x)`, in
  a model, we need to resolve how to combine the different values for `x`
  computed in the different replicas.

  * `NONE`: This is the default, giving an error if you use a
    variable-update operation with multiple replicas.
  * `SUM`: Add the updates across replicas.
  * `MEAN`: Take the arithmetic mean ("average") of the updates across
    replicas.
  * `ONLY_FIRST_REPLICA`: This is for when every replica is performing the
    same update, but we only want to perform the update once. Used, e.g.,
    for the global step counter.
  """
  NONE = 0
  SUM = 1
  MEAN = 2
  ONLY_FIRST_REPLICA = 3

  def __hash__(self):
    return hash(self.value)

  def __eq__(self, other):
    if self is other:
      return True
    elif isinstance(other, VariableAggregation):
      return int(self.value) == int(other.value)
    else:
      return False


@tf_export(v1=["VariableAggregation"])
class VariableAggregation(enum.Enum):
  NONE = 0
  SUM = 1
  MEAN = 2
  ONLY_FIRST_REPLICA = 3
  ONLY_FIRST_TOWER = 3  # DEPRECATED

  def __hash__(self):
    return hash(self.value)


VariableAggregation.__doc__ = (
    VariableAggregationV2.__doc__ +
    "* `ONLY_FIRST_TOWER`: Deprecated alias for `ONLY_FIRST_REPLICA`.\n  ")


class VariableMetaclass(type):
  """Metaclass to allow construction of tf.Variable to be overridden."""

  def _variable_v1_call(cls,
                        initial_value=None,
                        trainable=None,
                        collections=None,
                        validate_shape=True,
                        caching_device=None,
                        name=None,
                        variable_def=None,
                        dtype=None,
                        expected_shape=None,
                        import_scope=None,
                        constraint=None,
                        use_resource=None,
                        synchronization=VariableSynchronization.AUTO,
                        aggregation=VariableAggregation.NONE):
    """Call on Variable class. Useful to force the signature."""
    previous_getter = lambda **kwargs: default_variable_creator(None, **kwargs)
    for _, getter in ops.get_default_graph()._variable_creator_stack:  # pylint: disable=protected-access
      previous_getter = _make_getter(getter, previous_getter)

    # Reset `aggregation` that is explicitly set as `None` to the enum NONE.
    if aggregation is None:
      aggregation = VariableAggregation.NONE
    return previous_getter(
        initial_value=initial_value,
        trainable=trainable,
        collections=collections,
        validate_shape=validate_shape,
        caching_device=caching_device,
        name=name,
        variable_def=variable_def,
        dtype=dtype,
        expected_shape=expected_shape,
        import_scope=import_scope,
        constraint=constraint,
        use_resource=use_resource,
        synchronization=synchronization,
        aggregation=aggregation)

  def _variable_v2_call(cls,
                        initial_value=None,
                        trainable=None,
                        validate_shape=True,
                        caching_device=None,
                        name=None,
                        variable_def=None,
                        dtype=None,
                        import_scope=None,
                        constraint=None,
                        synchronization=VariableSynchronization.AUTO,
                        aggregation=VariableAggregation.NONE):
    """Call on Variable class. Useful to force the signature."""
    previous_getter = lambda **kws: default_variable_creator_v2(None, **kws)
    for _, getter in ops.get_default_graph()._variable_creator_stack:  # pylint: disable=protected-access
      previous_getter = _make_getter(getter, previous_getter)

    # Reset `aggregation` that is explicitly set as `None` to the enum NONE.
    if aggregation is None:
      aggregation = VariableAggregation.NONE
    return previous_getter(
        initial_value=initial_value,
        trainable=trainable,
        validate_shape=validate_shape,
        caching_device=caching_device,
        name=name,
        variable_def=variable_def,
        dtype=dtype,
        import_scope=import_scope,
        constraint=constraint,
        synchronization=synchronization,
        aggregation=aggregation)

  def __call__(cls, *args, **kwargs):
    if cls is VariableV1:
      return cls._variable_v1_call(*args, **kwargs)
    elif cls is Variable:
      return cls._variable_v2_call(*args, **kwargs)
    else:
      return super(VariableMetaclass, cls).__call__(*args, **kwargs)


@tf_export("Variable", v1=[])
class Variable(six.with_metaclass(VariableMetaclass,
                                  trackable.Trackable)):
  """See the [Variables Guide](https://tensorflow.org/guide/variables).

  A variable maintains state in the graph across calls to `run()`. You add a
  variable to the graph by constructing an instance of the class `Variable`.

  The `Variable()` constructor requires an initial value for the variable,
  which can be a `Tensor` of any type and shape. The initial value defines the
  type and shape of the variable. After construction, the type and shape of
  the variable are fixed. The value can be changed using one of the assign
  methods.

  If you want to change the shape of a variable later you have to use an
  `assign` Op with `validate_shape=False`.

  Just like any `Tensor`, variables created with `Variable()` can be used as
  inputs for other Ops in the graph. Additionally, all the operators
  overloaded for the `Tensor` class are carried over to variables, so you can
  also add nodes to the graph by just doing arithmetic on variables.

  ```python
  import tensorflow as tf

  # Create a variable.
  w = tf.Variable(<initial-value>, name=<optional-name>)

  # Use the variable in the graph like any Tensor.
  y = tf.matmul(w, ...another variable or tensor...)

  # The overloaded operators are available too.
  z = tf.sigmoid(w + y)

  # Assign a new value to the variable with `assign()` or a related method.
  w.assign(w + 1.0)
  w.assign_add(1.0)
  ```

  When you launch the graph, variables have to be explicitly initialized
  before you can run Ops that use their value. You can initialize a variable
  by running its *initializer op*, restoring the variable from a save file,
  or simply running an `assign` Op that assigns a value to the variable. In
  fact, the variable *initializer op* is just an `assign` Op that assigns the
  variable's initial value to the variable itself.

  ```python
  # Launch the graph in a session.
  with tf.Session() as sess:
    # Run the variable initializer.
    sess.run(w.initializer)
    # ...you now can run ops that use the value of 'w'...
  ```

  The most common initialization pattern is to use the convenience function
  `global_variables_initializer()` to add an Op to the graph that initializes
  all the variables. You then run that Op after launching the graph.

  ```python
  # Add an Op to initialize global variables.
  init_op = tf.global_variables_initializer()

  # Launch the graph in a session.
  with tf.Session() as sess:
    # Run the Op that initializes global variables.
    sess.run(init_op)
    # ...you can now run any Op that uses variable values...
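    # For example, you could now read back the value of the variable `w`
    # created earlier in this docstring (illustrative only; `w` must exist
    # in this graph):
    print(sess.run(w))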
  ```

  If you need to create a variable with an initial value dependent on another
  variable, use the other variable's `initialized_value()`. This ensures that
  variables are initialized in the right order.

  All variables are automatically collected in the graph where they are
  created. By default, the constructor adds the new variable to the graph
  collection `GraphKeys.GLOBAL_VARIABLES`. The convenience function
  `global_variables()` returns the contents of that collection.

  When building a machine learning model it is often convenient to
  distinguish between variables holding the trainable model parameters and
  other variables such as a `global step` variable used to count training
  steps. To make this easier, the variable constructor supports a
  `trainable=<bool>` parameter. If `True`, the new variable is also added to
  the graph collection `GraphKeys.TRAINABLE_VARIABLES`. The convenience
  function `trainable_variables()` returns the contents of this collection.
  The various `Optimizer` classes use this collection as the default list of
  variables to optimize.

  WARNING: tf.Variable objects by default have a non-intuitive memory model.
  A Variable is represented internally as a mutable Tensor which can
  non-deterministically alias other Tensors in a graph. The set of operations
  which consume a Variable and can lead to aliasing is undetermined and can
  change across TensorFlow versions. Avoid writing code which relies on the
  value of a Variable either changing or not changing as other operations
  happen. For example, using Variable objects or simple functions thereof as
  predicates in a `tf.cond` is dangerous and error-prone:

  ```
  v = tf.Variable(True)
  tf.cond(v, lambda: v.assign(False), my_false_fn)  # Note: this is broken.
  ```

  Here, adding `use_resource=True` when constructing the variable will
  fix any nondeterminism issues:

  ```
  v = tf.Variable(True, use_resource=True)
  tf.cond(v, lambda: v.assign(False), my_false_fn)
  ```

  To use the replacement for variables, which does not have these issues:

  * Add `use_resource=True` when constructing `tf.Variable`;
  * Call `tf.get_variable_scope().set_use_resource(True)` inside a
    `tf.variable_scope` before the `tf.get_variable()` call.
  """

  def __init__(self,
               initial_value=None,
               trainable=True,
               validate_shape=True,
               caching_device=None,
               name=None,
               variable_def=None,
               dtype=None,
               import_scope=None,
               constraint=None,
               synchronization=VariableSynchronization.AUTO,
               aggregation=VariableAggregation.NONE):
    """Creates a new variable with value `initial_value`.

    The new variable is added to the graph collections listed in
    `collections`, which defaults to `[GraphKeys.GLOBAL_VARIABLES]`.

    If `trainable` is `True` the variable is also added to the graph
    collection `GraphKeys.TRAINABLE_VARIABLES`.

    This constructor creates both a `variable` Op and an `assign` Op to set
    the variable to its initial value.

    Args:
      initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
        which is the initial value for the Variable. The initial value must
        have a shape specified unless `validate_shape` is set to False. Can
        also be a callable with no argument that returns the initial value
        when called. In that case, `dtype` must be specified.
        (Note that initializer functions from init_ops.py must first be bound
        to a shape before being used here.)
      trainable: If `True`, the default, GradientTapes automatically watch
        uses of this variable.
      validate_shape: If `False`, allows the variable to be initialized with
        a value of unknown shape. If `True`, the default, the shape of
        `initial_value` must be known.
      caching_device: Optional device string describing where the Variable
        should be cached for reading. Defaults to the Variable's device.
        If not `None`, caches on another device. Typical use is to cache
        on the device where the Ops using the Variable reside, to deduplicate
        copying through `Switch` and other conditional statements.
      name: Optional name for the variable. Defaults to `'Variable'` and gets
        uniquified automatically.
      variable_def: `VariableDef` protocol buffer. If not `None`, recreates
        the Variable object with its contents, referencing the variable's
        nodes in the graph, which must already exist. The graph is not
        changed. `variable_def` and the other arguments are mutually
        exclusive.
      dtype: If set, initial_value will be converted to the given type.
        If `None`, either the datatype will be kept (if `initial_value` is
        a Tensor), or `convert_to_tensor` will decide.
      import_scope: Optional `string`. Name scope to add to the `Variable`.
        Only used when initializing from protocol buffer.
      constraint: An optional projection function to be applied to the
        variable after being updated by an `Optimizer` (e.g. used to
        implement norm constraints or value constraints for layer weights).
        The function must take as input the unprojected Tensor representing
        the value of the variable and return the Tensor for the projected
        value (which must have the same shape). Constraints are not safe to
        use when doing asynchronous distributed training.
      synchronization: Indicates when a distributed variable will be
        synchronized. Accepted values are constants defined in the class
        `tf.VariableSynchronization`. By default the synchronization is set
        to `AUTO` and the current `DistributionStrategy` chooses
        when to synchronize. If `synchronization` is set to `ON_READ`,
        `trainable` must not be set to `True`.
      aggregation: Indicates how a distributed variable will be aggregated.
        Accepted values are constants defined in the class
        `tf.VariableAggregation`.

    Raises:
      ValueError: If both `variable_def` and initial_value are specified.
      ValueError: If the initial value is not specified, or does not have a
        shape and `validate_shape` is `True`.
      RuntimeError: If eager execution is enabled.
    """
    raise NotImplementedError

  def __repr__(self):
    raise NotImplementedError

  def value(self):
    """Returns the last snapshot of this variable.

    You usually do not need to call this method as all ops that need the
    value of the variable call it automatically through a
    `convert_to_tensor()` call.

    Returns a `Tensor` which holds the value of the variable. You can not
    assign a new value to this tensor as it is not a reference to the
    variable.

    To avoid copies, if the consumer of the returned value is on the same
    device as the variable, this actually returns the live value of the
    variable, not a copy. Updates to the variable are seen by the consumer.
    If the consumer is on a different device it will get a copy of the
    variable.
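
    For example, a minimal graph-mode sketch (assumes a `tf.Session`, as in
    the examples above):

    ```python
    v = tf.Variable([1.0, 2.0])
    t = v.value()  # A `Tensor` snapshot; it cannot be assigned to.
    with tf.Session() as sess:
      sess.run(v.initializer)
      print(sess.run(t))  # [1. 2.]
    ```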

    Returns:
      A `Tensor` containing the value of the variable.
    """
    raise NotImplementedError

  def read_value(self):
    """Returns the value of this variable, read in the current context.

    Can be different from value() if it's on another device, with control
    dependencies, etc.

    Returns:
      A `Tensor` containing the value of the variable.
    """
    raise NotImplementedError

  def set_shape(self, shape):
    """Overrides the shape for this variable.

    Args:
      shape: the `TensorShape` representing the overridden shape.
    """
    raise NotImplementedError

  @property
  def trainable(self):
    raise NotImplementedError

  def eval(self, session=None):
    """In a session, computes and returns the value of this variable.

    This is not a graph construction method, it does not add ops to the graph.

    This convenience method requires a session where the graph
    containing this variable has been launched. If no session is
    passed, the default session is used. See `tf.Session` for more
    information on launching a graph and on sessions.

    ```python
    v = tf.Variable([1, 2])
    init = tf.global_variables_initializer()

    with tf.Session() as sess:
      sess.run(init)
      # Usage passing the session explicitly.
      print(v.eval(sess))
      # Usage with the default session. The 'with' block
      # above makes 'sess' the default session.
      print(v.eval())
    ```

    Args:
      session: The session to use to evaluate this variable. If
        none, the default session is used.

    Returns:
      A numpy `ndarray` with a copy of the value of this variable.
    """
    raise NotImplementedError

  @deprecated(
      None,
      "Use Variable.read_value. Variables in 2.X are initialized "
      "automatically both in eager and graph (inside tf.defun) contexts.")
  def initialized_value(self):
    """Returns the value of the initialized variable.

    You should use this instead of the variable itself to initialize another
    variable with a value that depends on the value of this variable.

    ```python
    # Initialize 'v' with a random tensor.
    v = tf.Variable(tf.truncated_normal([10, 40]))
    # Use `initialized_value` to guarantee that `v` has been
    # initialized before its value is used to initialize `w`.
    # The random values are picked only once.
    w = tf.Variable(v.initialized_value() * 2.0)
    ```

    Returns:
      A `Tensor` holding the value of this variable after its initializer
      has run.
    """
    with ops.init_scope():
      return control_flow_ops.cond(is_variable_initialized(self),
                                   self.read_value,
                                   lambda: self.initial_value)

  @property
  def initial_value(self):
    """Returns the Tensor used as the initial value for the variable.

    Note that this is different from `initialized_value()` which runs
    the op that initializes the variable before returning its value.
    This method returns the tensor that is used by the op that initializes
    the variable.

    Returns:
      A `Tensor`.
    """
    raise NotImplementedError

  @property
  def constraint(self):
    """Returns the constraint function associated with this variable.

    Returns:
      The constraint function that was passed to the variable constructor.
      Can be `None` if no constraint was passed.
    """
    raise NotImplementedError

  def assign(self, value, use_locking=False, name=None, read_value=True):
    """Assigns a new value to the variable.

    This is essentially a shortcut for `assign(self, value)`.

    Args:
      value: A `Tensor`. The new value for this variable.
      use_locking: If `True`, use locking during the assignment.
      name: The name of the operation to be created.
      read_value: if True, will return something which evaluates to the
        new value of the variable; if False will return the assign op.

    Returns:
      A `Tensor` that will hold the new value of this variable after
      the assignment has completed.
    """
    raise NotImplementedError

  def assign_add(self, delta, use_locking=False, name=None, read_value=True):
    """Adds a value to this variable.

    This is essentially a shortcut for `assign_add(self, delta)`.

    Args:
      delta: A `Tensor`. The value to add to this variable.
      use_locking: If `True`, use locking during the operation.
      name: The name of the operation to be created.
      read_value: if True, will return something which evaluates to the
        new value of the variable; if False will return the assign op.

    Returns:
      A `Tensor` that will hold the new value of this variable after
      the addition has completed.
    """
    raise NotImplementedError

  def assign_sub(self, delta, use_locking=False, name=None, read_value=True):
    """Subtracts a value from this variable.

    This is essentially a shortcut for `assign_sub(self, delta)`.

    Args:
      delta: A `Tensor`. The value to subtract from this variable.
      use_locking: If `True`, use locking during the operation.
      name: The name of the operation to be created.
      read_value: if True, will return something which evaluates to the
        new value of the variable; if False will return the assign op.

    Returns:
      A `Tensor` that will hold the new value of this variable after
      the subtraction has completed.
    """
    raise NotImplementedError

  def scatter_sub(self, sparse_delta, use_locking=False, name=None):
    """Subtracts `IndexedSlices` from this variable.

    Args:
      sparse_delta: `IndexedSlices` to be subtracted from this variable.
      use_locking: If `True`, use locking during the operation.
      name: the name of the operation.

    Returns:
      A `Tensor` that will hold the new value of this variable after
      the scattered subtraction has completed.

    Raises:
      ValueError: if `sparse_delta` is not an `IndexedSlices`.
    """
    raise NotImplementedError

  def scatter_add(self, sparse_delta, use_locking=False, name=None):
    """Adds `IndexedSlices` to this variable.

    Args:
      sparse_delta: `IndexedSlices` to be added to this variable.
      use_locking: If `True`, use locking during the operation.
      name: the name of the operation.

    Returns:
      A `Tensor` that will hold the new value of this variable after
      the scattered addition has completed.

    Raises:
      ValueError: if `sparse_delta` is not an `IndexedSlices`.
    """
    raise NotImplementedError

  def scatter_update(self, sparse_delta, use_locking=False, name=None):
    """Assigns `IndexedSlices` to this variable.

    Args:
      sparse_delta: `IndexedSlices` to be assigned to this variable.
      use_locking: If `True`, use locking during the operation.
      name: the name of the operation.

    Returns:
      A `Tensor` that will hold the new value of this variable after
      the scattered assignment has completed.

    Raises:
      ValueError: if `sparse_delta` is not an `IndexedSlices`.
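
    For example, a minimal graph-mode sketch (the exact return value depends
    on the concrete variable implementation):

    ```python
    v = tf.Variable([1, 2, 3, 4])
    delta = tf.IndexedSlices(values=tf.constant([9, 10]),
                             indices=tf.constant([0, 2]))
    update = v.scatter_update(delta)
    with tf.Session() as sess:
      sess.run(v.initializer)
      print(sess.run(update))  # The variable now holds [9, 2, 10, 4].
    ```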
644 """ 645 raise NotImplementedError 646 647 def batch_scatter_update(self, sparse_delta, use_locking=False, name=None): 648 """Assigns `IndexedSlices` to this variable batch-wise. 649 650 Analogous to `batch_gather`. This assumes that this variable and the 651 sparse_delta IndexedSlices have a series of leading dimensions that are the 652 same for all of them, and the updates are performed on the last dimension of 653 indices. In other words, the dimensions should be the following: 654 655 `num_prefix_dims = sparse_delta.indices.ndims - 1` 656 `batch_dim = num_prefix_dims + 1` 657 `sparse_delta.updates.shape = sparse_delta.indices.shape + var.shape[ 658 batch_dim:]` 659 660 where 661 662 `sparse_delta.updates.shape[:num_prefix_dims]` 663 `== sparse_delta.indices.shape[:num_prefix_dims]` 664 `== var.shape[:num_prefix_dims]` 665 666 And the operation performed can be expressed as: 667 668 `var[i_1, ..., i_n, 669 sparse_delta.indices[i_1, ..., i_n, j]] = sparse_delta.updates[ 670 i_1, ..., i_n, j]` 671 672 When sparse_delta.indices is a 1D tensor, this operation is equivalent to 673 `scatter_update`. 674 675 To avoid this operation one can looping over the first `ndims` of the 676 variable and using `scatter_update` on the subtensors that result of slicing 677 the first dimension. This is a valid option for `ndims = 1`, but less 678 efficient than this implementation. 679 680 Args: 681 sparse_delta: `IndexedSlices` to be assigned to this variable. 682 use_locking: If `True`, use locking during the operation. 683 name: the name of the operation. 684 685 Returns: 686 A `Tensor` that will hold the new value of this variable after 687 the scattered subtraction has completed. 688 689 Raises: 690 ValueError: if `sparse_delta` is not an `IndexedSlices`. 691 """ 692 raise NotImplementedError 693 694 def scatter_nd_sub(self, indices, updates, name=None): 695 """Applies sparse subtraction to individual values or slices in a Variable. 696 697 Assuming the variable has rank `P` and `indices` is a `Tensor` of rank `Q`. 698 699 `indices` must be integer tensor, containing indices into self. 700 It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`. 701 702 The innermost dimension of `indices` (with length `K`) corresponds to 703 indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th 704 dimension of self. 705 706 `updates` is `Tensor` of rank `Q-1+P-K` with shape: 707 708 ``` 709 [d_0, ..., d_{Q-2}, self.shape[K], ..., self.shape[P-1]]. 710 ``` 711 712 For example, say we want to add 4 scattered elements to a rank-1 tensor to 713 8 elements. In Python, that update would look like this: 714 715 ```python 716 v = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8]) 717 indices = tf.constant([[4], [3], [1] ,[7]]) 718 updates = tf.constant([9, 10, 11, 12]) 719 op = v.scatter_nd_sub(indices, updates) 720 with tf.Session() as sess: 721 print sess.run(op) 722 ``` 723 724 The resulting update to v would look like this: 725 726 [1, -9, 3, -6, -6, 6, 7, -4] 727 728 See `tf.scatter_nd` for more details about how to make updates to 729 slices. 730 731 Args: 732 indices: The indices to be used in the operation. 733 updates: The values to be used in the operation. 734 name: the name of the operation. 735 736 Returns: 737 A `Tensor` that will hold the new value of this variable after 738 the scattered subtraction has completed. 739 740 Raises: 741 ValueError: if `sparse_delta` is not an `IndexedSlices`. 
742 """ 743 raise NotImplementedError 744 745 def scatter_nd_add(self, indices, updates, name=None): 746 """Applies sparse addition to individual values or slices in a Variable. 747 748 The Variable has rank `P` and `indices` is a `Tensor` of rank `Q`. 749 750 `indices` must be integer tensor, containing indices into self. 751 It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`. 752 753 The innermost dimension of `indices` (with length `K`) corresponds to 754 indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th 755 dimension of self. 756 757 `updates` is `Tensor` of rank `Q-1+P-K` with shape: 758 759 ``` 760 [d_0, ..., d_{Q-2}, self.shape[K], ..., self.shape[P-1]]. 761 ``` 762 763 For example, say we want to add 4 scattered elements to a rank-1 tensor to 764 8 elements. In Python, that update would look like this: 765 766 ```python 767 v = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8]) 768 indices = tf.constant([[4], [3], [1] ,[7]]) 769 updates = tf.constant([9, 10, 11, 12]) 770 add = v.scatter_nd_add(indices, updates) 771 with tf.Session() as sess: 772 print sess.run(add) 773 ``` 774 775 The resulting update to v would look like this: 776 777 [1, 13, 3, 14, 14, 6, 7, 20] 778 779 See `tf.scatter_nd` for more details about how to make updates to 780 slices. 781 782 Args: 783 indices: The indices to be used in the operation. 784 updates: The values to be used in the operation. 785 name: the name of the operation. 786 787 Returns: 788 A `Tensor` that will hold the new value of this variable after 789 the scattered subtraction has completed. 790 791 Raises: 792 ValueError: if `sparse_delta` is not an `IndexedSlices`. 793 """ 794 raise NotImplementedError 795 796 def scatter_nd_update(self, indices, updates, name=None): 797 """Applies sparse assignment to individual values or slices in a Variable. 798 799 The Variable has rank `P` and `indices` is a `Tensor` of rank `Q`. 800 801 `indices` must be integer tensor, containing indices into self. 802 It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`. 803 804 The innermost dimension of `indices` (with length `K`) corresponds to 805 indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th 806 dimension of self. 807 808 `updates` is `Tensor` of rank `Q-1+P-K` with shape: 809 810 ``` 811 [d_0, ..., d_{Q-2}, self.shape[K], ..., self.shape[P-1]]. 812 ``` 813 814 For example, say we want to add 4 scattered elements to a rank-1 tensor to 815 8 elements. In Python, that update would look like this: 816 817 ```python 818 v = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8]) 819 indices = tf.constant([[4], [3], [1] ,[7]]) 820 updates = tf.constant([9, 10, 11, 12]) 821 op = v.scatter_nd_assign(indices, updates) 822 with tf.Session() as sess: 823 print sess.run(op) 824 ``` 825 826 The resulting update to v would look like this: 827 828 [1, 11, 3, 10, 9, 6, 7, 12] 829 830 See `tf.scatter_nd` for more details about how to make updates to 831 slices. 832 833 Args: 834 indices: The indices to be used in the operation. 835 updates: The values to be used in the operation. 836 name: the name of the operation. 837 838 Returns: 839 A `Tensor` that will hold the new value of this variable after 840 the scattered subtraction has completed. 841 842 Raises: 843 ValueError: if `sparse_delta` is not an `IndexedSlices`. 844 """ 845 raise NotImplementedError 846 847 @deprecated(None, "Prefer Dataset.range instead.") 848 def count_up_to(self, limit): 849 """Increments this variable until it reaches `limit`. 

    When that Op is run it tries to increment the variable by `1`. If
    incrementing the variable would bring it above `limit` then the Op raises
    the exception `OutOfRangeError`.

    If no error is raised, the Op outputs the value of the variable before
    the increment.

    This is essentially a shortcut for `count_up_to(self, limit)`.

    Args:
      limit: value at which incrementing the variable raises an error.

    Returns:
      A `Tensor` that will hold the variable value before the increment. If
      no other Op modifies this variable, the values produced will all be
      distinct.
    """
    raise NotImplementedError

  @deprecated(
      None,
      "Prefer Variable.assign which has equivalent behavior in 2.X.")
  def load(self, value, session=None):
    """Load new value into this variable.

    Writes new value to variable's memory. Doesn't add ops to the graph.

    This convenience method requires a session where the graph
    containing this variable has been launched. If no session is
    passed, the default session is used. See `tf.Session` for more
    information on launching a graph and on sessions.

    ```python
    v = tf.Variable([1, 2])
    init = tf.global_variables_initializer()

    with tf.Session() as sess:
      sess.run(init)
      # Usage passing the session explicitly.
      v.load([2, 3], sess)
      print(v.eval(sess))  # prints [2 3]
      # Usage with the default session. The 'with' block
      # above makes 'sess' the default session.
      v.load([3, 4], sess)
      print(v.eval())  # prints [3 4]
    ```

    Args:
      value: New variable value.
      session: The session to use to evaluate this variable. If
        none, the default session is used.

    Raises:
      ValueError: Session is not passed and no default session is available.
    """
    if context.executing_eagerly():
      self.assign(value)
    else:
      session = session or ops.get_default_session()
      if session is None:
        raise ValueError(
            "Either session argument should be provided or default session "
            "should be established")
      session.run(self.initializer, {self.initializer.inputs[1]: value})

  # Conversion to tensor.
  @staticmethod
  def _TensorConversionFunction(v, dtype=None, name=None, as_ref=False):  # pylint: disable=invalid-name
    """Utility function for converting a Variable to a Tensor."""
    _ = name
    if dtype and not dtype.is_compatible_with(v.dtype):
      raise ValueError(
          "Incompatible type conversion requested to type '%s' for variable "
          "of type '%s'" % (dtype.name, v.dtype.name))
    if as_ref:
      return v._ref()  # pylint: disable=protected-access
    else:
      return v.value()

  @classmethod
  def _OverloadAllOperators(cls):  # pylint: disable=invalid-name
    """Register overloads for all operators."""
    for operator in ops.Tensor.OVERLOADABLE_OPERATORS:
      cls._OverloadOperator(operator)
    # For slicing, bind getitem differently than a tensor (use SliceHelperVar
    # instead)
    # pylint: disable=protected-access
    setattr(cls, "__getitem__", array_ops._SliceHelperVar)

  @classmethod
  def _OverloadOperator(cls, operator):  # pylint: disable=invalid-name
    """Defer an operator overload to `ops.Tensor`.

    We pull the operator out of ops.Tensor dynamically to avoid ordering
    issues.

    Args:
      operator: string. The operator name.
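
    For example (illustrative), after `_OverloadOperator("__add__")` runs,
    `v + x` is evaluated as `ops.Tensor.__add__(v.value(), x)`.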
948 """ 949 tensor_oper = getattr(ops.Tensor, operator) 950 951 def _run_op(a, *args, **kwargs): 952 # pylint: disable=protected-access 953 return tensor_oper(a.value(), *args, **kwargs) 954 955 functools.update_wrapper(_run_op, tensor_oper) 956 setattr(cls, operator, _run_op) 957 958 def __iter__(self): 959 """Dummy method to prevent iteration. Do not call. 960 961 NOTE(mrry): If we register __getitem__ as an overloaded operator, 962 Python will valiantly attempt to iterate over the variable's Tensor from 0 963 to infinity. Declaring this method prevents this unintended behavior. 964 965 Raises: 966 TypeError: when invoked. 967 """ 968 raise TypeError("'Variable' object is not iterable.") 969 970 # NOTE(mrry): This enables the Variable's overloaded "right" binary 971 # operators to run when the left operand is an ndarray, because it 972 # accords the Variable class higher priority than an ndarray, or a 973 # numpy matrix. 974 # TODO(mrry): Convert this to using numpy's __numpy_ufunc__ 975 # mechanism, which allows more control over how Variables interact 976 # with ndarrays. 977 __array_priority__ = 100 978 979 @property 980 def name(self): 981 """The name of this variable.""" 982 raise NotImplementedError 983 984 @property 985 def _shared_name(self): 986 """The shared name of the variable. 987 988 Unlike name(), shared_name doesn't have ":0" suffix. It is user-specified 989 name with name scope prefix. 990 991 Returns: 992 variable name. 993 """ 994 return self.name[:self.name.index(":")] 995 996 @property 997 def initializer(self): 998 """The initializer operation for this variable.""" 999 raise NotImplementedError 1000 1001 @property 1002 def device(self): 1003 """The device of this variable.""" 1004 raise NotImplementedError 1005 1006 @property 1007 def dtype(self): 1008 """The `DType` of this variable.""" 1009 raise NotImplementedError 1010 1011 @property 1012 def op(self): 1013 """The `Operation` of this variable.""" 1014 raise NotImplementedError 1015 1016 @property 1017 def graph(self): 1018 """The `Graph` of this variable.""" 1019 raise NotImplementedError 1020 1021 @property 1022 def shape(self): 1023 """The `TensorShape` of this variable. 1024 1025 Returns: 1026 A `TensorShape`. 1027 """ 1028 raise NotImplementedError 1029 1030 def get_shape(self): 1031 """Alias of `Variable.shape`.""" 1032 return self.shape 1033 1034 def _gather_saveables_for_checkpoint(self): 1035 """For implementing `Trackable`. This object is saveable on its own.""" 1036 return {trackable.VARIABLE_VALUE_KEY: self} 1037 1038 def to_proto(self, export_scope=None): 1039 """Converts a `Variable` to a `VariableDef` protocol buffer. 1040 1041 Args: 1042 export_scope: Optional `string`. Name scope to remove. 1043 1044 Returns: 1045 A `VariableDef` protocol buffer, or `None` if the `Variable` is not 1046 in the specified name scope. 1047 """ 1048 raise NotImplementedError 1049 1050 @staticmethod 1051 def from_proto(variable_def, import_scope=None): 1052 """Returns a `Variable` object created from `variable_def`.""" 1053 return RefVariable(variable_def=variable_def, 1054 import_scope=import_scope) 1055 1056 def _set_save_slice_info(self, save_slice_info): 1057 """Sets the slice info for this `Variable`. 1058 1059 Args: 1060 save_slice_info: A `Variable.SaveSliceInfo` object. 1061 """ 1062 self._save_slice_info = save_slice_info 1063 1064 def _get_save_slice_info(self): 1065 return self._save_slice_info 1066 1067 class SaveSliceInfo(object): 1068 """Information on how to save this Variable as a slice. 

    Provides internal support for saving variables as slices of a larger
    variable. This API is not public and is subject to change.

    Available properties:

    * full_name
    * full_shape
    * var_offset
    * var_shape
    """

    def __init__(self,
                 full_name=None,
                 full_shape=None,
                 var_offset=None,
                 var_shape=None,
                 save_slice_info_def=None,
                 import_scope=None):
      """Create a `SaveSliceInfo`.

      Args:
        full_name: Name of the full variable of which this `Variable` is a
          slice.
        full_shape: Shape of the full variable, as a list of int.
        var_offset: Offset of this `Variable` into the full variable, as a
          list of int.
        var_shape: Shape of this `Variable`, as a list of int.
        save_slice_info_def: `SaveSliceInfoDef` protocol buffer. If not
          `None`, recreates the SaveSliceInfo object from its contents.
          `save_slice_info_def` and other arguments are mutually
          exclusive.
        import_scope: Optional `string`. Name scope to add. Only used
          when initializing from protocol buffer.
      """
      if save_slice_info_def:
        assert isinstance(save_slice_info_def, variable_pb2.SaveSliceInfoDef)
        self.full_name = ops.prepend_name_scope(
            save_slice_info_def.full_name, import_scope=import_scope)
        self.full_shape = [i for i in save_slice_info_def.full_shape]
        self.var_offset = [i for i in save_slice_info_def.var_offset]
        self.var_shape = [i for i in save_slice_info_def.var_shape]
      else:
        self.full_name = full_name
        self.full_shape = full_shape
        self.var_offset = var_offset
        self.var_shape = var_shape

    @property
    def spec(self):
      """Computes the spec string used for saving."""
      full_shape_str = " ".join(["%d" % d for d in self.full_shape]) + " "
      sl_spec = ":".join([
          "%d,%d" % (o, s) for o, s in zip(self.var_offset, self.var_shape)
      ])
      return full_shape_str + sl_spec

    def to_proto(self, export_scope=None):
      """Returns a SaveSliceInfoDef() proto.

      Args:
        export_scope: Optional `string`. Name scope to remove.

      Returns:
        A `SaveSliceInfoDef` protocol buffer, or None if the `Variable` is
        not in the specified name scope.
      """
      if (export_scope is None or
          self.full_name.startswith(export_scope)):
        save_slice_info_def = variable_pb2.SaveSliceInfoDef()
        save_slice_info_def.full_name = ops.strip_name_scope(
            self.full_name, export_scope)
        for i in self.full_shape:
          save_slice_info_def.full_shape.append(i)
        for i in self.var_offset:
          save_slice_info_def.var_offset.append(i)
        for i in self.var_shape:
          save_slice_info_def.var_shape.append(i)
        return save_slice_info_def
      else:
        return None


Variable._OverloadAllOperators()  # pylint: disable=protected-access


@tf_export(v1=["Variable"])
class VariableV1(Variable):
  """See the [Variables Guide](https://tensorflow.org/guide/variables).

  A variable maintains state in the graph across calls to `run()`. You add a
  variable to the graph by constructing an instance of the class `Variable`.

  The `Variable()` constructor requires an initial value for the variable,
  which can be a `Tensor` of any type and shape. The initial value defines the
  type and shape of the variable. After construction, the type and shape of
  the variable are fixed. The value can be changed using one of the assign
  methods.

  If you want to change the shape of a variable later you have to use an
  `assign` Op with `validate_shape=False`.

  Just like any `Tensor`, variables created with `Variable()` can be used as
  inputs for other Ops in the graph. Additionally, all the operators
  overloaded for the `Tensor` class are carried over to variables, so you can
  also add nodes to the graph by just doing arithmetic on variables.

  ```python
  import tensorflow as tf

  # Create a variable.
  w = tf.Variable(<initial-value>, name=<optional-name>)

  # Use the variable in the graph like any Tensor.
  y = tf.matmul(w, ...another variable or tensor...)

  # The overloaded operators are available too.
  z = tf.sigmoid(w + y)

  # Assign a new value to the variable with `assign()` or a related method.
  w.assign(w + 1.0)
  w.assign_add(1.0)
  ```

  When you launch the graph, variables have to be explicitly initialized
  before you can run Ops that use their value. You can initialize a variable
  by running its *initializer op*, restoring the variable from a save file,
  or simply running an `assign` Op that assigns a value to the variable. In
  fact, the variable *initializer op* is just an `assign` Op that assigns the
  variable's initial value to the variable itself.

  ```python
  # Launch the graph in a session.
  with tf.Session() as sess:
    # Run the variable initializer.
    sess.run(w.initializer)
    # ...you now can run ops that use the value of 'w'...
  ```

  The most common initialization pattern is to use the convenience function
  `global_variables_initializer()` to add an Op to the graph that initializes
  all the variables. You then run that Op after launching the graph.

  ```python
  # Add an Op to initialize global variables.
  init_op = tf.global_variables_initializer()

  # Launch the graph in a session.
  with tf.Session() as sess:
    # Run the Op that initializes global variables.
    sess.run(init_op)
    # ...you can now run any Op that uses variable values...
  ```

  If you need to create a variable with an initial value dependent on another
  variable, use the other variable's `initialized_value()`. This ensures that
  variables are initialized in the right order.

  All variables are automatically collected in the graph where they are
  created. By default, the constructor adds the new variable to the graph
  collection `GraphKeys.GLOBAL_VARIABLES`. The convenience function
  `global_variables()` returns the contents of that collection.

  When building a machine learning model it is often convenient to
  distinguish between variables holding the trainable model parameters and
  other variables such as a `global step` variable used to count training
  steps. To make this easier, the variable constructor supports a
  `trainable=<bool>` parameter. If `True`, the new variable is also added to
  the graph collection `GraphKeys.TRAINABLE_VARIABLES`. The convenience
  function `trainable_variables()` returns the contents of this collection.
  The various `Optimizer` classes use this collection as the default list of
  variables to optimize.

  WARNING: tf.Variable objects by default have a non-intuitive memory model.
  A Variable is represented internally as a mutable Tensor which can
  non-deterministically alias other Tensors in a graph.
  The set of operations
  which consume a Variable and can lead to aliasing is undetermined and can
  change across TensorFlow versions. Avoid writing code which relies on the
  value of a Variable either changing or not changing as other operations
  happen. For example, using Variable objects or simple functions thereof as
  predicates in a `tf.cond` is dangerous and error-prone:

  ```
  v = tf.Variable(True)
  tf.cond(v, lambda: v.assign(False), my_false_fn)  # Note: this is broken.
  ```

  Here, adding `use_resource=True` when constructing the variable will
  fix any nondeterminism issues:

  ```
  v = tf.Variable(True, use_resource=True)
  tf.cond(v, lambda: v.assign(False), my_false_fn)
  ```

  To use the replacement for variables, which does not have these issues:

  * Add `use_resource=True` when constructing `tf.Variable`;
  * Call `tf.get_variable_scope().set_use_resource(True)` inside a
    `tf.variable_scope` before the `tf.get_variable()` call.
  """

  def __init__(self,  # pylint: disable=super-init-not-called
               initial_value=None,
               trainable=True,
               collections=None,
               validate_shape=True,
               caching_device=None,
               name=None,
               variable_def=None,
               dtype=None,
               expected_shape=None,
               import_scope=None,
               constraint=None,
               use_resource=None,
               synchronization=VariableSynchronization.AUTO,
               aggregation=VariableAggregation.NONE):
    """Creates a new variable with value `initial_value`.

    The new variable is added to the graph collections listed in
    `collections`, which defaults to `[GraphKeys.GLOBAL_VARIABLES]`.

    If `trainable` is `True` the variable is also added to the graph
    collection `GraphKeys.TRAINABLE_VARIABLES`.

    This constructor creates both a `variable` Op and an `assign` Op to set
    the variable to its initial value.

    Args:
      initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
        which is the initial value for the Variable. The initial value must
        have a shape specified unless `validate_shape` is set to False. Can
        also be a callable with no argument that returns the initial value
        when called. In that case, `dtype` must be specified. (Note that
        initializer functions from init_ops.py must first be bound to a
        shape before being used here.)
      trainable: If `True`, the default, also adds the variable to the graph
        collection `GraphKeys.TRAINABLE_VARIABLES`. This collection is used
        as the default list of variables to use by the `Optimizer` classes.
      collections: List of graph collections keys. The new variable is added
        to these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
      validate_shape: If `False`, allows the variable to be initialized with
        a value of unknown shape. If `True`, the default, the shape of
        `initial_value` must be known.
      caching_device: Optional device string describing where the Variable
        should be cached for reading. Defaults to the Variable's device.
        If not `None`, caches on another device. Typical use is to cache
        on the device where the Ops using the Variable reside, to deduplicate
        copying through `Switch` and other conditional statements.
      name: Optional name for the variable. Defaults to `'Variable'` and gets
        uniquified automatically.
      variable_def: `VariableDef` protocol buffer.
        If not `None`, recreates
        the Variable object with its contents, referencing the variable's
        nodes in the graph, which must already exist. The graph is not
        changed. `variable_def` and the other arguments are mutually
        exclusive.
      dtype: If set, initial_value will be converted to the given type.
        If `None`, either the datatype will be kept (if `initial_value` is
        a Tensor), or `convert_to_tensor` will decide.
      expected_shape: A TensorShape. If set, initial_value is expected
        to have this shape.
      import_scope: Optional `string`. Name scope to add to the `Variable`.
        Only used when initializing from protocol buffer.
      constraint: An optional projection function to be applied to the
        variable after being updated by an `Optimizer` (e.g. used to
        implement norm constraints or value constraints for layer weights).
        The function must take as input the unprojected Tensor representing
        the value of the variable and return the Tensor for the projected
        value (which must have the same shape). Constraints are not safe to
        use when doing asynchronous distributed training.
      use_resource: whether to use resource variables.
      synchronization: unused
      aggregation: unused

    Raises:
      ValueError: If both `variable_def` and initial_value are specified.
      ValueError: If the initial value is not specified, or does not have a
        shape and `validate_shape` is `True`.
      RuntimeError: If eager execution is enabled.
    """

  SaveSliceInfo = Variable.SaveSliceInfo


# TODO(apassos): do not repeat all comments here
class RefVariable(VariableV1):
  """Ref-based implementation of variables."""

  def __init__(self,  # pylint: disable=super-init-not-called
               initial_value=None,
               trainable=True,
               collections=None,
               validate_shape=True,
               caching_device=None,
               name=None,
               variable_def=None,
               dtype=None,
               expected_shape=None,
               import_scope=None,
               constraint=None):
    """Creates a new variable with value `initial_value`.

    The new variable is added to the graph collections listed in
    `collections`, which defaults to `[GraphKeys.GLOBAL_VARIABLES]`.

    If `trainable` is `True` the variable is also added to the graph
    collection `GraphKeys.TRAINABLE_VARIABLES`.

    This constructor creates both a `variable` Op and an `assign` Op to set
    the variable to its initial value.

    Args:
      initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
        which is the initial value for the Variable. The initial value must
        have a shape specified unless `validate_shape` is set to False. Can
        also be a callable with no argument that returns the initial value
        when called. In that case, `dtype` must be specified. (Note that
        initializer functions from init_ops.py must first be bound to a
        shape before being used here.)
      trainable: If `True`, the default, also adds the variable to the graph
        collection `GraphKeys.TRAINABLE_VARIABLES`. This collection is used
        as the default list of variables to use by the `Optimizer` classes.
      collections: List of graph collections keys. The new variable is added
        to these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
      validate_shape: If `False`, allows the variable to be initialized with
        a value of unknown shape. If `True`, the default, the shape of
        `initial_value` must be known.
      caching_device: Optional device string describing where the Variable
        should be cached for reading. Defaults to the Variable's device.
        If not `None`, caches on another device. Typical use is to cache
        on the device where the Ops using the Variable reside, to deduplicate
        copying through `Switch` and other conditional statements.
      name: Optional name for the variable. Defaults to `'Variable'` and gets
        uniquified automatically.
      variable_def: `VariableDef` protocol buffer. If not `None`, recreates
        the Variable object with its contents, referencing the variable's
        nodes in the graph, which must already exist. The graph is not
        changed. `variable_def` and the other arguments are mutually
        exclusive.
      dtype: If set, initial_value will be converted to the given type.
        If `None`, either the datatype will be kept (if `initial_value` is
        a Tensor), or `convert_to_tensor` will decide.
      expected_shape: A TensorShape. If set, initial_value is expected
        to have this shape.
      import_scope: Optional `string`. Name scope to add to the `Variable`.
        Only used when initializing from protocol buffer.
      constraint: An optional projection function to be applied to the
        variable after being updated by an `Optimizer` (e.g. used to
        implement norm constraints or value constraints for layer weights).
        The function must take as input the unprojected Tensor representing
        the value of the variable and return the Tensor for the projected
        value (which must have the same shape). Constraints are not safe to
        use when doing asynchronous distributed training.

    Raises:
      ValueError: If both `variable_def` and initial_value are specified.
      ValueError: If the initial value is not specified, or does not have a
        shape and `validate_shape` is `True`.
      RuntimeError: If eager execution is enabled.
    """
    self._in_graph_mode = True
    if variable_def:
      # If variable_def is provided, recreates the variable from its fields.
      if initial_value:
        raise ValueError("variable_def and initial_value are mutually "
                         "exclusive.")
      self._init_from_proto(variable_def, import_scope=import_scope)
    else:
      # Create from initial_value.
      self._init_from_args(
          initial_value=initial_value,
          trainable=trainable,
          collections=collections,
          validate_shape=validate_shape,
          caching_device=caching_device,
          name=name,
          dtype=dtype,
          expected_shape=expected_shape,
          constraint=constraint)

  def __repr__(self):
    if context.executing_eagerly() and not self._in_graph_mode:
      return "<tf.Variable '%s' shape=%s dtype=%s, numpy=%s>" % (
          self.name, self.get_shape(), self.dtype.name,
          ops.numpy_text(self.read_value(), is_repr=True))
    else:
      return "<tf.Variable '%s' shape=%s dtype=%s>" % (
          self.name, self.get_shape(), self.dtype.name)

  def _init_from_args(self,
                      initial_value=None,
                      trainable=True,
                      collections=None,
                      validate_shape=True,
                      caching_device=None,
                      name=None,
                      dtype=None,
                      expected_shape=None,
                      constraint=None):
    """Creates a new variable from arguments.

    Args:
      initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
        which is the initial value for the Variable. The initial value must
        have a shape specified unless `validate_shape` is set to False. Can
        also be a callable with no argument that returns the initial value
        when called. (Note that initializer functions from init_ops.py must
        first be bound to a shape before being used here.)
      trainable: If `True`, the default, also adds the variable to the graph
        collection `GraphKeys.TRAINABLE_VARIABLES`. This collection is used
        as the default list of variables to use by the `Optimizer` classes.
      collections: List of graph collections keys. The new variable is added
        to these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
      validate_shape: If `False`, allows the variable to be initialized with
        a value of unknown shape. If `True`, the default, the shape of
        `initial_value` must be known.
      caching_device: Optional device string or function describing where the
        Variable should be cached for reading. Defaults to the Variable's
        device. If not `None`, caches on another device. Typical use is to
        cache on the device where the Ops using the Variable reside, to
        deduplicate copying through `Switch` and other conditional statements.
      name: Optional name for the variable. Defaults to `'Variable'` and gets
        uniquified automatically.
      dtype: If set, initial_value will be converted to the given type.
        If None, either the datatype will be kept (if initial_value is
        a Tensor) or float32 will be used (if it is a Python object
        convertible to a Tensor).
      expected_shape: Deprecated. Ignored.
      constraint: An optional projection function to be applied to the
        variable after being updated by an `Optimizer` (e.g. used to
        implement norm constraints or value constraints for layer weights).
        The function must take as input the unprojected Tensor representing
        the value of the variable and return the Tensor for the projected
        value (which must have the same shape). Constraints are not safe to
        use when doing asynchronous distributed training.

    Raises:
      ValueError: If the initial value is not specified, or does not have a
        shape and `validate_shape` is `True`.
      RuntimeError: If lifted into the eager context.
    """
    _ = expected_shape
    if initial_value is None:
      raise ValueError("initial_value must be specified.")
    init_from_fn = callable(initial_value)

    if collections is None:
      collections = [ops.GraphKeys.GLOBAL_VARIABLES]
    if not isinstance(collections, (list, tuple, set)):
      raise ValueError(
          "collections argument to Variable constructor must be a list, "
          "tuple, or set. Got %s of type %s" % (collections, type(collections)))
    if constraint is not None and not callable(constraint):
      raise ValueError("The `constraint` argument must be a callable.")

    # Store the graph key so optimizers know how to only retrieve variables
    # from this graph.
    self._graph_key = ops.get_default_graph()._graph_key  # pylint: disable=protected-access
    if isinstance(initial_value, trackable.CheckpointInitialValue):
      self._maybe_initialize_trackable()
      self._update_uid = initial_value.checkpoint_position.restore_uid
      initial_value = initial_value.wrapped_value

    self._trainable = trainable
    if trainable and ops.GraphKeys.TRAINABLE_VARIABLES not in collections:
      collections = list(collections) + [ops.GraphKeys.TRAINABLE_VARIABLES]
    with ops.init_scope():
      # Ensure that we weren't lifted into the eager context.
      if context.executing_eagerly():
        raise RuntimeError(
            "RefVariable not supported when eager execution is enabled.")
      with ops.name_scope(name, "Variable", [] if init_from_fn else
                          [initial_value]) as name:

        if init_from_fn:
          # Use attr_scope and device(None) to simulate the behavior of
          # colocate_with when the variable we want to colocate with doesn't
          # yet exist.
          true_name = ops._name_from_scope_name(name)  # pylint: disable=protected-access
          attr = attr_value_pb2.AttrValue(
              list=attr_value_pb2.AttrValue.ListValue(
                  s=[compat.as_bytes("loc:@%s" % true_name)]))
          # pylint: disable=protected-access
          with ops.get_default_graph()._attr_scope({"_class": attr}):
            with ops.name_scope("Initializer"), ops.device(None):
              self._initial_value = ops.convert_to_tensor(
                  initial_value(), name="initial_value", dtype=dtype)
              shape = (self._initial_value.get_shape()
                       if validate_shape else tensor_shape.unknown_shape())
            self._variable = state_ops.variable_op_v2(
                shape,
                self._initial_value.dtype.base_dtype,
                name=name)
          # pylint: enable=protected-access

        # Or get the initial value from a Tensor or Python object.
        else:
          self._initial_value = ops.convert_to_tensor(
              initial_value, name="initial_value", dtype=dtype)
          # pylint: disable=protected-access
          if self._initial_value.op._get_control_flow_context() is not None:
            raise ValueError(
                "Initializer for variable %s is from inside a control-flow "
                "construct, such as a loop or conditional. When creating a "
                "variable inside a loop or conditional, use a lambda as the "
                "initializer." % name)
          # pylint: enable=protected-access
          shape = (self._initial_value.get_shape()
                   if validate_shape else tensor_shape.unknown_shape())
          # In this case, the variable op can't be created until after the
          # initial_value has been converted to a Tensor with a known type.
          self._variable = state_ops.variable_op_v2(
              shape,
              self._initial_value.dtype.base_dtype,
              name=name)

        # Manually overrides the variable's shape with the initial value's.
        if validate_shape:
          initial_value_shape = self._initial_value.get_shape()
          if not initial_value_shape.is_fully_defined():
            raise ValueError("initial_value must have a shape specified: %s" %
                             self._initial_value)

        # If 'initial_value' makes use of other variables, make sure we don't
        # have an issue if these other variables aren't initialized first by
        # using their initialized_value() method.
        self._initializer_op = state_ops.assign(
            self._variable,
            _try_guard_against_uninitialized_dependencies(
                name,
                self._initial_value),
            validate_shape=validate_shape).op

        # TODO(vrv): Change this class to not take caching_device, but
        # to take the op to colocate the snapshot with, so we can use
        # colocation rather than devices.
        if caching_device is not None:
          with ops.device(caching_device):
            self._snapshot = array_ops.identity(self._variable, name="read")
        else:
          with ops.colocate_with(self._variable.op):
            self._snapshot = array_ops.identity(self._variable, name="read")
      ops.add_to_collections(collections, self)

    self._caching_device = caching_device
    self._save_slice_info = None
    self._constraint = constraint

  def _init_from_proto(self, variable_def, import_scope=None):
    """Recreates the Variable object from a `VariableDef` protocol buffer.
1615 1616 Args: 1617 variable_def: `VariableDef` protocol buffer, describing a variable 1618 whose nodes already exists in the graph. 1619 import_scope: Optional `string`. Name scope to add. 1620 """ 1621 assert isinstance(variable_def, variable_pb2.VariableDef) 1622 # Create from variable_def. 1623 g = ops.get_default_graph() 1624 self._variable = g.as_graph_element( 1625 ops.prepend_name_scope(variable_def.variable_name, 1626 import_scope=import_scope)) 1627 self._initializer_op = g.as_graph_element( 1628 ops.prepend_name_scope(variable_def.initializer_name, 1629 import_scope=import_scope)) 1630 # Tests whether initial_value_name exists first for backwards compatibility. 1631 if (hasattr(variable_def, "initial_value_name") and 1632 variable_def.initial_value_name): 1633 self._initial_value = g.as_graph_element( 1634 ops.prepend_name_scope(variable_def.initial_value_name, 1635 import_scope=import_scope)) 1636 else: 1637 self._initial_value = None 1638 self._trainable = getattr(variable_def, "trainable", True) 1639 self._snapshot = g.as_graph_element( 1640 ops.prepend_name_scope(variable_def.snapshot_name, 1641 import_scope=import_scope)) 1642 if variable_def.HasField("save_slice_info_def"): 1643 self._save_slice_info = Variable.SaveSliceInfo( 1644 save_slice_info_def=variable_def.save_slice_info_def, 1645 import_scope=import_scope) 1646 else: 1647 self._save_slice_info = None 1648 self._caching_device = None 1649 self._constraint = None 1650 1651 def _as_graph_element(self): 1652 """Conversion function for Graph.as_graph_element().""" 1653 return self._variable 1654 1655 def value(self): 1656 """Returns the last snapshot of this variable. 1657 1658 You usually do not need to call this method as all ops that need the value 1659 of the variable call it automatically through a `convert_to_tensor()` call. 1660 1661 Returns a `Tensor` which holds the value of the variable. You can not 1662 assign a new value to this tensor as it is not a reference to the variable. 1663 1664 To avoid copies, if the consumer of the returned value is on the same device 1665 as the variable, this actually returns the live value of the variable, not 1666 a copy. Updates to the variable are seen by the consumer. If the consumer 1667 is on a different device it will get a copy of the variable. 1668 1669 Returns: 1670 A `Tensor` containing the value of the variable. 1671 """ 1672 return self._snapshot 1673 1674 def read_value(self): 1675 """Returns the value of this variable, read in the current context. 1676 1677 Can be different from value() if it's on another device, with control 1678 dependencies, etc. 1679 1680 Returns: 1681 A `Tensor` containing the value of the variable. 1682 """ 1683 return array_ops.identity(self._variable, name="read") 1684 1685 def _ref(self): 1686 """Returns a reference to this variable. 1687 1688 You usually do not need to call this method as all ops that need a reference 1689 to the variable call it automatically. 1690 1691 Returns is a `Tensor` which holds a reference to the variable. You can 1692 assign a new value to the variable by passing the tensor to an assign op. 1693 See `tf.Variable.value` if you want to get the value of the 1694 variable. 1695 1696 Returns: 1697 A `Tensor` that is a reference to the variable. 1698 """ 1699 return self._variable 1700 1701 def set_shape(self, shape): 1702 """Overrides the shape for this variable. 1703 1704 Args: 1705 shape: the `TensorShape` representing the overridden shape. 
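
    For example, a variable created with an unknown shape can be given one
    later (a minimal sketch, assuming graph mode and a variable built with
    `validate_shape=False`):

    ```python
    v = tf.Variable(tf.placeholder(tf.float32), validate_shape=False)
    print(v.get_shape())  # <unknown>
    v.set_shape([3, 2])
    print(v.get_shape())  # (3, 2)
    ```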
1706 """ 1707 self._ref().set_shape(shape) 1708 self.value().set_shape(shape) 1709 1710 @property 1711 def trainable(self): 1712 return self._trainable 1713 1714 def eval(self, session=None): 1715 """In a session, computes and returns the value of this variable. 1716 1717 This is not a graph construction method, it does not add ops to the graph. 1718 1719 This convenience method requires a session where the graph 1720 containing this variable has been launched. If no session is 1721 passed, the default session is used. See `tf.Session` for more 1722 information on launching a graph and on sessions. 1723 1724 ```python 1725 v = tf.Variable([1, 2]) 1726 init = tf.global_variables_initializer() 1727 1728 with tf.Session() as sess: 1729 sess.run(init) 1730 # Usage passing the session explicitly. 1731 print(v.eval(sess)) 1732 # Usage with the default session. The 'with' block 1733 # above makes 'sess' the default session. 1734 print(v.eval()) 1735 ``` 1736 1737 Args: 1738 session: The session to use to evaluate this variable. If 1739 none, the default session is used. 1740 1741 Returns: 1742 A numpy `ndarray` with a copy of the value of this variable. 1743 """ 1744 return self._variable.eval(session=session) 1745 1746 @property 1747 def initial_value(self): 1748 """Returns the Tensor used as the initial value for the variable. 1749 1750 Note that this is different from `initialized_value()` which runs 1751 the op that initializes the variable before returning its value. 1752 This method returns the tensor that is used by the op that initializes 1753 the variable. 1754 1755 Returns: 1756 A `Tensor`. 1757 """ 1758 return self._initial_value 1759 1760 @property 1761 def constraint(self): 1762 """Returns the constraint function associated with this variable. 1763 1764 Returns: 1765 The constraint function that was passed to the variable constructor. 1766 Can be `None` if no constraint was passed. 1767 """ 1768 return self._constraint 1769 1770 def assign(self, value, use_locking=False, name=None, read_value=True): 1771 """Assigns a new value to the variable. 1772 1773 This is essentially a shortcut for `assign(self, value)`. 1774 1775 Args: 1776 value: A `Tensor`. The new value for this variable. 1777 use_locking: If `True`, use locking during the assignment. 1778 name: The name of the operation to be created 1779 read_value: if True, will return something which evaluates to the 1780 new value of the variable; if False will return the assign op. 1781 1782 Returns: 1783 A `Tensor` that will hold the new value of this variable after 1784 the assignment has completed. 1785 """ 1786 assign = state_ops.assign(self._variable, value, use_locking=use_locking, 1787 name=name) 1788 if read_value: 1789 return assign 1790 return assign.op 1791 1792 def assign_add(self, delta, use_locking=False, name=None, read_value=True): 1793 """Adds a value to this variable. 1794 1795 This is essentially a shortcut for `assign_add(self, delta)`. 1796 1797 Args: 1798 delta: A `Tensor`. The value to add to this variable. 1799 use_locking: If `True`, use locking during the operation. 1800 name: The name of the operation to be created 1801 read_value: if True, will return something which evaluates to the 1802 new value of the variable; if False will return the assign op. 1803 1804 Returns: 1805 A `Tensor` that will hold the new value of this variable after 1806 the addition has completed. 
1807 """ 1808 assign = state_ops.assign_add( 1809 self._variable, delta, use_locking=use_locking, name=name) 1810 if read_value: 1811 return assign 1812 return assign.op 1813 1814 def assign_sub(self, delta, use_locking=False, name=None, read_value=True): 1815 """Subtracts a value from this variable. 1816 1817 This is essentially a shortcut for `assign_sub(self, delta)`. 1818 1819 Args: 1820 delta: A `Tensor`. The value to subtract from this variable. 1821 use_locking: If `True`, use locking during the operation. 1822 name: The name of the operation to be created 1823 read_value: if True, will return something which evaluates to the 1824 new value of the variable; if False will return the assign op. 1825 1826 Returns: 1827 A `Tensor` that will hold the new value of this variable after 1828 the subtraction has completed. 1829 """ 1830 assign = state_ops.assign_sub( 1831 self._variable, delta, use_locking=use_locking, name=name) 1832 if read_value: 1833 return assign 1834 return assign.op 1835 1836 def scatter_sub(self, sparse_delta, use_locking=False, name=None): 1837 """Subtracts `IndexedSlices` from this variable. 1838 1839 Args: 1840 sparse_delta: `IndexedSlices` to be subtracted from this variable. 1841 use_locking: If `True`, use locking during the operation. 1842 name: the name of the operation. 1843 1844 Returns: 1845 A `Tensor` that will hold the new value of this variable after 1846 the scattered subtraction has completed. 1847 1848 Raises: 1849 ValueError: if `sparse_delta` is not an `IndexedSlices`. 1850 """ 1851 if not isinstance(sparse_delta, ops.IndexedSlices): 1852 raise ValueError("sparse_delta is not IndexedSlices: %s" % sparse_delta) 1853 return gen_state_ops.scatter_sub( 1854 self._variable, 1855 sparse_delta.indices, 1856 sparse_delta.values, 1857 use_locking=use_locking, 1858 name=name) 1859 1860 def scatter_add(self, sparse_delta, use_locking=False, name=None): 1861 """Adds `IndexedSlices` from this variable. 1862 1863 Args: 1864 sparse_delta: `IndexedSlices` to be added to this variable. 1865 use_locking: If `True`, use locking during the operation. 1866 name: the name of the operation. 1867 1868 Returns: 1869 A `Tensor` that will hold the new value of this variable after 1870 the scattered subtraction has completed. 1871 1872 Raises: 1873 ValueError: if `sparse_delta` is not an `IndexedSlices`. 1874 """ 1875 if not isinstance(sparse_delta, ops.IndexedSlices): 1876 raise ValueError("sparse_delta is not IndexedSlices: %s" % sparse_delta) 1877 return gen_state_ops.scatter_add( 1878 self._variable, 1879 sparse_delta.indices, 1880 sparse_delta.values, 1881 use_locking=use_locking, 1882 name=name) 1883 1884 def scatter_update(self, sparse_delta, use_locking=False, name=None): 1885 """Assigns `IndexedSlices` to this variable. 1886 1887 Args: 1888 sparse_delta: `IndexedSlices` to be assigned to this variable. 1889 use_locking: If `True`, use locking during the operation. 1890 name: the name of the operation. 1891 1892 Returns: 1893 A `Tensor` that will hold the new value of this variable after 1894 the scattered subtraction has completed. 1895 1896 Raises: 1897 ValueError: if `sparse_delta` is not an `IndexedSlices`. 
1898 """ 1899 if not isinstance(sparse_delta, ops.IndexedSlices): 1900 raise ValueError("sparse_delta is not IndexedSlices: %s" % sparse_delta) 1901 return gen_state_ops.scatter_update( 1902 self._variable, 1903 sparse_delta.indices, 1904 sparse_delta.values, 1905 use_locking=use_locking, 1906 name=name) 1907 1908 def batch_scatter_update(self, sparse_delta, use_locking=False, name=None): 1909 """Assigns `IndexedSlices` to this variable batch-wise. 1910 1911 Analogous to `batch_gather`. This assumes that this variable and the 1912 sparse_delta IndexedSlices have a series of leading dimensions that are the 1913 same for all of them, and the updates are performed on the last dimension of 1914 indices. In other words, the dimensions should be the following: 1915 1916 `num_prefix_dims = sparse_delta.indices.ndims - 1` 1917 `batch_dim = num_prefix_dims + 1` 1918 `sparse_delta.updates.shape = sparse_delta.indices.shape + var.shape[ 1919 batch_dim:]` 1920 1921 where 1922 1923 `sparse_delta.updates.shape[:num_prefix_dims]` 1924 `== sparse_delta.indices.shape[:num_prefix_dims]` 1925 `== var.shape[:num_prefix_dims]` 1926 1927 And the operation performed can be expressed as: 1928 1929 `var[i_1, ..., i_n, 1930 sparse_delta.indices[i_1, ..., i_n, j]] = sparse_delta.updates[ 1931 i_1, ..., i_n, j]` 1932 1933 When sparse_delta.indices is a 1D tensor, this operation is equivalent to 1934 `scatter_update`. 1935 1936 To avoid this operation one can looping over the first `ndims` of the 1937 variable and using `scatter_update` on the subtensors that result of slicing 1938 the first dimension. This is a valid option for `ndims = 1`, but less 1939 efficient than this implementation. 1940 1941 Args: 1942 sparse_delta: `IndexedSlices` to be assigned to this variable. 1943 use_locking: If `True`, use locking during the operation. 1944 name: the name of the operation. 1945 1946 Returns: 1947 A `Tensor` that will hold the new value of this variable after 1948 the scattered subtraction has completed. 1949 1950 Raises: 1951 ValueError: if `sparse_delta` is not an `IndexedSlices`. 1952 """ 1953 return state_ops.batch_scatter_update( 1954 self, sparse_delta.indices, sparse_delta.values, 1955 use_locking=use_locking, name=name) 1956 1957 def scatter_nd_sub(self, indices, updates, name=None): 1958 """Applies sparse subtraction to individual values or slices in a Variable. 1959 1960 `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`. 1961 1962 `indices` must be integer tensor, containing indices into `ref`. 1963 It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`. 1964 1965 The innermost dimension of `indices` (with length `K`) corresponds to 1966 indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th 1967 dimension of `ref`. 1968 1969 `updates` is `Tensor` of rank `Q-1+P-K` with shape: 1970 1971 ``` 1972 [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]]. 1973 ``` 1974 1975 For example, say we want to add 4 scattered elements to a rank-1 tensor to 1976 8 elements. In Python, that update would look like this: 1977 1978 ```python 1979 ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8]) 1980 indices = tf.constant([[4], [3], [1] ,[7]]) 1981 updates = tf.constant([9, 10, 11, 12]) 1982 op = ref.scatter_nd_sub(indices, updates) 1983 with tf.Session() as sess: 1984 print sess.run(op) 1985 ``` 1986 1987 The resulting update to ref would look like this: 1988 1989 [1, -9, 3, -6, -6, 6, 7, -4] 1990 1991 See `tf.scatter_nd` for more details about how to make updates to 1992 slices. 

    Args:
      indices: The indices to be used in the operation.
      updates: The values to be used in the operation.
      name: the name of the operation.

    Returns:
      A `Tensor` that will hold the new value of this variable after
      the scattered subtraction has completed.

    Raises:
      ValueError: if `sparse_delta` is not an `IndexedSlices`.
    """
    return gen_state_ops.scatter_nd_sub(
        self._variable, indices, updates, use_locking=True, name=name)

  def scatter_nd_add(self, indices, updates, name=None):
    """Applies sparse addition to individual values or slices in a Variable.

    `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.

    `indices` must be an integer tensor, containing indices into `ref`.
    It must have shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.

    The innermost dimension of `indices` (with length `K`) corresponds to
    indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
    dimension of `ref`.

    `updates` is a `Tensor` of rank `Q-1+P-K` with shape:

    ```
    [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
    ```

    For example, say we want to add 4 scattered elements to a rank-1 tensor
    with 8 elements. In Python, that update would look like this:

    ```python
        ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
        indices = tf.constant([[4], [3], [1], [7]])
        updates = tf.constant([9, 10, 11, 12])
        add = ref.scatter_nd_add(indices, updates)
        with tf.Session() as sess:
          sess.run(ref.initializer)
          print(sess.run(add))
    ```

    The resulting update to ref would look like this:

        [1, 13, 3, 14, 14, 6, 7, 20]

    See `tf.scatter_nd` for more details about how to make updates to
    slices.

    Args:
      indices: The indices to be used in the operation.
      updates: The values to be used in the operation.
      name: the name of the operation.

    Returns:
      A `Tensor` that will hold the new value of this variable after
      the scattered addition has completed.

    Raises:
      ValueError: if `sparse_delta` is not an `IndexedSlices`.
    """
    return gen_state_ops.scatter_nd_add(
        self._variable, indices, updates, use_locking=True, name=name)

  def scatter_nd_update(self, indices, updates, name=None):
    """Applies sparse assignment to individual values or slices in a Variable.

    `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.

    `indices` must be an integer tensor, containing indices into `ref`.
    It must have shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.

    The innermost dimension of `indices` (with length `K`) corresponds to
    indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
    dimension of `ref`.

    `updates` is a `Tensor` of rank `Q-1+P-K` with shape:

    ```
    [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
    ```

    For example, say we want to assign 4 scattered elements to a rank-1 tensor
    with 8 elements.
    In Python, that update would look like this:

    ```python
        ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
        indices = tf.constant([[4], [3], [1], [7]])
        updates = tf.constant([9, 10, 11, 12])
        op = ref.scatter_nd_update(indices, updates)
        with tf.Session() as sess:
          sess.run(ref.initializer)
          print(sess.run(op))
    ```

    The resulting update to ref would look like this:

        [1, 11, 3, 10, 9, 6, 7, 12]

    See `tf.scatter_nd` for more details about how to make updates to
    slices.

    Args:
      indices: The indices to be used in the operation.
      updates: The values to be used in the operation.
      name: the name of the operation.

    Returns:
      A `Tensor` that will hold the new value of this variable after
      the scattered assignment has completed.

    Raises:
      ValueError: if `sparse_delta` is not an `IndexedSlices`.
    """
    return gen_state_ops.scatter_nd_update(
        self._variable, indices, updates, use_locking=True, name=name)

  def _strided_slice_assign(self,
                            begin,
                            end,
                            strides,
                            value,
                            name,
                            begin_mask,
                            end_mask,
                            ellipsis_mask,
                            new_axis_mask,
                            shrink_axis_mask):
    return gen_array_ops.strided_slice_assign(ref=self._ref(),
                                              begin=begin,
                                              end=end,
                                              strides=strides,
                                              value=value,
                                              name=name,
                                              begin_mask=begin_mask,
                                              end_mask=end_mask,
                                              ellipsis_mask=ellipsis_mask,
                                              new_axis_mask=new_axis_mask,
                                              shrink_axis_mask=shrink_axis_mask)

  @deprecated(None, "Prefer Dataset.range instead.")
  def count_up_to(self, limit):
    """Increments this variable until it reaches `limit`.

    When that Op is run it tries to increment the variable by `1`. If
    incrementing the variable would bring it above `limit` then the Op raises
    the exception `OutOfRangeError`.

    If no error is raised, the Op outputs the value of the variable before
    the increment.

    This is essentially a shortcut for `count_up_to(self, limit)`.

    Args:
      limit: value at which incrementing the variable raises an error.

    Returns:
      A `Tensor` that will hold the variable value before the increment. If no
      other Op modifies this variable, the values produced will all be
      distinct.
    """
    return state_ops.count_up_to(self._variable, limit=limit)

  # Conversion to tensor.
  @staticmethod
  def _TensorConversionFunction(v, dtype=None, name=None, as_ref=False):  # pylint: disable=invalid-name
    """Utility function for converting a Variable to a Tensor."""
    _ = name
    if dtype and not dtype.is_compatible_with(v.dtype):
      raise ValueError(
          "Incompatible type conversion requested to type '%s' for variable "
          "of type '%s'" % (dtype.name, v.dtype.name))
    if as_ref:
      return v._ref()  # pylint: disable=protected-access
    else:
      return v.value()

  # NOTE(mrry): This enables the Variable's overloaded "right" binary
  # operators to run when the left operand is an ndarray, because it
  # accords the Variable class higher priority than an ndarray, or a
  # numpy matrix.
  # TODO(mrry): Convert this to using numpy's __numpy_ufunc__
  # mechanism, which allows more control over how Variables interact
  # with ndarrays.
2180 __array_priority__ = 100 2181 2182 @property 2183 def name(self): 2184 """The name of this variable.""" 2185 return self._variable.name 2186 2187 @property 2188 def initializer(self): 2189 """The initializer operation for this variable.""" 2190 return self._initializer_op 2191 2192 @property 2193 def device(self): 2194 """The device of this variable.""" 2195 return self._variable.device 2196 2197 @property 2198 def dtype(self): 2199 """The `DType` of this variable.""" 2200 return self._variable.dtype 2201 2202 @property 2203 def op(self): 2204 """The `Operation` of this variable.""" 2205 return self._variable.op 2206 2207 @property 2208 def graph(self): 2209 """The `Graph` of this variable.""" 2210 return self._variable.graph 2211 2212 @property 2213 def _distribute_strategy(self): 2214 """The `tf.distribute.Strategy` that this variable was created under.""" 2215 return None # Ref variables are never created inside a strategy. 2216 2217 @property 2218 def shape(self): 2219 """The `TensorShape` of this variable. 2220 2221 Returns: 2222 A `TensorShape`. 2223 """ 2224 return self._variable.get_shape() 2225 2226 def to_proto(self, export_scope=None): 2227 """Converts a `Variable` to a `VariableDef` protocol buffer. 2228 2229 Args: 2230 export_scope: Optional `string`. Name scope to remove. 2231 2232 Returns: 2233 A `VariableDef` protocol buffer, or `None` if the `Variable` is not 2234 in the specified name scope. 2235 """ 2236 if (export_scope is None or 2237 self._variable.name.startswith(export_scope)): 2238 var_def = variable_pb2.VariableDef() 2239 var_def.variable_name = ops.strip_name_scope( 2240 self._variable.name, export_scope) 2241 if self._initial_value is not None: 2242 # For backwards compatibility. 2243 var_def.initial_value_name = ops.strip_name_scope( 2244 self._initial_value.name, export_scope) 2245 var_def.trainable = self.trainable 2246 var_def.initializer_name = ops.strip_name_scope( 2247 self.initializer.name, export_scope) 2248 var_def.snapshot_name = ops.strip_name_scope( 2249 self._snapshot.name, export_scope) 2250 if self._save_slice_info: 2251 var_def.save_slice_info_def.MergeFrom(self._save_slice_info.to_proto( 2252 export_scope=export_scope)) 2253 return var_def 2254 else: 2255 return None 2256 2257 def __iadd__(self, other): 2258 logging.log_first_n( 2259 logging.WARN, 2260 "Variable += will be deprecated. Use variable.assign_add" 2261 " if you want assignment to the variable value or 'x = x + y'" 2262 " if you want a new python Tensor object.", 1) 2263 return self + other 2264 2265 def __isub__(self, other): 2266 logging.log_first_n( 2267 logging.WARN, 2268 "Variable -= will be deprecated. Use variable.assign_sub" 2269 " if you want assignment to the variable value or 'x = x - y'" 2270 " if you want a new python Tensor object.", 1) 2271 return self - other 2272 2273 def __imul__(self, other): 2274 logging.log_first_n( 2275 logging.WARN, 2276 "Variable *= will be deprecated. Use `var.assign(var * other)`" 2277 " if you want assignment to the variable value or `x = x * y`" 2278 " if you want a new python Tensor object.", 1) 2279 return self * other 2280 2281 def __idiv__(self, other): 2282 logging.log_first_n( 2283 logging.WARN, 2284 "Variable /= will be deprecated. Use `var.assign(var / other)`" 2285 " if you want assignment to the variable value or `x = x / y`" 2286 " if you want a new python Tensor object.", 1) 2287 return self / other 2288 2289 def __itruediv__(self, other): 2290 logging.log_first_n( 2291 logging.WARN, 2292 "Variable /= will be deprecated. 
Use `var.assign(var / other)`" 2293 " if you want assignment to the variable value or `x = x / y`" 2294 " if you want a new python Tensor object.", 1) 2295 return self / other 2296 2297 def __irealdiv__(self, other): 2298 logging.log_first_n( 2299 logging.WARN, 2300 "Variable /= will be deprecated. Use `var.assign(var / other)`" 2301 " if you want assignment to the variable value or `x = x / y`" 2302 " if you want a new python Tensor object.", 1) 2303 return self / other 2304 2305 def __ipow__(self, other): 2306 logging.log_first_n( 2307 logging.WARN, 2308 "Variable **= will be deprecated. Use `var.assign(var ** other)`" 2309 " if you want assignment to the variable value or `x = x ** y`" 2310 " if you want a new python Tensor object.", 1) 2311 return self ** other 2312 2313 2314def _try_guard_against_uninitialized_dependencies(name, initial_value): 2315 """Attempt to guard against dependencies on uninitialized variables. 2316 2317 Replace references to variables in `initial_value` with references to the 2318 variable's initialized values. The initialized values are essentially 2319 conditional TensorFlow graphs that return a variable's value if it is 2320 initialized or its `initial_value` if it hasn't been initialized. This 2321 replacement is done on a best effort basis: 2322 2323 - If the `initial_value` graph contains cycles, we don't do any 2324 replacements for that graph. 2325 - If the variables that `initial_value` depends on are not present in the 2326 `GLOBAL_VARIABLES` or `LOCAL_VARIABLES` we don't replace them. 2327 2328 In these cases, it is up to the caller to ensure that the `initial_value` 2329 graph uses initialized variables or that they guard access to variables 2330 using their `initialized_value` method. 2331 2332 Args: 2333 name: Variable name. 2334 initial_value: `Tensor`. The initial value. 2335 Returns: 2336 A `Tensor` suitable to initialize a variable. 2337 Raises: 2338 TypeError: If `initial_value` is not a `Tensor`. 2339 """ 2340 if not isinstance(initial_value, ops.Tensor): 2341 raise TypeError("initial_value needs to be a Tensor: %s" % initial_value) 2342 2343 # Don't modify initial_value if it contains any cyclic dependencies. 2344 if _has_cycle(initial_value.op, path=set()): 2345 return initial_value 2346 return _safe_initial_value_from_tensor(name, initial_value, op_cache={}) 2347 2348 2349def _has_cycle(op, path): 2350 """Detect cycles in the dependencies of `initial_value`.""" 2351 if op.name in path: 2352 return True 2353 path.add(op.name) 2354 for op_input in op.inputs: 2355 if _has_cycle(op_input.op, path): 2356 return True 2357 for op_control_input in op.control_inputs: 2358 if _has_cycle(op_control_input, path): 2359 return True 2360 path.remove(op.name) 2361 return False 2362 2363 2364def _safe_initial_value_from_tensor(name, tensor, op_cache): 2365 """Replace dependencies on variables with their initialized values. 2366 2367 Args: 2368 name: Variable name. 2369 tensor: A `Tensor`. The tensor to replace. 2370 op_cache: A dict mapping operation names to `Operation`s. Used to memoize 2371 the results so as to avoid creating redundant operations. 2372 Returns: 2373 A `Tensor` compatible with `tensor`. Any inputs that lead to variable 2374 values will be replaced with a corresponding graph that uses the 2375 variable's initialized values. This is done on a best-effort basis. If no 2376 modifications need to be made then `tensor` will be returned unchanged. 
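
  For example, this is the kind of dependency the rewrite handles (a minimal
  sketch; the variable names are illustrative only):

  ```python
  w = tf.Variable(tf.ones([2, 2]), name="w")
  # `v`'s initializer depends on `w`. The rewrite makes the initializer read
  # `w.initialized_value()`, so the two initializers can run in any order.
  v = tf.Variable(2.0 * w, name="v")
  ```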
2377 """ 2378 op = tensor.op 2379 new_op = op_cache.get(op.name) 2380 if new_op is None: 2381 new_op = _safe_initial_value_from_op(name, op, op_cache) 2382 op_cache[op.name] = new_op 2383 return new_op.outputs[tensor.value_index] 2384 2385 2386def _safe_initial_value_from_op(name, op, op_cache): 2387 """Replace dependencies on variables with their initialized values. 2388 2389 Args: 2390 name: Variable name. 2391 op: An `Operation`. The operation to replace. 2392 op_cache: A dict mapping operation names to `Operation`s. Used to memoize 2393 the results so as to avoid creating redundant operations. 2394 Returns: 2395 An `Operation` compatible with `op`. Any inputs that lead to variable 2396 values will be replaced with a corresponding graph that uses the 2397 variable's initialized values. This is done on a best-effort basis. If no 2398 modifications need to be made then `op` will be returned unchanged. 2399 """ 2400 op_type = op.node_def.op 2401 if op_type in ("IsVariableInitialized", "VarIsInitializedOp", 2402 "ReadVariableOp"): 2403 return op 2404 2405 # Attempt to find the initialized_value of any variable reference / handles. 2406 # TODO(b/70206927): Fix handling of ResourceVariables. 2407 if op_type in ("Variable", "VariableV2", "VarHandleOp"): 2408 initialized_value = _find_initialized_value_for_variable(op) 2409 return op if initialized_value is None else initialized_value.op 2410 2411 # Recursively build initializer expressions for inputs. 2412 modified = False 2413 new_op_inputs = [] 2414 for op_input in op.inputs: 2415 new_op_input = _safe_initial_value_from_tensor(name, op_input, op_cache) 2416 new_op_inputs.append(new_op_input) 2417 modified = modified or (new_op_input != op_input) 2418 2419 # If at least one input was modified, replace the op. 2420 if modified: 2421 new_op_type = op_type 2422 if new_op_type == "RefSwitch": 2423 new_op_type = "Switch" 2424 new_op_name = op.node_def.name + "_" + name 2425 new_op_name = new_op_name.replace(":", "_") 2426 return op.graph.create_op( 2427 new_op_type, new_op_inputs, 2428 op._output_types, # pylint: disable=protected-access 2429 name=new_op_name, attrs=op.node_def.attr) 2430 2431 return op 2432 2433 2434def _find_initialized_value_for_variable(variable_op): 2435 """Find the initialized value for a variable op. 2436 2437 To do so, lookup the variable op in the variables collection. 2438 2439 Args: 2440 variable_op: A variable `Operation`. 2441 Returns: 2442 A `Tensor` representing the initialized value for the variable or `None` 2443 if the initialized value could not be found. 2444 """ 2445 try: 2446 var_names = [variable_op.node_def.name, variable_op.node_def.name + ":0"] 2447 for collection_name in (ops.GraphKeys.GLOBAL_VARIABLES, 2448 ops.GraphKeys.LOCAL_VARIABLES): 2449 for var in variable_op.graph.get_collection(collection_name): 2450 if var.name in var_names: 2451 return var.initialized_value() 2452 except AttributeError: 2453 # Return None when an incomplete user-defined variable type was put in 2454 # the collection. 2455 return None 2456 return None 2457 2458 2459class PartitionedVariable(object): 2460 """A container for partitioned `Variable` objects. 2461 2462 @compatibility(eager) `tf.PartitionedVariable` is not compatible with 2463 eager execution. Use `tf.Variable` instead which is compatible 2464 with both eager execution and graph construction. 
See [the 2465 TensorFlow Eager Execution 2466 guide](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/eager/python/g3doc/guide.md#variables-and-optimizers) 2467 for details on how variables work in eager execution. 2468 @end_compatibility 2469 """ 2470 2471 def __init__(self, name, shape, dtype, variable_list, partitions): 2472 """Creates a new partitioned variable wrapper. 2473 2474 Variables passed via the variable_list must contain a save_slice_info 2475 field. Concatenation and iteration is in lexicographic order according 2476 to the var_offset property of the save_slice_info. 2477 2478 Args: 2479 name: String. Overall name of the variables. 2480 shape: List of integers. Overall shape of the variables. 2481 dtype: Type of the variables. 2482 variable_list: List of `Variable` that comprise this partitioned variable. 2483 partitions: List of integers. Number of partitions for each dimension. 2484 2485 Raises: 2486 TypeError: If `variable_list` is not a list of `Variable` objects, or 2487 `partitions` is not a list. 2488 ValueError: If `variable_list` is empty, or the `Variable` shape 2489 information does not match `shape`, or `partitions` has invalid values. 2490 """ 2491 if not isinstance(variable_list, (list, tuple)): 2492 raise TypeError( 2493 "variable_list is not a list or tuple: %s" % variable_list) 2494 if not isinstance(partitions, (list, tuple)): 2495 raise TypeError("partitions is not a list or tuple: %s" % partitions) 2496 if not all(p >= 1 for p in partitions): 2497 raise ValueError("partition values must be positive: %s" % partitions) 2498 if not variable_list: 2499 raise ValueError("variable_list may not be empty") 2500 # pylint: disable=protected-access 2501 for v in variable_list: 2502 # Sort the variable_list lexicographically according to var offset value. 2503 if not all(v._get_save_slice_info() is not None for v in variable_list): 2504 raise ValueError( 2505 "All variables must have a save_slice_info available: %s" 2506 % [v.name for v in variable_list]) 2507 if len(shape) != len(partitions): 2508 raise ValueError("len(shape) != len(partitions): %s vs. %s" 2509 % (shape, partitions)) 2510 if v._get_save_slice_info().full_shape != shape: 2511 raise ValueError( 2512 "All variables' full shapes must match shape: %s; " 2513 "but full shapes were: %s" 2514 % (shape, str([v._get_save_slice_info().full_shape]))) 2515 self._variable_list = sorted( 2516 variable_list, key=lambda v: v._get_save_slice_info().var_offset) 2517 # pylint: enable=protected-access 2518 2519 self._name = name 2520 self._shape = shape 2521 self._dtype = dtype 2522 self._partitions = partitions 2523 self._as_tensor = None 2524 2525 def __iter__(self): 2526 """Return an iterable for accessing the underlying partition Variables.""" 2527 return iter(self._variable_list) 2528 2529 def __len__(self): 2530 num_partition_axes = len(self._partition_axes()) 2531 if num_partition_axes > 1: 2532 raise ValueError("Cannot get a length for %d > 1 partition axes" 2533 % num_partition_axes) 2534 return len(self._variable_list) 2535 2536 def _partition_axes(self): 2537 if all(p == 1 for p in self._partitions): 2538 return [0] 2539 else: 2540 return [i for i, p in enumerate(self._partitions) if p > 1] 2541 2542 def _concat(self): 2543 """Returns the overall concatenated value as a `Tensor`. 
2544 2545 This is different from using the partitioned variable directly as a tensor 2546 (through tensor conversion and `as_tensor`) in that it creates a new set of 2547 operations that keeps the control dependencies from its scope. 2548 2549 Returns: 2550 `Tensor` containing the concatenated value. 2551 """ 2552 if len(self._variable_list) == 1: 2553 with ops.name_scope(None): 2554 return array_ops.identity(self._variable_list[0], name=self._name) 2555 2556 partition_axes = self._partition_axes() 2557 2558 if len(partition_axes) > 1: 2559 raise NotImplementedError( 2560 "Cannot concatenate along more than one dimension: %s. " 2561 "Multi-axis partition concat is not supported" % str(partition_axes)) 2562 partition_ix = partition_axes[0] 2563 2564 with ops.name_scope(self._name + "/ConcatPartitions/"): 2565 concatenated = array_ops.concat(self._variable_list, partition_ix) 2566 2567 with ops.name_scope(None): 2568 return array_ops.identity(concatenated, name=self._name) 2569 2570 def as_tensor(self): 2571 """Returns the overall concatenated value as a `Tensor`. 2572 2573 The returned tensor will not inherit the control dependencies from the scope 2574 where the value is used, which is similar to getting the value of 2575 `Variable`. 2576 2577 Returns: 2578 `Tensor` containing the concatenated value. 2579 """ 2580 with ops.control_dependencies(None): 2581 return self._concat() 2582 2583 @staticmethod 2584 def _TensorConversionFunction(v, dtype=None, name=None, as_ref=False): 2585 # pylint: disable=invalid-name 2586 _ = name 2587 if dtype is not None and not dtype.is_compatible_with(v.dtype): 2588 raise ValueError( 2589 "Incompatible type conversion requested to type '%s' for variable " 2590 "of type '%s'" % (dtype.name, v.dtype.name)) 2591 if as_ref: 2592 raise NotImplementedError( 2593 "PartitionedVariable doesn't support being used as a reference.") 2594 else: 2595 return v.as_tensor() 2596 2597 @property 2598 def name(self): 2599 return self._name 2600 2601 @property 2602 def dtype(self): 2603 return self._dtype 2604 2605 @property 2606 def shape(self): 2607 return self.get_shape() 2608 2609 @property 2610 def _distribute_strategy(self): 2611 """The `tf.distribute.Strategy` that this variable was created under.""" 2612 # NOTE(yuefengz): Today, no partitioned variables in a distribute strategy. 2613 return None 2614 2615 def get_shape(self): 2616 return self._shape 2617 2618 def _get_variable_list(self): 2619 return self._variable_list 2620 2621 def _get_partitions(self): 2622 return self._partitions 2623 2624 def _apply_assign_fn(self, assign_fn, value): 2625 partition_axes = self._partition_axes() 2626 if len(partition_axes) > 1: 2627 raise NotImplementedError( 2628 "Cannot do assign action along more than one dimension: %s. 
" 2629 "Multi-axis partition assign action is not supported " % 2630 str(partition_axes)) 2631 if isinstance(value, list): 2632 assert len(value) == len(self._variable_list) 2633 value_list = value 2634 elif isinstance(value, PartitionedVariable): 2635 value_list = [var_part for var_part in value] 2636 else: 2637 partition_ix = partition_axes[0] 2638 size_splits_list = [ 2639 tensor_shape.dimension_value(var.shape[partition_ix]) 2640 for var in self._variable_list 2641 ] 2642 value_list = array_ops.split(value, size_splits_list, axis=partition_ix) 2643 2644 op_list = [ 2645 assign_fn(var, value_list[idx]) 2646 for idx, var in enumerate(self._variable_list) 2647 ] 2648 return op_list 2649 2650 def assign(self, value, use_locking=False, name=None, read_value=True): 2651 assign_fn = lambda var, r_value: var.assign( 2652 r_value, use_locking=use_locking, 2653 name=name, read_value=read_value) 2654 assign_list = self._apply_assign_fn(assign_fn, value) 2655 if read_value: 2656 return assign_list 2657 return [assign.op for assign in assign_list] 2658 2659 def assign_add(self, value, use_locking=False, name=None, read_value=True): 2660 assign_fn = lambda var, r_value: var.assign_add( 2661 r_value, use_locking=use_locking, 2662 name=name, read_value=read_value) 2663 assign_list = self._apply_assign_fn(assign_fn, value) 2664 if read_value: 2665 return assign_list 2666 return [assign.op for assign in assign_list] 2667 2668 def assign_sub(self, value, use_locking=False, name=None, read_value=True): 2669 assign_fn = lambda var, r_value: var.assign_sub( 2670 r_value, use_locking=use_locking, 2671 name=name, read_value=read_value) 2672 assign_list = self._apply_assign_fn(assign_fn, value) 2673 if read_value: 2674 return assign_list 2675 return [assign.op for assign in assign_list] 2676 2677 2678# Register a conversion function which reads the value of the variable, 2679# allowing instances of the class to be used as tensors. 2680ops.register_tensor_conversion_function( 2681 RefVariable, 2682 RefVariable._TensorConversionFunction) # pylint: disable=protected-access 2683ops.register_dense_tensor_like_type(RefVariable) 2684 2685 2686@tf_export(v1=["global_variables"]) 2687def global_variables(scope=None): 2688 """Returns global variables. 2689 2690 Global variables are variables that are shared across machines in a 2691 distributed environment. The `Variable()` constructor or `get_variable()` 2692 automatically adds new variables to the graph collection 2693 `GraphKeys.GLOBAL_VARIABLES`. 2694 This convenience function returns the contents of that collection. 2695 2696 An alternative to global variables are local variables. See 2697 `tf.local_variables` 2698 2699 Args: 2700 scope: (Optional.) A string. If supplied, the resulting list is filtered 2701 to include only items whose `name` attribute matches `scope` using 2702 `re.match`. Items without a `name` attribute are never returned if a 2703 scope is supplied. The choice of `re.match` means that a `scope` without 2704 special tokens filters by prefix. 2705 2706 Returns: 2707 A list of `Variable` objects. 2708 """ 2709 return ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES, scope) 2710 2711 2712@tf_export(v1=["all_variables"]) 2713@deprecated("2017-03-02", "Please use tf.global_variables instead.") 2714def all_variables(): 2715 """See `tf.global_variables`.""" 2716 return global_variables() 2717 2718 2719def _all_saveable_objects(scope=None): 2720 """Returns all variables and `SaveableObject`s that must be checkpointed. 
2721 2722 Args: 2723 scope: (Optional.) A string. If supplied, the resulting list is filtered 2724 to include only items whose `name` attribute matches `scope` using 2725 `re.match`. Items without a `name` attribute are never returned if a 2726 scope is supplied. The choice of `re.match` means that a `scope` without 2727 special tokens filters by prefix. 2728 2729 Returns: 2730 A list of `Variable` and `SaveableObject` to be checkpointed 2731 """ 2732 # TODO(andreasst): make this function public once things are settled. 2733 return (ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES, scope) + 2734 ops.get_collection(ops.GraphKeys.SAVEABLE_OBJECTS, scope)) 2735 2736 2737@tf_export(v1=["local_variables"]) 2738def local_variables(scope=None): 2739 """Returns local variables. 2740 2741 Local variables - per process variables, usually not saved/restored to 2742 checkpoint and used for temporary or intermediate values. 2743 For example, they can be used as counters for metrics computation or 2744 number of epochs this machine has read data. 2745 The `tf.contrib.framework.local_variable()` function automatically adds the 2746 new variable to `GraphKeys.LOCAL_VARIABLES`. 2747 This convenience function returns the contents of that collection. 2748 2749 An alternative to local variables are global variables. See 2750 `tf.global_variables` 2751 2752 Args: 2753 scope: (Optional.) A string. If supplied, the resulting list is filtered 2754 to include only items whose `name` attribute matches `scope` using 2755 `re.match`. Items without a `name` attribute are never returned if a 2756 scope is supplied. The choice of `re.match` means that a `scope` without 2757 special tokens filters by prefix. 2758 2759 Returns: 2760 A list of local `Variable` objects. 2761 """ 2762 return ops.get_collection(ops.GraphKeys.LOCAL_VARIABLES, scope) 2763 2764 2765@tf_export(v1=["model_variables"]) 2766def model_variables(scope=None): 2767 """Returns all variables in the MODEL_VARIABLES collection. 2768 2769 Args: 2770 scope: (Optional.) A string. If supplied, the resulting list is filtered 2771 to include only items whose `name` attribute matches `scope` using 2772 `re.match`. Items without a `name` attribute are never returned if a 2773 scope is supplied. The choice of `re.match` means that a `scope` without 2774 special tokens filters by prefix. 2775 2776 Returns: 2777 A list of local Variable objects. 2778 """ 2779 return ops.get_collection(ops.GraphKeys.MODEL_VARIABLES, scope) 2780 2781 2782@tf_export(v1=["trainable_variables"]) 2783def trainable_variables(scope=None): 2784 """Returns all variables created with `trainable=True`. 2785 2786 When passed `trainable=True`, the `Variable()` constructor automatically 2787 adds new variables to the graph collection 2788 `GraphKeys.TRAINABLE_VARIABLES`. This convenience function returns the 2789 contents of that collection. 2790 2791 Args: 2792 scope: (Optional.) A string. If supplied, the resulting list is filtered 2793 to include only items whose `name` attribute matches `scope` using 2794 `re.match`. Items without a `name` attribute are never returned if a 2795 scope is supplied. The choice of `re.match` means that a `scope` without 2796 special tokens filters by prefix. 2797 2798 Returns: 2799 A list of Variable objects. 2800 """ 2801 return ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES, scope) 2802 2803 2804@tf_export(v1=["moving_average_variables"]) 2805def moving_average_variables(scope=None): 2806 """Returns all variables that maintain their moving averages. 
2807 2808 If an `ExponentialMovingAverage` object is created and the `apply()` 2809 method is called on a list of variables, these variables will 2810 be added to the `GraphKeys.MOVING_AVERAGE_VARIABLES` collection. 2811 This convenience function returns the contents of that collection. 2812 2813 Args: 2814 scope: (Optional.) A string. If supplied, the resulting list is filtered 2815 to include only items whose `name` attribute matches `scope` using 2816 `re.match`. Items without a `name` attribute are never returned if a 2817 scope is supplied. The choice of `re.match` means that a `scope` without 2818 special tokens filters by prefix. 2819 2820 Returns: 2821 A list of Variable objects. 2822 """ 2823 return ops.get_collection(ops.GraphKeys.MOVING_AVERAGE_VARIABLES, scope) 2824 2825 2826@tf_export(v1=["initializers.variables", "variables_initializer"]) 2827def variables_initializer(var_list, name="init"): 2828 """Returns an Op that initializes a list of variables. 2829 2830 After you launch the graph in a session, you can run the returned Op to 2831 initialize all the variables in `var_list`. This Op runs all the 2832 initializers of the variables in `var_list` in parallel. 2833 2834 Calling `initialize_variables()` is equivalent to passing the list of 2835 initializers to `Group()`. 2836 2837 If `var_list` is empty, however, the function still returns an Op that can 2838 be run. That Op just has no effect. 2839 2840 Args: 2841 var_list: List of `Variable` objects to initialize. 2842 name: Optional name for the returned operation. 2843 2844 Returns: 2845 An Op that run the initializers of all the specified variables. 2846 """ 2847 if var_list and not context.executing_eagerly(): 2848 return control_flow_ops.group(*[v.initializer for v in var_list], name=name) 2849 return control_flow_ops.no_op(name=name) 2850 2851 2852@tf_export(v1=["initialize_variables"]) 2853@tf_should_use.should_use_result 2854@deprecated("2017-03-02", "Use `tf.variables_initializer` instead.") 2855def initialize_variables(var_list, name="init"): 2856 """See `tf.variables_initializer`.""" 2857 return variables_initializer(var_list, name=name) 2858 2859 2860@tf_export(v1=["initializers.global_variables", "global_variables_initializer"]) 2861def global_variables_initializer(): 2862 """Returns an Op that initializes global variables. 2863 2864 This is just a shortcut for `variables_initializer(global_variables())` 2865 2866 Returns: 2867 An Op that initializes global variables in the graph. 2868 """ 2869 if context.executing_eagerly(): 2870 return control_flow_ops.no_op(name="global_variables_initializer") 2871 return variables_initializer(global_variables()) 2872 2873 2874@tf_export(v1=["initialize_all_variables"]) 2875@tf_should_use.should_use_result 2876@deprecated("2017-03-02", "Use `tf.global_variables_initializer` instead.") 2877def initialize_all_variables(): 2878 """See `tf.global_variables_initializer`.""" 2879 return global_variables_initializer() 2880 2881 2882@tf_export(v1=["initializers.local_variables", "local_variables_initializer"]) 2883def local_variables_initializer(): 2884 """Returns an Op that initializes all local variables. 2885 2886 This is just a shortcut for `variables_initializer(local_variables())` 2887 2888 Returns: 2889 An Op that initializes all local variables in the graph. 
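
  For example (a minimal sketch, assuming graph mode):

  ```python
  with tf.Session() as sess:
    sess.run(tf.local_variables_initializer())
    # Local variables (e.g. metric accumulators) can now be read and updated.
  ```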
2890 """ 2891 if context.executing_eagerly(): 2892 return control_flow_ops.no_op(name="local_variables_initializer") 2893 return variables_initializer(local_variables()) 2894 2895 2896@tf_export(v1=["initialize_local_variables"]) 2897@tf_should_use.should_use_result 2898@deprecated("2017-03-02", "Use `tf.local_variables_initializer` instead.") 2899def initialize_local_variables(): 2900 """See `tf.local_variables_initializer`.""" 2901 return local_variables_initializer() 2902 2903 2904@tf_export(v1=["is_variable_initialized"]) 2905@tf_should_use.should_use_result 2906def is_variable_initialized(variable): 2907 """Tests if a variable has been initialized. 2908 2909 Args: 2910 variable: A `Variable`. 2911 2912 Returns: 2913 Returns a scalar boolean Tensor, `True` if the variable has been 2914 initialized, `False` otherwise. 2915 """ 2916 return state_ops.is_variable_initialized(variable) 2917 2918 2919@tf_export(v1=["assert_variables_initialized"]) 2920@tf_should_use.should_use_result 2921def assert_variables_initialized(var_list=None): 2922 """Returns an Op to check if variables are initialized. 2923 2924 NOTE: This function is obsolete and will be removed in 6 months. Please 2925 change your implementation to use `report_uninitialized_variables()`. 2926 2927 When run, the returned Op will raise the exception `FailedPreconditionError` 2928 if any of the variables has not yet been initialized. 2929 2930 Note: This function is implemented by trying to fetch the values of the 2931 variables. If one of the variables is not initialized a message may be 2932 logged by the C++ runtime. This is expected. 2933 2934 Args: 2935 var_list: List of `Variable` objects to check. Defaults to the 2936 value of `global_variables().` 2937 2938 Returns: 2939 An Op, or None if there are no variables. 2940 """ 2941 if var_list is None: 2942 var_list = global_variables() + local_variables() 2943 # Backwards compatibility for old-style variables. TODO(touts): remove. 2944 if not var_list: 2945 var_list = [] 2946 for op in ops.get_default_graph().get_operations(): 2947 if op.type in ["Variable", "VariableV2", "AutoReloadVariable"]: 2948 var_list.append(op.outputs[0]) 2949 if not var_list: 2950 return None 2951 else: 2952 ranks = [] 2953 for var in var_list: 2954 with ops.colocate_with(var.op): 2955 ranks.append(array_ops.rank_internal(var, optimize=False)) 2956 if len(ranks) == 1: 2957 return ranks[0] 2958 else: 2959 return array_ops.stack(ranks) 2960 2961 2962@tf_export(v1=["report_uninitialized_variables"]) 2963@tf_should_use.should_use_result 2964def report_uninitialized_variables(var_list=None, 2965 name="report_uninitialized_variables"): 2966 """Adds ops to list the names of uninitialized variables. 2967 2968 When run, it returns a 1-D tensor containing the names of uninitialized 2969 variables if there are any, or an empty array if there are none. 2970 2971 Args: 2972 var_list: List of `Variable` objects to check. Defaults to the 2973 value of `global_variables() + local_variables()` 2974 name: Optional name of the `Operation`. 2975 2976 Returns: 2977 A 1-D tensor containing names of the uninitialized variables, or an empty 2978 1-D tensor if there are no variables or no uninitialized variables. 2979 """ 2980 if var_list is None: 2981 var_list = global_variables() + local_variables() 2982 # Backwards compatibility for old-style variables. TODO(touts): remove. 
2983 if not var_list: 2984 var_list = [] 2985 for op in ops.get_default_graph().get_operations(): 2986 if op.type in ["Variable", "VariableV2", "AutoReloadVariable"]: 2987 var_list.append(op.outputs[0]) 2988 with ops.name_scope(name): 2989 # Run all operations on CPU 2990 if var_list: 2991 init_vars = [state_ops.is_variable_initialized(v) for v in var_list] 2992 local_device = os.environ.get( 2993 "TF_DEVICE_FOR_UNINITIALIZED_VARIABLE_REPORTING", "/cpu:0") 2994 with ops.device(local_device): 2995 if not var_list: 2996 # Return an empty tensor so we only need to check for returned tensor 2997 # size being 0 as an indication of model ready. 2998 return array_ops.constant([], dtype=dtypes.string) 2999 else: 3000 # Get a 1-D boolean tensor listing whether each variable is initialized. 3001 variables_mask = math_ops.logical_not(array_ops.stack(init_vars)) 3002 # Get a 1-D string tensor containing all the variable names. 3003 variable_names_tensor = array_ops.constant( 3004 [s.op.name for s in var_list]) 3005 # Return a 1-D tensor containing all the names of 3006 # uninitialized variables. 3007 return array_ops.boolean_mask(variable_names_tensor, variables_mask) 3008 3009 3010ops.register_tensor_conversion_function( 3011 PartitionedVariable, 3012 PartitionedVariable._TensorConversionFunction) # pylint: disable=protected-access 3013