# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Ops to use variables as resources."""

# pylint: disable=g-bad-name
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import contextlib
import functools

from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import variable_pb2
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.eager import context
from tensorflow.python.eager import tape
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import cpp_shape_inference_pb2
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import gen_resource_variable_ops
from tensorflow.python.ops import gen_state_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variables
# go/tf-wildcard-import
# pylint: disable=wildcard-import
from tensorflow.python.ops.gen_resource_variable_ops import *
# pylint: enable=wildcard-import
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.util import compat
from tensorflow.python.util.deprecation import deprecated


def get_resource_handle_data(graph_op):
  assert type(graph_op) == ops.Tensor  # pylint: disable=unidiomatic-typecheck

  handle_data = pywrap_tensorflow.GetHandleShapeAndType(
      graph_op.graph._c_graph, graph_op._as_tf_output())  # pylint: disable=protected-access

  return cpp_shape_inference_pb2.CppShapeInferenceResult.HandleData.FromString(
      compat.as_bytes(handle_data))


def get_eager_safe_handle_data(handle):
  """Gets the handle data from the Tensor `handle`."""
  assert isinstance(handle, ops.Tensor)

  if isinstance(handle, ops.EagerTensor):
    return handle._handle_data  # pylint: disable=protected-access
  else:
    return get_resource_handle_data(handle)


def _set_handle_shapes_and_types(tensor, handle_data, graph_mode):
  """Sets the shape inference result HandleData on tensor.

  Args:
    tensor: A `Tensor` or `EagerTensor`.
    handle_data: A `CppShapeInferenceResult.HandleData`.
    graph_mode: A python bool.
  """
  tensor._handle_data = handle_data  # pylint: disable=protected-access
  if not graph_mode:
    return

  # Not an EagerTensor, so a graph tensor.
  shapes, types = zip(*[(pair.shape, pair.dtype)
                        for pair in handle_data.shape_and_type])
  ranks = [len(s.dim) if not s.unknown_rank else -1 for s in shapes]
  shapes = [[d.size for d in s.dim]
            if not s.unknown_rank else None for s in shapes]
  pywrap_tensorflow.TF_GraphSetOutputHandleShapesAndTypes_wrapper(
      tensor._op._graph._c_graph,  # pylint: disable=protected-access
      tensor._as_tf_output(),  # pylint: disable=protected-access
      shapes, ranks, types)


def _combine_handle_data(handle, initial_value):
  """Concats HandleData from tensors `handle` and `initial_value`.

  Args:
    handle: A `Tensor` of dtype `resource`.
    initial_value: A `Tensor`.

  Returns:
    A `CppShapeInferenceResult.HandleData`. If `initial_value` has dtype
    `variant`, the `HandleData` contains the concatenation of the
    shape_and_type from both `handle` and `initial_value`.

  Raises:
    RuntimeError: If handle, which was returned by VarHandleOp, either has
      no handle data, or its len(handle_data.shape_and_type) != 1.
  """
  assert handle.dtype == dtypes.resource

  variable_handle_data = get_eager_safe_handle_data(handle)

  if initial_value.dtype != dtypes.variant:
    return variable_handle_data

  extra_handle_data = get_eager_safe_handle_data(initial_value)
  if extra_handle_data is not None and extra_handle_data.is_set:
    if (variable_handle_data is None
        or not variable_handle_data.is_set
        or len(variable_handle_data.shape_and_type) != 1):
      raise RuntimeError(
          "Expected VarHandleOp to return a length==1 shape_and_type, "
          "but saw: '%s'" % (variable_handle_data,))
    variable_handle_data.shape_and_type.extend(
        extra_handle_data.shape_and_type)
  return variable_handle_data


def eager_safe_variable_handle(initial_value, shared_name, name, graph_mode):
  """Creates a variable handle with information to do shape inference.

  The shape and dtype are read from `initial_value` and stored in the returned
  resource tensor's handle data.

  If `initial_value.dtype == tf.variant`, we additionally extract the handle
  data (if any) from `initial_value` and append it to the `handle_data`.
  In this case, the returned tensor's handle data is in the form

  ```
  is_set: true
  shape_and_type {
    shape {
      // initial_value.shape
    }
    dtype: DT_VARIANT
  }
  shape_and_type {
    // handle_data(initial_value).shape_and_type[0]
  }
  shape_and_type {
    // handle_data(initial_value).shape_and_type[1]
  }
  ...
  ```

  Ops that read from this tensor, such as `ReadVariableOp` and
  `AssignVariableOp`, know that `handle_data(handle).shape_and_type[1:]`
  correspond to the handle data of the variant(s) stored in the Variable.

  Args:
    initial_value: A `Tensor`.
    shared_name: A string.
    name: A string.
    graph_mode: A python bool.

  Returns:
    The handle, a `Tensor` of type `resource`.
  """
  shape = initial_value.get_shape()
  dtype = initial_value.dtype.base_dtype
  container = ops.get_default_graph()._container  # pylint: disable=protected-access
  if container is None:
    container = ""
  handle = gen_resource_variable_ops.var_handle_op(shape=shape, dtype=dtype,
                                                   shared_name=shared_name,
                                                   name=name,
                                                   container=container)

  if graph_mode:
    full_handle_data = _combine_handle_data(handle, initial_value)
    _set_handle_shapes_and_types(handle, full_handle_data, graph_mode)
    return handle
  else:
    # We do not want two distinct ResourceVariable objects for the same
    # underlying resource in the runtime.
    # When in eager mode, explicitly ensure so here. When in graph mode, it's
    # ensured by always generating different variable names.
    exists = gen_resource_variable_ops.var_is_initialized_op(handle)
    if exists:
      raise ValueError("variable object with name '%s' already created. Use "
                       "get_variable() if reuse is desired." %
                       shared_name)
    with context.graph_mode(), ops.Graph().as_default() as graph:
      h = gen_resource_variable_ops.var_handle_op(shape=shape, dtype=dtype,
                                                  shared_name=shared_name,
                                                  name=name,
                                                  container=container)

      # Tensor._handle_data contains information for the shape-inference code
      # to know the shape and dtype of the variable pointed to by a handle.
      # Since shape inference doesn't run in eager mode we copy this data here
      # for when the handle is captured by an eager mode function.
      # pylint: disable=protected-access
      full_handle_data = _combine_handle_data(h, initial_value)
      _set_handle_shapes_and_types(handle, full_handle_data, graph_mode)
      # pylint: enable=protected-access
    # Clean up op->graph->op reference cycles.
    ops.dismantle_graph(graph)
    return handle


@contextlib.contextmanager
def _handle_graph(handle):
  # Note: might have an eager tensor but not be executing eagerly when building
  # functions.
  if (context.executing_eagerly() or isinstance(handle, ops.EagerTensor)
      or ops.has_default_graph()):
    yield
  else:
    with handle.graph.as_default():
      yield


class EagerResourceDeleter(object):
  """An object which cleans up a resource handle.

  An alternative to defining a __del__ method on an object. The intended use is
  that ResourceVariables or other objects with resource handles will maintain a
  single reference to this object. When the parent object is collected, this
  object will be too. Even if the parent object is part of a reference cycle,
  the cycle will be collectable.
  """

  def __init__(self, handle, handle_device):
    if not isinstance(handle, ops.Tensor):
      raise ValueError(
          ("Passed handle=%s to EagerResourceDeleter. Was expecting a handle "
           "Tensor." % (handle,)))
    self._handle = handle
    self._handle_device = handle_device

  def __del__(self):
    # Resources follow object-identity when executing eagerly, so it is safe to
    # delete the resource we have a handle to.
    try:
      # This resource was created in eager mode. However, this destructor may
      # be running in graph mode (especially during unit tests). To clean up
      # successfully, we switch back into eager mode temporarily.
      with context.eager_mode():
        with ops.device(self._handle_device):
          gen_resource_variable_ops.destroy_resource_op(
              self._handle, ignore_lookup_error=True)
    except TypeError:
      # Suppress some exceptions, mainly for the case when we're running on
      # module deletion. Things that can go wrong include the context module
      # already being unloaded, self._handle._handle_data no longer being
      # valid, and so on. Printing warnings in these cases is silly
      # (exceptions raised from __del__ are printed as warnings to stderr).
      pass  # 'NoneType' object is not callable when the handle has been
            # partially unloaded.
    except AttributeError:
      pass  # 'NoneType' object has no attribute 'eager_mode' when context has
            # been unloaded. Will catch other module unloads as well.


def shape_safe_assign_variable_handle(handle, shape, value, name=None):
  """Helper that checks shape compatibility and assigns variable."""
  with _handle_graph(handle):
    value_tensor = ops.convert_to_tensor(value)
    shape.assert_is_compatible_with(value_tensor.shape)
    return gen_resource_variable_ops.assign_variable_op(handle,
                                                        value_tensor,
                                                        name=name)


def _maybe_set_handle_data(dtype, handle, tensor):
  if dtype == dtypes.variant:
    # For DT_VARIANT types, the handle's shape_and_type[1:] stores the
    # variant's handle data. Extract it.
    handle_data = get_eager_safe_handle_data(handle)
    if handle_data.is_set and len(handle_data.shape_and_type) > 1:
      tensor._handle_data = (  # pylint: disable=protected-access
          cpp_shape_inference_pb2.CppShapeInferenceResult.HandleData(
              is_set=True,
              shape_and_type=handle_data.shape_and_type[1:]))


class ResourceVariable(variables.VariableV1):
  """Variable based on resource handles.

  See the [Variables How To](https://tensorflow.org/guide/variables)
  for a high level overview.

  A `ResourceVariable` allows you to maintain state across subsequent calls to
  session.run.

  The `ResourceVariable` constructor requires an initial value for the
  variable, which can be a `Tensor` of any type and shape. The initial value
  defines the type and shape of the variable. After construction, the type and
  shape of the variable are fixed. The value can be changed using one of the
  assign methods.

  Just like any `Tensor`, variables created with
  `tf.Variable(use_resource=True)` can be used as inputs for other Ops in the
  graph. Additionally, all the operators overloaded for the `Tensor` class are
  carried over to variables, so you can also add nodes to the graph by just
  doing arithmetic on variables.

  Unlike ref-based variables, a `ResourceVariable` has well-defined semantics.
  Each usage of a `ResourceVariable` in a TensorFlow graph adds a read_value
  operation to the graph. The Tensors returned by a read_value operation are
  guaranteed to see all modifications to the value of the variable which happen
  in any operation on which the read_value depends (either directly, indirectly,
  or via a control dependency) and guaranteed to not see any modification to the
  value of the variable from operations that depend on the read_value operation.
  Updates from operations that have no dependency relationship to the read_value
  operation might or might not be visible to read_value.

  For example, if there is more than one assignment to a ResourceVariable in
  a single session.run call there is a well-defined value for each operation
  which uses the variable's value if the assignments and the read are connected
  by edges in the graph. Consider the following example, in which two writes
  can cause tf.Variable and tf.ResourceVariable to behave differently:

  ```python
  a = tf.Variable(1.0, use_resource=True)
  a.initializer.run()

  assign = a.assign(2.0)
  with tf.control_dependencies([assign]):
    b = a.read_value()
  with tf.control_dependencies([b]):
    other_assign = a.assign(3.0)
  with tf.control_dependencies([other_assign]):
    # Will print 2.0 because the value was read before other_assign ran. If
    # `a` was a tf.Variable instead, 2.0 or 3.0 could be printed.
    tf.Print(b, [b]).eval()
  ```
  """

  def __init__(self,
               initial_value=None,
               trainable=True,
               collections=None,
               validate_shape=True,  # pylint: disable=unused-argument
               caching_device=None,
               name=None,
               dtype=None,
               variable_def=None,
               import_scope=None,
               constraint=None,
               distribute_strategy=None):
    """Creates a variable.

    Args:
      initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
        which is the initial value for the Variable. Can also be a
        callable with no argument that returns the initial value when called.
        (Note that initializer functions from init_ops.py must first be bound
        to a shape before being used here.)
      trainable: If `True`, the default, also adds the variable to the graph
        collection `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as
        the default list of variables to use by the `Optimizer` classes.
      collections: List of graph collections keys. The new variable is added to
        these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
      validate_shape: Ignored. Provided for compatibility with tf.Variable.
      caching_device: Optional device string or function describing where the
        Variable should be cached for reading. Defaults to the Variable's
        device. If not `None`, caches on another device. Typical use is to
        cache on the device where the Ops using the Variable reside, to
        deduplicate copying through `Switch` and other conditional statements.
      name: Optional name for the variable. Defaults to `'Variable'` and gets
        uniquified automatically.
      dtype: If set, initial_value will be converted to the given type.
        If None, either the datatype will be kept (if initial_value is
        a Tensor) or float32 will be used (if it is a Python object convertible
        to a Tensor).
      variable_def: `VariableDef` protocol buffer. If not None, recreates the
        `ResourceVariable` object with its contents. `variable_def` and other
        arguments (except for import_scope) are mutually exclusive.
      import_scope: Optional `string`. Name scope to add to the
        ResourceVariable. Only used when `variable_def` is provided.
      constraint: An optional projection function to be applied to the variable
        after being updated by an `Optimizer` (e.g. used to implement norm
        constraints or value constraints for layer weights). The function must
        take as input the unprojected Tensor representing the value of the
        variable and return the Tensor for the projected value
        (which must have the same shape). Constraints are not safe to
        use when doing asynchronous distributed training.
      distribute_strategy: The tf.distribute.Strategy this variable is being
        created inside of.

    Raises:
      ValueError: If the initial value is not specified, or does not have a
        shape and `validate_shape` is `True`.

    @compatibility(eager)
    When Eager Execution is enabled, the default for the `collections` argument
    is `None`, which signifies that this `Variable` will not be added to any
    collections.
    @end_compatibility
    """
    self._distribute_strategy = distribute_strategy
    if variable_def:
      if initial_value is not None:
        raise ValueError("variable_def and initial_value are mutually "
                         "exclusive.")
      if context.executing_eagerly():
        raise ValueError("Creating ResourceVariable from variable_def is "
                         "not supported when eager execution is enabled.")
      self._init_from_proto(variable_def, import_scope=import_scope)
    else:
      self._init_from_args(
          initial_value=initial_value,
          trainable=trainable,
          collections=collections,
          caching_device=caching_device,
          name=name,
          dtype=dtype,
          constraint=constraint)

  def __repr__(self):
    if context.executing_eagerly() and not self._in_graph_mode:
      return "<tf.Variable '%s' shape=%s dtype=%s, numpy=%s>" % (
          self.name, self.get_shape(), self.dtype.name,
          ops.numpy_text(self.read_value(), is_repr=True))
    else:
      return "<tf.Variable '%s' shape=%s dtype=%s>" % (
          self.name, self.get_shape(), self.dtype.name)

  def _init_from_args(self,
                      initial_value=None,
                      trainable=True,
                      collections=None,
                      caching_device=None,
                      name=None,
                      dtype=None,
                      constraint=None):
    """Creates a variable.

    Args:
      initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
        which is the initial value for the Variable. The initial value must
        have a shape specified unless `validate_shape` is set to False. Can
        also be a callable with no argument that returns the initial value
        when called. (Note that initializer functions from init_ops.py must
        first be bound to a shape before being used here.)
      trainable: If `True`, the default, also adds the variable to the graph
        collection `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as
        the default list of variables to use by the `Optimizer` classes.
      collections: List of graph collections keys. The new variable is added to
        these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
      validate_shape: Ignored. Provided for compatibility with tf.Variable.
      caching_device: Optional device string or function describing where the
        Variable should be cached for reading. Defaults to the Variable's
        device. If not `None`, caches on another device. Typical use is to
        cache on the device where the Ops using the Variable reside, to
        deduplicate copying through `Switch` and other conditional statements.
      name: Optional name for the variable. Defaults to `'Variable'` and gets
        uniquified automatically.
      dtype: If set, initial_value will be converted to the given type.
        If None, either the datatype will be kept (if initial_value is
        a Tensor) or float32 will be used (if it is a Python object convertible
        to a Tensor).
      constraint: An optional projection function to be applied to the variable
        after being updated by an `Optimizer` (e.g. used to implement norm
        constraints or value constraints for layer weights). The function must
        take as input the unprojected Tensor representing the value of the
        variable and return the Tensor for the projected value
        (which must have the same shape). Constraints are not safe to
        use when doing asynchronous distributed training.

    Raises:
      ValueError: If the initial value is not specified, or does not have a
        shape and `validate_shape` is `True`.

    @compatibility(eager)
    When Eager Execution is enabled, variables are never added to collections.
    The variable is not implicitly added to the `GLOBAL_VARIABLES` or
    `TRAINABLE_VARIABLES` collections, and the `collections` argument is
    ignored.
    @end_compatibility
    """
    if initial_value is None:
      raise ValueError("initial_value must be specified.")
    init_from_fn = callable(initial_value)

    if isinstance(initial_value, ops.Tensor) and hasattr(
        initial_value, "graph") and initial_value.graph.building_function:
      raise ValueError("Tensor-typed variable initializers must either be "
                       "wrapped in an init_scope or callable "
                       "(e.g., `tf.Variable(lambda : "
                       "tf.truncated_normal([10, 40]))`) when building "
                       "functions. Please file a feature request if this "
                       "restriction inconveniences you.")

    if collections is None:
      collections = [ops.GraphKeys.GLOBAL_VARIABLES]
    if not isinstance(collections, (list, tuple, set)):
      raise ValueError(
          "collections argument to Variable constructor must be a list, tuple, "
          "or set. Got %s of type %s" % (collections, type(collections)))
    if constraint is not None and not callable(constraint):
      raise ValueError("The `constraint` argument must be a callable.")

    if isinstance(initial_value, trackable.CheckpointInitialValue):
      self._maybe_initialize_trackable()
      self._update_uid = initial_value.checkpoint_position.restore_uid
      initial_value = initial_value.wrapped_value

    self._trainable = trainable
    if trainable and ops.GraphKeys.TRAINABLE_VARIABLES not in collections:
      collections = list(collections) + [ops.GraphKeys.TRAINABLE_VARIABLES]
    self._save_slice_info = None
    # Store the graph key so optimizers know how to only retrieve variables
    # from this graph.
    self._graph_key = ops.get_default_graph()._graph_key  # pylint: disable=protected-access
    with ops.init_scope():
      self._in_graph_mode = not context.executing_eagerly()
      with ops.name_scope(name, "Variable", []
                          if init_from_fn else [initial_value]) as name:
        # pylint: disable=protected-access
        handle_name = ops._name_from_scope_name(name)
        if self._in_graph_mode:
          shared_name = handle_name
          unique_id = shared_name
        else:
          # When in eager mode use a uid for the shared_name, to prevent
          # accidental sharing.
          unique_id = "%s_%d" % (handle_name, ops.uid())
          shared_name = context.shared_name()
        # Use attr_scope and device(None) to simulate the behavior of
        # colocate_with when the variable we want to colocate with doesn't
        # yet exist.
        device_context_manager = (
            ops.device if self._in_graph_mode else ops.NullContextmanager)
        attr = attr_value_pb2.AttrValue(
            list=attr_value_pb2.AttrValue.ListValue(
                s=[compat.as_bytes("loc:@%s" % handle_name)]))
        with ops.get_default_graph()._attr_scope({"_class": attr}):
          with ops.name_scope("Initializer"), device_context_manager(None):
            initial_value = ops.convert_to_tensor(
                initial_value() if init_from_fn else initial_value,
                name="initial_value", dtype=dtype)
          self._handle = eager_safe_variable_handle(
              initial_value=initial_value,
              shared_name=shared_name,
              name=name,
              graph_mode=self._in_graph_mode)
        self._shape = initial_value.shape
        # pylint: disable=protected-access
        if (self._in_graph_mode and initial_value is not None and
            initial_value.op._get_control_flow_context() is not None):
          raise ValueError(
              "Initializer for variable %s is from inside a control-flow "
              "construct, such as a loop or conditional. When creating a "
              "variable inside a loop or conditional, use a lambda as the "
              "initializer." % name)
        # pylint: enable=protected-access
        self._unique_id = unique_id
        self._initial_value = initial_value if self._in_graph_mode else None
        self._handle_name = handle_name + ":0"
        self._dtype = initial_value.dtype.base_dtype
        self._constraint = constraint

        if self._in_graph_mode:
          with ops.name_scope("IsInitialized"):
            self._is_initialized_op = (
                gen_resource_variable_ops.var_is_initialized_op(self._handle))
          if initial_value is not None:
            with ops.name_scope("Assign") as n, ops.colocate_with(self._handle):
              # pylint: disable=protected-access
              self._initializer_op = (
                  gen_resource_variable_ops.assign_variable_op(
                      self._handle,
                      variables._try_guard_against_uninitialized_dependencies(
                          name,
                          initial_value),
                      name=n))
              # pylint: enable=protected-access
          with ops.name_scope("Read"), ops.colocate_with(self._handle):
            # Manually assign reads to the handle's device to avoid log
            # messages.
            with ops.device(self._handle.device):
              value = self._read_variable_op()
            self._graph_element = value
            if caching_device is not None:
              # Variables may be created in a tf.device() or ops.colocate_with()
              # context. At the same time, users would expect caching device to
              # be independent of this context, and/or would not expect the
              # current device context to be merged with the caching device
              # spec. Therefore we reset the colocation stack before creating
              # the cached value. Note that resetting the colocation stack will
              # also reset the device stack.
              with ops.colocate_with(None, ignore_existing=True):
                with ops.device(caching_device):
                  self._cached_value = array_ops.identity(value)
            else:
              self._cached_value = None
        else:
          gen_resource_variable_ops.assign_variable_op(self._handle,
                                                       initial_value)
          self._is_initialized_op = None
          self._initializer_op = None
          self._graph_element = None
          if caching_device:
            with ops.device(caching_device):
              self._cached_value = self._read_variable_op()
          else:
            self._cached_value = None
        if not context.executing_eagerly():
          # Eager variables are only added to collections if they are part of
          # an eager variable store (otherwise in an interactive session they
          # would hog memory and cause OOM). This is done in
          # ops/variable_scope.py.
          ops.add_to_collections(collections, self)
        elif ops.GraphKeys.GLOBAL_STEP in collections:
          ops.add_to_collections(ops.GraphKeys.GLOBAL_STEP, self)

    if not self._in_graph_mode:
      # After the handle has been created, set up a way to clean it up when
      # executing eagerly. We'll hold the only reference to the deleter, so
      # that when this object is garbage collected the deleter will be too.
      # This means ResourceVariables can be part of reference cycles without
      # those cycles being uncollectable, and means that no __del__ will be
      # defined at all in graph mode.
      self._handle_deleter = EagerResourceDeleter(
          handle=self._handle, handle_device=self._handle.device)

  def _init_from_proto(self, variable_def, import_scope=None):
    """Initializes from `VariableDef` proto."""
    # Note that init_from_proto is currently not supported in Eager mode.
    assert not context.executing_eagerly()
    self._in_graph_mode = True
    assert isinstance(variable_def, variable_pb2.VariableDef)
    if not variable_def.is_resource:
      raise ValueError("Trying to restore Variable as ResourceVariable.")

    # Create from variable_def.
    g = ops.get_default_graph()
    self._handle = g.as_graph_element(
        ops.prepend_name_scope(
            variable_def.variable_name, import_scope=import_scope))
    self._shape = tensor_shape.TensorShape(
        self._handle.op.get_attr("shape"))
    self._handle_name = self._handle.name
    self._unique_id = self._handle_name
    self._initializer_op = g.as_graph_element(
        ops.prepend_name_scope(
            variable_def.initializer_name, import_scope=import_scope))
    # Check whether initial_value_name exists for backwards compatibility.
    if (hasattr(variable_def, "initial_value_name") and
        variable_def.initial_value_name):
      self._initial_value = g.as_graph_element(
          ops.prepend_name_scope(variable_def.initial_value_name,
                                 import_scope=import_scope))
    else:
      self._initial_value = None
    self._trainable = getattr(variable_def, "trainable", True)
    if variable_def.snapshot_name:
      snapshot = g.as_graph_element(
          ops.prepend_name_scope(
              variable_def.snapshot_name, import_scope=import_scope))
      if snapshot.op.type != "ReadVariableOp":
        self._cached_value = snapshot
      else:
        self._cached_value = None
      while snapshot.op.type != "ReadVariableOp":
        snapshot = snapshot.op.inputs[0]
      self._graph_element = snapshot
    else:
      self._cached_value = None
      # Legacy case for protos without the snapshot name; assume it's the
      # following.
      self._graph_element = g.get_tensor_by_name(
          self._handle.op.name + "/Read/ReadVariableOp:0")
    if variable_def.HasField("save_slice_info_def"):
      self._save_slice_info = variables.Variable.SaveSliceInfo(
          save_slice_info_def=variable_def.save_slice_info_def,
          import_scope=import_scope)
    else:
      self._save_slice_info = None
    self._caching_device = None
    self._dtype = dtypes.as_dtype(self._handle.op.get_attr("dtype"))
    self._constraint = None

  @contextlib.contextmanager
  def _assign_dependencies(self):
    """Makes assignments depend on the cached value, if any.

    This prevents undefined behavior with reads not ordered wrt writes.

    Yields:
      None.
    """
    if self._cached_value is not None:
      with ops.control_dependencies([self._cached_value]):
        yield
    else:
      yield

  def __nonzero__(self):
    return self.__bool__()

  def __bool__(self):
    return bool(self.read_value())

  def __copy__(self):
    return self

  def __deepcopy__(self, memo):
    if not context.executing_eagerly():
      raise NotImplementedError(
          "__deepcopy__() is only available when eager execution is enabled.")
    copied_variable = ResourceVariable(
        initial_value=self.read_value(),
        trainable=self._trainable,
        constraint=self._constraint,
        dtype=self._dtype,
        name=self._shared_name + "_copy",
        distribute_strategy=self._distribute_strategy)
    memo[self._unique_id] = copied_variable
    return copied_variable

  @property
  def dtype(self):
    """The dtype of this variable."""
    return self._dtype

  @property
  def device(self):
    """The device this variable is on."""
    return self._handle.device

  @property
  def graph(self):
    """The `Graph` of this variable."""
    return self._handle.graph

  @property
  def name(self):
    """The name of the handle for this variable."""
    return self._handle_name

  @property
  def shape(self):
    """The shape of this variable."""
    return self._shape

  def _shape_as_list(self):
    if self.shape.ndims is None:
      return None
    return [dim.value for dim in self.shape.dims]

  def _shape_tuple(self):
    shape = self._shape_as_list()
    if shape is None:
      return None
    return tuple(shape)

  @property
  def create(self):
    """The op responsible for initializing this variable."""
    if not self._in_graph_mode:
      raise RuntimeError("Calling create is not supported when eager execution"
                         " is enabled.")
    return self._initializer_op

  @property
  def handle(self):
    """The handle by which this variable can be accessed."""
    return self._handle

  def value(self):
    """A cached operation which reads the value of this variable."""
    if self._cached_value is not None:
      return self._cached_value
    with ops.colocate_with(None, ignore_existing=True):
      with ops.device(self._handle.device):
        return self._read_variable_op()

  def _as_graph_element(self):
    """Conversion function for Graph.as_graph_element()."""
    return self._graph_element

  @property
  def initializer(self):
    """The op responsible for initializing this variable."""
    return self._initializer_op

  @property
  def initial_value(self):
    """Returns the Tensor used as the initial value for the variable."""
    if context.executing_eagerly():
      raise RuntimeError("initial_value not supported in EAGER mode.")
    return self._initial_value
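
  # Example (an illustrative sketch, not part of the original module): the
  # `constraint` returned by the property below is the projection function
  # passed to the constructor. A typical constraint keeps a weight
  # non-negative, e.g.
  #
  #   v = ResourceVariable([1.0, -2.0],
  #                        constraint=lambda x: math_ops.maximum(x, 0.))
  #
  # The function receives the unprojected value after an optimizer update and
  # must return a tensor of the same shape.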

  @property
  def constraint(self):
    """Returns the constraint function associated with this variable.

    Returns:
      The constraint function that was passed to the variable constructor.
      Can be `None` if no constraint was passed.
    """
    return self._constraint

  @property
  def op(self):
    """The op for this variable."""
    return self._handle.op

  @property
  def trainable(self):
    return self._trainable

  def eval(self, session=None):
    """Evaluates and returns the value of this variable."""
    if context.executing_eagerly():
      raise RuntimeError("Trying to eval in EAGER mode")
    return self._graph_element.eval(session=session)

  def numpy(self):
    if context.executing_eagerly():
      return self.read_value().numpy()
    raise NotImplementedError(
        "numpy() is only available when eager execution is enabled.")

  @deprecated(None, "Prefer Dataset.range instead.")
  def count_up_to(self, limit):
    """Increments this variable until it reaches `limit`.

    When that Op is run it tries to increment the variable by `1`. If
    incrementing the variable would bring it above `limit` then the Op raises
    the exception `OutOfRangeError`.

    If no error is raised, the Op outputs the value of the variable before
    the increment.

    This is essentially a shortcut for `count_up_to(self, limit)`.

    Args:
      limit: value at which incrementing the variable raises an error.

    Returns:
      A `Tensor` that will hold the variable value before the increment. If no
      other Op modifies this variable, the values produced will all be
      distinct.
    """
    return gen_state_ops.resource_count_up_to(self.handle, limit=limit,
                                              T=self.dtype)

  def _read_variable_op(self):
    if self.trainable:
      tape.variable_accessed(self)
    result = gen_resource_variable_ops.read_variable_op(self._handle,
                                                        self._dtype)
    _maybe_set_handle_data(self._dtype, self._handle, result)

    if not context.executing_eagerly():
      # Note that if a control flow context is active the input of the read op
      # might not actually be the handle. This line bypasses it.
      tape.record_operation(
          "ReadVariableOp", [result], [self._handle], lambda x: [x])
    return result

  def read_value(self):
    """Constructs an op which reads the value of this variable.

    Should be used when there are multiple reads, or when it is desirable to
    read the value only after some condition is true.

    Returns:
      the read operation.
    """
    with ops.name_scope("Read"):
      # Ensure we read the variable in the same device as the handle.
      with ops.device(self._handle.device):
        value = self._read_variable_op()
    # Return an identity so it can get placed on whatever device the context
    # specifies instead of the device where the variable is.
    return array_ops.identity(value)
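
  # Example (an illustrative sketch): `sparse_read` below gathers rows by
  # index without materializing a full read of the variable, e.g.
  #
  #   v = ResourceVariable([[1., 2.], [3., 4.], [5., 6.]])
  #   rows = v.sparse_read([2, 0])  # == [[5., 6.], [1., 2.]]
  #
  # Gradients flow back only to the gathered rows (see _GatherGrad below).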

  def sparse_read(self, indices, name=None):
    """Reads the value of this variable sparsely, using `gather`."""
    with ops.name_scope("Gather" if name is None else name) as name:
      if self.trainable:
        tape.variable_accessed(self)
      value = gen_resource_variable_ops.resource_gather(
          self._handle, indices, dtype=self._dtype, name=name)

      if self._dtype == dtypes.variant:
        # For DT_VARIANT types, the handle's shape_and_type[1:] stores the
        # variant's handle data. Extract it.
        handle_data = get_eager_safe_handle_data(self._handle)
        if handle_data.is_set and len(handle_data.shape_and_type) > 1:
          value._handle_data = (  # pylint: disable=protected-access
              cpp_shape_inference_pb2.CppShapeInferenceResult.HandleData(
                  is_set=True,
                  shape_and_type=handle_data.shape_and_type[1:]))

    return array_ops.identity(value)

  def to_proto(self, export_scope=None):
    """Converts a `ResourceVariable` to a `VariableDef` protocol buffer.

    Args:
      export_scope: Optional `string`. Name scope to remove.

    Raises:
      RuntimeError: If run in EAGER mode.

    Returns:
      A `VariableDef` protocol buffer, or `None` if the `Variable` is not
      in the specified name scope.
    """
    if context.executing_eagerly():
      raise RuntimeError("to_proto not supported in EAGER mode.")
    if export_scope is None or self.handle.name.startswith(export_scope):
      var_def = variable_pb2.VariableDef()
      var_def.variable_name = ops.strip_name_scope(self.handle.name,
                                                   export_scope)
      if self._initial_value is not None:
        # This is inside an if-statement for backwards compatibility, since
        # self._initial_value might be None for variables constructed from old
        # protos.
        var_def.initial_value_name = ops.strip_name_scope(
            self._initial_value.name, export_scope)
      var_def.initializer_name = ops.strip_name_scope(self.initializer.name,
                                                      export_scope)
      if self._cached_value is not None:
        var_def.snapshot_name = ops.strip_name_scope(self._cached_value.name,
                                                     export_scope)
      else:
        # Store the graph_element here
        var_def.snapshot_name = ops.strip_name_scope(self._graph_element.name,
                                                     export_scope)
      var_def.is_resource = True
      var_def.trainable = self.trainable
      if self._save_slice_info:
        var_def.save_slice_info_def.MergeFrom(
            self._save_slice_info.to_proto(export_scope=export_scope))
      return var_def
    else:
      return None

  @staticmethod
  def from_proto(variable_def, import_scope=None):
    if context.executing_eagerly():
      raise RuntimeError("from_proto not supported in EAGER mode.")
    return ResourceVariable(
        variable_def=variable_def, import_scope=import_scope)

  def set_shape(self, shape):
    """Unsupported."""
    raise NotImplementedError("ResourceVariable does not implement set_shape()")

  __array_priority__ = 100

  def is_initialized(self, name=None):
    """Checks whether a resource variable has been initialized.

    Outputs boolean scalar indicating whether the tensor has been initialized.

    Args:
      name: A name for the operation (optional).

    Returns:
      A `Tensor` of type `bool`.
    """
    return gen_resource_variable_ops.var_is_initialized_op(self.handle, name)
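
  # Example (an illustrative sketch): the assign* methods below return the
  # updated value by default, or the underlying op / None when
  # `read_value=False`, e.g.
  #
  #   v = ResourceVariable(3.0)
  #   new_value = v.assign_add(1.0)   # behaves like a tensor holding 4.0
  #   update_op = v.assign(10.0, read_value=False)  # an Operation in graph
  #                                                 # mode, None in eager mode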

  def assign_sub(self, delta, use_locking=None, name=None, read_value=True):
    """Subtracts a value from this variable.

    Args:
      delta: A `Tensor`. The value to subtract from this variable.
      use_locking: If `True`, use locking during the operation.
      name: The name to use for the operation.
      read_value: A `bool`. Whether to read and return the new value of the
        variable or not.

    Returns:
      If `read_value` is `True`, this method will return the new value of the
      variable after the assignment has completed. Otherwise, when in graph mode
      it will return the `Operation` that does the assignment, and when in eager
      mode it will return `None`.
    """
    # TODO(apassos): this here and below is not atomic. Consider making it
    # atomic if there's a way to do so without a performance cost for those who
    # don't need it.
    with _handle_graph(self.handle), self._assign_dependencies():
      assign_sub_op = gen_resource_variable_ops.assign_sub_variable_op(
          self.handle, ops.convert_to_tensor(delta, dtype=self.dtype),
          name=name)
      if read_value:
        return self._lazy_read(assign_sub_op)
    return assign_sub_op

  def assign_add(self, delta, use_locking=None, name=None, read_value=True):
    """Adds a value to this variable.

    Args:
      delta: A `Tensor`. The value to add to this variable.
      use_locking: If `True`, use locking during the operation.
      name: The name to use for the operation.
      read_value: A `bool`. Whether to read and return the new value of the
        variable or not.

    Returns:
      If `read_value` is `True`, this method will return the new value of the
      variable after the assignment has completed. Otherwise, when in graph mode
      it will return the `Operation` that does the assignment, and when in eager
      mode it will return `None`.
    """
    with _handle_graph(self.handle), self._assign_dependencies():
      assign_add_op = gen_resource_variable_ops.assign_add_variable_op(
          self.handle, ops.convert_to_tensor(delta, dtype=self.dtype),
          name=name)
      if read_value:
        return self._lazy_read(assign_add_op)
    return assign_add_op

  def _lazy_read(self, op):
    if self.trainable:
      tape.variable_accessed(self)
    return _UnreadVariable(
        handle=self._handle, dtype=self.dtype, shape=self._shape,
        in_graph_mode=self._in_graph_mode,
        deleter=self._handle_deleter if not self._in_graph_mode else None,
        parent_op=op, unique_id=self._unique_id)

  def assign(self, value, use_locking=None, name=None, read_value=True):
    """Assigns a new value to this variable.

    Args:
      value: A `Tensor`. The new value for this variable.
      use_locking: If `True`, use locking during the assignment.
      name: The name to use for the assignment.
      read_value: A `bool`. Whether to read and return the new value of the
        variable or not.

    Returns:
      If `read_value` is `True`, this method will return the new value of the
      variable after the assignment has completed. Otherwise, when in graph mode
      it will return the `Operation` that does the assignment, and when in eager
      mode it will return `None`.
    """
    # Note: not depending on the cached value here since this can be used to
    # initialize the variable.
    with _handle_graph(self.handle):
      value_tensor = ops.convert_to_tensor(value, dtype=self.dtype)
      self._shape.assert_is_compatible_with(value_tensor.shape)
      assign_op = gen_resource_variable_ops.assign_variable_op(
          self.handle, value_tensor, name=name)
      if read_value:
        return self._lazy_read(assign_op)
    return assign_op

  def __reduce__(self):
    # The implementation mirrors that of __deepcopy__.
    return functools.partial(
        ResourceVariable,
        initial_value=self.numpy(),
        trainable=self.trainable,
        name=self._shared_name,
        dtype=self.dtype,
        constraint=self.constraint,
        distribute_strategy=self._distribute_strategy), ()

  def scatter_sub(self, sparse_delta, use_locking=False, name=None):
    """Subtracts `IndexedSlices` from this variable.

    Args:
      sparse_delta: `IndexedSlices` to be subtracted from this variable.
      use_locking: If `True`, use locking during the operation.
      name: the name of the operation.

    Returns:
      A `Tensor` that will hold the new value of this variable after
      the scattered subtraction has completed.

    Raises:
      ValueError: if `sparse_delta` is not an `IndexedSlices`.
    """
    if not isinstance(sparse_delta, ops.IndexedSlices):
      raise ValueError("sparse_delta is not IndexedSlices: %s" % sparse_delta)
    return self._lazy_read(gen_resource_variable_ops.resource_scatter_sub(
        self.handle, sparse_delta.indices,
        ops.convert_to_tensor(sparse_delta.values, self.dtype), name=name))

  def scatter_add(self, sparse_delta, use_locking=False, name=None):
    """Adds `IndexedSlices` to this variable.

    Args:
      sparse_delta: `IndexedSlices` to be added to this variable.
      use_locking: If `True`, use locking during the operation.
      name: the name of the operation.

    Returns:
      A `Tensor` that will hold the new value of this variable after
      the scattered addition has completed.

    Raises:
      ValueError: if `sparse_delta` is not an `IndexedSlices`.
    """
    if not isinstance(sparse_delta, ops.IndexedSlices):
      raise ValueError("sparse_delta is not IndexedSlices: %s" % sparse_delta)
    return self._lazy_read(gen_resource_variable_ops.resource_scatter_add(
        self.handle, sparse_delta.indices,
        ops.convert_to_tensor(sparse_delta.values, self.dtype), name=name))

  def scatter_update(self, sparse_delta, use_locking=False, name=None):
    """Assigns `IndexedSlices` to this variable.

    Args:
      sparse_delta: `IndexedSlices` to be assigned to this variable.
      use_locking: If `True`, use locking during the operation.
      name: the name of the operation.

    Returns:
      A `Tensor` that will hold the new value of this variable after
      the scattered assignment has completed.

    Raises:
      ValueError: if `sparse_delta` is not an `IndexedSlices`.
    """
    if not isinstance(sparse_delta, ops.IndexedSlices):
      raise ValueError("sparse_delta is not IndexedSlices: %s" % sparse_delta)
    return self._lazy_read(gen_resource_variable_ops.resource_scatter_update(
        self.handle, sparse_delta.indices,
        ops.convert_to_tensor(sparse_delta.values, self.dtype), name=name))

  def batch_scatter_update(self, sparse_delta, use_locking=False, name=None):
    """Assigns `IndexedSlices` to this variable batch-wise.

    Analogous to `batch_gather`. This assumes that this variable and the
    sparse_delta IndexedSlices have a series of leading dimensions that are the
    same for all of them, and the updates are performed on the last dimension of
    indices. In other words, the dimensions should be the following:

    `num_prefix_dims = sparse_delta.indices.ndims - 1`
    `batch_dim = num_prefix_dims + 1`
    `sparse_delta.updates.shape = sparse_delta.indices.shape + var.shape[
         batch_dim:]`

    where

    `sparse_delta.updates.shape[:num_prefix_dims]`
    `== sparse_delta.indices.shape[:num_prefix_dims]`
    `== var.shape[:num_prefix_dims]`

    And the operation performed can be expressed as:

    `var[i_1, ..., i_n,
         sparse_delta.indices[i_1, ..., i_n, j]] = sparse_delta.updates[
            i_1, ..., i_n, j]`

    When sparse_delta.indices is a 1D tensor, this operation is equivalent to
    `scatter_update`.
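
    For example (an illustrative sketch of the shapes involved, not part of
    the original docstring):

    ```python
    # var has shape [2, 3]; updates are addressed per row.
    var = tf.Variable([[1., 1., 1.], [2., 2., 2.]], use_resource=True)
    indices = tf.constant([[0, 2], [1, 0]])      # shape [2, 2]
    updates = tf.constant([[9., 8.], [7., 6.]])  # same shape as indices
    var.batch_scatter_update(tf.IndexedSlices(updates, indices))
    # Row 0 becomes [9., 1., 8.]; row 1 becomes [6., 7., 2.].
    ```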

    To avoid this operation one can loop over the first `ndims` of the
    variable and use `scatter_update` on the subtensors that result from
    slicing the first dimension. This is a valid option for `ndims = 1`, but
    less efficient than this implementation.

    Args:
      sparse_delta: `IndexedSlices` to be assigned to this variable.
      use_locking: If `True`, use locking during the operation.
      name: the name of the operation.

    Returns:
      A `Tensor` that will hold the new value of this variable after
      the scattered assignment has completed.

    Raises:
      ValueError: if `sparse_delta` is not an `IndexedSlices`.
    """
    return self._lazy_read(state_ops.batch_scatter_update(
        self, sparse_delta.indices, sparse_delta.values,
        use_locking=use_locking, name=name))

  def scatter_nd_sub(self, indices, updates, name=None):
    """Applies sparse subtraction to individual values or slices in a Variable.

    `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.

    `indices` must be integer tensor, containing indices into `ref`.
    It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.

    The innermost dimension of `indices` (with length `K`) corresponds to
    indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
    dimension of `ref`.

    `updates` is `Tensor` of rank `Q-1+P-K` with shape:

    ```
    [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
    ```

    For example, say we want to subtract 4 scattered elements from a rank-1
    tensor with 8 elements. In Python, that update would look like this:

    ```python
    ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
    indices = tf.constant([[4], [3], [1] ,[7]])
    updates = tf.constant([9, 10, 11, 12])
    op = ref.scatter_nd_sub(indices, updates)
    with tf.Session() as sess:
      print(sess.run(op))
    ```

    The resulting update to ref would look like this:

        [1, -9, 3, -6, -4, 6, 7, -4]

    See `tf.scatter_nd` for more details about how to make updates to
    slices.

    Args:
      indices: The indices to be used in the operation.
      updates: The values to be used in the operation.
      name: the name of the operation.

    Returns:
      A `Tensor` that will hold the new value of this variable after
      the scattered subtraction has completed.

    Raises:
      ValueError: if `sparse_delta` is not an `IndexedSlices`.
    """
    return self._lazy_read(gen_state_ops.resource_scatter_nd_sub(
        self.handle, indices, ops.convert_to_tensor(updates, self.dtype),
        name=name))

  def scatter_nd_add(self, indices, updates, name=None):
    """Applies sparse addition to individual values or slices in a Variable.

    `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.

    `indices` must be integer tensor, containing indices into `ref`.
    It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.

    The innermost dimension of `indices` (with length `K`) corresponds to
    indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
    dimension of `ref`.

    `updates` is `Tensor` of rank `Q-1+P-K` with shape:

    ```
    [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
    ```

    For example, say we want to add 4 scattered elements to a rank-1 tensor
    with 8 elements. In Python, that update would look like this:

    ```python
    ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
    indices = tf.constant([[4], [3], [1] ,[7]])
    updates = tf.constant([9, 10, 11, 12])
    add = ref.scatter_nd_add(indices, updates)
    with tf.Session() as sess:
      print(sess.run(add))
    ```

    The resulting update to ref would look like this:

        [1, 13, 3, 14, 14, 6, 7, 20]

    See `tf.scatter_nd` for more details about how to make updates to
    slices.

    Args:
      indices: The indices to be used in the operation.
      updates: The values to be used in the operation.
      name: the name of the operation.

    Returns:
      A `Tensor` that will hold the new value of this variable after
      the scattered addition has completed.

    Raises:
      ValueError: if `sparse_delta` is not an `IndexedSlices`.
    """
    return self._lazy_read(gen_state_ops.resource_scatter_nd_add(
        self.handle, indices, ops.convert_to_tensor(updates, self.dtype),
        name=name))

  def scatter_nd_update(self, indices, updates, name=None):
    """Applies sparse assignment to individual values or slices in a Variable.

    `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.

    `indices` must be integer tensor, containing indices into `ref`.
    It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.

    The innermost dimension of `indices` (with length `K`) corresponds to
    indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
    dimension of `ref`.

    `updates` is `Tensor` of rank `Q-1+P-K` with shape:

    ```
    [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
    ```

    For example, say we want to update 4 scattered elements in a rank-1 tensor
    with 8 elements. In Python, that update would look like this:

    ```python
    ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
    indices = tf.constant([[4], [3], [1] ,[7]])
    updates = tf.constant([9, 10, 11, 12])
    op = ref.scatter_nd_update(indices, updates)
    with tf.Session() as sess:
      print(sess.run(op))
    ```

    The resulting update to ref would look like this:

        [1, 11, 3, 10, 9, 6, 7, 12]

    See `tf.scatter_nd` for more details about how to make updates to
    slices.

    Args:
      indices: The indices to be used in the operation.
      updates: The values to be used in the operation.
      name: the name of the operation.

    Returns:
      A `Tensor` that will hold the new value of this variable after
      the scattered assignment has completed.

    Raises:
      ValueError: if `sparse_delta` is not an `IndexedSlices`.
    """
    return self._lazy_read(gen_state_ops.resource_scatter_nd_update(
        self.handle, indices, ops.convert_to_tensor(updates, self.dtype),
        name=name))

  def _strided_slice_assign(self, begin, end, strides, value, name, begin_mask,
                            end_mask, ellipsis_mask, new_axis_mask,
                            shrink_axis_mask):
    with _handle_graph(self.handle), self._assign_dependencies():
      return self._lazy_read(
          gen_array_ops.resource_strided_slice_assign(
              ref=self.handle,
              begin=begin,
              end=end,
              strides=strides,
              value=ops.convert_to_tensor(value, dtype=self.dtype),
              name=name,
              begin_mask=begin_mask,
              end_mask=end_mask,
              ellipsis_mask=ellipsis_mask,
              new_axis_mask=new_axis_mask,
              shrink_axis_mask=shrink_axis_mask))

  def __int__(self):
    if self.dtype != dtypes.int32 and self.dtype != dtypes.int64:
      raise TypeError("Non-integer variable can't be converted to integer.")
    return int(self.value().numpy())

  def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
    del name
    if dtype is not None and not dtype.is_compatible_with(self.dtype):
      raise ValueError(
          "Incompatible type conversion requested to type {!r} for variable "
          "of type {!r}".format(dtype.name, self.dtype.name))
    if as_ref:
      return self.read_value().op.inputs[0]
    else:
      return self.value()

  def __iadd__(self, unused_other):
    raise RuntimeError("Variable += value not supported. Use "
                       "variable.assign_add(value) to modify the variable "
                       "value and variable = variable + value to get a new "
                       "Tensor object.")

  def __isub__(self, unused_other):
    raise RuntimeError("Variable -= value not supported. Use "
                       "variable.assign_sub(value) to modify the variable "
                       "value and variable = variable - value to get a new "
                       "Tensor object.")

  def __imul__(self, unused_other):
    raise RuntimeError("Variable *= value not supported. Use "
                       "`var.assign(var * value)` to modify the variable or "
                       "`var = var * value` to get a new Tensor object.")

  def __idiv__(self, unused_other):
    raise RuntimeError("Variable /= value not supported. Use "
                       "`var.assign(var / value)` to modify the variable or "
                       "`var = var / value` to get a new Tensor object.")

  def __itruediv__(self, unused_other):
    raise RuntimeError("Variable /= value not supported. Use "
                       "`var.assign(var / value)` to modify the variable or "
                       "`var = var / value` to get a new Tensor object.")

  def __irealdiv__(self, unused_other):
    raise RuntimeError("Variable /= value not supported. Use "
                       "`var.assign(var / value)` to modify the variable or "
                       "`var = var / value` to get a new Tensor object.")

  def __ipow__(self, unused_other):
    raise RuntimeError("Variable **= value not supported. Use "
                       "`var.assign(var ** value)` to modify the variable or "
                       "`var = var ** value` to get a new Tensor object.")


pywrap_tensorflow.TFE_Py_RegisterResourceVariableType(ResourceVariable)
math_ops._resource_variable_type = ResourceVariable  # pylint: disable=protected-access


def _dense_var_to_tensor(var, dtype=None, name=None, as_ref=False):
  return var._dense_var_to_tensor(dtype=dtype, name=name, as_ref=as_ref)  # pylint: disable=protected-access
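

# Example (an illustrative sketch): once the conversion function above is
# registered below, a ResourceVariable can be used wherever a Tensor is
# expected and is converted through a read of its value, e.g.
#
#   v = ResourceVariable([[1.0, 2.0]])
#   y = math_ops.matmul(v, [[3.0], [4.0]])  # implicitly calls
#                                           # _dense_var_to_tensor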

# Register a conversion function which reads the value of the variable,
# allowing instances of the class to be used as tensors.
ops.register_tensor_conversion_function(ResourceVariable, _dense_var_to_tensor)
ops.register_dense_tensor_like_type(ResourceVariable)


class _UnreadVariable(ResourceVariable):
  """Represents a future for a read of a variable.

  Pretends to be the tensor if anyone looks.
  """

  def __init__(self, handle, dtype,  # pylint: disable=super-init-not-called
               shape, in_graph_mode, deleter, parent_op, unique_id):
    # We do not call super init on purpose.
    self._trainable = False
    self._save_slice_info = None
    self._graph_key = ops.get_default_graph()._graph_key  # pylint: disable=protected-access
    self._in_graph_mode = in_graph_mode
    self._handle = handle
    self._shape = shape
    self._initial_value = None
    if isinstance(self._handle, ops.EagerTensor):
      self._handle_name = ""
    else:
      self._handle_name = self._handle.name
    self._unique_id = unique_id
    self._dtype = dtype
    self._constraint = None
    self._cached_value = None
    self._is_initialized_op = None
    self._initializer_op = None
    self._parent_op = parent_op
    if context.executing_eagerly():
      self._graph_element = None
    else:
      self._graph_element = self.read_value()
    self._handle_deleter = deleter

  @property
  def name(self):
    if self._in_graph_mode:
      return self._parent_op.name
    else:
      return "UnreadVariable"

  def value(self):
    return self._read_variable_op()

  def read_value(self):
    return self._read_variable_op()

  def _read_variable_op(self):
    with ops.control_dependencies([self._parent_op]):
      result = gen_resource_variable_ops.read_variable_op(self._handle,
                                                          self._dtype)
      _maybe_set_handle_data(self._dtype, self._handle, result)
      return result

  @property
  def op(self):
    """The op for this variable."""
    return self._parent_op


ops.register_dense_tensor_like_type(_UnreadVariable)


@ops.RegisterGradient("ReadVariableOp")
def _ReadGrad(_, grad):
  """Gradient for read op."""
  return grad


def variable_shape(handle, out_type=dtypes.int32):
  if getattr(
      handle, "_handle_data", None) is None or not handle._handle_data.is_set:
    return gen_resource_variable_ops.variable_shape(handle, out_type=out_type)
  shape_proto = handle._handle_data.shape_and_type[0].shape
  if shape_proto.unknown_rank or any(x.size == -1 for x in shape_proto.dim):
    return gen_resource_variable_ops.variable_shape(handle, out_type=out_type)
  return constant_op.constant([x.size for x in shape_proto.dim], dtype=out_type)


@ops.RegisterGradient("ResourceGather")
def _GatherGrad(op, grad):
  """Gradient for gather op."""
  # Build appropriately shaped IndexedSlices
  handle = op.inputs[0]
  indices = op.inputs[1]
  params_shape = variable_shape(handle)
  size = array_ops.expand_dims(array_ops.size(indices), 0)
  values_shape = array_ops.concat([size, params_shape[1:]], 0)
  values = array_ops.reshape(grad, values_shape)
  indices = array_ops.reshape(indices, size)
  return (ops.IndexedSlices(values, indices, params_shape), None)


def _to_proto_fn(v, export_scope=None):
  """Converts Variable and ResourceVariable to VariableDef for collections."""
  return v.to_proto(export_scope=export_scope)


def _from_proto_fn(v, import_scope=None):
  """Creates Variable or ResourceVariable from VariableDef as needed."""
  if v.is_resource:
    return ResourceVariable.from_proto(v, import_scope=import_scope)
  return variables.Variable.from_proto(v, import_scope=import_scope)


ops.register_proto_function(
    ops.GraphKeys.GLOBAL_VARIABLES,
    proto_type=variable_pb2.VariableDef,
    to_proto=_to_proto_fn,
    from_proto=_from_proto_fn)
ops.register_proto_function(
    ops.GraphKeys.TRAINABLE_VARIABLES,
    proto_type=variable_pb2.VariableDef,
    to_proto=_to_proto_fn,
    from_proto=_from_proto_fn)
ops.register_proto_function(
    ops.GraphKeys.MOVING_AVERAGE_VARIABLES,
    proto_type=variable_pb2.VariableDef,
    to_proto=_to_proto_fn,
    from_proto=_from_proto_fn)
ops.register_proto_function(
    ops.GraphKeys.LOCAL_VARIABLES,
    proto_type=variable_pb2.VariableDef,
    to_proto=_to_proto_fn,
    from_proto=_from_proto_fn)
ops.register_proto_function(
    ops.GraphKeys.MODEL_VARIABLES,
    proto_type=variable_pb2.VariableDef,
    to_proto=_to_proto_fn,
    from_proto=_from_proto_fn)
ops.register_proto_function(
    ops.GraphKeys.GLOBAL_STEP,
    proto_type=variable_pb2.VariableDef,
    to_proto=_to_proto_fn,
    from_proto=_from_proto_fn)


def is_resource_variable(var):
  """Returns True if `var` is to be considered a ResourceVariable."""
  return isinstance(var, ResourceVariable) or hasattr(
      var, "_should_act_as_resource_variable")


def copy_to_graph_uninitialized(var):
  """Copies an existing variable to a new graph, with no initializer."""
  # Like ResourceVariable.__deepcopy__, but does not set an initializer on the
  # new variable.
  # pylint: disable=protected-access
  new_variable = ResourceVariable(
      initial_value=array_ops.placeholder(
          shape=var.shape, dtype=var.dtype,
          name="unused_initial_variable_value"),
      trainable=var.trainable,
      constraint=var._constraint,
      dtype=var.dtype,
      name=var._shared_name)
  new_variable._maybe_initialize_trackable()
  # pylint: enable=protected-access
  return new_variable


ops.NotDifferentiable("VarIsInitializedOp")
ops.NotDifferentiable("VariableShape")
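
# Example (an illustrative sketch, not executed as part of this module):
# typical use of ResourceVariable through its public methods, assuming eager
# execution is enabled:
#
#   v = ResourceVariable(initial_value=[1.0, 2.0, 3.0], name="v")
#   v.assign([4.0, 5.0, 6.0])
#   v.assign_add([1.0, 1.0, 1.0])                 # v is now [5.0, 6.0, 7.0]
#   v.scatter_sub(ops.IndexedSlices([1.0], [0]))  # v is now [4.0, 6.0, 7.0]
#   snapshot = v.read_value()                     # a Tensor with the current
#                                                 # value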