# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Ops to use variables as resources."""

# pylint: disable=g-bad-name
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import variable_pb2
from tensorflow.python.eager import context
from tensorflow.python.eager import tape
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import gen_resource_variable_ops
from tensorflow.python.ops import gen_state_ops
from tensorflow.python.ops import variables
# go/tf-wildcard-import
# pylint: disable=wildcard-import
from tensorflow.python.ops.gen_resource_variable_ops import *
# pylint: enable=wildcard-import
from tensorflow.python.training import checkpointable
from tensorflow.python.util import compat


def _eager_safe_variable_handle(shape, dtype, shared_name, name, graph_mode):
  """Creates a variable handle with information to do shape inference."""
  container = ops.get_default_graph()._container  # pylint: disable=protected-access
  if container is None:
    container = ""
  if not graph_mode:
    # When in eager mode use a uid for the shared_name, to prevent accidental
    # sharing.
    shared_name = str(ops.uid())
  handle = gen_resource_variable_ops.var_handle_op(shape=shape, dtype=dtype,
                                                   shared_name=shared_name,
                                                   name=name,
                                                   container=container)
  if graph_mode:
    return handle

  # We do not want two distinct ResourceVariable objects for the same
  # underlying resource in the runtime.
  # When in eager mode, explicitly ensure so here. When in graph mode, it's
  # ensured by always generating different variable names.
  exists = gen_resource_variable_ops.var_is_initialized_op(handle)
  if exists:
    raise ValueError("variable object with name '%s' already created. Use "
                     "get_variable() if reuse is desired." %
                     shared_name)
  with context.graph_mode(), ops.Graph().as_default() as graph:
    h = gen_resource_variable_ops.var_handle_op(shape=shape, dtype=dtype,
                                                shared_name=shared_name,
                                                name=name,
                                                container=container)

    # Tensor._handle_data contains information for the shape-inference code to
    # know the shape and dtype of the variable pointed to by a handle. Since
    # shape inference doesn't run in eager mode we copy this data here for
    # when the handle is captured by an eager mode function.
    handle._handle_data = h._handle_data  # pylint: disable=protected-access
  # Clean up our reference cycles to avoid making the garbage collector run.
  # pylint: disable=protected-access
  # OrderedDict, constructed on Graph creation, makes a simple reference loop
  # and hides it in an __attribute in some Python versions. We don't need to
  # throw an error if we can't find it, but if we do find it we can break the
  # loop to avoid creating work for the garbage collector.
  problematic_cycle = graph._functions.__dict__.get("_OrderedDict__root", None)
  # pylint: enable=protected-access
  if problematic_cycle:
    try:
      del problematic_cycle[0][:]
    except TypeError:
      # This is probably not one of the problematic Python versions. Continue
      # with the rest of our cleanup.
      pass
  # Now clean up our own reference cycles by clearing all of the attributes
  # for the Graph and op we created.
  h.__dict__ = {}
  graph.__dict__ = {}
  return handle


class EagerResourceDeleter(object):
  """An object which cleans up a resource handle.

  An alternative to defining a __del__ method on an object. The intended use
  is that ResourceVariables or other objects with resource handles will
  maintain a single reference to this object. When the parent object is
  collected, this object will be too. Even if the parent object is part of a
  reference cycle, the cycle will be collectable.
  """

  def __init__(self, handle, handle_device):
    if not isinstance(handle, ops.Tensor):
      raise ValueError(
          ("Passed handle=%s to EagerResourceDeleter. Was expecting a handle "
           "Tensor." % (handle,)))
    self._handle = handle
    self._handle_device = handle_device

  def __del__(self):
    # Resources follow object-identity when executing eagerly, so it is safe
    # to delete the resource we have a handle to. Each Graph has a unique
    # container name, which prevents resource sharing.
    try:
      # This resource was created in eager mode. However, this destructor may
      # be running in graph mode (especially during unit tests). To clean up
      # successfully, we switch back into eager mode temporarily.
      with context.eager_mode():
        with ops.device(self._handle_device):
          gen_resource_variable_ops.destroy_resource_op(
              self._handle, ignore_lookup_error=True)
    except TypeError:
      # Suppress some exceptions, mainly for the case when we're running on
      # module deletion. Things that can go wrong include the context module
      # already being unloaded, self._handle._handle_data no longer being
      # valid, and so on. Printing warnings in these cases is silly
      # (exceptions raised from __del__ are printed as warnings to stderr).
      pass  # 'NoneType' object is not callable when the handle has been
      # partially unloaded.
    except AttributeError:
      pass  # 'NoneType' object has no attribute 'eager_mode' when context has
      # been unloaded. Will catch other module unloads as well.


def shape_safe_assign_variable_handle(handle, shape, value, name=None):
  """Helper that checks shape compatibility and assigns variable."""
  value_tensor = ops.convert_to_tensor(value)
  shape.assert_is_compatible_with(value_tensor.shape)
  return gen_resource_variable_ops.assign_variable_op(handle,
                                                      value_tensor,
                                                      name=name)


class ResourceVariable(variables.Variable):
  """Variable based on resource handles.

  See the ${variables} documentation for more details.

  A `ResourceVariable` allows you to maintain state across subsequent calls to
  session.run.

  The `ResourceVariable` constructor requires an initial value for the
  variable, which can be a `Tensor` of any type and shape. The initial value
  defines the type and shape of the variable. After construction, the type and
  shape of the variable are fixed. The value can be changed using one of the
  assign methods.

  Just like any `Tensor`, variables created with `ResourceVariable()` can be
  used as inputs for other Ops in the graph. Additionally, all the operators
  overloaded for the `Tensor` class are carried over to variables, so you can
  also add nodes to the graph by just doing arithmetic on variables.

  Unlike tf.Variable, a tf.ResourceVariable has well-defined semantics. Each
  usage of a ResourceVariable in a TensorFlow graph adds a read_value
  operation to the graph. The Tensors returned by a read_value operation are
  guaranteed to see all modifications to the value of the variable made by any
  operation that the read_value depends on (either directly, indirectly, or
  via a control dependency), and are guaranteed not to see any modification
  made by an operation that the read_value does not depend on.

  For example, if there is more than one assignment to a ResourceVariable in
  a single session.run call, there is a well-defined value for each operation
  which uses the variable's value if the assignments and the read are
  connected by edges in the graph. Consider the following example, in which
  two writes can cause tf.Variable and tf.ResourceVariable to behave
  differently:

  ```python
  a = tf.ResourceVariable(1.0)
  a.initializer.run()

  assign = a.assign(2.0)
  with tf.control_dependencies([assign]):
    b = a.read_value()
  with tf.control_dependencies([b]):
    other_assign = a.assign(3.0)
  with tf.control_dependencies([other_assign]):
    # Will print 2.0 because the value was read before other_assign ran. If
    # `a` was a tf.Variable instead, 2.0 or 3.0 could be printed.
    tf.Print(b, [b]).eval()
  ```

  To enforce these consistency properties, tf.ResourceVariable might make more
  copies than an equivalent tf.Variable under the hood, so tf.Variable is
  still not deprecated.
  """

  def __init__(self,
               initial_value=None,
               trainable=True,
               collections=None,
               validate_shape=True,
               caching_device=None,
               name=None,
               dtype=None,
               variable_def=None,
               import_scope=None,
               constraint=None):
    """Creates a variable.

    Args:
      initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
        which is the initial value for the Variable. The initial value must
        have a shape specified unless `validate_shape` is set to False. Can
        also be a callable with no argument that returns the initial value
        when called. (Note that initializer functions from init_ops.py must
        first be bound to a shape before being used here.)
      trainable: If `True`, the default, also adds the variable to the graph
        collection `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as
        the default list of variables to use by the `Optimizer` classes.
      collections: List of graph collection keys. The new variable is added to
        these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
      validate_shape: Ignored. Provided for compatibility with tf.Variable.
      caching_device: Optional device string or function describing where the
        Variable should be cached for reading.
        Defaults to the Variable's device. If not `None`, caches on another
        device. Typical use is to cache on the device where the Ops using the
        Variable reside, to deduplicate copying through `Switch` and other
        conditional statements.
      name: Optional name for the variable. Defaults to `'Variable'` and gets
        uniquified automatically.
      dtype: If set, initial_value will be converted to the given type.
        If None, either the datatype will be kept (if initial_value is
        a Tensor) or float32 will be used (if it is a Python object
        convertible to a Tensor).
      variable_def: `VariableDef` protocol buffer. If not None, recreates the
        `ResourceVariable` object with its contents. `variable_def` and other
        arguments (except for import_scope) are mutually exclusive.
      import_scope: Optional `string`. Name scope to add to the
        ResourceVariable. Only used when `variable_def` is provided.
      constraint: An optional projection function to be applied to the
        variable after being updated by an `Optimizer` (e.g. used to implement
        norm constraints or value constraints for layer weights). The function
        must take as input the unprojected Tensor representing the value of
        the variable and return the Tensor for the projected value (which must
        have the same shape). Constraints are not safe to use when doing
        asynchronous distributed training.

    Raises:
      ValueError: If the initial value is not specified, or does not have a
        shape and `validate_shape` is `True`.

    @compatibility(eager)
    When Eager Execution is enabled, the default for the `collections`
    argument is `None`, which signifies that this `Variable` will not be added
    to any collections.
    @end_compatibility
    """
    if variable_def:
      if initial_value is not None:
        raise ValueError("variable_def and initial_value are mutually "
                         "exclusive.")
      if not context.in_graph_mode():
        raise ValueError("Creating ResourceVariable from variable_def is "
                         "only supported in GRAPH mode.")
      self._init_from_proto(variable_def, import_scope=import_scope)
    else:
      self._init_from_args(
          initial_value=initial_value,
          trainable=trainable,
          collections=collections,
          validate_shape=validate_shape,
          caching_device=caching_device,
          name=name,
          dtype=dtype,
          constraint=constraint)

  # pylint: disable=unused-argument
  def _init_from_args(self,
                      initial_value=None,
                      trainable=True,
                      collections=None,
                      validate_shape=True,
                      caching_device=None,
                      name=None,
                      dtype=None,
                      constraint=None):
    """Creates a variable.

    Args:
      initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
        which is the initial value for the Variable. The initial value must
        have a shape specified unless `validate_shape` is set to False. Can
        also be a callable with no argument that returns the initial value
        when called. (Note that initializer functions from init_ops.py must
        first be bound to a shape before being used here.)
      trainable: If `True`, the default, also adds the variable to the graph
        collection `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as
        the default list of variables to use by the `Optimizer` classes.
      collections: List of graph collection keys. The new variable is added to
        these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
      validate_shape: Ignored. Provided for compatibility with tf.Variable.
      caching_device: Optional device string or function describing where the
        Variable should be cached for reading. Defaults to the Variable's
        device. If not `None`, caches on another device. Typical use is to
        cache on the device where the Ops using the Variable reside, to
        deduplicate copying through `Switch` and other conditional statements.
      name: Optional name for the variable. Defaults to `'Variable'` and gets
        uniquified automatically.
      dtype: If set, initial_value will be converted to the given type.
        If None, either the datatype will be kept (if initial_value is
        a Tensor) or float32 will be used (if it is a Python object
        convertible to a Tensor).
      constraint: An optional projection function to be applied to the
        variable after being updated by an `Optimizer` (e.g. used to implement
        norm constraints or value constraints for layer weights). The function
        must take as input the unprojected Tensor representing the value of
        the variable and return the Tensor for the projected value (which must
        have the same shape). Constraints are not safe to use when doing
        asynchronous distributed training.

    Raises:
      ValueError: If the initial value is not specified, or does not have a
        shape and `validate_shape` is `True`.

    @compatibility(eager)
    When Eager Execution is enabled, variables are never added to collections.
    They are not implicitly added to the `GLOBAL_VARIABLES` or
    `TRAINABLE_VARIABLES` collections, and the `collections` argument is
    ignored.
    @end_compatibility
    """
    if initial_value is None:
      raise ValueError("initial_value must be specified.")
    init_from_fn = callable(initial_value)

    if collections is None:
      collections = [ops.GraphKeys.GLOBAL_VARIABLES]
    if not isinstance(collections, (list, tuple, set)):
      raise ValueError(
          "collections argument to Variable constructor must be a list, "
          "tuple, or set. Got %s of type %s" % (collections,
                                                type(collections)))
    if constraint is not None and not callable(constraint):
      raise ValueError("The `constraint` argument must be a callable.")

    if isinstance(initial_value, checkpointable.CheckpointInitialValue):
      self._maybe_initialize_checkpointable()
      self._update_uid = initial_value.checkpoint_position.restore_uid
      initial_value = initial_value.wrapped_value

    self._trainable = trainable
    if trainable and ops.GraphKeys.TRAINABLE_VARIABLES not in collections:
      collections = list(collections) + [ops.GraphKeys.TRAINABLE_VARIABLES]
    self._save_slice_info = None
    # Store the graph key so optimizers know how to only retrieve variables
    # from this graph.
    self._graph_key = ops.get_default_graph()._graph_key  # pylint: disable=protected-access
    with ops.init_scope():
      self._in_graph_mode = context.in_graph_mode()
      with ops.name_scope(name, "Variable", []
                          if init_from_fn else [initial_value]) as name:
        # pylint: disable=protected-access
        handle_name = ops._name_from_scope_name(name)
        if init_from_fn:
          # Use attr_scope and device(None) to simulate the behavior of
          # colocate_with when the variable we want to colocate with doesn't
          # yet exist.
          if self._in_graph_mode:
            attr = attr_value_pb2.AttrValue(
                list=attr_value_pb2.AttrValue.ListValue(
                    s=[compat.as_bytes("loc:@%s" % handle_name)]))
            with ops.get_default_graph()._attr_scope({"_class": attr}):
              with ops.name_scope("Initializer"), ops.device(None):
                initial_value = ops.convert_to_tensor(
                    initial_value(), name="initial_value", dtype=dtype)
              self._handle = _eager_safe_variable_handle(
                  shape=initial_value.get_shape(),
                  dtype=initial_value.dtype.base_dtype,
                  shared_name=handle_name,
                  name=name,
                  graph_mode=self._in_graph_mode)
              self._handle_device = (
                  self._handle.device if self._in_graph_mode else
                  context.get_default_context().device_name)
              self._shape = initial_value.get_shape()
          else:
            initial_value = initial_value()
            with ops.name_scope("Initializer"):
              initial_value = ops.convert_to_tensor(
                  initial_value, name="initial_value", dtype=dtype)
            self._handle = _eager_safe_variable_handle(
                shape=initial_value.get_shape(),
                dtype=initial_value.dtype.base_dtype,
                shared_name=handle_name,
                name=name,
                graph_mode=False)
            self._handle_device = (
                self._handle.device if self._in_graph_mode else
                context.get_default_context().device_name)
            self._shape = initial_value.get_shape()
          # pylint: enable=protected-access

        # Or get the initial value from a Tensor or Python object.
        else:
          with ops.name_scope("Initializer"):
            initial_value = ops.convert_to_tensor(
                initial_value, name="initial_value", dtype=dtype)
          # pylint: disable=protected-access
          if (self._in_graph_mode and initial_value is not None and
              initial_value.op._get_control_flow_context() is not None):
            raise ValueError(
                "Initializer for variable %s is from inside a control-flow "
                "construct, such as a loop or conditional. When creating a "
                "variable inside a loop or conditional, use a lambda as the "
                "initializer." % name)
          # pylint: enable=protected-access
          self._handle = _eager_safe_variable_handle(
              shape=initial_value.get_shape(),
              dtype=initial_value.dtype.base_dtype,
              shared_name=handle_name,
              name=name,
              graph_mode=self._in_graph_mode)
          self._handle_device = (self._handle.device if self._in_graph_mode
                                 else context.get_default_context().device_name)
          self._shape = initial_value.get_shape()

        self._initial_value = initial_value if self._in_graph_mode else None
        self._handle_name = handle_name + ":0"
        self._dtype = initial_value.dtype.base_dtype
        self._constraint = constraint

        if self._in_graph_mode:
          with ops.name_scope("IsInitialized"):
            self._is_initialized_op = (
                gen_resource_variable_ops.var_is_initialized_op(self._handle))
          if initial_value is not None:
            with ops.name_scope("Assign") as n, ops.colocate_with(self._handle):
              self._initializer_op = (
                  gen_resource_variable_ops.assign_variable_op(
                      self._handle,
                      self._try_guard_against_uninitialized_dependencies(
                          initial_value),
                      name=n))
          with ops.name_scope("Read"), ops.colocate_with(self._handle):
            # Manually assign reads to the handle's device to avoid log
            # messages.
            with ops.device(self._handle_device):
              value = self._read_variable_op()
            self._graph_element = value
            if caching_device is not None:
              # Variables may be created in a tf.device() or
              # ops.colocate_with() context.
              # At the same time, users would expect caching device to
              # be independent of this context, and/or would not expect the
              # current device context to be merged with the caching device
              # spec. Therefore we reset the colocation stack before creating
              # the cached value. Note that resetting the colocation stack
              # will also reset the device stack.
              with ops.colocate_with(None, ignore_existing=True):
                with ops.device(caching_device):
                  self._cached_value = array_ops.identity(value)
            else:
              self._cached_value = None
        else:
          gen_resource_variable_ops.assign_variable_op(self._handle,
                                                       initial_value)
          self._is_initialized_op = None
          self._initializer_op = None
          self._graph_element = None
          if caching_device:
            with ops.device(caching_device):
              self._cached_value = self._read_variable_op()
          else:
            self._cached_value = None
        if context.in_graph_mode():
          ops.add_to_collections(collections, self)
        elif ops.GraphKeys.GLOBAL_STEP in collections:
          ops.add_to_collections(ops.GraphKeys.GLOBAL_STEP, self)

    if not self._in_graph_mode:
      # After the handle has been created, set up a way to clean it up when
      # executing eagerly. We'll hold the only reference to the deleter, so
      # that when this object is garbage collected the deleter will be too.
      # This means ResourceVariables can be part of reference cycles without
      # those cycles being uncollectable, and means that no __del__ will be
      # defined at all in graph mode.
      self._handle_deleter = EagerResourceDeleter(
          handle=self._handle, handle_device=self._handle_device)

  def _init_from_proto(self, variable_def, import_scope=None):
    """Initializes from `VariableDef` proto."""
    # Note that init_from_proto is currently not supported in Eager mode.
    assert context.in_graph_mode()
    self._in_graph_mode = True
    assert isinstance(variable_def, variable_pb2.VariableDef)
    if not variable_def.is_resource:
      raise ValueError("Trying to restore Variable as ResourceVariable.")

    # Create from variable_def.
    g = ops.get_default_graph()
    self._handle = g.as_graph_element(
        ops.prepend_name_scope(
            variable_def.variable_name, import_scope=import_scope))
    self._shape = tensor_shape.TensorShape(
        self._handle.op.get_attr("shape"))
    self._handle_device = self._handle.device
    self._handle_name = self._handle.name
    self._initializer_op = g.as_graph_element(
        ops.prepend_name_scope(
            variable_def.initializer_name, import_scope=import_scope))
    # Check whether initial_value_name exists for backwards compatibility.
    if (hasattr(variable_def, "initial_value_name") and
        variable_def.initial_value_name):
      self._initial_value = g.as_graph_element(
          ops.prepend_name_scope(variable_def.initial_value_name,
                                 import_scope=import_scope))
    else:
      self._initial_value = None
    if variable_def.snapshot_name:
      self._cached_value = g.as_graph_element(
          ops.prepend_name_scope(
              variable_def.snapshot_name, import_scope=import_scope))
    else:
      self._cached_value = None
    if variable_def.HasField("save_slice_info_def"):
      self._save_slice_info = variables.Variable.SaveSliceInfo(
          save_slice_info_def=variable_def.save_slice_info_def,
          import_scope=import_scope)
    else:
      self._save_slice_info = None
    self._caching_device = None
    self._dtype = dtypes.as_dtype(self._handle.op.get_attr("dtype"))
    self._graph_element = self.value()
    self._constraint = None

  def __nonzero__(self):
    return self.__bool__()

  def __bool__(self):
    return bool(self.read_value())

  @property
  def dtype(self):
    """The dtype of this variable."""
    return self._dtype

  @property
  def device(self):
    """The device this variable is on."""
    return self._handle_device

  @property
  def graph(self):
    """The `Graph` of this variable."""
    return self._handle.graph

  @property
  def name(self):
    """The name of the handle for this variable."""
    return self._handle_name

  @property
  def shape(self):
    """The shape of this variable."""
    return self._shape

  @property
  def create(self):
    """The op responsible for initializing this variable."""
    if not self._in_graph_mode:
      raise RuntimeError("Calling create in EAGER mode not supported.")
    return self._initializer_op

  @property
  def handle(self):
    """The handle by which this variable can be accessed."""
    return self._handle

  def value(self):
    """A cached operation which reads the value of this variable."""
    if self._cached_value is not None:
      return self._cached_value
    with ops.colocate_with(None, ignore_existing=True):
      with ops.device(self._handle_device):
        return self._read_variable_op()

  def _as_graph_element(self):
    """Conversion function for Graph.as_graph_element()."""
    return self._graph_element

  @property
  def initializer(self):
    """The op responsible for initializing this variable."""
    return self._initializer_op

  @property
  def initial_value(self):
    """Returns the Tensor used as the initial value for the variable."""
    if context.in_eager_mode():
      raise RuntimeError("initial_value not supported in EAGER mode.")
    return self._initial_value

  @property
  def constraint(self):
    """Returns the constraint function associated with this variable.

    Returns:
      The constraint function that was passed to the variable constructor.
      Can be `None` if no constraint was passed.
    """
    return self._constraint

  @property
  def op(self):
    """The op for this variable."""
    return self._handle.op

  def eval(self, session=None):
    """Evaluates and returns the value of this variable."""
    if context.in_eager_mode():
      raise RuntimeError("Trying to eval in EAGER mode")
    return self._graph_element.eval(session=session)

  def numpy(self):
    if context.in_graph_mode():
      raise NotImplementedError(
          "numpy() is only available when eager execution is enabled.")
    return self.read_value().numpy()

  def count_up_to(self, limit):
    """Increments this variable until it reaches `limit`.

    When that Op is run it tries to increment the variable by `1`. If
    incrementing the variable would bring it above `limit` then the Op raises
    the exception `OutOfRangeError`.

    If no error is raised, the Op outputs the value of the variable before
    the increment.

    This is essentially a shortcut for `count_up_to(self, limit)`.

    Args:
      limit: value at which incrementing the variable raises an error.

    Returns:
      A `Tensor` that will hold the variable value before the increment. If no
      other Op modifies this variable, the values produced will all be
      distinct.
    """
    return gen_state_ops.resource_count_up_to(self.handle, limit=limit,
                                              T=self.dtype)

  def _set_save_slice_info(self, save_slice_info):
    """Sets the slice info for this `ResourceVariable`.

    Args:
      save_slice_info: A `Variable.SaveSliceInfo` object.
    """
    self._save_slice_info = save_slice_info

  def _get_save_slice_info(self):
    return self._save_slice_info

  def _read_variable_op(self):
    if hasattr(self, "_trainable") and self._trainable:
      tape.watch_variable(self)
    return gen_resource_variable_ops.read_variable_op(self._handle,
                                                      self._dtype)

  def read_value(self):
    """Constructs an op which reads the value of this variable.

    Should be used when there are multiple reads, or when it is desirable to
    read the value only after some condition is true.

    Returns:
      the read operation.
    """
    with ops.name_scope("Read"):
      # Ensure we read the variable in the same device as the handle.
      with ops.device(self._handle_device):
        value = self._read_variable_op()
    # Return an identity so it can get placed on whatever device the context
    # specifies instead of the device where the variable is.
    return array_ops.identity(value)

  def sparse_read(self, indices, name=None):
    """Reads the value of this variable sparsely, using `gather`."""
    with ops.name_scope("Gather" if name is None else name) as name:
      if self._trainable:
        tape.watch_variable(self)
      value = gen_resource_variable_ops.resource_gather(
          self._handle, indices, dtype=self._dtype, name=name)
    return array_ops.identity(value)

  def to_proto(self, export_scope=None):
    """Converts a `ResourceVariable` to a `VariableDef` protocol buffer.

    Args:
      export_scope: Optional `string`. Name scope to remove.

    Raises:
      RuntimeError: If run in EAGER mode.

    Returns:
      A `VariableDef` protocol buffer, or `None` if the `Variable` is not
      in the specified name scope.
    """
    if context.in_eager_mode():
      raise RuntimeError("to_proto not supported in EAGER mode.")
    if export_scope is None or self.handle.name.startswith(export_scope):
      var_def = variable_pb2.VariableDef()
      var_def.variable_name = ops.strip_name_scope(self.handle.name,
                                                   export_scope)
      if self._initial_value is not None:
        # This is inside an if-statement for backwards compatibility, since
        # self._initial_value might be None for variables constructed from
        # old protos.
        var_def.initial_value_name = ops.strip_name_scope(
            self._initial_value.name, export_scope)
      var_def.initializer_name = ops.strip_name_scope(self.initializer.name,
                                                      export_scope)
      if self._cached_value is not None:
        var_def.snapshot_name = ops.strip_name_scope(self._cached_value.name,
                                                     export_scope)
      var_def.is_resource = True
      if self._save_slice_info:
        var_def.save_slice_info_def.MergeFrom(
            self._save_slice_info.to_proto(export_scope=export_scope))
      return var_def
    else:
      return None

  @staticmethod
  def from_proto(variable_def, import_scope=None):
    if context.in_eager_mode():
      raise RuntimeError("from_proto not supported in EAGER mode.")
    return ResourceVariable(
        variable_def=variable_def, import_scope=import_scope)

  @staticmethod
  def _OverloadAllOperators():  # pylint: disable=invalid-name
    """Register overloads for all operators."""
    for operator in ops.Tensor.OVERLOADABLE_OPERATORS:
      ResourceVariable._OverloadOperator(operator)
    # For slicing, bind getitem differently than a tensor (use SliceHelperVar
    # instead).
    # pylint: disable=protected-access
    setattr(ResourceVariable, "__getitem__", array_ops._SliceHelperVar)

  def _AsTensor(self):
    return self.value()

  def _ref(self):
    """Unsupported."""
    raise NotImplementedError("ResourceVariable does not implement _ref()")

  def set_shape(self, shape):
    """Unsupported."""
    raise NotImplementedError("ResourceVariable does not implement set_shape()")

  @staticmethod
  def _OverloadOperator(operator):  # pylint: disable=invalid-name
    """Defer an operator overload to `ops.Tensor`.

    We pull the operator out of ops.Tensor dynamically to avoid ordering
    issues.

    Args:
      operator: string. The operator name.
    """

    def _run_op(a, *args):
      # pylint: disable=protected-access
      value = a._AsTensor()
      return getattr(ops.Tensor, operator)(value, *args)

    # Propagate __doc__ to wrapper.
    try:
      _run_op.__doc__ = getattr(ops.Tensor, operator).__doc__
    except AttributeError:
      pass

    setattr(ResourceVariable, operator, _run_op)

  __array_priority__ = 100

  def assign_sub(self, delta, use_locking=None, name=None):
    # TODO(apassos): this here and below is not atomic. Consider making it
    # atomic if there's a way to do so without a performance cost for those
    # who don't need it.
    return self._lazy_read(gen_resource_variable_ops.assign_sub_variable_op(
        self.handle,
        ops.convert_to_tensor(delta, dtype=self.dtype),
        name=name))

  def assign_add(self, delta, use_locking=None, name=None):
    return self._lazy_read(gen_resource_variable_ops.assign_add_variable_op(
        self.handle,
        ops.convert_to_tensor(delta, dtype=self.dtype),
        name=name))

  def _lazy_read(self, op):
    if hasattr(self, "_trainable") and self._trainable:
      tape.watch_variable(self)
    return _UnreadVariable(
        self._handle, self.dtype, self._handle_device, self._shape,
        self._in_graph_mode,
        self._handle_deleter if not self._in_graph_mode else None, op)

  def assign(self, value, use_locking=None, name=None):
    value_tensor = ops.convert_to_tensor(value, dtype=self.dtype)
    self._shape.assert_is_compatible_with(value_tensor.shape)
    return self._lazy_read(
        gen_resource_variable_ops.assign_variable_op(
            self.handle,
            value_tensor,
            name=name))

  def _strided_slice_assign(self, begin, end, strides, value, name, begin_mask,
                            end_mask, ellipsis_mask, new_axis_mask,
                            shrink_axis_mask):
    return self._lazy_read(
        gen_array_ops.resource_strided_slice_assign(
            ref=self.handle,
            begin=begin,
            end=end,
            strides=strides,
            value=value,
            name=name,
            begin_mask=begin_mask,
            end_mask=end_mask,
            ellipsis_mask=ellipsis_mask,
            new_axis_mask=new_axis_mask,
            shrink_axis_mask=shrink_axis_mask))

  def __int__(self):
    if self.dtype != dtypes.int32 and self.dtype != dtypes.int64:
      raise TypeError("Non-integer variable can't be converted to integer.")
    return int(self.value().numpy())

  def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
    del name
    if dtype is not None and dtype != self.dtype:
      print("trying to switch the dtype to ", dtype, " from ", self.dtype)
      return NotImplemented
    if as_ref:
      return self.read_value().op.inputs[0]
    else:
      return self.value()

  def __iadd__(self, unused_other):
    raise RuntimeError("Variable += value not supported. Use "
                       "variable.assign_add(value) to modify the variable "
                       "value and variable = variable + value to get a new "
                       "Tensor object.")

  def __isub__(self, unused_other):
    raise RuntimeError("Variable -= value not supported. Use "
                       "variable.assign_sub(value) to modify the variable "
                       "value and variable = variable - value to get a new "
                       "Tensor object.")

  def __imul__(self, unused_other):
    raise RuntimeError("Variable *= value not supported. Use "
                       "variable.assign_mul(value) to modify the variable "
                       "value and variable = variable * value to get a new "
                       "Tensor object.")

  def __idiv__(self, unused_other):
    raise RuntimeError("Variable /= value not supported. Use "
                       "variable.assign_div(value) to modify the variable "
                       "value and variable = variable / value to get a new "
                       "Tensor object.")

  def __itruediv__(self, unused_other):
    raise RuntimeError("Variable /= value not supported. Use "
                       "variable.assign_div(value) to modify the variable "
                       "value and variable = variable / value to get a new "
                       "Tensor object.")

  def __irealdiv__(self, unused_other):
    raise RuntimeError("Variable /= value not supported. Use "
                       "variable.assign_div(value) to modify the variable "
                       "value and variable = variable / value to get a new "
                       "Tensor object.")

  def __ipow__(self, unused_other):
    raise RuntimeError("Variable **= value not supported. Use "
                       "variable = variable ** value to get a new "
                       "Tensor object.")


def _dense_var_to_tensor(var, dtype=None, name=None, as_ref=False):
  return var._dense_var_to_tensor(dtype=dtype, name=name, as_ref=as_ref)  # pylint: disable=protected-access


class _UnreadVariable(ResourceVariable):
  """Represents a future for a read of a variable.

  Pretends to be the tensor if anyone looks.
  """

  def __init__(self, handle, dtype, handle_device,  # pylint: disable=super-init-not-called
               shape, in_graph_mode, deleter, parent_op):
    # We do not call super init on purpose.
    self._trainable = False
    self._save_slice_info = None
    self._graph_key = ops.get_default_graph()._graph_key  # pylint: disable=protected-access
    self._in_graph_mode = in_graph_mode
    self._handle = handle
    self._handle_device = handle_device
    self._shape = shape
    self._initial_value = None
    if isinstance(self._handle, ops.EagerTensor):
      self._handle_name = ""
    else:
      self._handle_name = self._handle.name
    self._dtype = dtype
    self._constraint = None
    self._cached_value = None
    self._is_initialized_op = None
    self._initializer_op = None
    self._parent_op = parent_op
    if context.in_graph_mode():
      self._graph_element = self.read_value()
    else:
      self._graph_element = None
    self._handle_deleter = deleter

  def value(self):
    return self._read_variable_op()

  def read_value(self):
    return self._read_variable_op()

  def _read_variable_op(self):
    with ops.control_dependencies([self._parent_op]):
      return gen_resource_variable_ops.read_variable_op(self._handle,
                                                        self._dtype)

  def set_shape(self, shape):
    self._shape = shape

  @property
  def op(self):
    """The op for this variable."""
    return self._parent_op

ops.register_tensor_conversion_function(_UnreadVariable, _dense_var_to_tensor)
ops.register_dense_tensor_like_type(_UnreadVariable)

# Register a conversion function which reads the value of the variable,
# allowing instances of the class to be used as tensors.

# Note: registering for Variable after ResourceVariable because inheritance
# will otherwise lead to the wrong behavior.
ops.register_tensor_conversion_function(ResourceVariable, _dense_var_to_tensor)
ops.register_tensor_conversion_function(
    variables.Variable, variables.Variable._TensorConversionFunction)  # pylint: disable=protected-access

# pylint: disable=protected-access
ResourceVariable._OverloadAllOperators()
ops.register_dense_tensor_like_type(ResourceVariable)


@ops.RegisterGradient("ReadVariableOp")
def _ReadGrad(_, grad):
  """Gradient for read op."""
  return grad


@ops.RegisterGradient("ResourceGather")
def _GatherGrad(op, grad):
  """Gradient for gather op."""
  # Build appropriately shaped IndexedSlices.
  handle = op.inputs[0]
  indices = op.inputs[1]
  params_shape = gen_resource_variable_ops.variable_shape(handle)
  size = array_ops.expand_dims(array_ops.size(indices), 0)
  values_shape = array_ops.concat([size, params_shape[1:]], 0)
  values = array_ops.reshape(grad, values_shape)
  indices = array_ops.reshape(indices, size)
  return (ops.IndexedSlices(values, indices, params_shape), None)


def _to_proto_fn(v, export_scope=None):
  """Converts Variable and ResourceVariable to VariableDef for collections."""
  return v.to_proto(export_scope=export_scope)


def _from_proto_fn(v, import_scope=None):
  """Creates Variable or ResourceVariable from VariableDef as needed."""
  if v.is_resource:
    return ResourceVariable.from_proto(v, import_scope=import_scope)
  return variables.Variable.from_proto(v, import_scope=import_scope)


ops.register_proto_function(
    ops.GraphKeys.GLOBAL_VARIABLES,
    proto_type=variable_pb2.VariableDef,
    to_proto=_to_proto_fn,
    from_proto=_from_proto_fn)
ops.register_proto_function(
    ops.GraphKeys.TRAINABLE_VARIABLES,
    proto_type=variable_pb2.VariableDef,
    to_proto=_to_proto_fn,
    from_proto=_from_proto_fn)
ops.register_proto_function(
    ops.GraphKeys.MOVING_AVERAGE_VARIABLES,
    proto_type=variable_pb2.VariableDef,
    to_proto=_to_proto_fn,
    from_proto=_from_proto_fn)
ops.register_proto_function(
    ops.GraphKeys.LOCAL_VARIABLES,
    proto_type=variable_pb2.VariableDef,
    to_proto=_to_proto_fn,
    from_proto=_from_proto_fn)
ops.register_proto_function(
    ops.GraphKeys.MODEL_VARIABLES,
    proto_type=variable_pb2.VariableDef,
    to_proto=_to_proto_fn,
    from_proto=_from_proto_fn)


def is_resource_variable(var):
  """Returns True if `var` is to be considered a ResourceVariable."""
  return isinstance(var, ResourceVariable) or hasattr(
      var, "_should_act_as_resource_variable")
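

# The block below is an illustrative usage sketch, not part of the library
# API: it mirrors the read-consistency example in the `ResourceVariable` class
# docstring and only runs when this file is executed directly. It assumes a
# graph-mode (TF 1.x style) Session workflow; the Session import is done
# lazily here to avoid a module-level import cycle.
if __name__ == "__main__":
  from tensorflow.python.client import session as session_lib  # pylint: disable=g-import-not-at-top

  with ops.Graph().as_default():
    a = ResourceVariable(1.0)
    assign = a.assign(2.0)
    with ops.control_dependencies([assign]):
      b = a.read_value()
    with ops.control_dependencies([b]):
      other_assign = a.assign(3.0)
    with session_lib.Session() as sess:
      sess.run(a.initializer)
      # `b` depends on the first assign, and the second assign depends on `b`,
      # so the read always observes 2.0 (never 1.0 or 3.0).
      _, b_value = sess.run([other_assign.op, b])
      print(b_value)  # Expected output: 2.0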