1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Ops to use variables as resources.""" 16 17# pylint: disable=g-bad-name 18from __future__ import absolute_import 19from __future__ import division 20from __future__ import print_function 21 22import contextlib 23import functools 24import weakref 25 26import numpy as np 27 28from tensorflow.core.framework import attr_value_pb2 29from tensorflow.core.framework import variable_pb2 30from tensorflow.python.client import pywrap_tf_session 31from tensorflow.python.eager import context 32from tensorflow.python.eager import tape 33from tensorflow.python.framework import auto_control_deps_utils as acd 34from tensorflow.python.framework import constant_op 35from tensorflow.python.framework import cpp_shape_inference_pb2 36from tensorflow.python.framework import dtypes 37from tensorflow.python.framework import errors 38from tensorflow.python.framework import ops 39from tensorflow.python.framework import tensor_shape 40from tensorflow.python.framework import tensor_spec 41from tensorflow.python.ops import array_ops 42from tensorflow.python.ops import gen_array_ops 43from tensorflow.python.ops import gen_resource_variable_ops 44from tensorflow.python.ops import gen_state_ops 45from tensorflow.python.ops import handle_data_util 46from tensorflow.python.ops import math_ops 47from tensorflow.python.ops import state_ops 48from tensorflow.python.ops import variables 49# go/tf-wildcard-import 50# pylint: disable=wildcard-import 51from tensorflow.python.ops.gen_resource_variable_ops import * 52# pylint: enable=wildcard-import 53from tensorflow.python.training.tracking import base as trackable 54from tensorflow.python.types import core 55from tensorflow.python.util import _pywrap_utils 56from tensorflow.python.util import compat 57from tensorflow.python.util.deprecation import deprecated 58from tensorflow.python.util.tf_export import tf_export 59 60acd.register_read_only_resource_op("ReadVariableOp") 61acd.register_read_only_resource_op("VariableShape") 62acd.register_read_only_resource_op("ResourceGather") 63acd.register_read_only_resource_op("ResourceGatherNd") 64acd.register_read_only_resource_op("_ReadVariablesOp") 65 66 67# TODO(allenl): Remove this alias and migrate callers. 68get_resource_handle_data = handle_data_util.get_resource_handle_data 69 70 71def get_eager_safe_handle_data(handle): 72 """Get the data handle from the Tensor `handle`.""" 73 assert isinstance(handle, ops.Tensor) 74 75 if isinstance(handle, ops.EagerTensor): 76 return handle._handle_data # pylint: disable=protected-access 77 else: 78 return get_resource_handle_data(handle) 79 80 81def _set_handle_shapes_and_types(tensor, handle_data, graph_mode): 82 """Sets the shape inference result HandleData on tensor. 83 84 Args: 85 tensor: A `Tensor` or `EagerTensor`. 86 handle_data: A `CppShapeInferenceResult.HandleData`. 
87 graph_mode: A python bool. 88 """ 89 tensor._handle_data = handle_data # pylint: disable=protected-access 90 if not graph_mode: 91 return 92 93 # Not an EagerTensor, so a graph tensor. 94 shapes, types = zip(*[(pair.shape, pair.dtype) 95 for pair in handle_data.shape_and_type]) 96 ranks = [len(s.dim) if not s.unknown_rank else -1 for s in shapes] 97 shapes = [ 98 [d.size for d in s.dim] # pylint: disable=g-complex-comprehension 99 if not s.unknown_rank else None for s in shapes 100 ] 101 pywrap_tf_session.TF_GraphSetOutputHandleShapesAndTypes_wrapper( 102 tensor._op._graph._c_graph, # pylint: disable=protected-access 103 tensor._as_tf_output(), # pylint: disable=protected-access 104 shapes, 105 ranks, 106 types) 107 108 109def _combine_handle_data(handle, initial_value): 110 """Concats HandleData from tensors `handle` and `initial_value`. 111 112 Args: 113 handle: A `Tensor` of dtype `resource`. 114 initial_value: A `Tensor`. 115 116 Returns: 117 A `CppShapeInferenceResult.HandleData`. If `initial_value` has dtype 118 `variant`, the `HandleData` contains the concatenation of the shape_and_type 119 from both `handle` and `initial_value`. 120 121 Raises: 122 RuntimeError: If handle, which was returned by VarHandleOp, either has 123 no handle data, or its len(handle_data.shape_and_type) != 1. 124 """ 125 assert handle.dtype == dtypes.resource 126 127 variable_handle_data = get_eager_safe_handle_data(handle) 128 129 if initial_value.dtype != dtypes.variant: 130 return variable_handle_data 131 132 extra_handle_data = get_eager_safe_handle_data(initial_value) 133 if extra_handle_data is not None and extra_handle_data.is_set: 134 if (variable_handle_data is None or not variable_handle_data.is_set or 135 len(variable_handle_data.shape_and_type) != 1): 136 raise RuntimeError( 137 "Expected VarHandleOp to return a length==1 shape_and_type, " 138 "but saw: '%s'" % (variable_handle_data,)) 139 variable_handle_data.shape_and_type.extend(extra_handle_data.shape_and_type) 140 return variable_handle_data 141 142 143def _variable_handle_from_shape_and_dtype(shape, 144 dtype, 145 shared_name, 146 name, 147 graph_mode, 148 initial_value=None): 149 """Create a variable handle, copying in handle data from `initial_value`.""" 150 container = ops.get_default_graph()._container # pylint: disable=protected-access 151 if container is None: 152 container = "" 153 shape = tensor_shape.as_shape(shape) 154 dtype = dtypes.as_dtype(dtype) 155 if not graph_mode: 156 if shared_name is not None: 157 raise errors.InternalError( 158 "Using an explicit shared_name is not supported executing eagerly.") 159 shared_name = context.shared_name() 160 161 handle = gen_resource_variable_ops.var_handle_op( 162 shape=shape, 163 dtype=dtype, 164 shared_name=shared_name, 165 name=name, 166 container=container) 167 if initial_value is None: 168 initial_value = handle 169 if graph_mode: 170 full_handle_data = _combine_handle_data(handle, initial_value) 171 _set_handle_shapes_and_types(handle, full_handle_data, graph_mode) 172 return handle 173 else: 174 handle_data = cpp_shape_inference_pb2.CppShapeInferenceResult.HandleData() 175 handle_data.is_set = True 176 handle_data.shape_and_type.append( 177 cpp_shape_inference_pb2.CppShapeInferenceResult.HandleShapeAndType( 178 shape=shape.as_proto(), dtype=dtype.as_datatype_enum)) 179 180 if initial_value is not None and initial_value.dtype == dtypes.variant: 181 extra_handle_data = get_eager_safe_handle_data(initial_value) 182 if extra_handle_data is not None and extra_handle_data.is_set: 183 
if (not handle_data.is_set or len(handle_data.shape_and_type) != 1): 184 raise RuntimeError( 185 "Expected VarHandleOp to return a length==1 shape_and_type, " 186 "but saw: '%s'" % (handle_data,)) 187 handle_data.shape_and_type.extend(extra_handle_data.shape_and_type) 188 189 _set_handle_shapes_and_types(handle, handle_data, graph_mode) 190 return handle 191 192 193def eager_safe_variable_handle(initial_value, shape, shared_name, name, 194 graph_mode): 195 """Creates a variable handle with information to do shape inference. 196 197 The dtype is read from `initial_value` and stored in the returned 198 resource tensor's handle data. 199 200 If `initial_value.dtype == tf.variant`, we additionally extract the handle 201 data (if any) from `initial_value` and append it to the `handle_data`. 202 In this case, the returned tensor's handle data is in the form 203 204 ``` 205 is_set: true 206 shape_and_type { 207 shape { 208 // initial_value.shape 209 } 210 dtype: DT_VARIANT 211 } 212 shape_and_type { 213 // handle_data(initial_value).shape_and_type[0] 214 } 215 shape_and_type { 216 // handle_data(initial_value).shape_and_type[1] 217 } 218 ... 219 ``` 220 221 Ops that read from this tensor, such as `ReadVariableOp` and 222 `AssignVariableOp`, know that `handle_data(handle).shape_and_type[1:]` 223 correspond to the handle data of the variant(s) stored in the Variable. 224 225 Args: 226 initial_value: A `Tensor`. 227 shape: The shape of the handle data. Can be `TensorShape(None)` (i.e. 228 unknown shape). 229 shared_name: A string. 230 name: A string. 231 graph_mode: A python bool. 232 233 Returns: 234 The handle, a `Tensor` of type `resource`. 235 """ 236 dtype = initial_value.dtype.base_dtype 237 return _variable_handle_from_shape_and_dtype(shape, dtype, shared_name, name, 238 graph_mode, initial_value) 239 240 241@contextlib.contextmanager 242def _handle_graph(handle): 243 # Note: might have an eager tensor but not be executing eagerly when building 244 # functions. 245 if (context.executing_eagerly() or isinstance(handle, ops.EagerTensor) or 246 ops.has_default_graph()): 247 yield 248 else: 249 with handle.graph.as_default(): 250 yield 251 252 253class EagerResourceDeleter(object): 254 """An object which cleans up a resource handle. 255 256 An alternative to defining a __del__ method on an object. The intended use is 257 that ResourceVariables or other objects with resource handles will maintain a 258 single reference to this object. When the parent object is collected, this 259 object will be too. Even if the parent object is part of a reference cycle, 260 the cycle will be collectable. 261 """ 262 263 __slots__ = ["_handle", "_handle_device", "_context"] 264 265 def __init__(self, handle, handle_device): 266 if not isinstance(handle, ops.Tensor): 267 raise ValueError( 268 ("Passed handle=%s to EagerResourceDeleter. Was expecting a handle " 269 "Tensor." % (handle,))) 270 self._handle = handle 271 self._handle_device = handle_device 272 # This is held since the __del__ function runs an op, and if the context() 273 # is collected before this object, there will be a segfault when running the 274 # op. 275 self._context = context.context() 276 277 def __del__(self): 278 # Resources follow object-identity when executing eagerly, so it is safe to 279 # delete the resource we have a handle to. 280 try: 281 # A packed EagerTensor doesn't own any resource. 282 if isinstance(self._handle, ops.EagerTensor) and self._handle.is_packed: 283 return 284 # This resource was created in eager mode. 
However, this destructor may be 285 # running in graph mode (especially during unit tests). To clean up 286 # successfully, we switch back into eager mode temporarily. 287 with context.eager_mode(): 288 with ops.device(self._handle_device): 289 gen_resource_variable_ops.destroy_resource_op( 290 self._handle, ignore_lookup_error=True) 291 except TypeError: 292 # Suppress some exceptions, mainly for the case when we're running on 293 # module deletion. Things that can go wrong include the context module 294 # already being unloaded, self._handle._handle_data no longer being 295 # valid, and so on. Printing warnings in these cases is silly 296 # (exceptions raised from __del__ are printed as warnings to stderr). 297 pass # 'NoneType' object is not callable when the handle has been 298 # partially unloaded. 299 except AttributeError: 300 pass # 'NoneType' object has no attribute 'eager_mode' when context has 301 # been unloaded. Will catch other module unloads as well. 302 303 304def shape_safe_assign_variable_handle(handle, shape, value, name=None): 305 """Helper that checks shape compatibility and assigns variable.""" 306 with _handle_graph(handle): 307 value_tensor = ops.convert_to_tensor(value) 308 shape.assert_is_compatible_with(value_tensor.shape) 309 return gen_resource_variable_ops.assign_variable_op( 310 handle, value_tensor, name=name) 311 312 313def _maybe_set_handle_data(dtype, handle, tensor): 314 if dtype == dtypes.variant: 315 # For DT_VARIANT types, the handle's shape_and_type[1:] stores the 316 # variant's handle data. Extract it. 317 handle_data = get_eager_safe_handle_data(handle) 318 if handle_data.is_set and len(handle_data.shape_and_type) > 1: 319 tensor._handle_data = ( # pylint: disable=protected-access 320 cpp_shape_inference_pb2.CppShapeInferenceResult.HandleData( 321 is_set=True, shape_and_type=handle_data.shape_and_type[1:])) 322 323 324def variable_accessed(variable): 325 """Records that `variable` was accessed for the tape and FuncGraph.""" 326 if hasattr(ops.get_default_graph(), "watch_variable"): 327 ops.get_default_graph().watch_variable(variable) 328 if variable.trainable: 329 tape.variable_accessed(variable) 330 331 332class BaseResourceVariable(variables.VariableV1, core.Tensor): 333 """A python variable from an existing handle.""" 334 335 # TODO(wangpeng): Deprecate `constraint` when callers no long pass it in. 336 def __init__( # pylint: disable=super-init-not-called 337 self, 338 trainable=None, 339 shape=None, 340 dtype=None, 341 handle=None, 342 constraint=None, 343 synchronization=None, 344 aggregation=None, 345 distribute_strategy=None, 346 name=None, 347 unique_id=None, 348 handle_name=None, 349 graph_element=None, 350 initial_value=None, 351 initializer_op=None, 352 is_initialized_op=None, 353 cached_value=None, 354 save_slice_info=None, 355 handle_deleter=None, 356 caching_device=None, 357 **unused_kwargs): 358 """Creates a variable from a handle. 359 360 Args: 361 trainable: If `True`, GradientTapes automatically watch uses of this 362 Variable. 363 shape: The variable's shape. 364 dtype: The variable's dtype. 365 handle: The variable's handle 366 constraint: An optional projection function to be applied to the variable 367 after being updated by an `Optimizer` (e.g. used to implement norm 368 constraints or value constraints for layer weights). The function must 369 take as input the unprojected Tensor representing the value of the 370 variable and return the Tensor for the projected value (which must have 371 the same shape). 
Constraints are not safe to use when doing asynchronous 372 distributed training. 373 synchronization: Indicates when a distributed a variable will be 374 aggregated. Accepted values are constants defined in the class 375 `tf.VariableSynchronization`. By default the synchronization is set to 376 `AUTO` and the current `DistributionStrategy` chooses when to 377 synchronize. 378 aggregation: Indicates how a distributed variable will be aggregated. 379 Accepted values are constants defined in the class 380 `tf.VariableAggregation`. 381 distribute_strategy: The distribution strategy this variable was created 382 under. 383 name: The name for this variable. 384 unique_id: Internal. Unique ID for this variable's handle. 385 handle_name: The name for the variable's handle. 386 graph_element: Optional, required only in session.run-mode. Pre-created 387 tensor which reads this variable's value. 388 initial_value: Optional. Variable's initial value. 389 initializer_op: Operation which assigns the variable's initial value. 390 is_initialized_op: Pre-created operation to check whether this variable is 391 initialized. 392 cached_value: Pre-created operation to read this variable in a specific 393 device. 394 save_slice_info: Metadata for variable partitioning. 395 handle_deleter: EagerResourceDeleter responsible for cleaning up the 396 handle. 397 caching_device: Optional device string or function describing where the 398 Variable should be cached for reading. Defaults to the Variable's 399 device. If not `None`, caches on another device. Typical use is to 400 cache on the device where the Ops using the Variable reside, to 401 deduplicate copying through `Switch` and other conditional statements. 402 """ 403 with ops.init_scope(): 404 self._in_graph_mode = not context.executing_eagerly() 405 synchronization, aggregation, trainable = ( 406 variables.validate_synchronization_aggregation_trainable( 407 synchronization, aggregation, trainable, name)) 408 self._trainable = trainable 409 self._synchronization = synchronization 410 self._aggregation = aggregation 411 self._save_slice_info = save_slice_info 412 self._initial_value = initial_value 413 self._initializer_op = initializer_op 414 self._is_initialized_op = is_initialized_op 415 self._graph_element = graph_element 416 self._caching_device = caching_device 417 self._cached_value = cached_value 418 self._distribute_strategy = distribute_strategy 419 # Store the graph key so optimizers know how to only retrieve variables from 420 # this graph. Guaranteed to be the same as the eager graph_key. 421 self._graph_key = ops.get_default_graph()._graph_key # pylint: disable=protected-access 422 self._shape = tensor_shape.as_shape(shape) 423 self._dtype = dtypes.as_dtype(dtype) 424 self._handle = handle 425 self._unique_id = unique_id 426 self._handle_name = handle_name + ":0" 427 self._constraint = constraint 428 # After the handle has been created, set up a way to clean it up when 429 # executing eagerly. We'll hold the only reference to the deleter, so that 430 # when this object is garbage collected the deleter will be too. This 431 # means ResourceVariables can be part of reference cycles without those 432 # cycles being uncollectable. 
433 if not self._in_graph_mode: 434 if handle_deleter is None: 435 handle_deleter = EagerResourceDeleter( 436 handle=self._handle, handle_device=self._handle.device) 437 self._handle_deleter = handle_deleter 438 self._cached_shape_as_list = None 439 440 def __repr__(self): 441 if context.executing_eagerly() and not self._in_graph_mode: 442 # If we cannot read the value for any reason, still produce a __repr__. 443 try: 444 value_text = ops.numpy_text(self.read_value(), is_repr=True) 445 except: # pylint: disable=bare-except 446 value_text = "<unavailable>" 447 448 return "<tf.Variable '%s' shape=%s dtype=%s, numpy=%s>" % ( 449 self.name, self.get_shape(), self.dtype.name, value_text) 450 else: 451 return "<tf.Variable '%s' shape=%s dtype=%s>" % ( 452 self.name, self.get_shape(), self.dtype.name) 453 454 @contextlib.contextmanager 455 def _assign_dependencies(self): 456 """Makes assignments depend on the cached value, if any. 457 458 This prevents undefined behavior with reads not ordered wrt writes. 459 460 Yields: 461 None. 462 """ 463 if self._cached_value is not None: 464 with ops.control_dependencies([self._cached_value]): 465 yield 466 else: 467 yield 468 469 def __array__(self): 470 """Allows direct conversion to a numpy array. 471 472 >>> np.array(tf.Variable([1.0])) 473 array([1.], dtype=float32) 474 475 Returns: 476 The variable value as a numpy array. 477 """ 478 # You can't return `self.numpy()` here because for scalars 479 # that raises: 480 # ValueError: object __array__ method not producing an array 481 # Even `self.read_value().__array__()` and `self.read_value()._numpy()` give 482 # the same error. The `EagerTensor` class must be doing something behind the 483 # scenes to make `np.array(tf.constant(1))` work. 484 return np.asarray(self.numpy()) 485 486 def __nonzero__(self): 487 return self.__bool__() 488 489 def __bool__(self): 490 return bool(self.read_value()) 491 492 def __copy__(self): 493 return self 494 495 def __deepcopy__(self, memo): 496 if not context.executing_eagerly(): 497 raise NotImplementedError( 498 "__deepcopy__() is only available when eager execution is enabled.") 499 copied_variable = ResourceVariable( 500 initial_value=self.read_value(), 501 trainable=self._trainable, 502 constraint=self._constraint, 503 dtype=self._dtype, 504 name=self._shared_name, 505 distribute_strategy=self._distribute_strategy, 506 synchronization=self.synchronization, 507 aggregation=self.aggregation) 508 memo[self._unique_id] = copied_variable 509 return copied_variable 510 511 @property 512 def dtype(self): 513 """The dtype of this variable.""" 514 return self._dtype 515 516 @property 517 def device(self): 518 """The device this variable is on.""" 519 return self._handle.device 520 521 @property 522 def graph(self): 523 """The `Graph` of this variable.""" 524 return self._handle.graph 525 526 @property 527 def name(self): 528 """The name of the handle for this variable.""" 529 return self._handle_name 530 531 @property 532 def shape(self): 533 """The shape of this variable.""" 534 return self._shape 535 536 def set_shape(self, shape): 537 self._shape = self._shape.merge_with(shape) 538 539 def _shape_as_list(self): 540 if self.shape.ndims is None: 541 return None 542 return [dim.value for dim in self.shape.dims] 543 544 def _shape_tuple(self): 545 shape = self._shape_as_list() 546 if shape is None: 547 return None 548 return tuple(shape) 549 550 @property 551 def create(self): 552 """The op responsible for initializing this variable.""" 553 if not self._in_graph_mode: 554 
raise RuntimeError("Calling create is not supported when eager execution" 555 " is enabled.") 556 return self._initializer_op 557 558 @property 559 def handle(self): 560 """The handle by which this variable can be accessed.""" 561 return self._handle 562 563 def value(self): 564 """A cached operation which reads the value of this variable.""" 565 if self._cached_value is not None: 566 return self._cached_value 567 with ops.colocate_with(None, ignore_existing=True): 568 return self._read_variable_op() 569 570 def _as_graph_element(self): 571 """Conversion function for Graph.as_graph_element().""" 572 return self._graph_element 573 574 @property 575 def initializer(self): 576 """The op responsible for initializing this variable.""" 577 return self._initializer_op 578 579 @property 580 def initial_value(self): 581 """Returns the Tensor used as the initial value for the variable.""" 582 if context.executing_eagerly(): 583 raise RuntimeError("initial_value not supported in EAGER mode.") 584 return self._initial_value 585 586 @property 587 def constraint(self): 588 """Returns the constraint function associated with this variable. 589 590 Returns: 591 The constraint function that was passed to the variable constructor. 592 Can be `None` if no constraint was passed. 593 """ 594 return self._constraint 595 596 @property 597 def op(self): 598 """The op for this variable.""" 599 return self._handle.op 600 601 @property 602 def trainable(self): 603 return self._trainable 604 605 @property 606 def synchronization(self): 607 return self._synchronization 608 609 @property 610 def aggregation(self): 611 return self._aggregation 612 613 def eval(self, session=None): 614 """Evaluates and returns the value of this variable.""" 615 if context.executing_eagerly(): 616 raise RuntimeError("Trying to eval in EAGER mode") 617 return self._graph_element.eval(session=session) 618 619 def numpy(self): 620 if context.executing_eagerly(): 621 return self.read_value().numpy() 622 raise NotImplementedError( 623 "numpy() is only available when eager execution is enabled.") 624 625 @deprecated(None, "Prefer Dataset.range instead.") 626 def count_up_to(self, limit): 627 """Increments this variable until it reaches `limit`. 628 629 When that Op is run it tries to increment the variable by `1`. If 630 incrementing the variable would bring it above `limit` then the Op raises 631 the exception `OutOfRangeError`. 632 633 If no error is raised, the Op outputs the value of the variable before 634 the increment. 635 636 This is essentially a shortcut for `count_up_to(self, limit)`. 637 638 Args: 639 limit: value at which incrementing the variable raises an error. 640 641 Returns: 642 A `Tensor` that will hold the variable value before the increment. If no 643 other Op modifies this variable, the values produced will all be 644 distinct. 
645 """ 646 return gen_state_ops.resource_count_up_to( 647 self.handle, limit=limit, T=self.dtype) 648 649 def _map_resources(self, save_options): 650 """For implementing `Trackable`.""" 651 new_variable = None 652 if save_options.experimental_variable_policy._save_variable_devices(): # pylint:disable=protected-access 653 with ops.device(self.device): 654 new_variable = copy_to_graph_uninitialized(self) 655 else: 656 new_variable = copy_to_graph_uninitialized(self) 657 obj_map = {self: new_variable} 658 resource_map = {self._handle: new_variable.handle} 659 return obj_map, resource_map 660 661 def _read_variable_op(self): 662 variable_accessed(self) 663 664 def read_and_set_handle(): 665 result = gen_resource_variable_ops.read_variable_op( 666 self._handle, self._dtype) 667 _maybe_set_handle_data(self._dtype, self._handle, result) 668 return result 669 670 if getattr(self, "_caching_device", None) is not None: 671 with ops.colocate_with(None, ignore_existing=True): 672 with ops.device(self._caching_device): 673 result = read_and_set_handle() 674 else: 675 result = read_and_set_handle() 676 677 if not context.executing_eagerly(): 678 # Note that if a control flow context is active the input of the read op 679 # might not actually be the handle. This line bypasses it. 680 tape.record_operation( 681 "ReadVariableOp", [result], [self._handle], 682 backward_function=lambda x: [x], 683 forward_function=lambda x: [x]) 684 return result 685 686 def read_value(self): 687 """Constructs an op which reads the value of this variable. 688 689 Should be used when there are multiple reads, or when it is desirable to 690 read the value only after some condition is true. 691 692 Returns: 693 the read operation. 694 """ 695 with ops.name_scope("Read"): 696 value = self._read_variable_op() 697 # Return an identity so it can get placed on whatever device the context 698 # specifies instead of the device where the variable is. 699 return array_ops.identity(value) 700 701 def sparse_read(self, indices, name=None): 702 """Reads the value of this variable sparsely, using `gather`.""" 703 with ops.name_scope("Gather" if name is None else name) as name: 704 variable_accessed(self) 705 value = gen_resource_variable_ops.resource_gather( 706 self._handle, indices, dtype=self._dtype, name=name) 707 708 if self._dtype == dtypes.variant: 709 # For DT_VARIANT types, the handle's shape_and_type[1:] stores the 710 # variant's handle data. Extract it. 711 handle_data = get_eager_safe_handle_data(self._handle) 712 if handle_data.is_set and len(handle_data.shape_and_type) > 1: 713 value._handle_data = ( # pylint: disable=protected-access 714 cpp_shape_inference_pb2.CppShapeInferenceResult.HandleData( 715 is_set=True, shape_and_type=handle_data.shape_and_type[1:])) 716 717 return array_ops.identity(value) 718 719 def gather_nd(self, indices, name=None): 720 """Reads the value of this variable sparsely, using `gather_nd`.""" 721 with ops.name_scope("GatherNd" if name is None else name) as name: 722 if self.trainable: 723 variable_accessed(self) 724 value = gen_resource_variable_ops.resource_gather_nd( 725 self._handle, indices, dtype=self._dtype, name=name) 726 727 return array_ops.identity(value) 728 729 def to_proto(self, export_scope=None): 730 """Converts a `ResourceVariable` to a `VariableDef` protocol buffer. 731 732 Args: 733 export_scope: Optional `string`. Name scope to remove. 734 735 Raises: 736 RuntimeError: If run in EAGER mode. 
737 738 Returns: 739 A `VariableDef` protocol buffer, or `None` if the `Variable` is not 740 in the specified name scope. 741 """ 742 if context.executing_eagerly(): 743 raise RuntimeError("to_proto not supported in EAGER mode.") 744 if export_scope is None or self.handle.name.startswith(export_scope): 745 var_def = variable_pb2.VariableDef() 746 var_def.variable_name = ops.strip_name_scope(self.handle.name, 747 export_scope) 748 if self._initial_value is not None: 749 # This is inside an if-statement for backwards compatibility, since 750 # self._initial_value might be None for variables constructed from old 751 # protos. 752 var_def.initial_value_name = ops.strip_name_scope( 753 self._initial_value.name, export_scope) 754 var_def.initializer_name = ops.strip_name_scope(self.initializer.name, 755 export_scope) 756 if self._cached_value is not None: 757 var_def.snapshot_name = ops.strip_name_scope(self._cached_value.name, 758 export_scope) 759 else: 760 # Store the graph_element here 761 var_def.snapshot_name = ops.strip_name_scope(self._graph_element.name, 762 export_scope) 763 var_def.is_resource = True 764 var_def.trainable = self.trainable 765 var_def.synchronization = self.synchronization.value 766 var_def.aggregation = self.aggregation.value 767 if self._save_slice_info: 768 var_def.save_slice_info_def.MergeFrom( 769 self._save_slice_info.to_proto(export_scope=export_scope)) 770 return var_def 771 else: 772 return None 773 774 @staticmethod 775 def from_proto(variable_def, import_scope=None): 776 if context.executing_eagerly(): 777 raise RuntimeError("from_proto not supported in EAGER mode.") 778 return ResourceVariable( 779 variable_def=variable_def, import_scope=import_scope) 780 781 __array_priority__ = 100 782 783 def is_initialized(self, name=None): 784 """Checks whether a resource variable has been initialized. 785 786 Outputs boolean scalar indicating whether the tensor has been initialized. 787 788 Args: 789 name: A name for the operation (optional). 790 791 Returns: 792 A `Tensor` of type `bool`. 793 """ 794 # TODO(b/169792703): The current device placement logic never overrides an 795 # explicit placement with a custom device, causing `v.is_initalized()` to 796 # fail under a non-custom device context if `v` is in a custom device. The 797 # explicit placement below makes this work, but should not be necessary once 798 # the logic is updated to handle cases like this. 799 with ops.device(self.device): 800 return gen_resource_variable_ops.var_is_initialized_op(self.handle, name) 801 802 def assign_sub(self, delta, use_locking=None, name=None, read_value=True): 803 """Subtracts a value from this variable. 804 805 Args: 806 delta: A `Tensor`. The value to subtract from this variable. 807 use_locking: If `True`, use locking during the operation. 808 name: The name to use for the operation. 809 read_value: A `bool`. Whether to read and return the new value of the 810 variable or not. 811 812 Returns: 813 If `read_value` is `True`, this method will return the new value of the 814 variable after the assignment has completed. Otherwise, when in graph mode 815 it will return the `Operation` that does the assignment, and when in eager 816 mode it will return `None`. 817 """ 818 # TODO(apassos): this here and below is not atomic. Consider making it 819 # atomic if there's a way to do so without a performance cost for those who 820 # don't need it. 
821 with _handle_graph(self.handle), self._assign_dependencies(): 822 assign_sub_op = gen_resource_variable_ops.assign_sub_variable_op( 823 self.handle, 824 ops.convert_to_tensor(delta, dtype=self.dtype), 825 name=name) 826 if read_value: 827 return self._lazy_read(assign_sub_op) 828 return assign_sub_op 829 830 def assign_add(self, delta, use_locking=None, name=None, read_value=True): 831 """Adds a value to this variable. 832 833 Args: 834 delta: A `Tensor`. The value to add to this variable. 835 use_locking: If `True`, use locking during the operation. 836 name: The name to use for the operation. 837 read_value: A `bool`. Whether to read and return the new value of the 838 variable or not. 839 840 Returns: 841 If `read_value` is `True`, this method will return the new value of the 842 variable after the assignment has completed. Otherwise, when in graph mode 843 it will return the `Operation` that does the assignment, and when in eager 844 mode it will return `None`. 845 """ 846 with _handle_graph(self.handle), self._assign_dependencies(): 847 assign_add_op = gen_resource_variable_ops.assign_add_variable_op( 848 self.handle, 849 ops.convert_to_tensor(delta, dtype=self.dtype), 850 name=name) 851 if read_value: 852 return self._lazy_read(assign_add_op) 853 return assign_add_op 854 855 def _lazy_read(self, op): 856 variable_accessed(self) 857 return _UnreadVariable( 858 handle=self._handle, 859 dtype=self.dtype, 860 shape=self._shape, 861 in_graph_mode=self._in_graph_mode, 862 deleter=self._handle_deleter if not self._in_graph_mode else None, 863 parent_op=op, 864 unique_id=self._unique_id) 865 866 def assign(self, value, use_locking=None, name=None, read_value=True): 867 """Assigns a new value to this variable. 868 869 Args: 870 value: A `Tensor`. The new value for this variable. 871 use_locking: If `True`, use locking during the assignment. 872 name: The name to use for the assignment. 873 read_value: A `bool`. Whether to read and return the new value of the 874 variable or not. 875 876 Returns: 877 If `read_value` is `True`, this method will return the new value of the 878 variable after the assignment has completed. Otherwise, when in graph mode 879 it will return the `Operation` that does the assignment, and when in eager 880 mode it will return `None`. 881 """ 882 # Note: not depending on the cached value here since this can be used to 883 # initialize the variable. 884 with _handle_graph(self.handle): 885 value_tensor = ops.convert_to_tensor(value, dtype=self.dtype) 886 if not self._shape.is_compatible_with(value_tensor.shape): 887 if self.name is None: 888 tensor_name = "" 889 else: 890 tensor_name = " " + str(self.name) 891 raise ValueError( 892 ("Cannot assign to variable%s due to variable shape %s and value " 893 "shape %s are incompatible") % 894 (tensor_name, self._shape, value_tensor.shape)) 895 assign_op = gen_resource_variable_ops.assign_variable_op( 896 self.handle, value_tensor, name=name) 897 if read_value: 898 return self._lazy_read(assign_op) 899 return assign_op 900 901 def __reduce__(self): 902 # The implementation mirrors that of __deepcopy__. 903 return functools.partial( 904 ResourceVariable, 905 initial_value=self.numpy(), 906 trainable=self.trainable, 907 name=self._shared_name, 908 dtype=self.dtype, 909 constraint=self.constraint, 910 distribute_strategy=self._distribute_strategy), () 911 912 def scatter_sub(self, sparse_delta, use_locking=False, name=None): 913 """Subtracts `tf.IndexedSlices` from this variable. 
914 915 Args: 916 sparse_delta: `tf.IndexedSlices` to be subtracted from this variable. 917 use_locking: If `True`, use locking during the operation. 918 name: the name of the operation. 919 920 Returns: 921 The updated variable. 922 923 Raises: 924 TypeError: if `sparse_delta` is not an `IndexedSlices`. 925 """ 926 if not isinstance(sparse_delta, ops.IndexedSlices): 927 raise TypeError("sparse_delta is not IndexedSlices: %s" % sparse_delta) 928 return self._lazy_read( 929 gen_resource_variable_ops.resource_scatter_sub( 930 self.handle, 931 sparse_delta.indices, 932 ops.convert_to_tensor(sparse_delta.values, self.dtype), 933 name=name)) 934 935 def scatter_add(self, sparse_delta, use_locking=False, name=None): 936 """Adds `tf.IndexedSlices` to this variable. 937 938 Args: 939 sparse_delta: `tf.IndexedSlices` to be added to this variable. 940 use_locking: If `True`, use locking during the operation. 941 name: the name of the operation. 942 943 Returns: 944 The updated variable. 945 946 Raises: 947 TypeError: if `sparse_delta` is not an `IndexedSlices`. 948 """ 949 if not isinstance(sparse_delta, ops.IndexedSlices): 950 raise TypeError("sparse_delta is not IndexedSlices: %s" % sparse_delta) 951 return self._lazy_read( 952 gen_resource_variable_ops.resource_scatter_add( 953 self.handle, 954 sparse_delta.indices, 955 ops.convert_to_tensor(sparse_delta.values, self.dtype), 956 name=name)) 957 958 def scatter_max(self, sparse_delta, use_locking=False, name=None): 959 """Updates this variable with the max of `tf.IndexedSlices` and itself. 960 961 Args: 962 sparse_delta: `tf.IndexedSlices` to use as an argument of max with this 963 variable. 964 use_locking: If `True`, use locking during the operation. 965 name: the name of the operation. 966 967 Returns: 968 The updated variable. 969 970 Raises: 971 TypeError: if `sparse_delta` is not an `IndexedSlices`. 972 """ 973 if not isinstance(sparse_delta, ops.IndexedSlices): 974 raise TypeError("sparse_delta is not IndexedSlices: %s" % sparse_delta) 975 return self._lazy_read( 976 gen_resource_variable_ops.resource_scatter_max( 977 self.handle, 978 sparse_delta.indices, 979 ops.convert_to_tensor(sparse_delta.values, self.dtype), 980 name=name)) 981 982 def scatter_min(self, sparse_delta, use_locking=False, name=None): 983 """Updates this variable with the min of `tf.IndexedSlices` and itself. 984 985 Args: 986 sparse_delta: `tf.IndexedSlices` to use as an argument of min with this 987 variable. 988 use_locking: If `True`, use locking during the operation. 989 name: the name of the operation. 990 991 Returns: 992 The updated variable. 993 994 Raises: 995 TypeError: if `sparse_delta` is not an `IndexedSlices`. 996 """ 997 if not isinstance(sparse_delta, ops.IndexedSlices): 998 raise TypeError("sparse_delta is not IndexedSlices: %s" % sparse_delta) 999 return self._lazy_read( 1000 gen_resource_variable_ops.resource_scatter_min( 1001 self.handle, 1002 sparse_delta.indices, 1003 ops.convert_to_tensor(sparse_delta.values, self.dtype), 1004 name=name)) 1005 1006 def scatter_mul(self, sparse_delta, use_locking=False, name=None): 1007 """Multiply this variable by `tf.IndexedSlices`. 1008 1009 Args: 1010 sparse_delta: `tf.IndexedSlices` to multiply this variable by. 1011 use_locking: If `True`, use locking during the operation. 1012 name: the name of the operation. 1013 1014 Returns: 1015 The updated variable. 1016 1017 Raises: 1018 TypeError: if `sparse_delta` is not an `IndexedSlices`. 
1019 """ 1020 if not isinstance(sparse_delta, ops.IndexedSlices): 1021 raise TypeError("sparse_delta is not IndexedSlices: %s" % sparse_delta) 1022 return self._lazy_read( 1023 gen_resource_variable_ops.resource_scatter_mul( 1024 self.handle, 1025 sparse_delta.indices, 1026 ops.convert_to_tensor(sparse_delta.values, self.dtype), 1027 name=name)) 1028 1029 def scatter_div(self, sparse_delta, use_locking=False, name=None): 1030 """Divide this variable by `tf.IndexedSlices`. 1031 1032 Args: 1033 sparse_delta: `tf.IndexedSlices` to divide this variable by. 1034 use_locking: If `True`, use locking during the operation. 1035 name: the name of the operation. 1036 1037 Returns: 1038 The updated variable. 1039 1040 Raises: 1041 TypeError: if `sparse_delta` is not an `IndexedSlices`. 1042 """ 1043 if not isinstance(sparse_delta, ops.IndexedSlices): 1044 raise TypeError("sparse_delta is not IndexedSlices: %s" % sparse_delta) 1045 return self._lazy_read( 1046 gen_resource_variable_ops.resource_scatter_div( 1047 self.handle, 1048 sparse_delta.indices, 1049 ops.convert_to_tensor(sparse_delta.values, self.dtype), 1050 name=name)) 1051 1052 def scatter_update(self, sparse_delta, use_locking=False, name=None): 1053 """Assigns `tf.IndexedSlices` to this variable. 1054 1055 Args: 1056 sparse_delta: `tf.IndexedSlices` to be assigned to this variable. 1057 use_locking: If `True`, use locking during the operation. 1058 name: the name of the operation. 1059 1060 Returns: 1061 The updated variable. 1062 1063 Raises: 1064 TypeError: if `sparse_delta` is not an `IndexedSlices`. 1065 """ 1066 if not isinstance(sparse_delta, ops.IndexedSlices): 1067 raise TypeError("sparse_delta is not IndexedSlices: %s" % sparse_delta) 1068 return self._lazy_read( 1069 gen_resource_variable_ops.resource_scatter_update( 1070 self.handle, 1071 sparse_delta.indices, 1072 ops.convert_to_tensor(sparse_delta.values, self.dtype), 1073 name=name)) 1074 1075 def batch_scatter_update(self, sparse_delta, use_locking=False, name=None): 1076 """Assigns `tf.IndexedSlices` to this variable batch-wise. 1077 1078 Analogous to `batch_gather`. This assumes that this variable and the 1079 sparse_delta IndexedSlices have a series of leading dimensions that are the 1080 same for all of them, and the updates are performed on the last dimension of 1081 indices. In other words, the dimensions should be the following: 1082 1083 `num_prefix_dims = sparse_delta.indices.ndims - 1` 1084 `batch_dim = num_prefix_dims + 1` 1085 `sparse_delta.updates.shape = sparse_delta.indices.shape + var.shape[ 1086 batch_dim:]` 1087 1088 where 1089 1090 `sparse_delta.updates.shape[:num_prefix_dims]` 1091 `== sparse_delta.indices.shape[:num_prefix_dims]` 1092 `== var.shape[:num_prefix_dims]` 1093 1094 And the operation performed can be expressed as: 1095 1096 `var[i_1, ..., i_n, 1097 sparse_delta.indices[i_1, ..., i_n, j]] = sparse_delta.updates[ 1098 i_1, ..., i_n, j]` 1099 1100 When sparse_delta.indices is a 1D tensor, this operation is equivalent to 1101 `scatter_update`. 1102 1103 To avoid this operation one can looping over the first `ndims` of the 1104 variable and using `scatter_update` on the subtensors that result of slicing 1105 the first dimension. This is a valid option for `ndims = 1`, but less 1106 efficient than this implementation. 1107 1108 Args: 1109 sparse_delta: `tf.IndexedSlices` to be assigned to this variable. 1110 use_locking: If `True`, use locking during the operation. 1111 name: the name of the operation. 1112 1113 Returns: 1114 The updated variable. 
1115 1116 Raises: 1117 TypeError: if `sparse_delta` is not an `IndexedSlices`. 1118 """ 1119 if not isinstance(sparse_delta, ops.IndexedSlices): 1120 raise TypeError("sparse_delta is not IndexedSlices: %s" % sparse_delta) 1121 return self._lazy_read( 1122 state_ops.batch_scatter_update( 1123 self, 1124 sparse_delta.indices, 1125 sparse_delta.values, 1126 use_locking=use_locking, 1127 name=name)) 1128 1129 def scatter_nd_sub(self, indices, updates, name=None): 1130 """Applies sparse subtraction to individual values or slices in a Variable. 1131 1132 `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`. 1133 1134 `indices` must be integer tensor, containing indices into `ref`. 1135 It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`. 1136 1137 The innermost dimension of `indices` (with length `K`) corresponds to 1138 indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th 1139 dimension of `ref`. 1140 1141 `updates` is `Tensor` of rank `Q-1+P-K` with shape: 1142 1143 ``` 1144 [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]]. 1145 ``` 1146 1147 For example, say we want to add 4 scattered elements to a rank-1 tensor to 1148 8 elements. In Python, that update would look like this: 1149 1150 ```python 1151 ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8]) 1152 indices = tf.constant([[4], [3], [1] ,[7]]) 1153 updates = tf.constant([9, 10, 11, 12]) 1154 op = ref.scatter_nd_sub(indices, updates) 1155 with tf.compat.v1.Session() as sess: 1156 print sess.run(op) 1157 ``` 1158 1159 The resulting update to ref would look like this: 1160 1161 [1, -9, 3, -6, -6, 6, 7, -4] 1162 1163 See `tf.scatter_nd` for more details about how to make updates to 1164 slices. 1165 1166 Args: 1167 indices: The indices to be used in the operation. 1168 updates: The values to be used in the operation. 1169 name: the name of the operation. 1170 1171 Returns: 1172 The updated variable. 1173 """ 1174 return self._lazy_read( 1175 gen_state_ops.resource_scatter_nd_sub( 1176 self.handle, 1177 indices, 1178 ops.convert_to_tensor(updates, self.dtype), 1179 name=name)) 1180 1181 def scatter_nd_add(self, indices, updates, name=None): 1182 """Applies sparse addition to individual values or slices in a Variable. 1183 1184 `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`. 1185 1186 `indices` must be integer tensor, containing indices into `ref`. 1187 It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`. 1188 1189 The innermost dimension of `indices` (with length `K`) corresponds to 1190 indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th 1191 dimension of `ref`. 1192 1193 `updates` is `Tensor` of rank `Q-1+P-K` with shape: 1194 1195 ``` 1196 [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]]. 1197 ``` 1198 1199 For example, say we want to add 4 scattered elements to a rank-1 tensor to 1200 8 elements. In Python, that update would look like this: 1201 1202 ```python 1203 ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8]) 1204 indices = tf.constant([[4], [3], [1] ,[7]]) 1205 updates = tf.constant([9, 10, 11, 12]) 1206 add = ref.scatter_nd_add(indices, updates) 1207 with tf.compat.v1.Session() as sess: 1208 print sess.run(add) 1209 ``` 1210 1211 The resulting update to ref would look like this: 1212 1213 [1, 13, 3, 14, 14, 6, 7, 20] 1214 1215 See `tf.scatter_nd` for more details about how to make updates to 1216 slices. 1217 1218 Args: 1219 indices: The indices to be used in the operation. 
1220 updates: The values to be used in the operation. 1221 name: the name of the operation. 1222 1223 Returns: 1224 The updated variable. 1225 """ 1226 return self._lazy_read( 1227 gen_state_ops.resource_scatter_nd_add( 1228 self.handle, 1229 indices, 1230 ops.convert_to_tensor(updates, self.dtype), 1231 name=name)) 1232 1233 def scatter_nd_update(self, indices, updates, name=None): 1234 """Applies sparse assignment to individual values or slices in a Variable. 1235 1236 `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`. 1237 1238 `indices` must be integer tensor, containing indices into `ref`. 1239 It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`. 1240 1241 The innermost dimension of `indices` (with length `K`) corresponds to 1242 indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th 1243 dimension of `ref`. 1244 1245 `updates` is `Tensor` of rank `Q-1+P-K` with shape: 1246 1247 ``` 1248 [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]]. 1249 ``` 1250 1251 For example, say we want to add 4 scattered elements to a rank-1 tensor to 1252 8 elements. In Python, that update would look like this: 1253 1254 ```python 1255 ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8]) 1256 indices = tf.constant([[4], [3], [1] ,[7]]) 1257 updates = tf.constant([9, 10, 11, 12]) 1258 op = ref.scatter_nd_update(indices, updates) 1259 with tf.compat.v1.Session() as sess: 1260 print sess.run(op) 1261 ``` 1262 1263 The resulting update to ref would look like this: 1264 1265 [1, 11, 3, 10, 9, 6, 7, 12] 1266 1267 See `tf.scatter_nd` for more details about how to make updates to 1268 slices. 1269 1270 Args: 1271 indices: The indices to be used in the operation. 1272 updates: The values to be used in the operation. 1273 name: the name of the operation. 1274 1275 Returns: 1276 The updated variable. 1277 """ 1278 return self._lazy_read( 1279 gen_state_ops.resource_scatter_nd_update( 1280 self.handle, 1281 indices, 1282 ops.convert_to_tensor(updates, self.dtype), 1283 name=name)) 1284 1285 def scatter_nd_max(self, indices, updates, name=None): 1286 """Updates this variable with the max of `tf.IndexedSlices` and itself. 1287 1288 `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`. 1289 1290 `indices` must be integer tensor, containing indices into `ref`. 1291 It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`. 1292 1293 The innermost dimension of `indices` (with length `K`) corresponds to 1294 indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th 1295 dimension of `ref`. 1296 1297 `updates` is `Tensor` of rank `Q-1+P-K` with shape: 1298 1299 ``` 1300 [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]]. 1301 ``` 1302 1303 See `tf.scatter_nd` for more details about how to make updates to 1304 slices. 1305 1306 Args: 1307 indices: The indices to be used in the operation. 1308 updates: The values to be used in the operation. 1309 name: the name of the operation. 1310 1311 Returns: 1312 The updated variable. 1313 """ 1314 return self._lazy_read( 1315 gen_state_ops.resource_scatter_nd_max( 1316 self.handle, 1317 indices, 1318 ops.convert_to_tensor(updates, self.dtype), 1319 name=name)) 1320 1321 def scatter_nd_min(self, indices, updates, name=None): 1322 """Updates this variable with the min of `tf.IndexedSlices` and itself. 1323 1324 `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`. 1325 1326 `indices` must be integer tensor, containing indices into `ref`. 
1327 It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`. 1328 1329 The innermost dimension of `indices` (with length `K`) corresponds to 1330 indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th 1331 dimension of `ref`. 1332 1333 `updates` is `Tensor` of rank `Q-1+P-K` with shape: 1334 1335 ``` 1336 [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]]. 1337 ``` 1338 1339 See `tf.scatter_nd` for more details about how to make updates to 1340 slices. 1341 1342 Args: 1343 indices: The indices to be used in the operation. 1344 updates: The values to be used in the operation. 1345 name: the name of the operation. 1346 1347 Returns: 1348 The updated variable. 1349 """ 1350 return self._lazy_read( 1351 gen_state_ops.resource_scatter_nd_min( 1352 self.handle, 1353 indices, 1354 ops.convert_to_tensor(updates, self.dtype), 1355 name=name)) 1356 1357 def _strided_slice_assign(self, begin, end, strides, value, name, begin_mask, 1358 end_mask, ellipsis_mask, new_axis_mask, 1359 shrink_axis_mask): 1360 with _handle_graph(self.handle), self._assign_dependencies(): 1361 return self._lazy_read( 1362 gen_array_ops.resource_strided_slice_assign( 1363 ref=self.handle, 1364 begin=begin, 1365 end=end, 1366 strides=strides, 1367 value=ops.convert_to_tensor(value, dtype=self.dtype), 1368 name=name, 1369 begin_mask=begin_mask, 1370 end_mask=end_mask, 1371 ellipsis_mask=ellipsis_mask, 1372 new_axis_mask=new_axis_mask, 1373 shrink_axis_mask=shrink_axis_mask)) 1374 1375 def __complex__(self): 1376 return complex(self.value().numpy()) 1377 1378 def __int__(self): 1379 return int(self.value().numpy()) 1380 1381 def __long__(self): 1382 return long(self.value().numpy()) 1383 1384 def __float__(self): 1385 return float(self.value().numpy()) 1386 1387 def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False): 1388 del name 1389 if dtype is not None and not dtype.is_compatible_with(self.dtype): 1390 raise ValueError( 1391 "Incompatible type conversion requested to type {!r} for variable " 1392 "of type {!r}".format(dtype.name, self.dtype.name)) 1393 if as_ref: 1394 return self.read_value().op.inputs[0] 1395 else: 1396 return self.value() 1397 1398 def __iadd__(self, unused_other): 1399 raise RuntimeError("Variable += value not supported. Use " 1400 "variable.assign_add(value) to modify the variable " 1401 "value and variable = variable + value to get a new " 1402 "Tensor object.") 1403 1404 def __isub__(self, unused_other): 1405 raise RuntimeError("Variable -= value not supported. Use " 1406 "variable.assign_sub(value) to modify the variable " 1407 "value and variable = variable - value to get a new " 1408 "Tensor object.") 1409 1410 def __imul__(self, unused_other): 1411 raise RuntimeError("Variable *= value not supported. Use " 1412 "`var.assign(var * value)` to modify the variable or " 1413 "`var = var * value` to get a new Tensor object.") 1414 1415 def __idiv__(self, unused_other): 1416 raise RuntimeError("Variable /= value not supported. Use " 1417 "`var.assign(var / value)` to modify the variable or " 1418 "`var = var / value` to get a new Tensor object.") 1419 1420 def __itruediv__(self, unused_other): 1421 raise RuntimeError("Variable /= value not supported. Use " 1422 "`var.assign(var / value)` to modify the variable or " 1423 "`var = var / value` to get a new Tensor object.") 1424 1425 def __irealdiv__(self, unused_other): 1426 raise RuntimeError("Variable /= value not supported. 
Use " 1427 "`var.assign(var / value)` to modify the variable or " 1428 "`var = var / value` to get a new Tensor object.") 1429 1430 def __ipow__(self, unused_other): 1431 raise RuntimeError("Variable **= value not supported. Use " 1432 "`var.assign(var ** value)` to modify the variable or " 1433 "`var = var ** value` to get a new Tensor object.") 1434 1435 1436class ResourceVariable(BaseResourceVariable): 1437 """Variable based on resource handles. 1438 1439 See the [Variables How To](https://tensorflow.org/guide/variables) 1440 for a high level overview. 1441 1442 A `ResourceVariable` allows you to maintain state across subsequent calls to 1443 session.run. 1444 1445 The `ResourceVariable` constructor requires an initial value for the variable, 1446 which can be a `Tensor` of any type and shape. The initial value defines the 1447 type and shape of the variable. After construction, the type and shape of 1448 the variable are fixed. The value can be changed using one of the assign 1449 methods. 1450 1451 Just like any `Tensor`, variables created with 1452 `tf.Variable(use_resource=True)` can be used as inputs for other Ops in the 1453 graph. Additionally, all the operators overloaded for the `Tensor` class are 1454 carried over to variables, so you can also add nodes to the graph by just 1455 doing arithmetic on variables. 1456 1457 Unlike ref-based variable, a ResourceVariable has well-defined semantics. Each 1458 usage of a ResourceVariable in a TensorFlow graph adds a read_value operation 1459 to the graph. The Tensors returned by a read_value operation are guaranteed to 1460 see all modifications to the value of the variable which happen in any 1461 operation on which the read_value depends on (either directly, indirectly, or 1462 via a control dependency) and guaranteed to not see any modification to the 1463 value of the variable from operations that depend on the read_value operation. 1464 Updates from operations that have no dependency relationship to the read_value 1465 operation might or might not be visible to read_value. 1466 1467 For example, if there is more than one assignment to a ResourceVariable in 1468 a single session.run call there is a well-defined value for each operation 1469 which uses the variable's value if the assignments and the read are connected 1470 by edges in the graph. Consider the following example, in which two writes 1471 can cause tf.Variable and tf.ResourceVariable to behave differently: 1472 1473 ```python 1474 a = tf.Variable(1.0, use_resource=True) 1475 a.initializer.run() 1476 1477 assign = a.assign(2.0) 1478 with tf.control_dependencies([assign]): 1479 b = a.read_value() 1480 with tf.control_dependencies([b]): 1481 other_assign = a.assign(3.0) 1482 with tf.control_dependencies([other_assign]): 1483 # Will print 2.0 because the value was read before other_assign ran. If 1484 # `a` was a tf.Variable instead, 2.0 or 3.0 could be printed. 1485 tf.compat.v1.Print(b, [b]).eval() 1486 ``` 1487 """ 1488 1489 def __init__( 1490 self, # pylint: disable=super-init-not-called 1491 initial_value=None, 1492 trainable=None, 1493 collections=None, 1494 validate_shape=True, # pylint: disable=unused-argument 1495 caching_device=None, 1496 name=None, 1497 dtype=None, 1498 variable_def=None, 1499 import_scope=None, 1500 constraint=None, 1501 distribute_strategy=None, 1502 synchronization=None, 1503 aggregation=None, 1504 shape=None): 1505 """Creates a variable. 
1506 1507 Args: 1508 initial_value: A `Tensor`, or Python object convertible to a `Tensor`, 1509 which is the initial value for the Variable. Can also be a callable with 1510 no argument that returns the initial value when called. (Note that 1511 initializer functions from init_ops.py must first be bound to a shape 1512 before being used here.) 1513 trainable: If `True`, the default, also adds the variable to the graph 1514 collection `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as 1515 the default list of variables to use by the `Optimizer` classes. 1516 Defaults to `True`, unless `synchronization` is set to `ON_READ`, in 1517 which case it defaults to `False`. 1518 collections: List of graph collections keys. The new variable is added to 1519 these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`. 1520 validate_shape: Ignored. Provided for compatibility with tf.Variable. 1521 caching_device: Optional device string or function describing where the 1522 Variable should be cached for reading. Defaults to the Variable's 1523 device. If not `None`, caches on another device. Typical use is to 1524 cache on the device where the Ops using the Variable reside, to 1525 deduplicate copying through `Switch` and other conditional statements. 1526 name: Optional name for the variable. Defaults to `'Variable'` and gets 1527 uniquified automatically. 1528 dtype: If set, initial_value will be converted to the given type. If None, 1529 either the datatype will be kept (if initial_value is a Tensor) or 1530 float32 will be used (if it is a Python object convertible to a Tensor). 1531 variable_def: `VariableDef` protocol buffer. If not None, recreates the 1532 `ResourceVariable` object with its contents. `variable_def` and other 1533 arguments (except for import_scope) are mutually exclusive. 1534 import_scope: Optional `string`. Name scope to add to the 1535 ResourceVariable. Only used when `variable_def` is provided. 1536 constraint: An optional projection function to be applied to the variable 1537 after being updated by an `Optimizer` (e.g. used to implement norm 1538 constraints or value constraints for layer weights). The function must 1539 take as input the unprojected Tensor representing the value of the 1540 variable and return the Tensor for the projected value (which must have 1541 the same shape). Constraints are not safe to use when doing asynchronous 1542 distributed training. 1543 distribute_strategy: The tf.distribute.Strategy this variable is being 1544 created inside of. 1545 synchronization: Indicates when a distributed a variable will be 1546 aggregated. Accepted values are constants defined in the class 1547 `tf.VariableSynchronization`. By default the synchronization is set to 1548 `AUTO` and the current `DistributionStrategy` chooses when to 1549 synchronize. 1550 aggregation: Indicates how a distributed variable will be aggregated. 1551 Accepted values are constants defined in the class 1552 `tf.VariableAggregation`. 1553 shape: (optional) The shape of this variable. If None, the shape of 1554 `initial_value` will be used. When setting this argument to 1555 `tf.TensorShape(None)` (representing an unspecified shape), the variable 1556 can be assigned with values of different shapes. 1557 1558 Raises: 1559 ValueError: If the initial value is not specified, or does not have a 1560 shape and `validate_shape` is `True`. 

    @compatibility(eager)
    When Eager Execution is enabled, the default for the `collections` argument
    is `None`, which signifies that this `Variable` will not be added to any
    collections.
    @end_compatibility
    """
    if variable_def:
      if initial_value is not None:
        raise ValueError("variable_def and initial_value are mutually "
                         "exclusive.")
      if context.executing_eagerly():
        raise ValueError("Creating ResourceVariable from variable_def is "
                         "not supported when eager execution is enabled.")
      self._init_from_proto(variable_def, import_scope=import_scope)
    else:
      self._init_from_args(
          initial_value=initial_value,
          trainable=trainable,
          collections=collections,
          caching_device=caching_device,
          name=name,
          dtype=dtype,
          constraint=constraint,
          synchronization=synchronization,
          aggregation=aggregation,
          shape=shape,
          distribute_strategy=distribute_strategy)

  def _init_from_args(self,
                      initial_value=None,
                      trainable=None,
                      collections=None,
                      caching_device=None,
                      name=None,
                      dtype=None,
                      constraint=None,
                      synchronization=None,
                      aggregation=None,
                      distribute_strategy=None,
                      shape=None):
    """Creates a variable.

    Args:
      initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
        which is the initial value for the Variable. The initial value must have
        a shape specified unless `validate_shape` is set to False. Can also be a
        callable with no argument that returns the initial value when called.
        (Note that initializer functions from init_ops.py must first be bound to
        a shape before being used here.)
      trainable: If `True`, the default, also adds the variable to the graph
        collection `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as
        the default list of variables to use by the `Optimizer` classes.
        Defaults to `True`, unless `synchronization` is set to `ON_READ`, in
        which case it defaults to `False`.
      collections: List of graph collections keys. The new variable is added to
        these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
      caching_device: Optional device string or function describing where the
        Variable should be cached for reading. Defaults to the Variable's
        device. If not `None`, caches on another device. Typical use is to
        cache on the device where the Ops using the Variable reside, to
        deduplicate copying through `Switch` and other conditional statements.
      name: Optional name for the variable. Defaults to `'Variable'` and gets
        uniquified automatically.
      dtype: If set, initial_value will be converted to the given type. If None,
        either the datatype will be kept (if initial_value is a Tensor) or
        float32 will be used (if it is a Python object convertible to a Tensor).
      constraint: An optional projection function to be applied to the variable
        after being updated by an `Optimizer` (e.g. used to implement norm
        constraints or value constraints for layer weights). The function must
        take as input the unprojected Tensor representing the value of the
        variable and return the Tensor for the projected value (which must have
        the same shape). Constraints are not safe to use when doing asynchronous
        distributed training.
      synchronization: Indicates when a distributed variable will be
        aggregated. Accepted values are constants defined in the class
        `tf.VariableSynchronization`. By default the synchronization is set to
        `AUTO` and the current `DistributionStrategy` chooses when to
        synchronize.
      aggregation: Indicates how a distributed variable will be aggregated.
        Accepted values are constants defined in the class
        `tf.VariableAggregation`.
      distribute_strategy: DistributionStrategy under which this variable was
        created.
      shape: (optional) The shape of this variable. If None, the shape of
        `initial_value` will be used. When setting this argument to
        `tf.TensorShape(None)` (representing an unspecified shape), the
        variable can be assigned values of different shapes.

    Raises:
      ValueError: If the initial value is not specified, or does not have a
        shape and `validate_shape` is `True`.

    @compatibility(eager)
    When Eager Execution is enabled, variables are never added to collections.
    They are not implicitly added to the `GLOBAL_VARIABLES` or
    `TRAINABLE_VARIABLES` collections, and the `collections` argument is
    ignored.
    @end_compatibility
    """
    synchronization, aggregation, trainable = (
        variables.validate_synchronization_aggregation_trainable(
            synchronization, aggregation, trainable, name))
    if initial_value is None:
      raise ValueError("initial_value must be specified.")
    init_from_fn = callable(initial_value)

    if isinstance(initial_value, ops.Tensor) and hasattr(
        initial_value, "graph") and initial_value.graph.building_function:
      raise ValueError("Tensor-typed variable initializers must either be "
                       "wrapped in an init_scope or callable "
                       "(e.g., `tf.Variable(lambda : "
                       "tf.truncated_normal([10, 40]))`) when building "
                       "functions. Please file a feature request if this "
                       "restriction inconveniences you.")

    if collections is None:
      collections = [ops.GraphKeys.GLOBAL_VARIABLES]
    if not isinstance(collections, (list, tuple, set)):
      raise ValueError(
          "collections argument to Variable constructor must be a list, tuple, "
          "or set. Got %s of type %s" % (collections, type(collections)))
    if constraint is not None and not callable(constraint):
      raise ValueError("The `constraint` argument must be a callable.")

    if trainable and ops.GraphKeys.TRAINABLE_VARIABLES not in collections:
      collections = list(collections) + [ops.GraphKeys.TRAINABLE_VARIABLES]
    with ops.init_scope():
      self._in_graph_mode = not context.executing_eagerly()
      with ops.name_scope(
          name,
          "Variable", [] if init_from_fn else [initial_value],
          skip_on_eager=False) as name:
        # pylint: disable=protected-access
        handle_name = ops.name_from_scope_name(name)
        if self._in_graph_mode:
          shared_name = handle_name
          unique_id = shared_name
        else:
          # When in eager mode use a uid for the shared_name, to prevent
          # accidental sharing.
          unique_id = "%s_%d" % (handle_name, ops.uid())
          shared_name = None  # Never shared
        # Use attr_scope and device(None) to simulate the behavior of
        # colocate_with when the variable we want to colocate with doesn't
        # yet exist.
        device_context_manager = (
            ops.device if self._in_graph_mode else ops.NullContextmanager)
        attr = attr_value_pb2.AttrValue(
            list=attr_value_pb2.AttrValue.ListValue(
                s=[compat.as_bytes("loc:@%s" % handle_name)]))
        with ops.get_default_graph()._attr_scope({"_class": attr}):
          with ops.name_scope("Initializer"), device_context_manager(None):
            if init_from_fn:
              initial_value = initial_value()
            if isinstance(initial_value, trackable.CheckpointInitialValue):
              self._maybe_initialize_trackable()
              self._update_uid = initial_value.checkpoint_position.restore_uid
              initial_value = initial_value.wrapped_value
            initial_value = ops.convert_to_tensor(initial_value,
                                                  name="initial_value",
                                                  dtype=dtype)
          if shape is not None:
            if not initial_value.shape.is_compatible_with(shape):
              raise ValueError(
                  "The initial value's shape (%s) is not compatible with "
                  "the explicitly supplied `shape` argument (%s)." %
                  (initial_value.shape, shape))
          else:
            shape = initial_value.shape
          handle = eager_safe_variable_handle(
              initial_value=initial_value,
              shape=shape,
              shared_name=shared_name,
              name=name,
              graph_mode=self._in_graph_mode)
        # pylint: disable=protected-access
        if (self._in_graph_mode and initial_value is not None and
            initial_value.op._get_control_flow_context() is not None):
          raise ValueError(
              "Initializer for variable %s is from inside a control-flow "
              "construct, such as a loop or conditional. When creating a "
              "variable inside a loop or conditional, use a lambda as the "
              "initializer." % name)
        # pylint: enable=protected-access
        dtype = initial_value.dtype.base_dtype

        if self._in_graph_mode:
          with ops.name_scope("IsInitialized"):
            is_initialized_op = (
                gen_resource_variable_ops.var_is_initialized_op(handle))
          if initial_value is not None:
            # pylint: disable=g-backslash-continuation
            with ops.name_scope("Assign") as n, \
                 ops.colocate_with(None, ignore_existing=True), \
                 ops.device(handle.device):
              # pylint: disable=protected-access
              initializer_op = (
                  gen_resource_variable_ops.assign_variable_op(
                      handle,
                      variables._try_guard_against_uninitialized_dependencies(
                          name, initial_value),
                      name=n))
              # pylint: enable=protected-access
            # pylint: enable=g-backslash-continuation
          with ops.name_scope("Read"):
            # Manually assign reads to the handle's device to avoid log
            # messages.
            with ops.device(handle.device):
              value = gen_resource_variable_ops.read_variable_op(handle, dtype)
              _maybe_set_handle_data(dtype, handle, value)
            graph_element = value
            if caching_device is not None:
              # Variables may be created in a tf.device() or ops.colocate_with()
              # context. At the same time, users would expect caching device to
              # be independent of this context, and/or would not expect the
              # current device context to be merged with the caching device
              # spec. Therefore we reset the colocation stack before creating
              # the cached value. Note that resetting the colocation stack will
              # also reset the device stack.
              with ops.colocate_with(None, ignore_existing=True):
                with ops.device(caching_device):
                  cached_value = array_ops.identity(value)
            else:
              cached_value = None
        else:
          gen_resource_variable_ops.assign_variable_op(handle, initial_value)
          is_initialized_op = None
          initializer_op = None
          graph_element = None
          if caching_device:
            with ops.device(caching_device):
              cached_value = gen_resource_variable_ops.read_variable_op(
                  handle, dtype)
              _maybe_set_handle_data(dtype, handle, cached_value)
          else:
            cached_value = None

        if cached_value is not None:
          # Store the variable object so that the original variable can be
          # accessed to generate functions that are compatible with SavedModel.
          cached_value._cached_variable = weakref.ref(self)  # pylint: disable=protected-access

      if not context.executing_eagerly():
        # Eager variables are only added to collections if they are part of an
        # eager variable store (otherwise in an interactive session they would
        # hog memory and cause OOM). This is done in ops/variable_scope.py.
        ops.add_to_collections(collections, self)
      elif ops.GraphKeys.GLOBAL_STEP in collections:
        ops.add_to_collections(ops.GraphKeys.GLOBAL_STEP, self)
    initial_value = initial_value if self._in_graph_mode else None
    super(ResourceVariable, self).__init__(
        trainable=trainable,
        shape=shape,
        dtype=dtype,
        handle=handle,
        synchronization=synchronization,
        constraint=constraint,
        aggregation=aggregation,
        distribute_strategy=distribute_strategy,
        name=name,
        unique_id=unique_id,
        handle_name=handle_name,
        graph_element=graph_element,
        initial_value=initial_value,
        initializer_op=initializer_op,
        is_initialized_op=is_initialized_op,
        cached_value=cached_value,
        caching_device=caching_device)

  def _init_from_proto(self, variable_def, import_scope=None):
    """Initializes from `VariableDef` proto."""
    # Note that init_from_proto is currently not supported in Eager mode.
    assert not context.executing_eagerly()
    self._in_graph_mode = True
    assert isinstance(variable_def, variable_pb2.VariableDef)
    if not variable_def.is_resource:
      raise ValueError("Trying to restore Variable as ResourceVariable.")

    # Create from variable_def.
    g = ops.get_default_graph()
    self._handle = g.as_graph_element(
        ops.prepend_name_scope(
            variable_def.variable_name, import_scope=import_scope))
    self._shape = tensor_shape.TensorShape(self._handle.op.get_attr("shape"))
    self._handle_name = self._handle.name
    self._unique_id = self._handle_name
    self._initializer_op = g.as_graph_element(
        ops.prepend_name_scope(
            variable_def.initializer_name, import_scope=import_scope))
    # Check whether initial_value_name exists for backwards compatibility.
    if (hasattr(variable_def, "initial_value_name") and
        variable_def.initial_value_name):
      self._initial_value = g.as_graph_element(
          ops.prepend_name_scope(
              variable_def.initial_value_name, import_scope=import_scope))
    else:
      self._initial_value = None
    synchronization, aggregation, trainable = (
        variables.validate_synchronization_aggregation_trainable(
            variable_def.synchronization, variable_def.aggregation,
            variable_def.trainable, variable_def.variable_name))
    self._synchronization = synchronization
    self._aggregation = aggregation
    self._trainable = trainable
    if variable_def.snapshot_name:
      snapshot = g.as_graph_element(
          ops.prepend_name_scope(
              variable_def.snapshot_name, import_scope=import_scope))
      if snapshot.op.type != "ReadVariableOp":
        self._cached_value = snapshot
      else:
        self._cached_value = None
      while snapshot.op.type != "ReadVariableOp":
        snapshot = snapshot.op.inputs[0]
      self._graph_element = snapshot
    else:
      self._cached_value = None
      # Legacy case for protos without the snapshot name; assume it's the
      # following.
      self._graph_element = g.get_tensor_by_name(self._handle.op.name +
                                                 "/Read/ReadVariableOp:0")
    if variable_def.HasField("save_slice_info_def"):
      self._save_slice_info = variables.Variable.SaveSliceInfo(
          save_slice_info_def=variable_def.save_slice_info_def,
          import_scope=import_scope)
    else:
      self._save_slice_info = None
    self._caching_device = None
    self._dtype = dtypes.as_dtype(self._handle.op.get_attr("dtype"))
    self._constraint = None
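

# A minimal sketch of the `variable_def` path handled by
# `ResourceVariable._init_from_proto` above (graph mode only; the variable
# names are illustrative):
#
#   with ops.Graph().as_default():
#     v = ResourceVariable(initial_value=[1.0, 2.0], name="v")
#     v_def = v.to_proto()
#     # Rebuilds a ResourceVariable wrapper around the same graph elements.
#     v_copy = ResourceVariable(variable_def=v_def)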


class UninitializedVariable(BaseResourceVariable):
  """A variable with no initializer."""

  def __init__(  # pylint: disable=super-init-not-called
      self,
      trainable=None,
      caching_device=None,
      name=None,
      shape=None,
      dtype=None,
      constraint=None,
      synchronization=None,
      aggregation=None,
      extra_handle_data=None,
      distribute_strategy=None,
      **unused_kwargs):
    """Creates the variable handle.

    Args:
      trainable: If `True`, GradientTapes automatically watch uses of this
        Variable.
      caching_device: Optional device string or function describing where the
        Variable should be cached for reading. Defaults to the Variable's
        device. If not `None`, caches on another device. Typical use is to
        cache on the device where the Ops using the Variable reside, to
        deduplicate copying through `Switch` and other conditional statements.
      name: Optional name for the variable. Defaults to `'Variable'` and gets
        uniquified automatically.
      shape: The variable's shape.
      dtype: The variable's dtype.
      constraint: An optional projection function to be applied to the variable
        after being updated by an `Optimizer` (e.g. used to implement norm
        constraints or value constraints for layer weights). The function must
        take as input the unprojected Tensor representing the value of the
        variable and return the Tensor for the projected value (which must have
        the same shape). Constraints are not safe to use when doing asynchronous
        distributed training.
      synchronization: Indicates when a distributed variable will be
        aggregated. Accepted values are constants defined in the class
        `tf.VariableSynchronization`. By default the synchronization is set to
        `AUTO` and the current `DistributionStrategy` chooses when to
        synchronize.
      aggregation: Indicates how a distributed variable will be aggregated.
        Accepted values are constants defined in the class
        `tf.VariableAggregation`.
      extra_handle_data: Optional, another resource handle or Tensor with handle
        data to merge with `shape` and `dtype`.
      distribute_strategy: The tf.distribute.Strategy this variable is being
        created inside of.
    """
    with ops.init_scope():
      self._in_graph_mode = not context.executing_eagerly()
    with ops.init_scope():
      with ops.name_scope(name, "Variable", skip_on_eager=False) as name:
        handle_name = ops.name_from_scope_name(name)
        if self._in_graph_mode:
          shared_name = handle_name
          unique_id = shared_name
        else:
          unique_id = "%s_%d" % (handle_name, ops.uid())
          shared_name = None  # Never shared
        handle = _variable_handle_from_shape_and_dtype(
            shape=shape,
            dtype=dtype,
            shared_name=shared_name,
            name=name,
            graph_mode=self._in_graph_mode,
            initial_value=extra_handle_data)
        if not context.executing_eagerly():
          with ops.name_scope("Read"):
            # Manually assign reads to the handle's device to avoid log
            # messages.
            with ops.device(handle.device):
              value = gen_resource_variable_ops.read_variable_op(handle, dtype)
              _maybe_set_handle_data(dtype, handle, value)
            graph_element = value
          ops.add_to_collection(ops.GraphKeys.GLOBAL_VARIABLES, self)
          # Do *not* add to TRAINABLE_VARIABLES here, even if self._trainable,
          # because retraining or frozen use of imported SavedModels is
          # controlled at higher levels of model building.
        else:
          graph_element = None
    super(UninitializedVariable, self).__init__(
        distribute_strategy=distribute_strategy,
        shape=shape,
        dtype=dtype,
        unique_id=unique_id,
        handle_name=handle_name,
        constraint=constraint,
        handle=handle,
        graph_element=graph_element,
        trainable=trainable,
        synchronization=synchronization,
        aggregation=aggregation)


_pywrap_utils.RegisterType("ResourceVariable", ResourceVariable)
math_ops._resource_variable_type = ResourceVariable  # pylint: disable=protected-access


def _dense_var_to_tensor(var, dtype=None, name=None, as_ref=False):
  return var._dense_var_to_tensor(dtype=dtype, name=name, as_ref=as_ref)  # pylint: disable=protected-access


# Register a conversion function which reads the value of the variable,
# allowing instances of the class to be used as tensors.
ops.register_tensor_conversion_function(BaseResourceVariable,
                                        _dense_var_to_tensor)
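

# With the conversion function registered above, resource variables can be
# passed directly to ops that expect tensors and the read happens implicitly.
# Sketch (eager mode; illustrative values):
#
#   v = ResourceVariable([[1.0, 2.0]])
#   w = constant_op.constant([[3.0], [4.0]])
#   product = math_ops.matmul(v, w)  # `v` is read via ReadVariableOp.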


class _UnreadVariable(BaseResourceVariable):
  """Represents a future for a read of a variable.

  Pretends to be the tensor if anyone looks.
  """

  def __init__(self, handle, dtype, shape, in_graph_mode, deleter, parent_op,
               unique_id):
    if isinstance(handle, ops.EagerTensor):
      handle_name = ""
    else:
      handle_name = handle.name
    # Only create a graph_element if we're in session.run-land as only
    # session.run requires a preexisting tensor to evaluate. Otherwise we can
    # avoid accidentally reading the variable.
    if context.executing_eagerly() or ops.inside_function():
      graph_element = None
    else:
      with ops.control_dependencies([parent_op]):
        graph_element = gen_resource_variable_ops.read_variable_op(
            handle, dtype)
        _maybe_set_handle_data(dtype, handle, graph_element)
    super(_UnreadVariable, self).__init__(
        handle=handle,
        shape=shape,
        handle_name=handle_name,
        unique_id=unique_id,
        dtype=dtype,
        handle_deleter=deleter,
        graph_element=graph_element)
    self._parent_op = parent_op

  @property
  def name(self):
    if self._in_graph_mode:
      return self._parent_op.name
    else:
      return "UnreadVariable"

  def value(self):
    return self._read_variable_op()

  def read_value(self):
    return self._read_variable_op()

  def _read_variable_op(self):
    with ops.control_dependencies([self._parent_op]):
      result = gen_resource_variable_ops.read_variable_op(
          self._handle, self._dtype)
      _maybe_set_handle_data(self._dtype, self._handle, result)
      return result

  def assign_sub(self, delta, use_locking=None, name=None, read_value=True):
    with ops.control_dependencies([self._parent_op]):
      return super(_UnreadVariable, self).assign_sub(delta, use_locking, name,
                                                     read_value)

  def assign_add(self, delta, use_locking=None, name=None, read_value=True):
    with ops.control_dependencies([self._parent_op]):
      return super(_UnreadVariable, self).assign_add(delta, use_locking, name,
                                                     read_value)

  def assign(self, value, use_locking=None, name=None, read_value=True):
    with ops.control_dependencies([self._parent_op]):
      return super(_UnreadVariable, self).assign(value, use_locking, name,
                                                 read_value)

  def scatter_sub(self, sparse_delta, use_locking=False, name=None):
    with ops.control_dependencies([self._parent_op]):
      return super(_UnreadVariable, self).scatter_sub(sparse_delta, use_locking,
                                                      name)

  def scatter_add(self, sparse_delta, use_locking=False, name=None):
    with ops.control_dependencies([self._parent_op]):
      return super(_UnreadVariable, self).scatter_add(sparse_delta, use_locking,
                                                      name)

  def scatter_max(self, sparse_delta, use_locking=False, name=None):
    with ops.control_dependencies([self._parent_op]):
      return super(_UnreadVariable, self).scatter_max(sparse_delta, use_locking,
                                                      name)

  def scatter_min(self, sparse_delta, use_locking=False, name=None):
    with ops.control_dependencies([self._parent_op]):
      return super(_UnreadVariable, self).scatter_min(sparse_delta, use_locking,
                                                      name)

  def scatter_mul(self, sparse_delta, use_locking=False, name=None):
    with ops.control_dependencies([self._parent_op]):
      return super(_UnreadVariable, self).scatter_mul(sparse_delta, use_locking,
                                                      name)

  def scatter_div(self, sparse_delta, use_locking=False, name=None):
    with ops.control_dependencies([self._parent_op]):
      return super(_UnreadVariable, self).scatter_div(sparse_delta, use_locking,
                                                      name)

  def scatter_update(self, sparse_delta, use_locking=False, name=None):
    with ops.control_dependencies([self._parent_op]):
      return super(_UnreadVariable,
                   self).scatter_update(sparse_delta, use_locking, name)

  def batch_scatter_update(self, sparse_delta, use_locking=False, name=None):
    with ops.control_dependencies([self._parent_op]):
      return super(_UnreadVariable,
                   self).batch_scatter_update(sparse_delta, use_locking, name)

  def scatter_nd_sub(self, indices, updates, name=None):
    with ops.control_dependencies([self._parent_op]):
      return super(_UnreadVariable, self).scatter_nd_sub(indices, updates, name)

  def scatter_nd_add(self, indices, updates, name=None):
    with ops.control_dependencies([self._parent_op]):
      return super(_UnreadVariable, self).scatter_nd_add(indices, updates, name)

  def scatter_nd_update(self, indices, updates, name=None):
    with ops.control_dependencies([self._parent_op]):
      return super(_UnreadVariable,
                   self).scatter_nd_update(indices, updates, name)

  def scatter_nd_max(self, indices, updates, name=None):
    with ops.control_dependencies([self._parent_op]):
      return super(_UnreadVariable, self).scatter_nd_max(indices, updates, name)

  def scatter_nd_min(self, indices, updates, name=None):
    with ops.control_dependencies([self._parent_op]):
      return super(_UnreadVariable, self).scatter_nd_min(indices, updates, name)

  @property
  def op(self):
    """The op for this variable."""
    return self._parent_op


@ops.RegisterGradient("ReadVariableOp")
def _ReadGrad(_, grad):
  """Gradient for read op."""
  return grad


def variable_shape(handle, out_type=dtypes.int32):
  if getattr(handle, "_handle_data",
             None) is None or not handle._handle_data.is_set:  # pylint: disable=protected-access
    return gen_resource_variable_ops.variable_shape(handle, out_type=out_type)
  shape_proto = handle._handle_data.shape_and_type[0].shape  # pylint: disable=protected-access
  if shape_proto.unknown_rank or any(x.size == -1 for x in shape_proto.dim):
    return gen_resource_variable_ops.variable_shape(handle, out_type=out_type)
  return constant_op.constant([x.size for x in shape_proto.dim], dtype=out_type)
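

# Sketch of the fast path above (illustrative): when the handle carries
# HandleData with a fully defined static shape, the shape can be answered as a
# constant without running the VariableShape kernel, e.g.
#
#   v = ResourceVariable(array_ops.zeros([3, 4]))
#   shape_t = variable_shape(v.handle)  # Expected to be a constant [3, 4].
#
# Otherwise the gen_resource_variable_ops.variable_shape op is used.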


@ops.RegisterGradient("ResourceGather")
def _GatherGrad(op, grad):
  """Gradient for gather op."""
  # Build appropriately shaped IndexedSlices
  handle = op.inputs[0]
  indices = op.inputs[1]
  params_shape = variable_shape(handle)
  size = array_ops.expand_dims(array_ops.size(indices), 0)
  values_shape = array_ops.concat([size, params_shape[1:]], 0)
  values = array_ops.reshape(grad, values_shape)
  indices = array_ops.reshape(indices, size)
  return (ops.IndexedSlices(values, indices, params_shape), None)


def _to_proto_fn(v, export_scope=None):
  """Converts Variable and ResourceVariable to VariableDef for collections."""
  return v.to_proto(export_scope=export_scope)


def _from_proto_fn(v, import_scope=None):
  """Creates Variable or ResourceVariable from VariableDef as needed."""
  if v.is_resource:
    return ResourceVariable.from_proto(v, import_scope=import_scope)
  return variables.Variable.from_proto(v, import_scope=import_scope)


ops.register_proto_function(
    ops.GraphKeys.GLOBAL_VARIABLES,
    proto_type=variable_pb2.VariableDef,
    to_proto=_to_proto_fn,
    from_proto=_from_proto_fn)
ops.register_proto_function(
    ops.GraphKeys.TRAINABLE_VARIABLES,
    proto_type=variable_pb2.VariableDef,
    to_proto=_to_proto_fn,
    from_proto=_from_proto_fn)
ops.register_proto_function(
    ops.GraphKeys.MOVING_AVERAGE_VARIABLES,
    proto_type=variable_pb2.VariableDef,
    to_proto=_to_proto_fn,
    from_proto=_from_proto_fn)
ops.register_proto_function(
    ops.GraphKeys.LOCAL_VARIABLES,
    proto_type=variable_pb2.VariableDef,
    to_proto=_to_proto_fn,
    from_proto=_from_proto_fn)
ops.register_proto_function(
    ops.GraphKeys.MODEL_VARIABLES,
    proto_type=variable_pb2.VariableDef,
    to_proto=_to_proto_fn,
    from_proto=_from_proto_fn)
ops.register_proto_function(
    ops.GraphKeys.GLOBAL_STEP,
    proto_type=variable_pb2.VariableDef,
    to_proto=_to_proto_fn,
    from_proto=_from_proto_fn)
ops.register_proto_function(
    ops.GraphKeys.METRIC_VARIABLES,
    proto_type=variable_pb2.VariableDef,
    to_proto=_to_proto_fn,
    from_proto=_from_proto_fn)
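

# The registrations above are what allow variables held in these collections to
# be serialized into a MetaGraphDef and recreated on import. Rough sketch
# (graph mode; `tf` here is the public TensorFlow module, shown for
# illustration only):
#
#   with ops.Graph().as_default():
#     v = ResourceVariable(initial_value=1.0, name="v")  # In GLOBAL_VARIABLES.
#     meta_graph_def = tf.compat.v1.train.export_meta_graph()
#   # Importing `meta_graph_def` elsewhere recreates `v` from its VariableDef
#   # via `_from_proto_fn`.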


@tf_export("__internal__.ops.is_resource_variable", v1=[])
def is_resource_variable(var):
  """Returns True if `var` is to be considered a ResourceVariable."""
  return isinstance(var, BaseResourceVariable) or hasattr(
      var, "_should_act_as_resource_variable")


def copy_to_graph_uninitialized(var):
  """Copies an existing variable to a new graph, with no initializer."""
  # Like ResourceVariable.__deepcopy__, but does not set an initializer on the
  # new variable.
  # pylint: disable=protected-access
  new_variable = UninitializedVariable(
      trainable=var.trainable,
      constraint=var._constraint,
      shape=var.shape,
      dtype=var.dtype,
      name=var._shared_name,
      synchronization=var.synchronization,
      aggregation=var.aggregation,
      extra_handle_data=var.handle)
  new_variable._maybe_initialize_trackable()
  # pylint: enable=protected-access
  return new_variable


ops.NotDifferentiable("Assert")
ops.NotDifferentiable("VarIsInitializedOp")
ops.NotDifferentiable("VariableShape")


class VariableSpec(tensor_spec.DenseSpec):
  """Describes a tf.Variable."""

  __slots__ = []

  value_type = property(lambda self: BaseResourceVariable)

  def _to_components(self, value):
    raise NotImplementedError

  def _from_components(self, components):
    raise NotImplementedError

  def _from_compatible_tensor_list(self, tensor_list):
    assert len(tensor_list) == 1
    return tensor_list[0]


_pywrap_utils.RegisterType("VariableSpec", VariableSpec)
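

# Sketch: a VariableSpec records only the static shape and dtype of a variable;
# component conversion is left unimplemented above (see `_to_components`).
# Illustrative usage:
#
#   spec = VariableSpec(shape=[2], dtype=dtypes.float32)
#   spec.value_type  # -> BaseResourceVariable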