1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Ops to use variables as resources.""" 16 17# pylint: disable=g-bad-name 18from __future__ import absolute_import 19from __future__ import division 20from __future__ import print_function 21 22import contextlib 23import functools 24import weakref 25 26import numpy as np 27 28from tensorflow.core.framework import attr_value_pb2 29from tensorflow.core.framework import variable_pb2 30from tensorflow.python.client import pywrap_tf_session 31from tensorflow.python.eager import context 32from tensorflow.python.eager import tape 33from tensorflow.python.framework import auto_control_deps_utils as acd 34from tensorflow.python.framework import constant_op 35from tensorflow.python.framework import cpp_shape_inference_pb2 36from tensorflow.python.framework import dtypes 37from tensorflow.python.framework import errors 38from tensorflow.python.framework import meta_graph 39from tensorflow.python.framework import ops 40from tensorflow.python.framework import tensor_shape 41from tensorflow.python.framework import tensor_spec 42from tensorflow.python.ops import array_ops 43from tensorflow.python.ops import gen_array_ops 44from tensorflow.python.ops import gen_resource_variable_ops 45from tensorflow.python.ops import gen_state_ops 46from tensorflow.python.ops import handle_data_util 47from tensorflow.python.ops import math_ops 48from tensorflow.python.ops import state_ops 49from tensorflow.python.ops import variables 50# go/tf-wildcard-import 51# pylint: disable=wildcard-import 52from tensorflow.python.ops.gen_resource_variable_ops import * 53# pylint: enable=wildcard-import 54from tensorflow.python.training.tracking import base as trackable 55from tensorflow.python.types import core 56from tensorflow.python.util import _pywrap_utils 57from tensorflow.python.util import compat 58from tensorflow.python.util.deprecation import deprecated 59from tensorflow.python.util.tf_export import tf_export 60 61acd.register_read_only_resource_op("ReadVariableOp") 62acd.register_read_only_resource_op("VariableShape") 63acd.register_read_only_resource_op("ResourceGather") 64acd.register_read_only_resource_op("ResourceGatherNd") 65acd.register_read_only_resource_op("_ReadVariablesOp") 66 67 68# TODO(allenl): Remove this alias and migrate callers. 69get_resource_handle_data = handle_data_util.get_resource_handle_data 70 71 72def get_eager_safe_handle_data(handle): 73 """Get the data handle from the Tensor `handle`.""" 74 assert isinstance(handle, ops.Tensor) 75 76 if isinstance(handle, ops.EagerTensor): 77 return handle._handle_data # pylint: disable=protected-access 78 else: 79 return get_resource_handle_data(handle) 80 81 82def _set_handle_shapes_and_types(tensor, handle_data, graph_mode): 83 """Sets the shape inference result HandleData on tensor. 84 85 Args: 86 tensor: A `Tensor` or `EagerTensor`. 
87 handle_data: A `CppShapeInferenceResult.HandleData`. 88 graph_mode: A python bool. 89 """ 90 tensor._handle_data = handle_data # pylint: disable=protected-access 91 if not graph_mode: 92 return 93 94 # Not an EagerTensor, so a graph tensor. 95 shapes, types = zip(*[(pair.shape, pair.dtype) 96 for pair in handle_data.shape_and_type]) 97 ranks = [len(s.dim) if not s.unknown_rank else -1 for s in shapes] 98 shapes = [ 99 [d.size for d in s.dim] # pylint: disable=g-complex-comprehension 100 if not s.unknown_rank else None for s in shapes 101 ] 102 pywrap_tf_session.TF_GraphSetOutputHandleShapesAndTypes_wrapper( 103 tensor._op._graph._c_graph, # pylint: disable=protected-access 104 tensor._as_tf_output(), # pylint: disable=protected-access 105 shapes, 106 ranks, 107 types) 108 109 110def _combine_handle_data(handle, initial_value): 111 """Concats HandleData from tensors `handle` and `initial_value`. 112 113 Args: 114 handle: A `Tensor` of dtype `resource`. 115 initial_value: A `Tensor`. 116 117 Returns: 118 A `CppShapeInferenceResult.HandleData`. If `initial_value` has dtype 119 `variant`, the `HandleData` contains the concatenation of the shape_and_type 120 from both `handle` and `initial_value`. 121 122 Raises: 123 RuntimeError: If handle, which was returned by VarHandleOp, either has 124 no handle data, or its len(handle_data.shape_and_type) != 1. 125 """ 126 assert handle.dtype == dtypes.resource 127 128 variable_handle_data = get_eager_safe_handle_data(handle) 129 130 if initial_value.dtype != dtypes.variant: 131 return variable_handle_data 132 133 extra_handle_data = get_eager_safe_handle_data(initial_value) 134 if extra_handle_data is not None and extra_handle_data.is_set: 135 if (variable_handle_data is None or not variable_handle_data.is_set or 136 len(variable_handle_data.shape_and_type) != 1): 137 raise RuntimeError( 138 "Expected VarHandleOp to return a length==1 shape_and_type, " 139 f"but saw: '{variable_handle_data}'") 140 variable_handle_data.shape_and_type.extend(extra_handle_data.shape_and_type) 141 return variable_handle_data 142 143 144def _variable_handle_from_shape_and_dtype(shape, 145 dtype, 146 shared_name, 147 name, 148 graph_mode, 149 initial_value=None): 150 """Create a variable handle, copying in handle data from `initial_value`.""" 151 container = ops.get_default_graph()._container # pylint: disable=protected-access 152 if container is None: 153 container = "" 154 shape = tensor_shape.as_shape(shape) 155 dtype = dtypes.as_dtype(dtype) 156 if not graph_mode: 157 if shared_name is not None: 158 raise errors.InternalError( # pylint: disable=no-value-for-parameter 159 "Using an explicit shared_name is not supported executing eagerly.") 160 shared_name = context.shared_name() 161 162 handle = gen_resource_variable_ops.var_handle_op( 163 shape=shape, 164 dtype=dtype, 165 shared_name=shared_name, 166 name=name, 167 container=container) 168 if initial_value is None: 169 initial_value = handle 170 if graph_mode: 171 full_handle_data = _combine_handle_data(handle, initial_value) 172 _set_handle_shapes_and_types(handle, full_handle_data, graph_mode) 173 return handle 174 else: 175 handle_data = cpp_shape_inference_pb2.CppShapeInferenceResult.HandleData() 176 handle_data.is_set = True 177 handle_data.shape_and_type.append( 178 cpp_shape_inference_pb2.CppShapeInferenceResult.HandleShapeAndType( 179 shape=shape.as_proto(), dtype=dtype.as_datatype_enum)) 180 181 if initial_value is not None and initial_value.dtype == dtypes.variant: 182 extra_handle_data = 
get_eager_safe_handle_data(initial_value) 183 if extra_handle_data is not None and extra_handle_data.is_set: 184 if (not handle_data.is_set or len(handle_data.shape_and_type) != 1): 185 raise RuntimeError( 186 "Expected VarHandleOp to return a length==1 shape_and_type, " 187 f"but saw: '{handle_data}'") 188 handle_data.shape_and_type.extend(extra_handle_data.shape_and_type) 189 190 _set_handle_shapes_and_types(handle, handle_data, graph_mode) 191 return handle 192 193 194def eager_safe_variable_handle(initial_value, shape, shared_name, name, 195 graph_mode): 196 """Creates a variable handle with information to do shape inference. 197 198 The dtype is read from `initial_value` and stored in the returned 199 resource tensor's handle data. 200 201 If `initial_value.dtype == tf.variant`, we additionally extract the handle 202 data (if any) from `initial_value` and append it to the `handle_data`. 203 In this case, the returned tensor's handle data is in the form 204 205 ``` 206 is_set: true 207 shape_and_type { 208 shape { 209 // initial_value.shape 210 } 211 dtype: DT_VARIANT 212 } 213 shape_and_type { 214 // handle_data(initial_value).shape_and_type[0] 215 } 216 shape_and_type { 217 // handle_data(initial_value).shape_and_type[1] 218 } 219 ... 220 ``` 221 222 Ops that read from this tensor, such as `ReadVariableOp` and 223 `AssignVariableOp`, know that `handle_data(handle).shape_and_type[1:]` 224 correspond to the handle data of the variant(s) stored in the Variable. 225 226 Args: 227 initial_value: A `Tensor`. 228 shape: The shape of the handle data. Can be `TensorShape(None)` (i.e. 229 unknown shape). 230 shared_name: A string. 231 name: A string. 232 graph_mode: A python bool. 233 234 Returns: 235 The handle, a `Tensor` of type `resource`. 236 """ 237 dtype = initial_value.dtype.base_dtype 238 return _variable_handle_from_shape_and_dtype(shape, dtype, shared_name, name, 239 graph_mode, initial_value) 240 241 242@contextlib.contextmanager 243def _handle_graph(handle): 244 # Note: might have an eager tensor but not be executing eagerly when building 245 # functions. 246 if (context.executing_eagerly() or isinstance(handle, ops.EagerTensor) or 247 ops.has_default_graph()): 248 yield 249 else: 250 with handle.graph.as_default(): 251 yield 252 253 254class EagerResourceDeleter(object): 255 """An object which cleans up a resource handle. 256 257 An alternative to defining a __del__ method on an object. The intended use is 258 that ResourceVariables or other objects with resource handles will maintain a 259 single reference to this object. When the parent object is collected, this 260 object will be too. Even if the parent object is part of a reference cycle, 261 the cycle will be collectable. 262 """ 263 264 __slots__ = ["_handle", "_handle_device", "_context"] 265 266 def __init__(self, handle, handle_device): 267 if not isinstance(handle, ops.Tensor): 268 raise ValueError( 269 (f"Passed handle={handle} to EagerResourceDeleter. Was expecting " 270 f"the handle to be a `tf.Tensor`.")) 271 self._handle = handle 272 self._handle_device = handle_device 273 # This is held since the __del__ function runs an op, and if the context() 274 # is collected before this object, there will be a segfault when running the 275 # op. 276 self._context = context.context() 277 278 def __del__(self): 279 # Resources follow object-identity when executing eagerly, so it is safe to 280 # delete the resource we have a handle to. 281 try: 282 # A packed EagerTensor doesn't own any resource. 
283 if isinstance(self._handle, ops.EagerTensor) and self._handle.is_packed: 284 return 285 # This resource was created in eager mode. However, this destructor may be 286 # running in graph mode (especially during unit tests). To clean up 287 # successfully, we switch back into eager mode temporarily. 288 with context.eager_mode(): 289 with ops.device(self._handle_device): 290 gen_resource_variable_ops.destroy_resource_op( 291 self._handle, ignore_lookup_error=True) 292 except TypeError: 293 # Suppress some exceptions, mainly for the case when we're running on 294 # module deletion. Things that can go wrong include the context module 295 # already being unloaded, self._handle._handle_data no longer being 296 # valid, and so on. Printing warnings in these cases is silly 297 # (exceptions raised from __del__ are printed as warnings to stderr). 298 pass # 'NoneType' object is not callable when the handle has been 299 # partially unloaded. 300 except AttributeError: 301 pass # 'NoneType' object has no attribute 'eager_mode' when context has 302 # been unloaded. Will catch other module unloads as well. 303 304 305def shape_safe_assign_variable_handle(handle, shape, value, name=None): 306 """Helper that checks shape compatibility and assigns variable.""" 307 with _handle_graph(handle): 308 value_tensor = ops.convert_to_tensor(value) 309 shape.assert_is_compatible_with(value_tensor.shape) 310 return gen_resource_variable_ops.assign_variable_op( 311 handle, value_tensor, name=name) 312 313 314def _maybe_set_handle_data(dtype, handle, tensor): 315 if dtype == dtypes.variant: 316 # For DT_VARIANT types, the handle's shape_and_type[1:] stores the 317 # variant's handle data. Extract it. 318 handle_data = get_eager_safe_handle_data(handle) 319 if handle_data.is_set and len(handle_data.shape_and_type) > 1: 320 tensor._handle_data = ( # pylint: disable=protected-access 321 cpp_shape_inference_pb2.CppShapeInferenceResult.HandleData( 322 is_set=True, shape_and_type=handle_data.shape_and_type[1:])) 323 324 325def variable_accessed(variable): 326 """Records that `variable` was accessed for the tape and FuncGraph.""" 327 if hasattr(ops.get_default_graph(), "watch_variable"): 328 ops.get_default_graph().watch_variable(variable) 329 if variable.trainable: 330 tape.variable_accessed(variable) 331 332 333class BaseResourceVariable(variables.VariableV1, core.Tensor): 334 """A python variable from an existing handle.""" 335 336 # TODO(wangpeng): Deprecate `constraint` when callers no long pass it in. 337 def __init__( # pylint: disable=super-init-not-called 338 self, 339 trainable=None, 340 shape=None, 341 dtype=None, 342 handle=None, 343 constraint=None, 344 synchronization=None, 345 aggregation=None, 346 distribute_strategy=None, 347 name=None, 348 unique_id=None, 349 handle_name=None, 350 graph_element=None, 351 initial_value=None, 352 initializer_op=None, 353 is_initialized_op=None, 354 cached_value=None, 355 save_slice_info=None, 356 handle_deleter=None, 357 caching_device=None, 358 in_graph_mode=None, 359 **unused_kwargs): 360 """Creates a variable from a handle. 361 362 Args: 363 trainable: If `True`, GradientTapes automatically watch uses of this 364 Variable. 365 shape: The variable's shape. 366 dtype: The variable's dtype. 367 handle: The variable's handle 368 constraint: An optional projection function to be applied to the variable 369 after being updated by an `Optimizer` (e.g. used to implement norm 370 constraints or value constraints for layer weights). 
The function must
        take as input the unprojected Tensor representing the value of the
        variable and return the Tensor for the projected value (which must have
        the same shape). Constraints are not safe to use when doing asynchronous
        distributed training.
      synchronization: Indicates when a distributed variable will be
        aggregated. Accepted values are constants defined in the class
        `tf.VariableSynchronization`. By default the synchronization is set to
        `AUTO` and the current `DistributionStrategy` chooses when to
        synchronize.
      aggregation: Indicates how a distributed variable will be aggregated.
        Accepted values are constants defined in the class
        `tf.VariableAggregation`.
      distribute_strategy: The distribution strategy this variable was created
        under.
      name: The name for this variable.
      unique_id: Internal. Unique ID for this variable's handle.
      handle_name: The name for the variable's handle.
      graph_element: Optional, required only in session.run-mode. Pre-created
        tensor which reads this variable's value.
      initial_value: Optional. Variable's initial value.
      initializer_op: Operation which assigns the variable's initial value.
      is_initialized_op: Pre-created operation to check whether this variable
        is initialized.
      cached_value: Pre-created operation to read this variable in a specific
        device.
      save_slice_info: Metadata for variable partitioning.
      handle_deleter: EagerResourceDeleter responsible for cleaning up the
        handle.
      caching_device: Optional device string or function describing where the
        Variable should be cached for reading. Defaults to the Variable's
        device. If not `None`, caches on another device. Typical use is to
        cache on the device where the Ops using the Variable reside, to
        deduplicate copying through `Switch` and other conditional statements.
      in_graph_mode: Whether we are executing in TF1 graph mode. If None, will
        detect within the function. This is to avoid repeated init_scope()
        context entrances, which can add up.
    """
    if in_graph_mode is None:
      with ops.init_scope():
        self._in_graph_mode = not context.executing_eagerly()
    else:
      self._in_graph_mode = in_graph_mode
    synchronization, aggregation, trainable = (
        variables.validate_synchronization_aggregation_trainable(
            synchronization, aggregation, trainable, name))
    self._trainable = trainable
    self._synchronization = synchronization
    self._aggregation = aggregation
    self._save_slice_info = save_slice_info
    self._initial_value = initial_value
    self._initializer_op = initializer_op
    self._is_initialized_op = is_initialized_op
    self._graph_element = graph_element
    self._caching_device = caching_device
    self._cached_value = cached_value
    self._distribute_strategy = distribute_strategy
    # Store the graph key so optimizers know how to only retrieve variables
    # from this graph. Guaranteed to be the same as the eager graph_key.
    self._graph_key = ops.get_default_graph()._graph_key  # pylint: disable=protected-access
    self._shape = tensor_shape.as_shape(shape)
    self._dtype = dtypes.as_dtype(dtype)
    self._handle = handle
    self._unique_id = unique_id
    self._handle_name = handle_name + ":0"
    self._constraint = constraint
    # After the handle has been created, set up a way to clean it up when
    # executing eagerly.
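    # For illustration only (not part of the original comment): a rough sketch
    # of the ownership pattern described here, assuming eager execution and
    # CPython reference counting. The names below are hypothetical and exist
    # only for this example.
    #
    #   v = ResourceVariable(1.0)      # creates a handle and a deleter
    #   v_ref = weakref.ref(v)
    #   del v                          # the deleter runs DestroyResourceOp
    #   assert v_ref() is None         # no leaked Python object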
We'll hold the only reference to the deleter, so that 438 # when this object is garbage collected the deleter will be too. This 439 # means ResourceVariables can be part of reference cycles without those 440 # cycles being uncollectable. 441 if not self._in_graph_mode: 442 if handle_deleter is None: 443 handle_deleter = EagerResourceDeleter( 444 handle=self._handle, handle_device=self._handle.device) 445 self._handle_deleter = handle_deleter 446 self._cached_shape_as_list = None 447 448 def __repr__(self): 449 if context.executing_eagerly() and not self._in_graph_mode: 450 # If we cannot read the value for any reason, still produce a __repr__. 451 try: 452 value_text = ops.numpy_text(self.read_value(), is_repr=True) 453 except: # pylint: disable=bare-except 454 value_text = "<unavailable>" 455 456 return "<tf.Variable '%s' shape=%s dtype=%s, numpy=%s>" % ( 457 self.name, self.get_shape(), self.dtype.name, value_text) 458 else: 459 return "<tf.Variable '%s' shape=%s dtype=%s>" % ( 460 self.name, self.get_shape(), self.dtype.name) 461 462 @contextlib.contextmanager 463 def _assign_dependencies(self): 464 """Makes assignments depend on the cached value, if any. 465 466 This prevents undefined behavior with reads not ordered wrt writes. 467 468 Yields: 469 None. 470 """ 471 if self._cached_value is not None: 472 with ops.control_dependencies([self._cached_value]): 473 yield 474 else: 475 yield 476 477 def __array__(self): 478 """Allows direct conversion to a numpy array. 479 480 >>> np.array(tf.Variable([1.0])) 481 array([1.], dtype=float32) 482 483 Returns: 484 The variable value as a numpy array. 485 """ 486 # You can't return `self.numpy()` here because for scalars 487 # that raises: 488 # ValueError: object __array__ method not producing an array 489 # Even `self.read_value().__array__()` and `self.read_value()._numpy()` give 490 # the same error. The `EagerTensor` class must be doing something behind the 491 # scenes to make `np.array(tf.constant(1))` work. 
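    # Illustrative usage (a sketch, assuming eager execution; `v` below is a
    # hypothetical variable created only for this example):
    #
    #   v = tf.Variable([[1.0, 2.0], [3.0, 4.0]])
    #   arr = np.array(v)      # invokes __array__, i.e. np.asarray(v.numpy())
    #   arr.shape              # (2, 2)
    #   arr.dtype              # dtype('float32')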
492 return np.asarray(self.numpy()) 493 494 def __nonzero__(self): 495 return self.__bool__() 496 497 def __bool__(self): 498 return bool(self.read_value()) 499 500 def __copy__(self): 501 return self 502 503 def __deepcopy__(self, memo): 504 if not context.executing_eagerly(): 505 raise NotImplementedError( 506 "__deepcopy__() is only available when eager execution is enabled.") 507 copied_variable = ResourceVariable( 508 initial_value=self.read_value(), 509 trainable=self._trainable, 510 constraint=self._constraint, 511 dtype=self._dtype, 512 name=self._shared_name, 513 distribute_strategy=self._distribute_strategy, 514 synchronization=self.synchronization, 515 aggregation=self.aggregation) 516 memo[self._unique_id] = copied_variable 517 return copied_variable 518 519 @property 520 def dtype(self): 521 """The dtype of this variable.""" 522 return self._dtype 523 524 @property 525 def device(self): 526 """The device this variable is on.""" 527 return self.handle.device 528 529 @property 530 def graph(self): 531 """The `Graph` of this variable.""" 532 return self.handle.graph 533 534 @property 535 def name(self): 536 """The name of the handle for this variable.""" 537 return self._handle_name 538 539 @property 540 def shape(self): 541 """The shape of this variable.""" 542 return self._shape 543 544 def set_shape(self, shape): 545 self._shape = self._shape.merge_with(shape) 546 547 def _shape_as_list(self): 548 if self.shape.ndims is None: 549 return None 550 return [dim.value for dim in self.shape.dims] 551 552 def _shape_tuple(self): 553 shape = self._shape_as_list() 554 if shape is None: 555 return None 556 return tuple(shape) 557 558 @property 559 def create(self): 560 """The op responsible for initializing this variable.""" 561 if not self._in_graph_mode: 562 raise RuntimeError("This operation is not supported " 563 "when eager execution is enabled.") 564 return self._initializer_op 565 566 @property 567 def handle(self): 568 """The handle by which this variable can be accessed.""" 569 return self._handle 570 571 def value(self): 572 """A cached operation which reads the value of this variable.""" 573 if self._cached_value is not None: 574 return self._cached_value 575 with ops.colocate_with(None, ignore_existing=True): 576 return self._read_variable_op() 577 578 def _as_graph_element(self): 579 """Conversion function for Graph.as_graph_element().""" 580 return self._graph_element 581 582 @property 583 def initializer(self): 584 """The op responsible for initializing this variable.""" 585 return self._initializer_op 586 587 @property 588 def initial_value(self): 589 """Returns the Tensor used as the initial value for the variable.""" 590 if context.executing_eagerly(): 591 raise RuntimeError("This property is not supported " 592 "when eager execution is enabled.") 593 return self._initial_value 594 595 @property 596 def constraint(self): 597 """Returns the constraint function associated with this variable. 598 599 Returns: 600 The constraint function that was passed to the variable constructor. 601 Can be `None` if no constraint was passed. 
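    For illustration only (a minimal sketch, not part of the original
    docstring): a constraint is usually supplied at construction time as a
    callable that projects the assigned value, for example clipping it into a
    range.

    ```python
    v = tf.Variable(1.0,
                    constraint=lambda x: tf.clip_by_value(x, 0.0, 2.0))
    v.constraint  # the callable passed above, or None if none was passed
    ```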
602 """ 603 return self._constraint 604 605 @property 606 def op(self): 607 """The op for this variable.""" 608 return self.handle.op 609 610 @property 611 def trainable(self): 612 return self._trainable 613 614 @property 615 def synchronization(self): 616 return self._synchronization 617 618 @property 619 def aggregation(self): 620 return self._aggregation 621 622 def eval(self, session=None): 623 """Evaluates and returns the value of this variable.""" 624 if context.executing_eagerly(): 625 raise RuntimeError("This operation is not supported " 626 "when eager execution is enabled.") 627 return self._graph_element.eval(session=session) 628 629 def numpy(self): 630 if context.executing_eagerly(): 631 return self.read_value().numpy() 632 raise NotImplementedError( 633 "numpy() is only available when eager execution is enabled.") 634 635 @deprecated(None, "Prefer Dataset.range instead.") 636 def count_up_to(self, limit): 637 """Increments this variable until it reaches `limit`. 638 639 When that Op is run it tries to increment the variable by `1`. If 640 incrementing the variable would bring it above `limit` then the Op raises 641 the exception `OutOfRangeError`. 642 643 If no error is raised, the Op outputs the value of the variable before 644 the increment. 645 646 This is essentially a shortcut for `count_up_to(self, limit)`. 647 648 Args: 649 limit: value at which incrementing the variable raises an error. 650 651 Returns: 652 A `Tensor` that will hold the variable value before the increment. If no 653 other Op modifies this variable, the values produced will all be 654 distinct. 655 """ 656 return gen_state_ops.resource_count_up_to( 657 self.handle, limit=limit, T=self.dtype) 658 659 def _map_resources(self, save_options): 660 """For implementing `Trackable`.""" 661 new_variable = None 662 if save_options.experimental_variable_policy._save_variable_devices(): # pylint:disable=protected-access 663 with ops.device(self.device): 664 new_variable = copy_to_graph_uninitialized(self) 665 else: 666 new_variable = copy_to_graph_uninitialized(self) 667 obj_map = {self: new_variable} 668 resource_map = {self.handle: new_variable.handle} 669 return obj_map, resource_map 670 671 def _read_variable_op(self): 672 variable_accessed(self) 673 674 def read_and_set_handle(): 675 result = gen_resource_variable_ops.read_variable_op( 676 self.handle, self._dtype) 677 _maybe_set_handle_data(self._dtype, self.handle, result) 678 return result 679 680 if getattr(self, "_caching_device", None) is not None: 681 with ops.colocate_with(None, ignore_existing=True): 682 with ops.device(self._caching_device): 683 result = read_and_set_handle() 684 else: 685 result = read_and_set_handle() 686 687 if not context.executing_eagerly(): 688 # Note that if a control flow context is active the input of the read op 689 # might not actually be the handle. This line bypasses it. 690 tape.record_operation( 691 "ReadVariableOp", [result], [self.handle], 692 backward_function=lambda x: [x], 693 forward_function=lambda x: [x]) 694 return result 695 696 def read_value(self): 697 """Constructs an op which reads the value of this variable. 698 699 Should be used when there are multiple reads, or when it is desirable to 700 read the value only after some condition is true. 701 702 Returns: 703 the read operation. 704 """ 705 with ops.name_scope("Read"): 706 value = self._read_variable_op() 707 # Return an identity so it can get placed on whatever device the context 708 # specifies instead of the device where the variable is. 
709 return array_ops.identity(value) 710 711 def sparse_read(self, indices, name=None): 712 """Reads the value of this variable sparsely, using `gather`.""" 713 with ops.name_scope("Gather" if name is None else name) as name: 714 variable_accessed(self) 715 value = gen_resource_variable_ops.resource_gather( 716 self.handle, indices, dtype=self._dtype, name=name) 717 718 if self._dtype == dtypes.variant: 719 # For DT_VARIANT types, the handle's shape_and_type[1:] stores the 720 # variant's handle data. Extract it. 721 handle_data = get_eager_safe_handle_data(self.handle) 722 if handle_data.is_set and len(handle_data.shape_and_type) > 1: 723 value._handle_data = ( # pylint: disable=protected-access 724 cpp_shape_inference_pb2.CppShapeInferenceResult.HandleData( 725 is_set=True, shape_and_type=handle_data.shape_and_type[1:])) 726 727 return array_ops.identity(value) 728 729 def gather_nd(self, indices, name=None): 730 """Reads the value of this variable sparsely, using `gather_nd`.""" 731 with ops.name_scope("GatherNd" if name is None else name) as name: 732 if self.trainable: 733 variable_accessed(self) 734 value = gen_resource_variable_ops.resource_gather_nd( 735 self.handle, indices, dtype=self._dtype, name=name) 736 737 return array_ops.identity(value) 738 739 def to_proto(self, export_scope=None): 740 """Converts a `ResourceVariable` to a `VariableDef` protocol buffer. 741 742 Args: 743 export_scope: Optional `string`. Name scope to remove. 744 745 Raises: 746 RuntimeError: If run in EAGER mode. 747 748 Returns: 749 A `VariableDef` protocol buffer, or `None` if the `Variable` is not 750 in the specified name scope. 751 """ 752 if context.executing_eagerly(): 753 raise RuntimeError("This operation is not supported " 754 "when eager execution is enabled.") 755 if export_scope is None or self.handle.name.startswith(export_scope): 756 var_def = variable_pb2.VariableDef() 757 var_def.variable_name = ops.strip_name_scope(self.handle.name, 758 export_scope) 759 if self._initial_value is not None: 760 # This is inside an if-statement for backwards compatibility, since 761 # self._initial_value might be None for variables constructed from old 762 # protos. 763 var_def.initial_value_name = ops.strip_name_scope( 764 self._initial_value.name, export_scope) 765 var_def.initializer_name = ops.strip_name_scope(self.initializer.name, 766 export_scope) 767 if self._cached_value is not None: 768 var_def.snapshot_name = ops.strip_name_scope(self._cached_value.name, 769 export_scope) 770 else: 771 # Store the graph_element here 772 var_def.snapshot_name = ops.strip_name_scope(self._graph_element.name, 773 export_scope) 774 var_def.is_resource = True 775 var_def.trainable = self.trainable 776 var_def.synchronization = self.synchronization.value 777 var_def.aggregation = self.aggregation.value 778 if self._save_slice_info: 779 var_def.save_slice_info_def.MergeFrom( 780 self._save_slice_info.to_proto(export_scope=export_scope)) 781 return var_def 782 else: 783 return None 784 785 @staticmethod 786 def from_proto(variable_def, import_scope=None): 787 if context.executing_eagerly(): 788 raise RuntimeError("This operation is not supported " 789 "when eager execution is enabled.") 790 return ResourceVariable( 791 variable_def=variable_def, import_scope=import_scope) 792 793 __array_priority__ = 100 794 795 def is_initialized(self, name=None): 796 """Checks whether a resource variable has been initialized. 797 798 Outputs boolean scalar indicating whether the tensor has been initialized. 
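    For example (an illustrative sketch, assuming eager execution, where
    variables are initialized as soon as they are created):

    ```python
    v = tf.Variable([1, 2, 3])
    v.is_initialized().numpy()  # True
    ```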
799 800 Args: 801 name: A name for the operation (optional). 802 803 Returns: 804 A `Tensor` of type `bool`. 805 """ 806 return gen_resource_variable_ops.var_is_initialized_op(self.handle, name) 807 808 def assign_sub(self, delta, use_locking=None, name=None, read_value=True): 809 """Subtracts a value from this variable. 810 811 Args: 812 delta: A `Tensor`. The value to subtract from this variable. 813 use_locking: If `True`, use locking during the operation. 814 name: The name to use for the operation. 815 read_value: A `bool`. Whether to read and return the new value of the 816 variable or not. 817 818 Returns: 819 If `read_value` is `True`, this method will return the new value of the 820 variable after the assignment has completed. Otherwise, when in graph mode 821 it will return the `Operation` that does the assignment, and when in eager 822 mode it will return `None`. 823 """ 824 # TODO(apassos): this here and below is not atomic. Consider making it 825 # atomic if there's a way to do so without a performance cost for those who 826 # don't need it. 827 with _handle_graph(self.handle), self._assign_dependencies(): 828 assign_sub_op = gen_resource_variable_ops.assign_sub_variable_op( 829 self.handle, 830 ops.convert_to_tensor(delta, dtype=self.dtype), 831 name=name) 832 if read_value: 833 return self._lazy_read(assign_sub_op) 834 return assign_sub_op 835 836 def assign_add(self, delta, use_locking=None, name=None, read_value=True): 837 """Adds a value to this variable. 838 839 Args: 840 delta: A `Tensor`. The value to add to this variable. 841 use_locking: If `True`, use locking during the operation. 842 name: The name to use for the operation. 843 read_value: A `bool`. Whether to read and return the new value of the 844 variable or not. 845 846 Returns: 847 If `read_value` is `True`, this method will return the new value of the 848 variable after the assignment has completed. Otherwise, when in graph mode 849 it will return the `Operation` that does the assignment, and when in eager 850 mode it will return `None`. 851 """ 852 with _handle_graph(self.handle), self._assign_dependencies(): 853 assign_add_op = gen_resource_variable_ops.assign_add_variable_op( 854 self.handle, 855 ops.convert_to_tensor(delta, dtype=self.dtype), 856 name=name) 857 if read_value: 858 return self._lazy_read(assign_add_op) 859 return assign_add_op 860 861 def _lazy_read(self, op): 862 variable_accessed(self) 863 return _UnreadVariable( 864 handle=self.handle, 865 dtype=self.dtype, 866 shape=self._shape, 867 in_graph_mode=self._in_graph_mode, 868 deleter=self._handle_deleter if not self._in_graph_mode else None, 869 parent_op=op, 870 unique_id=self._unique_id) 871 872 def assign(self, value, use_locking=None, name=None, read_value=True): 873 """Assigns a new value to this variable. 874 875 Args: 876 value: A `Tensor`. The new value for this variable. 877 use_locking: If `True`, use locking during the assignment. 878 name: The name to use for the assignment. 879 read_value: A `bool`. Whether to read and return the new value of the 880 variable or not. 881 882 Returns: 883 If `read_value` is `True`, this method will return the new value of the 884 variable after the assignment has completed. Otherwise, when in graph mode 885 it will return the `Operation` that does the assignment, and when in eager 886 mode it will return `None`. 887 """ 888 # Note: not depending on the cached value here since this can be used to 889 # initialize the variable. 
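    # Illustrative usage (a sketch, assuming eager execution; `v` below is a
    # hypothetical variable, not part of this module):
    #
    #   v = tf.Variable([1.0, 2.0])
    #   v.assign([3.0, 4.0])                    # returns the updated value
    #   v.assign([5.0, 6.0], read_value=False)  # returns None in eager mode
    #   v.numpy()                               # array([5., 6.], dtype=float32)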
890 with _handle_graph(self.handle): 891 value_tensor = ops.convert_to_tensor(value, dtype=self.dtype) 892 if not self._shape.is_compatible_with(value_tensor.shape): 893 if self.name is None: 894 tensor_name = "" 895 else: 896 tensor_name = " " + str(self.name) 897 raise ValueError( 898 (f"Cannot assign value to variable '{tensor_name}': Shape mismatch." 899 f"The variable shape {self._shape}, and the " 900 f"assigned value shape {value_tensor.shape} are incompatible.")) 901 assign_op = gen_resource_variable_ops.assign_variable_op( 902 self.handle, value_tensor, name=name) 903 if read_value: 904 return self._lazy_read(assign_op) 905 return assign_op 906 907 def __reduce__(self): 908 # The implementation mirrors that of __deepcopy__. 909 return functools.partial( 910 ResourceVariable, 911 initial_value=self.numpy(), 912 trainable=self.trainable, 913 name=self._shared_name, 914 dtype=self.dtype, 915 constraint=self.constraint, 916 distribute_strategy=self._distribute_strategy), () 917 918 def scatter_sub(self, sparse_delta, use_locking=False, name=None): 919 """Subtracts `tf.IndexedSlices` from this variable. 920 921 Args: 922 sparse_delta: `tf.IndexedSlices` to be subtracted from this variable. 923 use_locking: If `True`, use locking during the operation. 924 name: the name of the operation. 925 926 Returns: 927 The updated variable. 928 929 Raises: 930 TypeError: if `sparse_delta` is not an `IndexedSlices`. 931 """ 932 if not isinstance(sparse_delta, ops.IndexedSlices): 933 raise TypeError(f"Argument `sparse_delta` must be a " 934 f"`tf.IndexedSlices`. Received arg: {sparse_delta}") 935 return self._lazy_read( 936 gen_resource_variable_ops.resource_scatter_sub( 937 self.handle, 938 sparse_delta.indices, 939 ops.convert_to_tensor(sparse_delta.values, self.dtype), 940 name=name)) 941 942 def scatter_add(self, sparse_delta, use_locking=False, name=None): 943 """Adds `tf.IndexedSlices` to this variable. 944 945 Args: 946 sparse_delta: `tf.IndexedSlices` to be added to this variable. 947 use_locking: If `True`, use locking during the operation. 948 name: the name of the operation. 949 950 Returns: 951 The updated variable. 952 953 Raises: 954 TypeError: if `sparse_delta` is not an `IndexedSlices`. 955 """ 956 if not isinstance(sparse_delta, ops.IndexedSlices): 957 raise TypeError(f"Argument `sparse_delta` must be a " 958 f"`tf.IndexedSlices`. Received arg: {sparse_delta}") 959 return self._lazy_read( 960 gen_resource_variable_ops.resource_scatter_add( 961 self.handle, 962 sparse_delta.indices, 963 ops.convert_to_tensor(sparse_delta.values, self.dtype), 964 name=name)) 965 966 def scatter_max(self, sparse_delta, use_locking=False, name=None): 967 """Updates this variable with the max of `tf.IndexedSlices` and itself. 968 969 Args: 970 sparse_delta: `tf.IndexedSlices` to use as an argument of max with this 971 variable. 972 use_locking: If `True`, use locking during the operation. 973 name: the name of the operation. 974 975 Returns: 976 The updated variable. 977 978 Raises: 979 TypeError: if `sparse_delta` is not an `IndexedSlices`. 980 """ 981 if not isinstance(sparse_delta, ops.IndexedSlices): 982 raise TypeError(f"Argument `sparse_delta` must be a " 983 f"`tf.IndexedSlices`. 
Received arg: {sparse_delta}") 984 return self._lazy_read( 985 gen_resource_variable_ops.resource_scatter_max( 986 self.handle, 987 sparse_delta.indices, 988 ops.convert_to_tensor(sparse_delta.values, self.dtype), 989 name=name)) 990 991 def scatter_min(self, sparse_delta, use_locking=False, name=None): 992 """Updates this variable with the min of `tf.IndexedSlices` and itself. 993 994 Args: 995 sparse_delta: `tf.IndexedSlices` to use as an argument of min with this 996 variable. 997 use_locking: If `True`, use locking during the operation. 998 name: the name of the operation. 999 1000 Returns: 1001 The updated variable. 1002 1003 Raises: 1004 TypeError: if `sparse_delta` is not an `IndexedSlices`. 1005 """ 1006 if not isinstance(sparse_delta, ops.IndexedSlices): 1007 raise TypeError(f"Argument `sparse_delta` must be a " 1008 f"`tf.IndexedSlices`. Received arg: {sparse_delta}") 1009 return self._lazy_read( 1010 gen_resource_variable_ops.resource_scatter_min( 1011 self.handle, 1012 sparse_delta.indices, 1013 ops.convert_to_tensor(sparse_delta.values, self.dtype), 1014 name=name)) 1015 1016 def scatter_mul(self, sparse_delta, use_locking=False, name=None): 1017 """Multiply this variable by `tf.IndexedSlices`. 1018 1019 Args: 1020 sparse_delta: `tf.IndexedSlices` to multiply this variable by. 1021 use_locking: If `True`, use locking during the operation. 1022 name: the name of the operation. 1023 1024 Returns: 1025 The updated variable. 1026 1027 Raises: 1028 TypeError: if `sparse_delta` is not an `IndexedSlices`. 1029 """ 1030 if not isinstance(sparse_delta, ops.IndexedSlices): 1031 raise TypeError(f"Argument `sparse_delta` must be a " 1032 f"`tf.IndexedSlices`. Received arg: {sparse_delta}") 1033 return self._lazy_read( 1034 gen_resource_variable_ops.resource_scatter_mul( 1035 self.handle, 1036 sparse_delta.indices, 1037 ops.convert_to_tensor(sparse_delta.values, self.dtype), 1038 name=name)) 1039 1040 def scatter_div(self, sparse_delta, use_locking=False, name=None): 1041 """Divide this variable by `tf.IndexedSlices`. 1042 1043 Args: 1044 sparse_delta: `tf.IndexedSlices` to divide this variable by. 1045 use_locking: If `True`, use locking during the operation. 1046 name: the name of the operation. 1047 1048 Returns: 1049 The updated variable. 1050 1051 Raises: 1052 TypeError: if `sparse_delta` is not an `IndexedSlices`. 1053 """ 1054 if not isinstance(sparse_delta, ops.IndexedSlices): 1055 raise TypeError(f"Argument `sparse_delta` must be a " 1056 f"`tf.IndexedSlices`. Received arg: {sparse_delta}") 1057 return self._lazy_read( 1058 gen_resource_variable_ops.resource_scatter_div( 1059 self.handle, 1060 sparse_delta.indices, 1061 ops.convert_to_tensor(sparse_delta.values, self.dtype), 1062 name=name)) 1063 1064 def scatter_update(self, sparse_delta, use_locking=False, name=None): 1065 """Assigns `tf.IndexedSlices` to this variable. 1066 1067 Args: 1068 sparse_delta: `tf.IndexedSlices` to be assigned to this variable. 1069 use_locking: If `True`, use locking during the operation. 1070 name: the name of the operation. 1071 1072 Returns: 1073 The updated variable. 1074 1075 Raises: 1076 TypeError: if `sparse_delta` is not an `IndexedSlices`. 1077 """ 1078 if not isinstance(sparse_delta, ops.IndexedSlices): 1079 raise TypeError(f"Argument `sparse_delta` must be a " 1080 f"`tf.IndexedSlices`. 
Received arg: {sparse_delta}") 1081 return self._lazy_read( 1082 gen_resource_variable_ops.resource_scatter_update( 1083 self.handle, 1084 sparse_delta.indices, 1085 ops.convert_to_tensor(sparse_delta.values, self.dtype), 1086 name=name)) 1087 1088 def batch_scatter_update(self, sparse_delta, use_locking=False, name=None): 1089 """Assigns `tf.IndexedSlices` to this variable batch-wise. 1090 1091 Analogous to `batch_gather`. This assumes that this variable and the 1092 sparse_delta IndexedSlices have a series of leading dimensions that are the 1093 same for all of them, and the updates are performed on the last dimension of 1094 indices. In other words, the dimensions should be the following: 1095 1096 `num_prefix_dims = sparse_delta.indices.ndims - 1` 1097 `batch_dim = num_prefix_dims + 1` 1098 `sparse_delta.updates.shape = sparse_delta.indices.shape + var.shape[ 1099 batch_dim:]` 1100 1101 where 1102 1103 `sparse_delta.updates.shape[:num_prefix_dims]` 1104 `== sparse_delta.indices.shape[:num_prefix_dims]` 1105 `== var.shape[:num_prefix_dims]` 1106 1107 And the operation performed can be expressed as: 1108 1109 `var[i_1, ..., i_n, 1110 sparse_delta.indices[i_1, ..., i_n, j]] = sparse_delta.updates[ 1111 i_1, ..., i_n, j]` 1112 1113 When sparse_delta.indices is a 1D tensor, this operation is equivalent to 1114 `scatter_update`. 1115 1116 To avoid this operation one can looping over the first `ndims` of the 1117 variable and using `scatter_update` on the subtensors that result of slicing 1118 the first dimension. This is a valid option for `ndims = 1`, but less 1119 efficient than this implementation. 1120 1121 Args: 1122 sparse_delta: `tf.IndexedSlices` to be assigned to this variable. 1123 use_locking: If `True`, use locking during the operation. 1124 name: the name of the operation. 1125 1126 Returns: 1127 The updated variable. 1128 1129 Raises: 1130 TypeError: if `sparse_delta` is not an `IndexedSlices`. 1131 """ 1132 if not isinstance(sparse_delta, ops.IndexedSlices): 1133 raise TypeError(f"Argument `sparse_delta` must be a " 1134 f"`tf.IndexedSlices`. Received arg: {sparse_delta}") 1135 return self._lazy_read( 1136 state_ops.batch_scatter_update( 1137 self, 1138 sparse_delta.indices, 1139 sparse_delta.values, 1140 use_locking=use_locking, 1141 name=name)) 1142 1143 def scatter_nd_sub(self, indices, updates, name=None): 1144 """Applies sparse subtraction to individual values or slices in a Variable. 1145 1146 `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`. 1147 1148 `indices` must be integer tensor, containing indices into `ref`. 1149 It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`. 1150 1151 The innermost dimension of `indices` (with length `K`) corresponds to 1152 indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th 1153 dimension of `ref`. 1154 1155 `updates` is `Tensor` of rank `Q-1+P-K` with shape: 1156 1157 ``` 1158 [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]]. 1159 ``` 1160 1161 For example, say we want to add 4 scattered elements to a rank-1 tensor to 1162 8 elements. 
In Python, that update would look like this: 1163 1164 ```python 1165 ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8]) 1166 indices = tf.constant([[4], [3], [1] ,[7]]) 1167 updates = tf.constant([9, 10, 11, 12]) 1168 op = ref.scatter_nd_sub(indices, updates) 1169 with tf.compat.v1.Session() as sess: 1170 print sess.run(op) 1171 ``` 1172 1173 The resulting update to ref would look like this: 1174 1175 [1, -9, 3, -6, -6, 6, 7, -4] 1176 1177 See `tf.scatter_nd` for more details about how to make updates to 1178 slices. 1179 1180 Args: 1181 indices: The indices to be used in the operation. 1182 updates: The values to be used in the operation. 1183 name: the name of the operation. 1184 1185 Returns: 1186 The updated variable. 1187 """ 1188 return self._lazy_read( 1189 gen_state_ops.resource_scatter_nd_sub( 1190 self.handle, 1191 indices, 1192 ops.convert_to_tensor(updates, self.dtype), 1193 name=name)) 1194 1195 def scatter_nd_add(self, indices, updates, name=None): 1196 """Applies sparse addition to individual values or slices in a Variable. 1197 1198 `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`. 1199 1200 `indices` must be integer tensor, containing indices into `ref`. 1201 It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`. 1202 1203 The innermost dimension of `indices` (with length `K`) corresponds to 1204 indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th 1205 dimension of `ref`. 1206 1207 `updates` is `Tensor` of rank `Q-1+P-K` with shape: 1208 1209 ``` 1210 [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]]. 1211 ``` 1212 1213 For example, say we want to add 4 scattered elements to a rank-1 tensor to 1214 8 elements. In Python, that update would look like this: 1215 1216 ```python 1217 ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8]) 1218 indices = tf.constant([[4], [3], [1] ,[7]]) 1219 updates = tf.constant([9, 10, 11, 12]) 1220 add = ref.scatter_nd_add(indices, updates) 1221 with tf.compat.v1.Session() as sess: 1222 print sess.run(add) 1223 ``` 1224 1225 The resulting update to ref would look like this: 1226 1227 [1, 13, 3, 14, 14, 6, 7, 20] 1228 1229 See `tf.scatter_nd` for more details about how to make updates to 1230 slices. 1231 1232 Args: 1233 indices: The indices to be used in the operation. 1234 updates: The values to be used in the operation. 1235 name: the name of the operation. 1236 1237 Returns: 1238 The updated variable. 1239 """ 1240 return self._lazy_read( 1241 gen_state_ops.resource_scatter_nd_add( 1242 self.handle, 1243 indices, 1244 ops.convert_to_tensor(updates, self.dtype), 1245 name=name)) 1246 1247 def scatter_nd_update(self, indices, updates, name=None): 1248 """Applies sparse assignment to individual values or slices in a Variable. 1249 1250 `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`. 1251 1252 `indices` must be integer tensor, containing indices into `ref`. 1253 It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`. 1254 1255 The innermost dimension of `indices` (with length `K`) corresponds to 1256 indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th 1257 dimension of `ref`. 1258 1259 `updates` is `Tensor` of rank `Q-1+P-K` with shape: 1260 1261 ``` 1262 [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]]. 1263 ``` 1264 1265 For example, say we want to add 4 scattered elements to a rank-1 tensor to 1266 8 elements. 
In Python, that update would look like this: 1267 1268 ```python 1269 ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8]) 1270 indices = tf.constant([[4], [3], [1] ,[7]]) 1271 updates = tf.constant([9, 10, 11, 12]) 1272 op = ref.scatter_nd_update(indices, updates) 1273 with tf.compat.v1.Session() as sess: 1274 print sess.run(op) 1275 ``` 1276 1277 The resulting update to ref would look like this: 1278 1279 [1, 11, 3, 10, 9, 6, 7, 12] 1280 1281 See `tf.scatter_nd` for more details about how to make updates to 1282 slices. 1283 1284 Args: 1285 indices: The indices to be used in the operation. 1286 updates: The values to be used in the operation. 1287 name: the name of the operation. 1288 1289 Returns: 1290 The updated variable. 1291 """ 1292 return self._lazy_read( 1293 gen_state_ops.resource_scatter_nd_update( 1294 self.handle, 1295 indices, 1296 ops.convert_to_tensor(updates, self.dtype), 1297 name=name)) 1298 1299 def scatter_nd_max(self, indices, updates, name=None): 1300 """Updates this variable with the max of `tf.IndexedSlices` and itself. 1301 1302 `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`. 1303 1304 `indices` must be integer tensor, containing indices into `ref`. 1305 It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`. 1306 1307 The innermost dimension of `indices` (with length `K`) corresponds to 1308 indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th 1309 dimension of `ref`. 1310 1311 `updates` is `Tensor` of rank `Q-1+P-K` with shape: 1312 1313 ``` 1314 [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]]. 1315 ``` 1316 1317 See `tf.scatter_nd` for more details about how to make updates to 1318 slices. 1319 1320 Args: 1321 indices: The indices to be used in the operation. 1322 updates: The values to be used in the operation. 1323 name: the name of the operation. 1324 1325 Returns: 1326 The updated variable. 1327 """ 1328 return self._lazy_read( 1329 gen_state_ops.resource_scatter_nd_max( 1330 self.handle, 1331 indices, 1332 ops.convert_to_tensor(updates, self.dtype), 1333 name=name)) 1334 1335 def scatter_nd_min(self, indices, updates, name=None): 1336 """Updates this variable with the min of `tf.IndexedSlices` and itself. 1337 1338 `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`. 1339 1340 `indices` must be integer tensor, containing indices into `ref`. 1341 It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`. 1342 1343 The innermost dimension of `indices` (with length `K`) corresponds to 1344 indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th 1345 dimension of `ref`. 1346 1347 `updates` is `Tensor` of rank `Q-1+P-K` with shape: 1348 1349 ``` 1350 [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]]. 1351 ``` 1352 1353 See `tf.scatter_nd` for more details about how to make updates to 1354 slices. 1355 1356 Args: 1357 indices: The indices to be used in the operation. 1358 updates: The values to be used in the operation. 1359 name: the name of the operation. 1360 1361 Returns: 1362 The updated variable. 1363 """ 1364 return self._lazy_read( 1365 gen_state_ops.resource_scatter_nd_min( 1366 self.handle, 1367 indices, 1368 ops.convert_to_tensor(updates, self.dtype), 1369 name=name)) 1370 1371 def _write_object_proto(self, proto, options): 1372 """Writes additional information of the variable into the SavedObject proto. 
    Subclasses of ResourceVariables could choose to override this method to
    customize extra information to provide when saving a SavedModel.

    Ideally, this should contain the logic in
    write_object_proto_for_resource_variable but `DistributedValue` is an
    outlier at the moment. Once `DistributedValue` becomes a proper
    ResourceVariable, we should remove the helper method below.

    Args:
      proto: `SavedObject` proto to update.
      options: A `SaveOptions` instance that configures save behavior.
    """
    write_object_proto_for_resource_variable(self, proto, options)

  def _strided_slice_assign(self, begin, end, strides, value, name, begin_mask,
                            end_mask, ellipsis_mask, new_axis_mask,
                            shrink_axis_mask):
    with _handle_graph(self.handle), self._assign_dependencies():
      return self._lazy_read(
          gen_array_ops.resource_strided_slice_assign(
              ref=self.handle,
              begin=begin,
              end=end,
              strides=strides,
              value=ops.convert_to_tensor(value, dtype=self.dtype),
              name=name,
              begin_mask=begin_mask,
              end_mask=end_mask,
              ellipsis_mask=ellipsis_mask,
              new_axis_mask=new_axis_mask,
              shrink_axis_mask=shrink_axis_mask))

  def __complex__(self):
    return complex(self.value().numpy())

  def __int__(self):
    return int(self.value().numpy())

  def __long__(self):
    return long(self.value().numpy())

  def __float__(self):
    return float(self.value().numpy())

  def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
    del name
    if dtype is not None and not dtype.is_compatible_with(self.dtype):
      raise ValueError(
          f"Incompatible type conversion requested to type {dtype.name} for "
          f"a `tf.Variable` of type {self.dtype.name}. (Variable: {self})")
    if as_ref:
      return self.read_value().op.inputs[0]
    else:
      return self.value()

  def __iadd__(self, unused_other):
    raise RuntimeError("`variable += value` with `tf.Variable`s is not "
                       "supported. Use `variable.assign_add(value)` to modify "
                       "the variable, or `out = variable + value` if you "
                       "need to get a new output Tensor.")

  def __isub__(self, unused_other):
    raise RuntimeError("`variable -= value` with `tf.Variable`s is not "
                       "supported. Use `variable.assign_sub(value)` to modify "
                       "the variable, or `out = variable - value` if you "
                       "need to get a new output Tensor.")

  def __imul__(self, unused_other):
    raise RuntimeError("`var *= value` with `tf.Variable`s is not "
                       "supported. Use `var.assign(var * value)` to modify "
                       "the variable, or `out = var * value` if you "
                       "need to get a new output Tensor.")

  def __idiv__(self, unused_other):
    raise RuntimeError("`var /= value` with `tf.Variable`s is not "
                       "supported. Use `var.assign(var / value)` to modify "
                       "the variable, or `out = var / value` if you "
                       "need to get a new output Tensor.")

  def __itruediv__(self, unused_other):
    raise RuntimeError("`var /= value` with `tf.Variable`s is not "
                       "supported. Use `var.assign(var / value)` to modify "
                       "the variable, or `out = var / value` if you "
                       "need to get a new output Tensor.")

  def __irealdiv__(self, unused_other):
    raise RuntimeError("`var /= value` with `tf.Variable`s is not "
                       "supported. 
Use `var.assign(var / value)` to modify " 1462 "the variable, or `out = var / value` if you " 1463 "need to get a new output Tensor.") 1464 1465 def __ipow__(self, unused_other): 1466 raise RuntimeError("`var **= value` with `tf.Variable`s is not " 1467 "supported. Use `var.assign(var ** value)` to modify " 1468 "the variable, or `out = var ** value` if you " 1469 "need to get a new output Tensor.") 1470 1471 1472class ResourceVariable(BaseResourceVariable): 1473 """Variable based on resource handles. 1474 1475 See the [Variables How To](https://tensorflow.org/guide/variables) 1476 for a high level overview. 1477 1478 A `ResourceVariable` allows you to maintain state across subsequent calls to 1479 session.run. 1480 1481 The `ResourceVariable` constructor requires an initial value for the variable, 1482 which can be a `Tensor` of any type and shape. The initial value defines the 1483 type and shape of the variable. After construction, the type and shape of 1484 the variable are fixed. The value can be changed using one of the assign 1485 methods. 1486 1487 Just like any `Tensor`, variables created with 1488 `tf.Variable(use_resource=True)` can be used as inputs for other Ops in the 1489 graph. Additionally, all the operators overloaded for the `Tensor` class are 1490 carried over to variables, so you can also add nodes to the graph by just 1491 doing arithmetic on variables. 1492 1493 Unlike ref-based variable, a ResourceVariable has well-defined semantics. Each 1494 usage of a ResourceVariable in a TensorFlow graph adds a read_value operation 1495 to the graph. The Tensors returned by a read_value operation are guaranteed to 1496 see all modifications to the value of the variable which happen in any 1497 operation on which the read_value depends on (either directly, indirectly, or 1498 via a control dependency) and guaranteed to not see any modification to the 1499 value of the variable from operations that depend on the read_value operation. 1500 Updates from operations that have no dependency relationship to the read_value 1501 operation might or might not be visible to read_value. 1502 1503 For example, if there is more than one assignment to a ResourceVariable in 1504 a single session.run call there is a well-defined value for each operation 1505 which uses the variable's value if the assignments and the read are connected 1506 by edges in the graph. Consider the following example, in which two writes 1507 can cause tf.Variable and tf.ResourceVariable to behave differently: 1508 1509 ```python 1510 a = tf.Variable(1.0, use_resource=True) 1511 a.initializer.run() 1512 1513 assign = a.assign(2.0) 1514 with tf.control_dependencies([assign]): 1515 b = a.read_value() 1516 with tf.control_dependencies([b]): 1517 other_assign = a.assign(3.0) 1518 with tf.control_dependencies([other_assign]): 1519 # Will print 2.0 because the value was read before other_assign ran. If 1520 # `a` was a tf.Variable instead, 2.0 or 3.0 could be printed. 1521 tf.compat.v1.Print(b, [b]).eval() 1522 ``` 1523 """ 1524 1525 def __init__( 1526 self, # pylint: disable=super-init-not-called 1527 initial_value=None, 1528 trainable=None, 1529 collections=None, 1530 validate_shape=True, # pylint: disable=unused-argument 1531 caching_device=None, 1532 name=None, 1533 dtype=None, 1534 variable_def=None, 1535 import_scope=None, 1536 constraint=None, 1537 distribute_strategy=None, 1538 synchronization=None, 1539 aggregation=None, 1540 shape=None): 1541 """Creates a variable. 
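    For illustration only (a minimal sketch, not part of the original
    docstring), the common construction path supplies an `initial_value`,
    which fixes the variable's dtype and shape:

    ```python
    v = ResourceVariable(initial_value=tf.zeros([2, 3]), name="v")
    v.shape  # TensorShape([2, 3])
    v.dtype  # tf.float32
    ```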
    Args:
      initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
        which is the initial value for the Variable. Can also be a callable
        with no argument that returns the initial value when called. (Note that
        initializer functions from init_ops.py must first be bound to a shape
        before being used here.)
      trainable: If `True`, the default, also adds the variable to the graph
        collection `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as
        the default list of variables to use by the `Optimizer` classes.
        Defaults to `True`, unless `synchronization` is set to `ON_READ`, in
        which case it defaults to `False`.
      collections: List of graph collection keys. The new variable is added to
        these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
      validate_shape: Ignored. Provided for compatibility with tf.Variable.
      caching_device: Optional device string or function describing where the
        Variable should be cached for reading. Defaults to the Variable's
        device. If not `None`, caches on another device. Typical use is to
        cache on the device where the Ops using the Variable reside, to
        deduplicate copying through `Switch` and other conditional statements.
      name: Optional name for the variable. Defaults to `'Variable'` and gets
        uniquified automatically.
      dtype: If set, initial_value will be converted to the given type. If
        None, either the datatype will be kept (if initial_value is a Tensor)
        or float32 will be used (if it is a Python object convertible to a
        Tensor).
      variable_def: `VariableDef` protocol buffer. If not None, recreates the
        `ResourceVariable` object with its contents. `variable_def` and other
        arguments (except for import_scope) are mutually exclusive.
      import_scope: Optional `string`. Name scope to add to the
        ResourceVariable. Only used when `variable_def` is provided.
      constraint: An optional projection function to be applied to the variable
        after being updated by an `Optimizer` (e.g. used to implement norm
        constraints or value constraints for layer weights). The function must
        take as input the unprojected Tensor representing the value of the
        variable and return the Tensor for the projected value (which must have
        the same shape). Constraints are not safe to use when doing asynchronous
        distributed training.
      distribute_strategy: The tf.distribute.Strategy this variable is being
        created inside of.
      synchronization: Indicates when a distributed variable will be
        aggregated. Accepted values are constants defined in the class
        `tf.VariableSynchronization`. By default the synchronization is set to
        `AUTO` and the current `DistributionStrategy` chooses when to
        synchronize.
      aggregation: Indicates how a distributed variable will be aggregated.
        Accepted values are constants defined in the class
        `tf.VariableAggregation`.
      shape: (optional) The shape of this variable. If None, the shape of
        `initial_value` will be used. When setting this argument to
        `tf.TensorShape(None)` (representing an unspecified shape), the
        variable can be assigned values of different shapes.

    Raises:
      ValueError: If the initial value is not specified, or does not have a
        shape and `validate_shape` is `True`.
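    As an illustrative sketch of the `shape` argument described above (not
    part of the original docstring), a variable constructed with an
    unspecified shape can later be assigned values of different shapes:

    ```python
    v = ResourceVariable(initial_value=[1.0], shape=tf.TensorShape(None))
    v.assign([1.0, 2.0, 3.0])  # allowed because the shape is unspecified
    ```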

    @compatibility(eager)
    When Eager Execution is enabled, the default for the `collections`
    argument is `None`, which signifies that this `Variable` will not be added
    to any collections.
    @end_compatibility
    """
    if variable_def:
      if initial_value is not None:
        raise ValueError(f"The variable_def and initial_value args to "
                         f"`tf.Variable` are mutually exclusive, but got "
                         f"both: variable_def={variable_def},\n"
                         f"initial_value={initial_value}")
      if context.executing_eagerly():
        raise ValueError(f"Creating a `tf.Variable` with a `variable_def` arg "
                         f"is not supported when eager execution is enabled. "
                         f"Got: variable_def={variable_def}")
      self._init_from_proto(variable_def, import_scope=import_scope)
    else:
      self._init_from_args(
          initial_value=initial_value,
          trainable=trainable,
          collections=collections,
          caching_device=caching_device,
          name=name,
          dtype=dtype,
          constraint=constraint,
          synchronization=synchronization,
          aggregation=aggregation,
          shape=shape,
          distribute_strategy=distribute_strategy)

  def _init_from_args(self,
                      initial_value=None,
                      trainable=None,
                      collections=None,
                      caching_device=None,
                      name=None,
                      dtype=None,
                      constraint=None,
                      synchronization=None,
                      aggregation=None,
                      distribute_strategy=None,
                      shape=None):
    """Creates a variable.

    Args:
      initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
        which is the initial value for the Variable. The initial value must
        have a shape specified unless `validate_shape` is set to False. Can
        also be a callable with no argument that returns the initial value
        when called. (Note that initializer functions from init_ops.py must
        first be bound to a shape before being used here.)
      trainable: If `True`, the default, also adds the variable to the graph
        collection `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as
        the default list of variables to use by the `Optimizer` classes.
        Defaults to `True`, unless `synchronization` is set to `ON_READ`, in
        which case it defaults to `False`.
      collections: List of graph collections keys. The new variable is added
        to these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
      caching_device: Optional device string or function describing where the
        Variable should be cached for reading. Defaults to the Variable's
        device. If not `None`, caches on another device. Typical use is to
        cache on the device where the Ops using the Variable reside, to
        deduplicate copying through `Switch` and other conditional statements.
      name: Optional name for the variable. Defaults to `'Variable'` and gets
        uniquified automatically.
      dtype: If set, initial_value will be converted to the given type. If
        None, either the datatype will be kept (if initial_value is a Tensor)
        or float32 will be used (if it is a Python object convertible to a
        Tensor).
      constraint: An optional projection function to be applied to the
        variable after being updated by an `Optimizer` (e.g. used to implement
        norm constraints or value constraints for layer weights). The function
        must take as input the unprojected Tensor representing the value of
        the variable and return the Tensor for the projected value (which must
        have the same shape).
        Constraints are not safe to use when doing asynchronous
        distributed training.
      synchronization: Indicates when a distributed variable will be
        aggregated. Accepted values are constants defined in the class
        `tf.VariableSynchronization`. By default the synchronization is set to
        `AUTO` and the current `DistributionStrategy` chooses when to
        synchronize.
      aggregation: Indicates how a distributed variable will be aggregated.
        Accepted values are constants defined in the class
        `tf.VariableAggregation`.
      distribute_strategy: DistributionStrategy under which this variable was
        created.
      shape: (optional) The shape of this variable. If None, the shape of
        `initial_value` will be used. When setting this argument to
        `tf.TensorShape(None)` (representing an unspecified shape), the
        variable can be assigned with values of different shapes.

    Raises:
      ValueError: If the initial value is not specified, or does not have a
        shape and `validate_shape` is `True`.

    @compatibility(eager)
    When Eager Execution is enabled, variables are never added to collections.
    A variable is not implicitly added to the `GLOBAL_VARIABLES` or
    `TRAINABLE_VARIABLES` collections, and the `collections` argument is
    ignored.
    @end_compatibility
    """
    synchronization, aggregation, trainable = (
        variables.validate_synchronization_aggregation_trainable(
            synchronization, aggregation, trainable, name))
    if initial_value is None:
      raise ValueError("The `initial_value` arg to `tf.Variable` must "
                       "be specified unless a `variable_def` is provided. "
                       "You provided neither.")
    init_from_fn = callable(initial_value)

    if isinstance(initial_value, ops.Tensor) and hasattr(
        initial_value, "graph") and initial_value.graph.building_function:
      raise ValueError(f"Argument `initial_value` ({initial_value}) could not "
                       "be lifted out of a `tf.function`. "
                       f"(Tried to create variable with name='{name}'). "
                       "To avoid this error, when constructing `tf.Variable`s "
                       "inside of `tf.function` you can create the "
                       "`initial_value` tensor in a "
                       "`tf.init_scope` or pass a callable `initial_value` "
                       "(e.g., `tf.Variable(lambda : "
                       "tf.truncated_normal([10, 40]))`). "
                       "Please file a feature request if this "
                       "restriction inconveniences you.")

    if collections is None:
      collections = [ops.GraphKeys.GLOBAL_VARIABLES]
    if not isinstance(collections, (list, tuple, set)):
      raise ValueError(
          f"collections argument to Variable constructor must be a list, "
          f"tuple, or set. Got {collections} of type {type(collections)}")
    if constraint is not None and not callable(constraint):
      raise ValueError(f"Argument `constraint` must be None or a callable. "
                       f"Got a {type(constraint)}: {constraint}")

    if trainable and ops.GraphKeys.TRAINABLE_VARIABLES not in collections:
      collections = list(collections) + [ops.GraphKeys.TRAINABLE_VARIABLES]
    with ops.init_scope():
      self._in_graph_mode = not context.executing_eagerly()
      with ops.name_scope(
          name,
          "Variable", [] if init_from_fn else [initial_value],
          skip_on_eager=False) as name:
        # pylint: disable=protected-access
        handle_name = ops.name_from_scope_name(name)
        if self._in_graph_mode:
          shared_name = handle_name
          unique_id = shared_name
        else:
          # When in eager mode use a uid for the shared_name, to prevent
          # accidental sharing.
          unique_id = "%s_%d" % (handle_name, ops.uid())
          shared_name = None  # Never shared
        # Use attr_scope and device(None) to simulate the behavior of
        # colocate_with when the variable we want to colocate with doesn't
        # yet exist.
        device_context_manager = (
            ops.device if self._in_graph_mode else ops.NullContextmanager)
        attr = attr_value_pb2.AttrValue(
            list=attr_value_pb2.AttrValue.ListValue(
                s=[compat.as_bytes("loc:@%s" % handle_name)]))
        with ops.get_default_graph()._attr_scope({"_class": attr}):
          with ops.name_scope("Initializer"), device_context_manager(None):
            if init_from_fn:
              initial_value = initial_value()
            if isinstance(initial_value, trackable.CheckpointInitialValue):
              self._maybe_initialize_trackable()
              self._update_uid = initial_value.checkpoint_position.restore_uid
              initial_value = initial_value.wrapped_value
            initial_value = ops.convert_to_tensor(initial_value,
                                                  name="initial_value",
                                                  dtype=dtype)
          if shape is not None:
            if not initial_value.shape.is_compatible_with(shape):
              raise ValueError(
                  f"In this `tf.Variable` creation, the initial value's shape "
                  f"({initial_value.shape}) is not compatible with "
                  f"the explicitly supplied `shape` argument ({shape}).")
          else:
            shape = initial_value.shape
          handle = eager_safe_variable_handle(
              initial_value=initial_value,
              shape=shape,
              shared_name=shared_name,
              name=name,
              graph_mode=self._in_graph_mode)
        # pylint: disable=protected-access
        if (self._in_graph_mode and initial_value is not None and
            initial_value.op._get_control_flow_context() is not None):
          raise ValueError(
              f"The `initial_value` passed to `tf.Variable` {name} is from "
              f"inside a control-flow construct, such as a loop or "
              f"conditional. When creating a "
              f"`tf.Variable` inside a loop or conditional, use a lambda as "
              f"the `initial_value`. Got: initial_value=({initial_value})")
        # pylint: enable=protected-access
        dtype = initial_value.dtype.base_dtype

        if self._in_graph_mode:
          with ops.name_scope("IsInitialized"):
            is_initialized_op = (
                gen_resource_variable_ops.var_is_initialized_op(handle))
          if initial_value is not None:
            # pylint: disable=g-backslash-continuation
            with ops.name_scope("Assign") as n, \
                 ops.colocate_with(None, ignore_existing=True), \
                 ops.device(handle.device):
              # pylint: disable=protected-access
              initializer_op = (
                  gen_resource_variable_ops.assign_variable_op(
                      handle,
                      variables._try_guard_against_uninitialized_dependencies(
                          name, initial_value),
                      name=n))
              # pylint: enable=protected-access
            # pylint: enable=g-backslash-continuation
          with ops.name_scope("Read"):
            # Manually assign reads to the handle's device to avoid log
            # messages.
            with ops.device(handle.device):
              value = gen_resource_variable_ops.read_variable_op(handle, dtype)
              _maybe_set_handle_data(dtype, handle, value)
            graph_element = value
            if caching_device is not None:
              # Variables may be created in a tf.device() or
              # ops.colocate_with() context. At the same time, users would
              # expect caching device to be independent of this context,
              # and/or would not expect the current device context to be
              # merged with the caching device spec. Therefore we reset the
              # colocation stack before creating the cached value. Note that
              # resetting the colocation stack will also reset the device
              # stack.
              with ops.colocate_with(None, ignore_existing=True):
                with ops.device(caching_device):
                  cached_value = array_ops.identity(value)
            else:
              cached_value = None
        else:
          gen_resource_variable_ops.assign_variable_op(handle, initial_value)
          is_initialized_op = None
          initializer_op = None
          graph_element = None
          if caching_device:
            with ops.device(caching_device):
              cached_value = gen_resource_variable_ops.read_variable_op(
                  handle, dtype)
              _maybe_set_handle_data(dtype, handle, cached_value)
          else:
            cached_value = None

        if cached_value is not None:
          # Store the variable object so that the original variable can be
          # accessed to generate functions that are compatible with SavedModel.
          cached_value._cached_variable = weakref.ref(self)  # pylint: disable=protected-access

        if not context.executing_eagerly():
          # Eager variables are only added to collections if they are part of
          # an eager variable store (otherwise in an interactive session they
          # would hog memory and cause OOM). This is done in
          # ops/variable_scope.py.
          ops.add_to_collections(collections, self)
        elif ops.GraphKeys.GLOBAL_STEP in collections:
          ops.add_to_collections(ops.GraphKeys.GLOBAL_STEP, self)
        initial_value = initial_value if self._in_graph_mode else None
        super(ResourceVariable, self).__init__(
            trainable=trainable,
            shape=shape,
            dtype=dtype,
            handle=handle,
            synchronization=synchronization,
            constraint=constraint,
            aggregation=aggregation,
            distribute_strategy=distribute_strategy,
            name=name,
            unique_id=unique_id,
            handle_name=handle_name,
            graph_element=graph_element,
            initial_value=initial_value,
            initializer_op=initializer_op,
            is_initialized_op=is_initialized_op,
            cached_value=cached_value,
            caching_device=caching_device)

  def _init_from_proto(self, variable_def, import_scope=None):
    """Initializes from `VariableDef` proto."""
    # Note that init_from_proto is currently not supported in Eager mode.
    assert not context.executing_eagerly()
    self._in_graph_mode = True
    assert isinstance(variable_def, variable_pb2.VariableDef)
    if not variable_def.is_resource:
      raise ValueError(f"The `variable_def` you passed to `tf.Variable` "
                       f"describes a TF 1.x Reference Variable; restoring it "
                       f"as a TF 2.x ResourceVariable is unsupported. "
                       f"Got variable_def={variable_def}")

    # Create from variable_def.
    g = ops.get_default_graph()
    self._handle = g.as_graph_element(
        ops.prepend_name_scope(
            variable_def.variable_name, import_scope=import_scope))
    self._shape = tensor_shape.TensorShape(self._handle.op.get_attr("shape"))
    self._handle_name = self._handle.name
    self._unique_id = self._handle_name
    self._initializer_op = g.as_graph_element(
        ops.prepend_name_scope(
            variable_def.initializer_name, import_scope=import_scope))
    # Check whether initial_value_name exists for backwards compatibility.
    if (hasattr(variable_def, "initial_value_name") and
        variable_def.initial_value_name):
      self._initial_value = g.as_graph_element(
          ops.prepend_name_scope(
              variable_def.initial_value_name, import_scope=import_scope))
    else:
      self._initial_value = None
    synchronization, aggregation, trainable = (
        variables.validate_synchronization_aggregation_trainable(
            variable_def.synchronization, variable_def.aggregation,
            variable_def.trainable, variable_def.variable_name))
    self._synchronization = synchronization
    self._aggregation = aggregation
    self._trainable = trainable
    if variable_def.snapshot_name:
      snapshot = g.as_graph_element(
          ops.prepend_name_scope(
              variable_def.snapshot_name, import_scope=import_scope))
      if snapshot.op.type != "ReadVariableOp":
        self._cached_value = snapshot
      else:
        self._cached_value = None
      while snapshot.op.type != "ReadVariableOp":
        snapshot = snapshot.op.inputs[0]
      self._graph_element = snapshot
    else:
      self._cached_value = None
      # Legacy case for protos without the snapshot name; assume it's the
      # following.
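      # (That is, the ReadVariableOp created under the "Read" name scope in
      # `_init_from_args`; older protos are assumed to follow this naming.)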
      self._graph_element = g.get_tensor_by_name(self._handle.op.name +
                                                 "/Read/ReadVariableOp:0")
    if variable_def.HasField("save_slice_info_def"):
      self._save_slice_info = variables.Variable.SaveSliceInfo(
          save_slice_info_def=variable_def.save_slice_info_def,
          import_scope=import_scope)
    else:
      self._save_slice_info = None
    self._caching_device = None
    self._dtype = dtypes.as_dtype(self._handle.op.get_attr("dtype"))
    self._constraint = None


class UninitializedVariable(BaseResourceVariable):
  """A variable with no initializer."""

  def __init__(  # pylint: disable=super-init-not-called
      self,
      trainable=None,
      caching_device=None,
      name=None,
      shape=None,
      dtype=None,
      constraint=None,
      synchronization=None,
      aggregation=None,
      extra_handle_data=None,
      distribute_strategy=None,
      **unused_kwargs):
    """Creates the variable handle.

    Args:
      trainable: If `True`, GradientTapes automatically watch uses of this
        Variable.
      caching_device: Optional device string or function describing where the
        Variable should be cached for reading. Defaults to the Variable's
        device. If not `None`, caches on another device. Typical use is to
        cache on the device where the Ops using the Variable reside, to
        deduplicate copying through `Switch` and other conditional statements.
      name: Optional name for the variable. Defaults to `'Variable'` and gets
        uniquified automatically.
      shape: The variable's shape.
      dtype: The variable's dtype.
      constraint: An optional projection function to be applied to the
        variable after being updated by an `Optimizer` (e.g. used to implement
        norm constraints or value constraints for layer weights). The function
        must take as input the unprojected Tensor representing the value of
        the variable and return the Tensor for the projected value (which must
        have the same shape). Constraints are not safe to use when doing
        asynchronous distributed training.
      synchronization: Indicates when a distributed variable will be
        aggregated. Accepted values are constants defined in the class
        `tf.VariableSynchronization`. By default the synchronization is set to
        `AUTO` and the current `DistributionStrategy` chooses when to
        synchronize.
      aggregation: Indicates how a distributed variable will be aggregated.
        Accepted values are constants defined in the class
        `tf.VariableAggregation`.
      extra_handle_data: Optional, another resource handle or Tensor with
        handle data to merge with `shape` and `dtype`.
      distribute_strategy: The tf.distribute.Strategy this variable is being
        created inside of.
    """
    with ops.init_scope():
      # Here we are detecting eagerness within an init_scope, so this will
      # only be true when we are running in TF1 graph mode.
      self._in_graph_mode = not context.executing_eagerly()
      with ops.name_scope(name, "Variable", skip_on_eager=False) as name:
        handle_name = ops.name_from_scope_name(name)
        if self._in_graph_mode:
          shared_name = handle_name
          unique_id = shared_name
        else:
          unique_id = "%s_%d" % (handle_name, ops.uid())
          shared_name = None  # Never shared
        handle = _variable_handle_from_shape_and_dtype(
            shape=shape,
            dtype=dtype,
            shared_name=shared_name,
            name=name,
            graph_mode=self._in_graph_mode,
            initial_value=extra_handle_data)
        if self._in_graph_mode:
          # We only need to add the read_variable_op in TF1.
          with ops.name_scope("Read"):
            # Manually assign reads to the handle's device to avoid log
            # messages.
            with ops.device(handle.device):
              value = gen_resource_variable_ops.read_variable_op(handle, dtype)
              _maybe_set_handle_data(dtype, handle, value)
            graph_element = value
          ops.add_to_collection(ops.GraphKeys.GLOBAL_VARIABLES, self)
          # Do *not* add to TRAINABLE_VARIABLES here, even if self._trainable,
          # because retraining or frozen use of imported SavedModels is
          # controlled at higher levels of model building.
        else:
          graph_element = None
    super(UninitializedVariable, self).__init__(
        distribute_strategy=distribute_strategy,
        shape=shape,
        dtype=dtype,
        unique_id=unique_id,
        handle_name=handle_name,
        constraint=constraint,
        handle=handle,
        graph_element=graph_element,
        trainable=trainable,
        synchronization=synchronization,
        aggregation=aggregation,
        in_graph_mode=self._in_graph_mode)


_pywrap_utils.RegisterType("ResourceVariable", ResourceVariable)
math_ops._resource_variable_type = ResourceVariable  # pylint: disable=protected-access


def _dense_var_to_tensor(var, dtype=None, name=None, as_ref=False):
  return var._dense_var_to_tensor(dtype=dtype, name=name, as_ref=as_ref)  # pylint: disable=protected-access


# Register a conversion function which reads the value of the variable,
# allowing instances of the class to be used as tensors.
ops.register_tensor_conversion_function(BaseResourceVariable,
                                        _dense_var_to_tensor)


class _UnreadVariable(BaseResourceVariable):
  """Represents a future for a read of a variable.

  Pretends to be the tensor if anyone looks.
  """

  def __init__(self, handle, dtype, shape, in_graph_mode, deleter, parent_op,
               unique_id):
    if isinstance(handle, ops.EagerTensor):
      handle_name = ""
    else:
      handle_name = handle.name
    # Only create a graph_element if we're in session.run-land as only
    # session.run requires a preexisting tensor to evaluate. Otherwise we can
    # avoid accidentally reading the variable.
    if context.executing_eagerly() or ops.inside_function():
      graph_element = None
    else:
      with ops.control_dependencies([parent_op]):
        graph_element = gen_resource_variable_ops.read_variable_op(
            handle, dtype)
        _maybe_set_handle_data(dtype, handle, graph_element)
    super(_UnreadVariable, self).__init__(
        handle=handle,
        shape=shape,
        handle_name=handle_name,
        unique_id=unique_id,
        dtype=dtype,
        handle_deleter=deleter,
        graph_element=graph_element)
    self._parent_op = parent_op

  @property
  def name(self):
    if self._in_graph_mode:
      return self._parent_op.name
    else:
      return "UnreadVariable"

  def value(self):
    return self._read_variable_op()

  def read_value(self):
    return self._read_variable_op()

  def _read_variable_op(self):
    with ops.control_dependencies([self._parent_op]):
      result = gen_resource_variable_ops.read_variable_op(
          self._handle, self._dtype)
      _maybe_set_handle_data(self._dtype, self._handle, result)
      return result

  def assign_sub(self, delta, use_locking=None, name=None, read_value=True):
    with ops.control_dependencies([self._parent_op]):
      return super(_UnreadVariable, self).assign_sub(delta, use_locking, name,
                                                     read_value)

  def assign_add(self, delta, use_locking=None, name=None, read_value=True):
    with ops.control_dependencies([self._parent_op]):
      return super(_UnreadVariable, self).assign_add(delta, use_locking, name,
                                                     read_value)

  def assign(self, value, use_locking=None, name=None, read_value=True):
    with ops.control_dependencies([self._parent_op]):
      return super(_UnreadVariable, self).assign(value, use_locking, name,
                                                 read_value)

  def scatter_sub(self, sparse_delta, use_locking=False, name=None):
    with ops.control_dependencies([self._parent_op]):
      return super(_UnreadVariable, self).scatter_sub(sparse_delta,
                                                      use_locking, name)

  def scatter_add(self, sparse_delta, use_locking=False, name=None):
    with ops.control_dependencies([self._parent_op]):
      return super(_UnreadVariable, self).scatter_add(sparse_delta,
                                                      use_locking, name)

  def scatter_max(self, sparse_delta, use_locking=False, name=None):
    with ops.control_dependencies([self._parent_op]):
      return super(_UnreadVariable, self).scatter_max(sparse_delta,
                                                      use_locking, name)

  def scatter_min(self, sparse_delta, use_locking=False, name=None):
    with ops.control_dependencies([self._parent_op]):
      return super(_UnreadVariable, self).scatter_min(sparse_delta,
                                                      use_locking, name)

  def scatter_mul(self, sparse_delta, use_locking=False, name=None):
    with ops.control_dependencies([self._parent_op]):
      return super(_UnreadVariable, self).scatter_mul(sparse_delta,
                                                      use_locking, name)

  def scatter_div(self, sparse_delta, use_locking=False, name=None):
    with ops.control_dependencies([self._parent_op]):
      return super(_UnreadVariable, self).scatter_div(sparse_delta,
                                                      use_locking, name)

  def scatter_update(self, sparse_delta, use_locking=False, name=None):
    with ops.control_dependencies([self._parent_op]):
      return super(_UnreadVariable,
                   self).scatter_update(sparse_delta, use_locking, name)

  def batch_scatter_update(self, sparse_delta, use_locking=False, name=None):
    with ops.control_dependencies([self._parent_op]):
      return super(_UnreadVariable,
                   self).batch_scatter_update(sparse_delta, use_locking, name)

  def scatter_nd_sub(self, indices, updates, name=None):
    with ops.control_dependencies([self._parent_op]):
      return super(_UnreadVariable, self).scatter_nd_sub(indices, updates,
                                                         name)

  def scatter_nd_add(self, indices, updates, name=None):
    with ops.control_dependencies([self._parent_op]):
      return super(_UnreadVariable, self).scatter_nd_add(indices, updates,
                                                         name)

  def scatter_nd_update(self, indices, updates, name=None):
    with ops.control_dependencies([self._parent_op]):
      return super(_UnreadVariable,
                   self).scatter_nd_update(indices, updates, name)

  def scatter_nd_max(self, indices, updates, name=None):
    with ops.control_dependencies([self._parent_op]):
      return super(_UnreadVariable, self).scatter_nd_max(indices, updates,
                                                         name)

  def scatter_nd_min(self, indices, updates, name=None):
    with ops.control_dependencies([self._parent_op]):
      return super(_UnreadVariable, self).scatter_nd_min(indices, updates,
                                                         name)

  @property
  def op(self):
    """The op for this variable."""
    return self._parent_op


@ops.RegisterGradient("ReadVariableOp")
def _ReadGrad(_, grad):
  """Gradient for read op."""
  return grad


def variable_shape(handle, out_type=dtypes.int32):
  handle_data = get_eager_safe_handle_data(handle)
  if handle_data is None or not handle_data.is_set:
    return gen_resource_variable_ops.variable_shape(handle, out_type=out_type)
  shape_proto = handle_data.shape_and_type[0].shape
  if shape_proto.unknown_rank or any(x.size == -1 for x in shape_proto.dim):
    return gen_resource_variable_ops.variable_shape(handle, out_type=out_type)
  return constant_op.constant([x.size for x in shape_proto.dim], dtype=out_type)


@ops.RegisterGradient("ResourceGather")
def _GatherGrad(op, grad):
  """Gradient for gather op."""
  # Build appropriately shaped IndexedSlices.
  handle = op.inputs[0]
  indices = op.inputs[1]
  params_shape = variable_shape(handle)
  size = array_ops.expand_dims(array_ops.size(indices), 0)
  values_shape = array_ops.concat([size, params_shape[1:]], 0)
  values = array_ops.reshape(grad, values_shape)
  indices = array_ops.reshape(indices, size)
  return (ops.IndexedSlices(values, indices, params_shape), None)


def _to_proto_fn(v, export_scope=None):
  """Converts Variable and ResourceVariable to VariableDef for collections."""
  return v.to_proto(export_scope=export_scope)


def _from_proto_fn(v, import_scope=None):
  """Creates Variable or ResourceVariable from VariableDef as needed."""
  if v.is_resource:
    return ResourceVariable.from_proto(v, import_scope=import_scope)
  return variables.Variable.from_proto(v, import_scope=import_scope)


ops.register_proto_function(
    ops.GraphKeys.GLOBAL_VARIABLES,
    proto_type=variable_pb2.VariableDef,
    to_proto=_to_proto_fn,
    from_proto=_from_proto_fn)
ops.register_proto_function(
    ops.GraphKeys.TRAINABLE_VARIABLES,
    proto_type=variable_pb2.VariableDef,
    to_proto=_to_proto_fn,
    from_proto=_from_proto_fn)
ops.register_proto_function(
    ops.GraphKeys.MOVING_AVERAGE_VARIABLES,
    proto_type=variable_pb2.VariableDef,
    to_proto=_to_proto_fn,
    from_proto=_from_proto_fn)
ops.register_proto_function(
    ops.GraphKeys.LOCAL_VARIABLES,
    proto_type=variable_pb2.VariableDef,
    to_proto=_to_proto_fn,
    from_proto=_from_proto_fn)
ops.register_proto_function(
    ops.GraphKeys.MODEL_VARIABLES,
    proto_type=variable_pb2.VariableDef,
    to_proto=_to_proto_fn,
    from_proto=_from_proto_fn)
ops.register_proto_function(
    ops.GraphKeys.GLOBAL_STEP,
    proto_type=variable_pb2.VariableDef,
    to_proto=_to_proto_fn,
    from_proto=_from_proto_fn)
ops.register_proto_function(
    ops.GraphKeys.METRIC_VARIABLES,
    proto_type=variable_pb2.VariableDef,
    to_proto=_to_proto_fn,
    from_proto=_from_proto_fn)


@tf_export("__internal__.ops.is_resource_variable", v1=[])
def is_resource_variable(var):
  """Returns True if `var` is to be considered a ResourceVariable."""
  return isinstance(var, BaseResourceVariable) or hasattr(
      var, "_should_act_as_resource_variable")


def copy_to_graph_uninitialized(var):
  """Copies an existing variable to a new graph, with no initializer."""
  # Like ResourceVariable.__deepcopy__, but does not set an initializer on the
  # new variable.
  # pylint: disable=protected-access
  new_variable = UninitializedVariable(
      trainable=var.trainable,
      constraint=var._constraint,
      shape=var.shape,
      dtype=var.dtype,
      name=var._shared_name,
      synchronization=var.synchronization,
      aggregation=var.aggregation,
      extra_handle_data=var.handle)
  new_variable._maybe_initialize_trackable()
  # pylint: enable=protected-access
  return new_variable


ops.NotDifferentiable("Assert")
ops.NotDifferentiable("VarIsInitializedOp")
ops.NotDifferentiable("VariableShape")


class VariableSpec(tensor_spec.DenseSpec):
  """Describes a tf.Variable."""

  __slots__ = ["trainable"]

  value_type = property(lambda self: BaseResourceVariable)

  def __init__(self, shape, dtype=dtypes.float32,
               name=None, trainable=True):
    super(VariableSpec, self).__init__(shape, dtype=dtype, name=name)
    self.trainable = trainable

  def _to_components(self, value):
    raise NotImplementedError

  def _from_components(self, components):
    raise NotImplementedError

  def _from_compatible_tensor_list(self, tensor_list):
    assert len(tensor_list) == 1
    return tensor_list[0]


_pywrap_utils.RegisterType("VariableSpec", VariableSpec)


def write_object_proto_for_resource_variable(resource_variable, proto, options):
  """Writes additional information of the variable into the SavedObject proto.

  This allows users to define a `hook` to provide extra information about the
  variable to the SavedObject.

  For example, the `DistributedVariable` class would fill in components in the
  distributed context.

  Args:
    resource_variable: A `ResourceVariable` or `DistributedValue` that has the
      information to be saved into the proto.
    proto: `SavedObject` proto to update.
    options: A `SaveOptions` instance that configures save behavior.
2342 """ 2343 proto.variable.SetInParent() 2344 if not resource_variable.name.endswith(":0"): 2345 raise ValueError(f"Cowardly refusing to save variable " 2346 f"{resource_variable.name} because of " 2347 f"unexpected suffix in the name (':0') " 2348 f"which won't be restored.") 2349 proto.variable.name = meta_graph._op_name(resource_variable.name) # pylint: disable=protected-access 2350 proto.variable.trainable = resource_variable.trainable 2351 proto.variable.dtype = resource_variable.dtype.as_datatype_enum 2352 proto.variable.synchronization = resource_variable.synchronization.value 2353 proto.variable.aggregation = resource_variable.aggregation.value 2354 proto.variable.shape.CopyFrom(resource_variable.shape.as_proto()) 2355 if options.experimental_variable_policy._save_variable_devices( # pylint: disable=protected-access 2356 ): 2357 if hasattr(resource_variable, "device"): 2358 proto.variable.device = resource_variable.device 2359