1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for variable store.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import gc 22import threading 23 24import numpy 25 26from tensorflow.python.eager import context 27from tensorflow.python.eager import function 28from tensorflow.python.eager import wrap_function 29from tensorflow.python.framework import constant_op 30from tensorflow.python.framework import dtypes 31from tensorflow.python.framework import errors 32from tensorflow.python.framework import ops 33from tensorflow.python.framework import test_util 34from tensorflow.python.layers import core as core_layers 35from tensorflow.python.ops import array_ops 36from tensorflow.python.ops import control_flow_ops 37from tensorflow.python.ops import init_ops 38from tensorflow.python.ops import math_ops 39from tensorflow.python.ops import resource_variable_ops 40from tensorflow.python.ops import state_ops 41from tensorflow.python.ops import variable_scope 42from tensorflow.python.ops import variables as variables_lib 43from tensorflow.python.platform import test 44from tensorflow.python.util import compat 45from tensorflow.python.util import tf_inspect 46 47 48def run_inside_wrap_function_in_eager_mode(graph_function): 49 """Decorator to 
execute the same graph code in eager and graph modes. 50 51 In graph mode, we just execute the graph_function passed as argument. In eager 52 mode, we wrap the function using wrap_function and then execute the wrapped 53 result. 54 55 Args: 56 graph_function: python function containing graph code to be wrapped 57 58 Returns: 59 decorated function 60 """ 61 def wrap_and_execute(self): 62 if context.executing_eagerly(): 63 wrapped = wrap_function.wrap_function(graph_function, [self]) 64 # use the wrapped graph function 65 wrapped() 66 else: 67 # use the original function 68 graph_function(self) 69 return wrap_and_execute 70 71 72class VariableScopeTest(test.TestCase): 73 74 def tearDown(self): 75 gc.collect() 76 # This will only contain uncollectable garbage, i.e. reference cycles 77 # involving objects with __del__ defined. 78 self.assertEqual(0, len(gc.garbage)) 79 80 @test_util.run_in_graph_and_eager_modes 81 @run_inside_wrap_function_in_eager_mode 82 def testGetVar(self): 83 vs = variable_scope._get_default_variable_store() 84 v = vs.get_variable("v", [1]) 85 v1 = vs.get_variable("v", [1]) 86 self.assertIs(v, v1) 87 88 @test_util.run_in_graph_and_eager_modes 89 @run_inside_wrap_function_in_eager_mode 90 def testResource(self): 91 vs = variable_scope._get_default_variable_store() 92 v1 = vs.get_variable("v", [1], use_resource=True) 93 self.assertTrue(isinstance(v1, resource_variable_ops.ResourceVariable)) 94 95 @test_util.run_in_graph_and_eager_modes 96 @run_inside_wrap_function_in_eager_mode 97 def testNameExists(self): 98 vs = variable_scope._get_default_variable_store() 99 # No check by default, so we can both create and get existing names. 100 v = vs.get_variable("v", [1]) 101 v1 = vs.get_variable("v", [1]) 102 self.assertIs(v, v1) 103 104 # When reuse is False, we fail when variables are already there. 105 vs.get_variable("w", [1], reuse=False) # That's ok. 106 with self.assertRaises(ValueError): 107 vs.get_variable("v", [1], reuse=False) # That fails. 
108 # When reuse is True, we fail when variables are new. 109 vs.get_variable("v", [1], reuse=True) # That's ok. 110 with self.assertRaises(ValueError): 111 vs.get_variable("u", [1], reuse=True) # That fails. 112 113 @test_util.run_in_graph_and_eager_modes 114 @run_inside_wrap_function_in_eager_mode 115 def testNamelessStore(self): 116 vs = variable_scope._get_default_variable_store() 117 vs.get_variable("v1", [2]) 118 vs.get_variable("v2", [2]) 119 expected_names = ["%s:0" % name for name in ["v1", "v2"]] 120 self.assertEqual( 121 set(expected_names), set(v.name for v in vs._vars.values())) 122 123 # TODO(mihaimaruseac): Not converted to use wrap_function because of 124 # TypeError: Expected tf.group() expected Tensor arguments not 'None' with 125 # type '<type 'NoneType'>' 126 @test_util.run_in_graph_and_eager_modes 127 def testVarScopeInitializer(self): 128 init = init_ops.constant_initializer(0.3) 129 with variable_scope.variable_scope("tower0") as tower: 130 with variable_scope.variable_scope("foo", initializer=init): 131 v = variable_scope.get_variable("v", []) 132 self.evaluate(variables_lib.variables_initializer([v])) 133 self.assertAllClose(self.evaluate(v.value()), 0.3) 134 with variable_scope.variable_scope(tower, initializer=init): 135 w = variable_scope.get_variable("w", []) 136 self.evaluate(variables_lib.variables_initializer([w])) 137 self.assertAllClose(self.evaluate(w.value()), 0.3) 138 139 @test_util.run_in_graph_and_eager_modes 140 @run_inside_wrap_function_in_eager_mode 141 def testVarScopeConstraint(self): 142 constraint = lambda x: 0. 
* x 143 with variable_scope.variable_scope("tower1") as tower: 144 with variable_scope.variable_scope("foo", constraint=constraint): 145 v = variable_scope.get_variable("v", []) 146 self.assertEqual(v.constraint, constraint) 147 with variable_scope.variable_scope(tower, constraint=constraint): 148 w = variable_scope.get_variable("w", []) 149 self.assertEqual(w.constraint, constraint) 150 151 @test_util.run_in_graph_and_eager_modes 152 @run_inside_wrap_function_in_eager_mode 153 def testVarScopeNestingError(self): 154 with variable_scope.variable_scope("aa"): 155 scope = variable_scope.variable_scope("bb") 156 scope.__enter__() 157 with variable_scope.variable_scope("cc"): 158 with self.assertRaises(RuntimeError): 159 scope.__exit__(None, None, None) 160 scope.__exit__(None, None, None) 161 162 # TODO(mihaimaruseac): Not converted to use wrap_function because of 163 # TypeError: Fetch argument <tf.Variable 'string:0' shape=() dtype=string> 164 # has invalid type <class '...ResourceVariable'>, must be a string or Tensor. 165 # (Can not convert a ResourceVariable into a Tensor or Operation.) 
# VariableScopeTest (continued): default initializers, scope dtypes and
# EagerVariableStore behavior.
  @test_util.run_deprecated_v1
  def testStringDefaultInitializer(self):
    """A string variable with no initializer defaults to the empty string."""
    with self.cached_session():
      v = variable_scope.get_variable("string", shape=[], dtype=dtypes.string)
      variables_lib.global_variables_initializer().run()
      self.assertAllEqual(compat.as_bytes(self.evaluate(v)), b"")

  @test_util.run_in_graph_and_eager_modes
  @run_inside_wrap_function_in_eager_mode
  def testVarScopeDType(self):
    """A scope dtype becomes the default for variables created in it."""
    with variable_scope.variable_scope("tower2") as tower:
      with variable_scope.variable_scope("foo", dtype=dtypes.float16):
        v = variable_scope.get_variable("v", [])
        self.assertEqual(v.dtype.base_dtype, dtypes.float16)
      # Re-entering a captured scope object with a dtype behaves the same.
      with variable_scope.variable_scope(tower, dtype=dtypes.float16):
        w = variable_scope.get_variable("w", [])
        self.assertEqual(w.dtype.base_dtype, dtypes.float16)

  def testGetVariableInGraphNestedUnderEagerContext(self):
    """get_variable inside a defun traced under eager mode is a resource var."""
    with context.eager_mode():

      @function.defun
      def f():
        v = variable_scope.get_variable("should_be_resource", [])
        self.assertEqual(type(v), resource_variable_ops.ResourceVariable)

      f()

  def testEagerVariableStore(self):
    """Checks EagerVariableStore bookkeeping and that copy() is independent."""
    with context.eager_mode():
      store = variable_scope.EagerVariableStore()
      with store.as_default():
        v = variable_scope.get_variable("v", shape=(), trainable=True)
        w = variable_scope.get_variable("w", shape=(), trainable=False)

      self.assertTrue(v in store.variables())
      self.assertTrue(w in store.variables())
      self.assertTrue(v in store.trainable_variables())
      self.assertFalse(w in store.trainable_variables())
      self.assertFalse(v in store.non_trainable_variables())
      self.assertTrue(w in store.non_trainable_variables())

      # Test copying.
      new_store = store.copy()
      with new_store.as_default():
        new_v = variable_scope.get_variable("v")
        new_w = variable_scope.get_variable("w")
      self.assertEqual(new_v.numpy(), v.numpy())
      self.assertEqual(new_w.numpy(), w.numpy())
      self.assertTrue(new_v in new_store.variables())
      self.assertTrue(new_w in new_store.variables())
      self.assertTrue(new_v in new_store.trainable_variables())
      self.assertFalse(new_w in new_store.trainable_variables())
      self.assertFalse(new_v in new_store.non_trainable_variables())
      self.assertTrue(new_w in new_store.non_trainable_variables())

      # Check that variables are separate instances.
      for v in store.variables():
        v.assign(-1)
      for v in new_store.variables():
        v.assign(1)
      for v in store.variables():
        self.assertEqual(v.numpy(), -1)
      for v in new_store.variables():
        self.assertEqual(v.numpy(), 1)

  def testEagerVariableStoreWithEagerDefun(self):
    """A Dense layer with _reuse=True in a defun reuses same-named variables."""
    with context.eager_mode():

      @function.defun
      def f():
        x = constant_op.constant([[2.0]])
        d1 = core_layers.Dense(
            1, name="my_dense", kernel_initializer=init_ops.ones_initializer())
        _ = d1(x)  # create variables
        self.assertEqual(len(d1.variables), 2)
        v1, v2 = d1.variables
        d2 = core_layers.Dense(
            1,
            name="my_dense",
            kernel_initializer=init_ops.ones_initializer(),
            _reuse=True)
        _ = d2(x)
        self.assertEqual(len(d2.variables), 2)
        v3, v4 = d2.variables
        self.assertIs(v1, v3)
        self.assertIs(v2, v4)

      f()

  # TODO(mihaimaruseac): Not converted to use wrap_function because of
  # obtaining different results in the eager case compared to the graph one
  @test_util.run_in_graph_and_eager_modes
  def testEagerVariablesStoreAddsToCollections(self):
    """Checks which collections receive variables created inside a store."""
    store = variable_scope.EagerVariableStore()
    with store.as_default():
      trainable = variable_scope.get_variable("v1", [], trainable=True)
      not_trainable = variable_scope.get_variable("v2", [], trainable=False)
      concat = variable_scope.get_variable(
          "v3", [], collections=[ops.GraphKeys.CONCATENATED_VARIABLES])
      # "v3" was given explicit `collections`, so it is expected to be absent
      # from GLOBAL_VARIABLES while still appearing in TRAINABLE_VARIABLES.
      self.assertEqual(
          ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES),
          [trainable, not_trainable])
      self.assertEqual(
          ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES),
          [trainable, concat])
      self.assertEqual(
          ops.get_collection(ops.GraphKeys.CONCATENATED_VARIABLES), [concat])

  def testEagerVariablesOutsideStoreNotAddedToCollections(self):
    """Eager variables created without a store are not put in collections."""
    with context.eager_mode():
      variable_scope.get_variable("v1", [], trainable=True)
      variable_scope.get_variable("v2", [], trainable=False)
      self.assertFalse(ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES))
      self.assertFalse(ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES))

  def testEagerVariableStoreWithFunctionalLayer(self):
    """Calling core_layers.dense twice under one store reuses its variables."""
    with context.eager_mode():
      container = variable_scope.EagerVariableStore()
      x = constant_op.constant([[2.0]])
      with container.as_default():
        y = core_layers.dense(x, 1, name="my_dense",
                              kernel_initializer=init_ops.ones_initializer())
      self.assertAllEqual(y, [[2.0]])
      self.assertEqual(len(container.variables()), 2)
      # Recreate the layer to test reuse.
      with container.as_default():
        core_layers.dense(x, 1, name="my_dense",
                          kernel_initializer=init_ops.ones_initializer())
      self.assertEqual(len(container.variables()), 2)

  # TODO(mihaimaruseac): Not converted to use wrap_function because of
  # TypeError: Expected tf.group() expected Tensor arguments not 'None' with
  # type '<type 'NoneType'>'.
# VariableScopeTest (continued): initial values that are not Tensors and
# default initializers.
  @test_util.run_in_graph_and_eager_modes
  def testInitFromNonTensorValue(self):
    """get_variable accepts Python/numpy initial values but rejects a dict."""
    v = variable_scope.get_variable("v4", initializer=4, dtype=dtypes.int32)
    self.evaluate(variables_lib.variables_initializer([v]))
    self.assertAllClose(self.evaluate(v.value()), 4)

    w = variable_scope.get_variable(
        "w4", initializer=numpy.array([1, 2, 3]), dtype=dtypes.int64)
    self.evaluate(variables_lib.variables_initializer([w]))
    self.assertAllClose(self.evaluate(w.value()), [1, 2, 3])

    # A quirk to be revisited? Eager and graph modes currently raise
    # different exception types for an unsupported initializer value ({}).
    error = ValueError if context.executing_eagerly() else TypeError
    with self.assertRaises(error):
      variable_scope.get_variable("x4", initializer={})

  # TODO(mihaimaruseac): Not converted to use wrap_function because of
  # InvalidArgumentError: You must feed a value for placeholder tensor
  # 'ReadVariableOp/resource' with dtype resource
  @test_util.run_in_graph_and_eager_modes
  def testInitFromNonInitializer(self):
    """For int/bool dtypes the default initializer matches zeros_initializer."""
    # Test various dtypes with zeros initializer as following:
    types = [
        dtypes.int8, dtypes.uint8, dtypes.int16, dtypes.uint16, dtypes.int32,
        dtypes.int64, dtypes.bool
    ]

    # Use different variable_name to distinguish various dtypes
    for (i, dtype) in enumerate(types):
      x = variable_scope.get_variable(
          name="xx%d" % i, shape=(3, 4), dtype=dtype)
      y = variable_scope.get_variable(
          name="yy%d" % i,
          shape=(3, 4),
          dtype=dtype,
          initializer=init_ops.zeros_initializer(dtype=dtype))

      self.evaluate(variables_lib.global_variables_initializer())
      self.assertAllEqual(self.evaluate(x.value()), self.evaluate(y.value()))

  # TODO(mihaimaruseac): Not converted to use wrap_function because of
  # InvalidArgumentError: /job:moo/replica:0/task:0/device:CPU:0 unknown device.
# VariableScopeTest (continued): caching_device propagation.
  @test_util.run_deprecated_v1
  def testVarScopeCachingDevice(self):
    """Checks how caching_device is inherited, overridden and scoped.

    Covers inheritance by child scopes, overriding with "" or a callable,
    set_caching_device on a live scope, and that the setting does not apply
    outside the scope that declared it.
    """
    with self.cached_session():
      caching_device = "/job:moo"
      with variable_scope.variable_scope("tower"):
        with variable_scope.variable_scope(
            "caching", caching_device=caching_device):
          v = variable_scope.get_variable("v", [])
          self.assertTrue(v.value().device.startswith(caching_device))

          with variable_scope.variable_scope("child"):
            v2 = variable_scope.get_variable("v", [])
            self.assertTrue(v2.value().device.startswith(caching_device))

          with variable_scope.variable_scope("not_cached", caching_device=""):
            v2_not_cached = variable_scope.get_variable("v", [])
            self.assertFalse(
                v2_not_cached.value().device.startswith(caching_device))

          # A callable caching_device that returns the op's own device.
          with variable_scope.variable_scope(
              "not_cached_identity_device",
              caching_device=lambda op: op.device):
            v2_identity_device = variable_scope.get_variable("v", [])
            self.assertFalse(
                v2_identity_device.value().device.startswith(caching_device))

          with variable_scope.variable_scope("we_will_do_it_live") as vs_live:
            vs_live.set_caching_device("/job:live")
            v_live = variable_scope.get_variable("v", [])
            self.assertTrue(v_live.value().device.startswith("/job:live"))

        # Outside the "caching" scope the caching device no longer applies.
        v_tower = variable_scope.get_variable("v", [])
        self.assertFalse(v_tower.value().device.startswith(caching_device))

  # TODO(mihaimaruseac): Not converted to use wrap_function because of
  # AttributeError: Tensor.name is meaningless when eager execution is enabled.
# VariableScopeTest (continued): scope regularizers and Tensor-valued
# initializers.
  @test_util.run_in_graph_and_eager_modes
  def testVarScopeRegularizer(self):
    """Checks regularizer inheritance, overrides and opt-outs via scopes."""
    init = init_ops.constant_initializer(0.3)

    def regularizer1(v):
      return math_ops.reduce_mean(v) + 0.1

    def regularizer2(v):
      return math_ops.reduce_mean(v) + 0.2

    with variable_scope.variable_scope(
        "tower3", regularizer=regularizer1) as tower:
      with variable_scope.variable_scope("foo", initializer=init):
        v = variable_scope.get_variable("v", [])
        self.evaluate(variables_lib.variables_initializer([v]))
        losses = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
        self.assertEqual(1, len(losses))
        # 0.3 (initial value) + 0.1 (regularizer1 offset).
        self.assertAllClose(self.evaluate(losses[0]), 0.4)
      with variable_scope.variable_scope(tower, initializer=init) as vs:
        u = variable_scope.get_variable("u", [])
        vs.set_regularizer(regularizer2)
        w = variable_scope.get_variable("w", [])
        # The next 3 variables are not regularized, to test disabling
        # regularization.
        x = variable_scope.get_variable(
            "x", [], regularizer=variable_scope.no_regularizer)
        with variable_scope.variable_scope(
            "baz", regularizer=variable_scope.no_regularizer):
          y = variable_scope.get_variable("y", [])
        vs.set_regularizer(variable_scope.no_regularizer)
        z = variable_scope.get_variable("z", [])
        # Check results: losses for v and u (regularizer1) and w (regularizer2).
        losses = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
        self.assertEqual(3, len(losses))
        self.evaluate(variables_lib.variables_initializer([u, w, x, y, z]))
        self.assertAllClose(self.evaluate(losses[0]), 0.4)
        self.assertAllClose(self.evaluate(losses[1]), 0.4)
        self.assertAllClose(self.evaluate(losses[2]), 0.5)
      with variable_scope.variable_scope("foo", reuse=True):
        # reuse=True is for now only supported when eager execution is disabled.
        if not context.executing_eagerly():
          v = variable_scope.get_variable("v",
                                          [])  # "v" is already there, reused
          losses = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
          self.assertEqual(3, len(losses))  # No new loss added.

  # TODO(mihaimaruseac): Not converted to use wrap_function because of
  # ValueError: Tensor-typed variable initializers must either be wrapped in an
  # init_scope or callable...
  @test_util.run_in_graph_and_eager_modes
  def testInitializeFromValue(self):
    """Checks Tensor-valued initializers: value, shape, scopes and dtype."""
    init = constant_op.constant(0.1)
    w = variable_scope.get_variable("v", initializer=init)
    self.evaluate(variables_lib.variables_initializer([w]))
    self.assertAllClose(self.evaluate(w.value()), 0.1)

    with self.assertRaisesRegex(ValueError, "shape"):
      # We disallow explicit shape specification when initializer is constant.
      variable_scope.get_variable("u", [1], initializer=init)

    with variable_scope.variable_scope("foo", initializer=init):
      # Constant initializer can be passed through scopes if needed.
      v = variable_scope.get_variable("v")
      self.evaluate(variables_lib.variables_initializer([v]))
      self.assertAllClose(self.evaluate(v.value()), 0.1)

    # Check that non-float32 initializer creates a non-float32 variable.
    init = constant_op.constant(1, dtype=dtypes.int32)
    t = variable_scope.get_variable("t", initializer=init)
    self.assertEqual(t.dtype.base_dtype, dtypes.int32)

    # Raise error if `initializer` dtype and `dtype` are not identical.
    with self.assertRaisesRegex(ValueError, "don't match"):
      variable_scope.get_variable("s", initializer=init, dtype=dtypes.float64)

  # TODO(mihaimaruseac): Not converted to use wrap_function because of
  # TypeError: Fetch argument <tf.Variable 'v0:0' shape=(1,) dtype=float32> has
  # invalid type <class '...ops.resource_variable_ops.ResourceVariable'>, must
  # be a string or Tensor.
(Can not convert a ResourceVariable into a Tensor or 456 # Operation.) 457 @test_util.run_deprecated_v1 458 def testControlDeps(self): 459 with self.cached_session() as sess: 460 v0 = variable_scope.get_variable( 461 "v0", [1], initializer=init_ops.constant_initializer(0)) 462 with ops.control_dependencies([v0.value()]): 463 v1 = variable_scope.get_variable( 464 "v1", [1], initializer=init_ops.constant_initializer(1)) 465 add = v1 + v0 466 # v0 should be uninitialized. 467 with self.assertRaisesRegex(errors.OpError, "uninitialized"): 468 self.evaluate(v0) 469 # We should be able to initialize and run v1 without initializing 470 # v0, even if the variable was created with a control dep on v0. 471 self.evaluate(v1.initializer) 472 self.assertEqual(1, self.evaluate(v1)) 473 # v0 should still be uninitialized. 474 with self.assertRaisesRegex(errors.OpError, "uninitialized"): 475 self.evaluate(v0) 476 with self.assertRaisesRegex(errors.OpError, "uninitialized"): 477 self.evaluate(add) 478 # If we initialize v0 we should be able to run 'add'. 
479 self.evaluate(v0.initializer) 480 self.evaluate(add) 481 482 # TODO(mihaimaruseac): Not converted to use wrap_function because of 483 # AssertionError: True is not false (last assertFalse) 484 @test_util.run_deprecated_v1 485 def testEnableResourceVariables(self): 486 old = variable_scope._DEFAULT_USE_RESOURCE 487 try: 488 variable_scope.enable_resource_variables() 489 self.assertTrue(isinstance(variables_lib.VariableV1(1.0), 490 resource_variable_ops.ResourceVariable)) 491 variable_scope.disable_resource_variables() 492 self.assertFalse(isinstance(variables_lib.VariableV1(1.0), 493 resource_variable_ops.ResourceVariable)) 494 finally: 495 variable_scope._DEFAULT_USE_RESOURCE = old 496 497 # TODO(mihaimaruseac): Not converted to use wrap_function because of 498 # TypeError: Fetch argument None has invalid type <type 'NoneType'> 499 @test_util.run_deprecated_v1 500 def testControlFlow(self): 501 with self.cached_session() as sess: 502 v0 = variable_scope.get_variable( 503 "v0", [], initializer=init_ops.constant_initializer(0)) 504 var_dict = {} 505 506 # Call get_variable in each of the cond clauses. 507 def var_in_then_clause(): 508 v1 = variable_scope.get_variable( 509 "v1", [1], initializer=init_ops.constant_initializer(1)) 510 var_dict["v1"] = v1 511 return v1 + v0 512 513 def var_in_else_clause(): 514 v2 = variable_scope.get_variable( 515 "v2", [1], initializer=init_ops.constant_initializer(2)) 516 var_dict["v2"] = v2 517 return v2 + v0 518 519 add = control_flow_ops.cond( 520 math_ops.less(v0, 10), var_in_then_clause, var_in_else_clause) 521 v1 = var_dict["v1"] 522 v2 = var_dict["v2"] 523 # We should be able to initialize and run v1 and v2 without initializing 524 # v0, even if the variable was created with a control dep on v0. 525 self.evaluate(v1.initializer) 526 self.assertEqual([1], self.evaluate(v1)) 527 self.evaluate(v2.initializer) 528 self.assertEqual([2], self.evaluate(v2)) 529 # v0 should still be uninitialized. 
530 with self.assertRaisesRegex(errors.OpError, "uninitialized"): 531 self.evaluate(v0) 532 # We should not be able to run 'add' yet. 533 with self.assertRaisesRegex(errors.OpError, "uninitialized"): 534 self.evaluate(add) 535 # If we initialize v0 we should be able to run 'add'. 536 self.evaluate(v0.initializer) 537 self.evaluate(add) 538 539 # TODO(mihaimaruseac): Not converted to use wrap_function because of 540 # TypeError: Expected tf.group() expected Tensor arguments not 'None' with 541 # type '<type 'NoneType'>'. 542 @test_util.run_in_graph_and_eager_modes 543 def testGetVariableScope(self): 544 # Test the get_variable_scope() function and setting properties of result. 545 init = init_ops.constant_initializer(0.3) 546 with variable_scope.variable_scope("bar"): 547 new_init1 = variable_scope.get_variable_scope().initializer 548 self.assertEqual(new_init1, None) 549 # Check that we can set initializer like this. 550 variable_scope.get_variable_scope().set_initializer(init) 551 v = variable_scope.get_variable("v", []) 552 self.evaluate(variables_lib.variables_initializer([v])) 553 self.assertAllClose(self.evaluate(v.value()), 0.3) 554 if not context.executing_eagerly(): 555 # Check that we can set reuse. 556 variable_scope.get_variable_scope().reuse_variables() 557 with self.assertRaises(ValueError): # Fail, w does not exist yet. 558 variable_scope.get_variable("w", [1]) 559 # Check that the set initializer goes away. 
560 new_init = variable_scope.get_variable_scope().initializer 561 self.assertEqual(new_init, None) 562 563 @test_util.run_in_graph_and_eager_modes 564 @run_inside_wrap_function_in_eager_mode 565 def testVarScope(self): 566 with variable_scope.variable_scope("tower4") as tower: 567 self.assertEqual(tower.name, "tower4") 568 with ops.name_scope("scope") as sc: 569 self.assertEqual(sc, "tower4/scope/") 570 571 with variable_scope.variable_scope("tower5"): 572 with variable_scope.variable_scope("bar") as bar: 573 self.assertEqual(bar.name, "tower5/bar") 574 with ops.name_scope("scope") as sc: 575 self.assertEqual(sc, "tower5/bar/scope/") 576 577 with variable_scope.variable_scope("tower6"): 578 with variable_scope.variable_scope(tower, reuse=True) as tower_shared: 579 self.assertEqual(tower_shared.name, "tower4") 580 with ops.name_scope("scope") as sc: 581 self.assertEqual(sc, "tower6/tower4/scope/") 582 583 @test_util.run_in_graph_and_eager_modes 584 @run_inside_wrap_function_in_eager_mode 585 def testVarScopeNameScope(self): 586 with ops.name_scope("testVarScopeNameScope1"): 587 with variable_scope.variable_scope("tower") as tower: 588 with ops.name_scope("scope2") as sc2: 589 self.assertEqual(sc2, "testVarScopeNameScope1/tower/scope2/") 590 if not context.executing_eagerly(): 591 with variable_scope.variable_scope( 592 tower): # Re-entering acts like another "tower". 593 with ops.name_scope("scope2") as sc2: 594 self.assertEqual(sc2, "testVarScopeNameScope1/tower_1/scope2/") 595 with variable_scope.variable_scope( 596 "tower"): # Re-entering by string acts the same. 
597 with ops.name_scope("scope2") as sc2: 598 self.assertEqual(sc2, "testVarScopeNameScope1/tower_2/scope2/") 599 600 with ops.name_scope("testVarScopeNameScope2"): 601 with variable_scope.variable_scope("tower"): 602 with ops.name_scope("scope2") as sc2: 603 self.assertEqual(sc2, "testVarScopeNameScope2/tower/scope2/") 604 if not context.executing_eagerly(): 605 with variable_scope.variable_scope(tower): 606 with ops.name_scope("scope2") as sc2: 607 self.assertEqual(sc2, "testVarScopeNameScope2/tower_1/scope2/") 608 609 root_var_scope = variable_scope.get_variable_scope() 610 with ops.name_scope("testVarScopeNameScope3"): 611 with variable_scope.variable_scope(root_var_scope): 612 with ops.name_scope("scope2") as sc2: 613 self.assertEqual(sc2, "testVarScopeNameScope3/scope2/") 614 615 @test_util.run_in_graph_and_eager_modes 616 @run_inside_wrap_function_in_eager_mode 617 def testVarScopeOriginalNameScope(self): 618 with self.cached_session(): 619 with ops.name_scope("scope1"): 620 with variable_scope.variable_scope("tower") as tower: 621 self.assertEqual(tower.original_name_scope, "scope1/tower/") 622 with ops.name_scope("scope2") as sc2: 623 self.assertEqual(sc2, "scope1/tower/scope2/") 624 with ops.name_scope("scope2"): 625 with variable_scope.variable_scope(tower) as tower1: 626 # Re-entering preserves original name scope. 627 self.assertEqual(tower1.original_name_scope, "scope1/tower/") 628 with ops.name_scope("foo") as sc2: 629 self.assertEqual(sc2, "scope2/tower/foo/") 630 # Test re-entering original name scope. 
631 with ops.name_scope(tower.original_name_scope): 632 with ops.name_scope("bar") as sc3: 633 self.assertEqual(sc3, "scope1/tower/bar/") 634 with ops.name_scope("scope2"): 635 with variable_scope.variable_scope(tower): 636 with ops.name_scope(tower.original_name_scope): 637 with ops.name_scope("bar") as sc3: 638 self.assertEqual(sc3, "scope1/tower/bar_1/") 639 640 @test_util.run_in_graph_and_eager_modes 641 @run_inside_wrap_function_in_eager_mode 642 def testVarScopeObjectReuse(self): 643 with self.cached_session(): 644 vs = None 645 with variable_scope.variable_scope("jump", reuse=True) as scope: 646 vs = scope 647 648 with variable_scope.variable_scope(vs) as jump: 649 self.assertTrue(jump.reuse) 650 651 with variable_scope.variable_scope(vs, reuse=True) as jump_reuse: 652 self.assertTrue(jump_reuse.reuse) 653 654 with variable_scope.variable_scope(vs, reuse=False) as jump_no_reuse: 655 self.assertTrue(jump_no_reuse.reuse) # Inherited, cannot be undone. 656 657 with variable_scope.variable_scope("jump", reuse=False) as scope: 658 vs = scope 659 660 with variable_scope.variable_scope(vs) as jump: 661 self.assertFalse(jump.reuse) 662 663 with variable_scope.variable_scope(vs, reuse=True) as jump_reuse: 664 self.assertTrue(jump_reuse.reuse) 665 666 with variable_scope.variable_scope(vs, reuse=False) as jump_no_reuse: 667 self.assertFalse(jump_no_reuse.reuse) 668 669 @test_util.run_in_graph_and_eager_modes 670 @run_inside_wrap_function_in_eager_mode 671 def testVarScopeGetOrCreateReuse(self): 672 with self.cached_session(): 673 674 def test_value(value): 675 x = constant_op.constant(value) 676 with variable_scope.variable_scope( 677 "testVarScopeGetOrCreateReuse_bar", 678 reuse=variable_scope.AUTO_REUSE): 679 _ = state_ops.assign(variable_scope.get_variable("var", []), x) 680 with variable_scope.variable_scope( 681 "testVarScopeGetOrCreateReuse_bar", 682 reuse=variable_scope.AUTO_REUSE): 683 _ = variable_scope.get_variable("var", []) 684 self.assertEqual(value, 
self.evaluate(x)) 685 686 test_value(42.) # Variable is created. 687 test_value(13.) # Variable is reused hereafter. 688 test_value(17.) 689 690 @test_util.run_in_graph_and_eager_modes 691 @run_inside_wrap_function_in_eager_mode 692 def testVarOpScope(self): 693 with self.cached_session(): 694 with ops.name_scope("testVarOpScope1"): 695 with variable_scope.variable_scope("tower", "default", []): 696 self.assertEqual( 697 variable_scope.get_variable("w", []).name, "tower/w:0") 698 with ops.name_scope("testVarOpScope2") as sc2: 699 self.assertEqual(sc2, "testVarOpScope1/tower/testVarOpScope2/") 700 with variable_scope.variable_scope("tower", "default", []): 701 with self.assertRaises(ValueError): 702 variable_scope.get_variable("w", []) 703 with ops.name_scope("testVarOpScope2") as sc2: 704 self.assertEqual(sc2, "testVarOpScope1/tower_1/testVarOpScope2/") 705 706 with ops.name_scope("testVarOpScope2"): 707 with variable_scope.variable_scope(None, "default", []): 708 self.assertEqual( 709 variable_scope.get_variable("w", []).name, "default/w:0") 710 with ops.name_scope("testVarOpScope2") as sc2: 711 self.assertEqual(sc2, "testVarOpScope2/default/testVarOpScope2/") 712 with variable_scope.variable_scope(None, "default", []): 713 self.assertEqual( 714 variable_scope.get_variable("w", []).name, "default_1/w:0") 715 with ops.name_scope("testVarOpScope2") as sc2: 716 self.assertEqual(sc2, "testVarOpScope2/default_1/testVarOpScope2/") 717 718 @test_util.run_in_graph_and_eager_modes 719 @run_inside_wrap_function_in_eager_mode 720 def testVarOpScopeUniqueNamesInterleavedSubstringScopes(self): 721 with self.cached_session(): 722 with variable_scope.variable_scope(None, "defaultScope1"): 723 with variable_scope.variable_scope(None, "layer"): 724 self.assertEqual( 725 variable_scope.get_variable("w", []).name, 726 "defaultScope1/layer/w:0") 727 with variable_scope.variable_scope(None, "defaultScope1"): 728 with variable_scope.variable_scope(None, "layer"): 729 self.assertEqual( 
730 variable_scope.get_variable("w", []).name, 731 "defaultScope1_1/layer/w:0") 732 with variable_scope.variable_scope(None, "defaultScope"): 733 with variable_scope.variable_scope(None, "layer"): 734 self.assertEqual( 735 variable_scope.get_variable("w", []).name, 736 "defaultScope/layer/w:0") 737 with variable_scope.variable_scope(None, "defaultScope1"): 738 with variable_scope.variable_scope(None, "layer"): 739 self.assertEqual( 740 variable_scope.get_variable("w", []).name, 741 "defaultScope1_2/layer/w:0") 742 743 @test_util.run_in_graph_and_eager_modes 744 @run_inside_wrap_function_in_eager_mode 745 def testVarOpScopeUniqueNamesWithJump(self): 746 with self.cached_session(): 747 with variable_scope.variable_scope("default") as default: 748 with variable_scope.variable_scope(None, "layer"): 749 self.assertEqual( 750 variable_scope.get_variable("w", []).name, "default/layer/w:0") 751 with variable_scope.variable_scope(None, "layer"): 752 self.assertEqual( 753 variable_scope.get_variable("w", []).name, 754 "default/layer_1/w:0") 755 with variable_scope.variable_scope(default): 756 pass 757 # No matter the jump in the middle, unique numbering continues. 
758 with variable_scope.variable_scope(None, "layer"): 759 self.assertEqual( 760 variable_scope.get_variable("w", []).name, 761 "default/layer_2/w:0") 762 763 @test_util.run_in_graph_and_eager_modes 764 @run_inside_wrap_function_in_eager_mode 765 def testVarOpScopeReuse(self): 766 with self.cached_session(): 767 with variable_scope.variable_scope("outer") as outer: 768 with variable_scope.variable_scope("tower", "default", []): 769 self.assertEqual( 770 variable_scope.get_variable("w", []).name, "outer/tower/w:0") 771 with ops.name_scope("scope2") as sc2: 772 self.assertEqual(sc2, "outer/tower/scope2/") 773 with variable_scope.variable_scope(None, "default", []): 774 self.assertEqual( 775 variable_scope.get_variable("w", []).name, "outer/default/w:0") 776 with ops.name_scope("scope2") as sc2: 777 self.assertEqual(sc2, "outer/default/scope2/") 778 779 with variable_scope.variable_scope(outer, reuse=True) as outer: 780 with variable_scope.variable_scope("tower", "default", []): 781 self.assertEqual( 782 variable_scope.get_variable("w", []).name, "outer/tower/w:0") 783 with ops.name_scope("scope2") as sc2: 784 self.assertEqual(sc2, "outer_1/tower/scope2/") 785 with variable_scope.variable_scope(None, "default", []): 786 self.assertEqual( 787 variable_scope.get_variable("w", []).name, "outer/default/w:0") 788 with ops.name_scope("scope2") as sc2: 789 self.assertEqual(sc2, "outer_1/default/scope2/") 790 791 @test_util.run_in_graph_and_eager_modes 792 @run_inside_wrap_function_in_eager_mode 793 def testVarScopeGetVar(self): 794 with self.cached_session(): 795 with variable_scope.variable_scope("root"): 796 with variable_scope.variable_scope("towerA") as tower_a: 797 va = variable_scope.get_variable("v", [1]) 798 self.assertEqual(va.name, "root/towerA/v:0") 799 800 with variable_scope.variable_scope(tower_a, reuse=True): 801 va2 = variable_scope.get_variable("v", [1]) 802 self.assertIs(va2, va) 803 804 with variable_scope.variable_scope("towerB"): 805 vb = 
variable_scope.get_variable("v", [1]) 806 self.assertEqual(vb.name, "root/towerB/v:0") 807 808 with self.assertRaises(ValueError): 809 with variable_scope.variable_scope("towerA"): 810 va2 = variable_scope.get_variable("v", [1]) 811 812 with variable_scope.variable_scope("towerA", reuse=True): 813 va2 = variable_scope.get_variable("v", [1]) 814 self.assertIs(va2, va) 815 816 with variable_scope.variable_scope("foo"): 817 with variable_scope.variable_scope("bar"): 818 v = variable_scope.get_variable("v", [1]) 819 self.assertEqual(v.name, "root/foo/bar/v:0") 820 with variable_scope.variable_scope(tower_a, reuse=True): 821 va3 = variable_scope.get_variable("v", [1]) 822 self.assertIs(va, va3) 823 824 with self.assertRaises(ValueError): 825 with variable_scope.variable_scope(tower_a, reuse=True): 826 with variable_scope.variable_scope("baz"): 827 variable_scope.get_variable("v", [1]) 828 829 with self.assertRaises(ValueError) as exc: 830 with variable_scope.variable_scope(tower_a, reuse=True): 831 variable_scope.get_variable("v", [2]) # Different shape. 
832 self.assertEqual("shape" in str(exc.exception), True) 833 834 with self.assertRaises(ValueError) as exc: 835 with variable_scope.variable_scope(tower_a, reuse=True): 836 variable_scope.get_variable("v", [1], dtype=dtypes.int32) 837 self.assertEqual("dtype" in str(exc.exception), True) 838 839 @test_util.run_in_graph_and_eager_modes 840 @run_inside_wrap_function_in_eager_mode 841 def testVarScopeOuterScope(self): 842 with self.cached_session(): 843 with variable_scope.variable_scope("outer") as outer: 844 pass 845 with variable_scope.variable_scope(outer): 846 self.assertEqual( 847 variable_scope.get_variable("w", []).name, "outer/w:0") 848 with ops.name_scope("scope2") as sc2: 849 self.assertEqual(sc2, "outer_1/scope2/") 850 with variable_scope.variable_scope("default"): 851 self.assertEqual( 852 variable_scope.get_variable("w", []).name, "outer/default/w:0") 853 with ops.name_scope("scope2") as sc2: 854 self.assertEqual(sc2, "outer_1/default/scope2/") 855 856 with variable_scope.variable_scope(outer, reuse=True): 857 self.assertEqual( 858 variable_scope.get_variable("w", []).name, "outer/w:0") 859 with ops.name_scope("scope2") as sc2: 860 self.assertEqual(sc2, "outer_2/scope2/") 861 with variable_scope.variable_scope("default", reuse=True): 862 self.assertEqual( 863 variable_scope.get_variable("w", []).name, "outer/default/w:0") 864 with ops.name_scope("scope2") as sc2: 865 self.assertEqual(sc2, "outer_2/default/scope2/") 866 867 @test_util.run_in_graph_and_eager_modes 868 @run_inside_wrap_function_in_eager_mode 869 def testVarScopeNestedOuterScope(self): 870 with self.cached_session(): 871 with variable_scope.variable_scope("outer") as outer: 872 with variable_scope.variable_scope(outer): 873 self.assertEqual( 874 variable_scope.get_variable("w", []).name, "outer/w:0") 875 with ops.name_scope("scope2") as sc2: 876 self.assertEqual(sc2, "outer/outer/scope2/") 877 with variable_scope.variable_scope("default"): 878 self.assertEqual( 879 
variable_scope.get_variable("w", []).name, "outer/default/w:0") 880 with ops.name_scope("scope2") as sc2: 881 self.assertEqual(sc2, "outer/default/scope2/") 882 883 with variable_scope.variable_scope(outer, reuse=True): 884 self.assertEqual( 885 variable_scope.get_variable("w", []).name, "outer/w:0") 886 with ops.name_scope("scope2") as sc2: 887 self.assertEqual(sc2, "outer/outer_1/scope2/") 888 with variable_scope.variable_scope("default", reuse=True): 889 self.assertEqual( 890 variable_scope.get_variable("w", []).name, "outer/default/w:0") 891 with ops.name_scope("scope2") as sc2: 892 self.assertEqual(sc2, "outer/default_1/scope2/") 893 894 @test_util.run_in_graph_and_eager_modes 895 @run_inside_wrap_function_in_eager_mode 896 def testVarOpScopeReuseParam(self): 897 with self.cached_session(): 898 with variable_scope.variable_scope("outer") as outer: 899 with variable_scope.variable_scope("tower", "default", []): 900 self.assertEqual( 901 variable_scope.get_variable("w", []).name, "outer/tower/w:0") 902 with ops.name_scope("scope2") as sc2: 903 self.assertEqual(sc2, "outer/tower/scope2/") 904 with variable_scope.variable_scope(None, "default", []): 905 self.assertEqual( 906 variable_scope.get_variable("w", []).name, "outer/default/w:0") 907 with ops.name_scope("scope2") as sc2: 908 self.assertEqual(sc2, "outer/default/scope2/") 909 910 with variable_scope.variable_scope(outer) as outer: 911 with variable_scope.variable_scope("tower", "default", reuse=True): 912 self.assertEqual( 913 variable_scope.get_variable("w", []).name, "outer/tower/w:0") 914 with ops.name_scope("scope2") as sc2: 915 self.assertEqual(sc2, "outer_1/tower/scope2/") 916 outer.reuse_variables() 917 with variable_scope.variable_scope(None, "default", []): 918 self.assertEqual( 919 variable_scope.get_variable("w", []).name, "outer/default/w:0") 920 with ops.name_scope("scope2") as sc2: 921 self.assertEqual(sc2, "outer_1/default/scope2/") 922 923 @test_util.run_in_graph_and_eager_modes 924 
@run_inside_wrap_function_in_eager_mode 925 def testVarOpScopeReuseError(self): 926 with self.cached_session(): 927 with self.assertRaises(ValueError): 928 with variable_scope.variable_scope(None, "default", reuse=True): 929 self.assertEqual( 930 variable_scope.get_variable("w", []).name, "outer/tower/w:0") 931 932 @test_util.run_in_graph_and_eager_modes 933 @run_inside_wrap_function_in_eager_mode 934 def testVarOpScopeOuterScope(self): 935 with self.cached_session(): 936 with variable_scope.variable_scope("outer") as outer: 937 pass 938 with variable_scope.variable_scope(outer, "default", []): 939 self.assertEqual( 940 variable_scope.get_variable("w", []).name, "outer/w:0") 941 with ops.name_scope("scope2") as sc2: 942 self.assertEqual(sc2, "outer_1/scope2/") 943 with variable_scope.variable_scope(None, "default", []): 944 self.assertEqual( 945 variable_scope.get_variable("w", []).name, "outer/default/w:0") 946 with ops.name_scope("scope2") as sc2: 947 self.assertEqual(sc2, "outer_1/default/scope2/") 948 949 with variable_scope.variable_scope(outer, "default", reuse=True): 950 self.assertEqual( 951 variable_scope.get_variable("w", []).name, "outer/w:0") 952 with ops.name_scope("scope2") as sc2: 953 self.assertEqual(sc2, "outer_2/scope2/") 954 outer.reuse_variables() 955 with variable_scope.variable_scope(None, "default", []): 956 self.assertEqual( 957 variable_scope.get_variable("w", []).name, "outer/default/w:0") 958 with ops.name_scope("scope2") as sc2: 959 self.assertEqual(sc2, "outer_2/default/scope2/") 960 961 @test_util.run_in_graph_and_eager_modes 962 @run_inside_wrap_function_in_eager_mode 963 def testVarOpScopeNestedOuterScope(self): 964 with self.cached_session(): 965 with variable_scope.variable_scope("outer") as outer: 966 with variable_scope.variable_scope(outer, "default", []): 967 self.assertEqual( 968 variable_scope.get_variable("w", []).name, "outer/w:0") 969 with ops.name_scope("scope2") as sc2: 970 self.assertEqual(sc2, "outer/outer/scope2/") 
971 with variable_scope.variable_scope(None, "default", []): 972 self.assertEqual( 973 variable_scope.get_variable("w", []).name, "outer/default/w:0") 974 with ops.name_scope("scope2") as sc2: 975 self.assertEqual(sc2, "outer/default/scope2/") 976 977 with variable_scope.variable_scope(outer, "default", reuse=True): 978 self.assertEqual( 979 variable_scope.get_variable("w", []).name, "outer/w:0") 980 with ops.name_scope("scope2") as sc2: 981 self.assertEqual(sc2, "outer_1/scope2/") 982 with variable_scope.variable_scope(None, "default", []): 983 self.assertEqual( 984 variable_scope.get_variable("w", []).name, "outer/default/w:0") 985 with ops.name_scope("scope2") as sc2: 986 self.assertEqual(sc2, "outer_1/default/scope2/") 987 988 @test_util.run_in_graph_and_eager_modes 989 @run_inside_wrap_function_in_eager_mode 990 def testBasicWhenAuxiliaryNameScopeIsFalse(self): 991 with self.cached_session(): 992 with variable_scope.variable_scope( 993 "scope", auxiliary_name_scope=False) as scope: 994 self.assertEqual(scope.original_name_scope, "") 995 self.assertEqual( 996 variable_scope.get_variable("w", []).name, "scope/w:0") 997 self.assertEqual(constant_op.constant([], name="c").name, "c:0") 998 with variable_scope.variable_scope(scope, auxiliary_name_scope=False): 999 self.assertEqual(scope.original_name_scope, "") 1000 self.assertEqual( 1001 variable_scope.get_variable("w1", []).name, "scope/w1:0") 1002 self.assertEqual(constant_op.constant([], name="c1").name, "c1:0") 1003 # Recheck: new name scope is NOT created before 1004 with ops.name_scope("scope"): 1005 self.assertEqual(constant_op.constant([], name="c").name, "scope/c:0") 1006 1007 with variable_scope.variable_scope("outer"): 1008 with variable_scope.variable_scope( 1009 "inner", auxiliary_name_scope=False) as inner: 1010 self.assertEqual(inner.original_name_scope, "outer/") 1011 self.assertEqual( 1012 variable_scope.get_variable("w", []).name, "outer/inner/w:0") 1013 self.assertEqual( 1014 
constant_op.constant([], name="c").name, "outer/c:0") 1015 with variable_scope.variable_scope( 1016 inner, auxiliary_name_scope=False) as inner1: 1017 self.assertEqual(inner1.original_name_scope, "outer/") 1018 self.assertEqual( 1019 variable_scope.get_variable("w1", []).name, "outer/inner/w1:0") 1020 self.assertEqual( 1021 constant_op.constant([], name="c1").name, "outer/c1:0") 1022 # Recheck: new name scope is NOT created before 1023 with ops.name_scope("inner"): 1024 self.assertEqual( 1025 constant_op.constant([], name="c").name, "outer/inner/c:0") 1026 1027 @test_util.run_in_graph_and_eager_modes 1028 @run_inside_wrap_function_in_eager_mode 1029 def testCreatedByDefaultNameWhenAuxiliaryNameScopeIsFalse(self): 1030 with self.cached_session(): 1031 with variable_scope.variable_scope( 1032 None, default_name="default", auxiliary_name_scope=False) as scope: 1033 self.assertEqual(scope.original_name_scope, "") 1034 self.assertEqual( 1035 variable_scope.get_variable("w", []).name, "default/w:0") 1036 self.assertEqual(constant_op.constant([], name="c").name, "c:0") 1037 # Recheck: new name scope is NOT created before 1038 with ops.name_scope("default"): 1039 self.assertEqual( 1040 constant_op.constant([], name="c").name, "default/c:0") 1041 1042 with variable_scope.variable_scope("outer"): 1043 with variable_scope.variable_scope( 1044 None, default_name="default", 1045 auxiliary_name_scope=False) as inner: 1046 self.assertEqual(inner.original_name_scope, "outer/") 1047 self.assertEqual( 1048 variable_scope.get_variable("w", []).name, "outer/default/w:0") 1049 self.assertEqual( 1050 constant_op.constant([], name="c").name, "outer/c:0") 1051 # Recheck: new name scope is NOT created before 1052 with ops.name_scope("default"): 1053 self.assertEqual( 1054 constant_op.constant([], name="c").name, "outer/default/c:0") 1055 1056 @test_util.run_in_graph_and_eager_modes 1057 @run_inside_wrap_function_in_eager_mode 1058 def 
testReenterRootScopeWhenAuxiliaryNameScopeIsFalse(self): 1059 with self.cached_session(): 1060 root_scope = variable_scope.get_variable_scope() 1061 with variable_scope.variable_scope( 1062 root_scope, auxiliary_name_scope=False) as scope: 1063 self.assertEqual(scope.original_name_scope, "") 1064 self.assertEqual(variable_scope.get_variable("w", []).name, "w:0") 1065 self.assertEqual(constant_op.constant([], name="c").name, "c:0") 1066 1067 with variable_scope.variable_scope("outer"): 1068 with variable_scope.variable_scope( 1069 root_scope, auxiliary_name_scope=False) as inner: 1070 self.assertEqual(inner.original_name_scope, "") 1071 self.assertEqual(variable_scope.get_variable("w1", []).name, "w1:0") 1072 self.assertEqual( 1073 constant_op.constant([], name="c1").name, "outer/c1:0") 1074 1075 @test_util.run_in_graph_and_eager_modes 1076 @run_inside_wrap_function_in_eager_mode 1077 def testAuxiliaryNameScopeIsInvalid(self): 1078 with self.cached_session(): 1079 with self.assertRaisesRegex(TypeError, "auxiliary_name_scope"): 1080 with variable_scope.variable_scope( 1081 None, default_name="scope", auxiliary_name_scope="invalid"): 1082 pass 1083 1084 with self.assertRaisesRegex(TypeError, "auxiliary_name_scope"): 1085 with variable_scope.variable_scope( 1086 "scope", auxiliary_name_scope="invalid"): 1087 pass 1088 1089 with variable_scope.variable_scope("scope") as scope: 1090 pass 1091 with self.assertRaisesRegex(TypeError, "auxiliary_name_scope"): 1092 with variable_scope.variable_scope( 1093 scope, auxiliary_name_scope="invalid"): 1094 pass 1095 1096 @test_util.run_in_graph_and_eager_modes 1097 @run_inside_wrap_function_in_eager_mode 1098 def testReuseScopeWithoutNameScopeCollision(self): 1099 # Github issue: #13429 1100 with self.cached_session(): 1101 with variable_scope.variable_scope("outer"): 1102 with variable_scope.variable_scope("inner") as inner: 1103 pass 1104 1105 with variable_scope.variable_scope( 1106 inner, auxiliary_name_scope=False) as scope: 
1107 with ops.name_scope(scope.original_name_scope): 1108 self.assertEqual( 1109 variable_scope.get_variable("w", []).name, "outer/inner/w:0") 1110 self.assertEqual( 1111 constant_op.constant([], name="c").name, "outer/inner/c:0") 1112 with ops.name_scope("inner"): 1113 self.assertEqual( 1114 constant_op.constant([], name="c").name, "inner/c:0") 1115 1116 with variable_scope.variable_scope("another"): 1117 with variable_scope.variable_scope( 1118 inner, auxiliary_name_scope=False) as scope1: 1119 with ops.name_scope(scope1.original_name_scope): 1120 self.assertEqual( 1121 variable_scope.get_variable("w1", []).name, 1122 "outer/inner/w1:0") 1123 self.assertEqual( 1124 constant_op.constant([], name="c1").name, "outer/inner/c1:0") 1125 with ops.name_scope("inner"): 1126 self.assertEqual( 1127 constant_op.constant([], name="c").name, "another/inner/c:0") 1128 1129 # TODO(mihaimaruseac): Not converted to use wrap_function because of 1130 # obtaining different results in the eager case compared to the graph one 1131 # (different assertions failing after wrapping, in both execution modes) 1132 @test_util.run_in_graph_and_eager_modes 1133 def testGetLocalVar(self): 1134 # Check that local variable respects naming. 1135 with variable_scope.variable_scope("outer") as outer: 1136 with variable_scope.variable_scope(outer, "default", []): 1137 local_var = variable_scope.get_local_variable( 1138 "w", [], collections=["foo"]) 1139 self.assertEqual(local_var.name, "outer/w:0") 1140 1141 if not context.executing_eagerly(): 1142 # Since variable is local, it should be in the local variable collection 1143 # but not the trainable collection. 1144 self.assertIn(local_var, 1145 ops.get_collection(ops.GraphKeys.LOCAL_VARIABLES)) 1146 self.assertIn(local_var, ops.get_collection("foo")) 1147 self.assertNotIn(local_var, 1148 ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)) 1149 # Check that local variable respects `reuse`. 
1150 with variable_scope.variable_scope(outer, "default", reuse=True): 1151 self.assertEqual( 1152 variable_scope.get_local_variable("w", []).name, "outer/w:0") 1153 1154 @test_util.run_in_graph_and_eager_modes 1155 @run_inside_wrap_function_in_eager_mode 1156 def testSignatureGetVarVsGetLocalVar(self): 1157 """get_{local,}variable() must take the same list of args.""" 1158 arg_names = tf_inspect.getargspec(variable_scope.get_variable)[0] 1159 local_arg_names = tf_inspect.getargspec( 1160 variable_scope.get_local_variable)[0] 1161 self.assertEqual(arg_names, local_arg_names) 1162 1163 @test_util.run_in_graph_and_eager_modes 1164 @run_inside_wrap_function_in_eager_mode 1165 def testGetVarWithDevice(self): 1166 g = ops.Graph() 1167 varname_type = [] 1168 1169 def device_func(op): 1170 if op.type in ["Variable", "VariableV2", "VarHandleOp"]: 1171 varname_type.append((op.name, op.get_attr("dtype"))) 1172 return "/device:GPU:0" 1173 1174 with g.as_default(): 1175 with ops.device(device_func): 1176 _ = variable_scope.get_variable("x", (100, 200)) 1177 _ = variable_scope.get_variable( 1178 "y", dtype=dtypes.int64, initializer=numpy.arange(73)) 1179 self.assertEqual(varname_type[0], ("x", dtypes.float32)) 1180 self.assertEqual(varname_type[1], ("y", dtypes.int64)) 1181 1182 # TODO(mihaimaruseac): Not converted to use wrap_function because of 1183 # obtaining different results in the eager case compared to the graph one 1184 @test_util.run_deprecated_v1 1185 def testGetCollection(self): 1186 with self.cached_session(): 1187 _ = variable_scope.get_variable("testGetCollection_a", []) 1188 _ = variable_scope.get_variable( 1189 "testGetCollection_b", [], trainable=False) 1190 with variable_scope.variable_scope("testGetCollection_foo_") as scope1: 1191 _ = variable_scope.get_variable("testGetCollection_a", []) 1192 _ = variable_scope.get_variable( 1193 "testGetCollection_b", [], trainable=False) 1194 self.assertEqual([ 1195 v.name 1196 for v in 
scope1.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES) 1197 ], ["testGetCollection_foo_/testGetCollection_a:0"]) 1198 self.assertEqual([ 1199 v.name 1200 for v in scope1.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) 1201 ], [ 1202 "testGetCollection_foo_/testGetCollection_a:0", 1203 "testGetCollection_foo_/testGetCollection_b:0" 1204 ]) 1205 with variable_scope.variable_scope("testGetCollection_foo") as scope2: 1206 _ = variable_scope.get_variable("testGetCollection_a", []) 1207 _ = variable_scope.get_variable( 1208 "testGetCollection_b", [], trainable=False) 1209 self.assertEqual([ 1210 v.name 1211 for v in scope2.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES) 1212 ], ["testGetCollection_foo/testGetCollection_a:0"]) 1213 self.assertEqual([ 1214 v.name 1215 for v in scope2.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) 1216 ], [ 1217 "testGetCollection_foo/testGetCollection_a:0", 1218 "testGetCollection_foo/testGetCollection_b:0" 1219 ]) 1220 scope = variable_scope.get_variable_scope() 1221 self.assertEqual([ 1222 v.name for v in scope.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) 1223 ], [ 1224 "testGetCollection_a:0", "testGetCollection_b:0", 1225 "testGetCollection_foo_/testGetCollection_a:0", 1226 "testGetCollection_foo_/testGetCollection_b:0", 1227 "testGetCollection_foo/testGetCollection_a:0", 1228 "testGetCollection_foo/testGetCollection_b:0" 1229 ]) 1230 self.assertEqual([ 1231 v.name 1232 for v in scope.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES) 1233 ], [ 1234 "testGetCollection_a:0", 1235 "testGetCollection_foo_/testGetCollection_a:0", 1236 "testGetCollection_foo/testGetCollection_a:0" 1237 ]) 1238 1239 # TODO(mihaimaruseac): Not converted to use wrap_function because of 1240 # obtaining different results in the eager case compared to the graph one 1241 @test_util.run_deprecated_v1 1242 def testGetTrainableVariablesWithGetVariable(self): 1243 with self.cached_session(): 1244 _ = variable_scope.get_variable("testGetTrainableVariables_a", []) 
1245 with variable_scope.variable_scope( 1246 "testGetTrainableVariables_foo") as scope: 1247 _ = variable_scope.get_variable("testGetTrainableVariables_b", []) 1248 _ = variable_scope.get_variable( 1249 "testGetTrainableVariables_c", [], trainable=False) 1250 1251 # sync `ON_READ` sets trainable=False 1252 _ = variable_scope.get_variable( 1253 "testGetTrainableVariables_d", [], 1254 synchronization=variable_scope.VariableSynchronization.ON_READ) 1255 self.assertEqual( 1256 [v.name for v in scope.trainable_variables()], 1257 ["testGetTrainableVariables_foo/testGetTrainableVariables_b:0"]) 1258 1259 _ = variable_scope.get_variable( 1260 "testGetTrainableVariables_e", [], 1261 synchronization=variable_scope.VariableSynchronization.ON_READ, 1262 trainable=True) 1263 self.assertEqual([v.name for v in scope.trainable_variables()], [ 1264 "testGetTrainableVariables_foo/testGetTrainableVariables_b:0", 1265 "testGetTrainableVariables_foo/testGetTrainableVariables_e:0", 1266 ]) 1267 1268 # All other sync values sets trainable=True 1269 _ = variable_scope.get_variable( 1270 "testGetTrainableVariables_f", [], 1271 synchronization=variable_scope.VariableSynchronization.ON_WRITE) 1272 self.assertEqual([v.name for v in scope.trainable_variables()], [ 1273 "testGetTrainableVariables_foo/testGetTrainableVariables_b:0", 1274 "testGetTrainableVariables_foo/testGetTrainableVariables_e:0", 1275 "testGetTrainableVariables_foo/testGetTrainableVariables_f:0", 1276 ]) 1277 1278 # TODO(mihaimaruseac): Not converted to use wrap_function because of 1279 # obtaining different results in the eager case compared to the graph one 1280 @test_util.run_deprecated_v1 1281 def testGetTrainableVariablesWithVariable(self): 1282 with self.cached_session(): 1283 _ = variable_scope.variable(1.0, name="testGetTrainableVariables_a") 1284 with variable_scope.variable_scope( 1285 "testGetTrainableVariables_foo") as scope: 1286 _ = variable_scope.variable(1.0, name="testGetTrainableVariables_b") 1287 _ = 
variable_scope.variable( 1288 1.0, name="testGetTrainableVariables_c", trainable=False) 1289 1290 # sync `ON_READ` sets trainable=False 1291 _ = variable_scope.variable( 1292 1.0, 1293 name="testGetTrainableVariables_d", 1294 synchronization=variable_scope.VariableSynchronization.ON_READ) 1295 self.assertEqual( 1296 [v.name for v in scope.trainable_variables()], 1297 ["testGetTrainableVariables_foo/testGetTrainableVariables_b:0"]) 1298 1299 _ = variable_scope.variable( 1300 1.0, 1301 name="testGetTrainableVariables_e", 1302 synchronization=variable_scope.VariableSynchronization.ON_READ, 1303 trainable=True) 1304 self.assertEqual([v.name for v in scope.trainable_variables()], [ 1305 "testGetTrainableVariables_foo/testGetTrainableVariables_b:0", 1306 "testGetTrainableVariables_foo/testGetTrainableVariables_e:0", 1307 ]) 1308 1309 # All other sync values sets trainable=True 1310 _ = variable_scope.variable( 1311 1.0, 1312 name="testGetTrainableVariables_f", 1313 synchronization=variable_scope.VariableSynchronization.ON_WRITE) 1314 self.assertEqual([v.name for v in scope.trainable_variables()], [ 1315 "testGetTrainableVariables_foo/testGetTrainableVariables_b:0", 1316 "testGetTrainableVariables_foo/testGetTrainableVariables_e:0", 1317 "testGetTrainableVariables_foo/testGetTrainableVariables_f:0", 1318 ]) 1319 1320 # TODO(mihaimaruseac): Not converted to use wrap_function because of 1321 # obtaining different results in the eager case compared to the graph one 1322 @test_util.run_deprecated_v1 1323 def testGetGlobalVariables(self): 1324 with self.cached_session(): 1325 _ = variable_scope.get_variable("testGetGlobalVariables_a", []) 1326 with variable_scope.variable_scope("testGetGlobalVariables_foo") as scope: 1327 _ = variable_scope.get_variable("testGetGlobalVariables_b", []) 1328 self.assertEqual( 1329 [v.name for v in scope.global_variables()], 1330 ["testGetGlobalVariables_foo/" 1331 "testGetGlobalVariables_b:0"]) 1332 1333 # TODO(mihaimaruseac): Not converted to 
use wrap_function because of 1334 # obtaining different results in the eager case compared to the graph one 1335 @test_util.run_deprecated_v1 1336 def testGetLocalVariables(self): 1337 with self.cached_session(): 1338 _ = variable_scope.get_variable( 1339 "a", [], collections=[ops.GraphKeys.LOCAL_VARIABLES]) 1340 with variable_scope.variable_scope("foo") as scope: 1341 _ = variable_scope.get_variable( 1342 "b", [], collections=[ops.GraphKeys.LOCAL_VARIABLES]) 1343 _ = variable_scope.get_variable("c", []) 1344 self.assertEqual([v.name for v in scope.local_variables()], ["foo/b:0"]) 1345 1346 @test_util.run_in_graph_and_eager_modes 1347 @run_inside_wrap_function_in_eager_mode 1348 def testGetVariableWithRefDtype(self): 1349 v = variable_scope.get_variable("v", shape=[3, 4], dtype=dtypes.float32) 1350 # Ensure it is possible to do get_variable with a _ref dtype passed in. 1351 _ = variable_scope.get_variable("w", shape=[5, 6], dtype=v.dtype) 1352 1353 @test_util.run_in_graph_and_eager_modes 1354 @run_inside_wrap_function_in_eager_mode 1355 def testGetVariableWithInitializerWhichTakesNoArgs(self): 1356 v = variable_scope.get_variable("foo", initializer=lambda: [2]) 1357 self.assertEqual(v.name, "foo:0") 1358 1359 @test_util.run_in_graph_and_eager_modes 1360 @run_inside_wrap_function_in_eager_mode 1361 def testGetVariableWithInitializerWhichTakesOptionalArgs(self): 1362 v = variable_scope.get_variable("foo", initializer=lambda x=True: [2]) 1363 self.assertEqual(v.name, "foo:0") 1364 1365 @test_util.run_in_graph_and_eager_modes 1366 @run_inside_wrap_function_in_eager_mode 1367 def testGetVariableWithInitializerWhichTakesUnprovidedArgsAndNoShape(self): 1368 with self.assertRaisesRegex( 1369 ValueError, 1370 "The initializer passed is not valid. 
It should be a callable with no " 1371 "arguments and the shape should not be provided or an instance of " 1372 "`tf.keras.initializers.*' and `shape` should be fully defined."): 1373 variable_scope.get_variable("foo", initializer=lambda x: [2]) 1374 1375 @test_util.run_in_graph_and_eager_modes 1376 @run_inside_wrap_function_in_eager_mode 1377 def testTwoGraphs(self): 1378 1379 def f(): 1380 g1 = ops.Graph() 1381 g2 = ops.Graph() 1382 with g1.as_default(): 1383 with g2.as_default(): 1384 with variable_scope.variable_scope("_"): 1385 pass 1386 1387 self.assertRaisesRegex(ValueError, "'_' is not a valid scope name", f) 1388 1389 1390def axis0_into1_partitioner(shape=None, **unused_kwargs): 1391 part = [1] * len(shape) 1392 return part 1393 1394 1395def axis0_into2_partitioner(shape=None, **unused_kwargs): 1396 part = [1] * len(shape) 1397 part[0] = 2 1398 return part 1399 1400 1401def axis0_into3_partitioner(shape=None, **unused_kwargs): 1402 part = [1] * len(shape) 1403 part[0] = 3 1404 return part 1405 1406 1407class VariableScopeWithPartitioningTest(test.TestCase): 1408 1409 # TODO(mihaimaruseac): Not converted to use wrap_function because of 1410 # obtaining different results in the eager case compared to the graph one 1411 @test_util.run_deprecated_v1 1412 def testResultNameMatchesRequested(self): 1413 with variable_scope.variable_scope( 1414 "scope0", partitioner=axis0_into2_partitioner): 1415 v = variable_scope.get_variable("name0", shape=(3, 1, 1)) 1416 self.assertEqual(v.name, "scope0/name0") 1417 v_concat = v.as_tensor() 1418 self.assertEqual(v_concat.name, "scope0/name0:0") 1419 variables = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) 1420 self.assertIn("scope0/name0/part_0:0", [x.name for x in variables]) 1421 self.assertIn("scope0/name0/part_1:0", [x.name for x in variables]) 1422 self.assertNotIn("scope0/name0/part_2:0", [x.name for x in variables]) 1423 1424 @test_util.run_in_graph_and_eager_modes 1425 @run_inside_wrap_function_in_eager_mode 1426 
def testBreaksIfPartitioningChanges(self): 1427 with variable_scope.variable_scope( 1428 "scope0", partitioner=axis0_into2_partitioner): 1429 variable_scope.get_variable("name0", shape=(3, 1, 1)) 1430 1431 with variable_scope.variable_scope( 1432 "scope0", partitioner=axis0_into3_partitioner, reuse=True): 1433 with self.assertRaisesRegex( 1434 ValueError, 1435 "Trying to reuse partitioned variable .* but specified partitions " 1436 ".* and found partitions .*"): 1437 variable_scope.get_variable("name0", shape=(3, 1, 1)) 1438 1439 with variable_scope.variable_scope( 1440 "scope0", partitioner=axis0_into1_partitioner, reuse=True): 1441 with self.assertRaisesRegex( 1442 ValueError, 1443 "Trying to reuse partitioned variable .* but specified partitions " 1444 ".* and found partitions .*"): 1445 variable_scope.get_variable("name0", shape=(3, 1, 1)) 1446 1447 @test_util.run_in_graph_and_eager_modes 1448 @run_inside_wrap_function_in_eager_mode 1449 def testReturnsExistingConcatenatedValueIfReuse(self): 1450 with variable_scope.variable_scope( 1451 "scope0", partitioner=axis0_into2_partitioner): 1452 v_concat = variable_scope.get_variable("name0", shape=(3, 1, 1)) 1453 variable_scope.get_variable_scope().reuse_variables() 1454 v_concat_2 = variable_scope.get_variable("name0", shape=(3, 1, 1)) 1455 self.assertEqual(v_concat, v_concat_2) 1456 1457 @test_util.run_in_graph_and_eager_modes 1458 @run_inside_wrap_function_in_eager_mode 1459 def testAllowsReuseWithoutPartitioner(self): 1460 with variable_scope.variable_scope( 1461 "scope0", partitioner=axis0_into2_partitioner): 1462 v = variable_scope.get_variable("name0", shape=(3, 1, 1)) 1463 with variable_scope.variable_scope("scope0", reuse=True): 1464 v_reused = variable_scope.get_variable("name0") 1465 self.assertIs(v, v_reused) 1466 1467 def testNoReuseInEagerByDefault(self): 1468 with context.eager_mode(): 1469 with variable_scope.variable_scope( 1470 "scope0", partitioner=axis0_into2_partitioner): 1471 v1 = 
variable_scope.get_variable("name0", shape=(3, 1, 1)) 1472 v2 = variable_scope.get_variable("name0", shape=(3, 1, 1)) 1473 self.assertIsNot(v1, v2) 1474 1475 @test_util.run_in_graph_and_eager_modes 1476 @run_inside_wrap_function_in_eager_mode 1477 def testPropagatePartitionerOnReopening(self): 1478 with variable_scope.variable_scope( 1479 "scope0", partitioner=axis0_into2_partitioner) as vs: 1480 self.assertEqual(axis0_into2_partitioner, vs.partitioner) 1481 with variable_scope.variable_scope(vs) as vs1: 1482 self.assertEqual(axis0_into2_partitioner, vs1.partitioner) 1483 1484 # TODO(mihaimaruseac): Not converted to use wrap_function because of 1485 # obtaining different results in the eager case compared to the graph one 1486 @test_util.run_deprecated_v1 1487 def testScalarIgnoresPartitioner(self): 1488 with variable_scope.variable_scope( 1489 "scope0", partitioner=axis0_into2_partitioner): 1490 v = variable_scope.get_variable("name0", shape=()) 1491 self.assertEqual(v.name, "scope0/name0:0") 1492 variables = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) 1493 self.assertIn("scope0/name0:0", [x.name for x in variables]) 1494 1495 def _testPartitionConcatenatesAlongCorrectAxis(self, use_resource): 1496 def _part_axis_0(**unused_kwargs): 1497 return (2, 1, 1) 1498 1499 def _part_axis_1(**unused_kwargs): 1500 return (1, 2, 1) 1501 1502 with variable_scope.variable_scope("root", use_resource=use_resource): 1503 v0 = variable_scope.get_variable( 1504 "n0", shape=(2, 2, 2), partitioner=_part_axis_0) 1505 v1 = variable_scope.get_variable( 1506 "n1", shape=(2, 2, 2), partitioner=_part_axis_1) 1507 1508 self.assertEqual(v0.get_shape(), (2, 2, 2)) 1509 self.assertEqual(v1.get_shape(), (2, 2, 2)) 1510 1511 n0_0 = list(v0)[0] 1512 n0_1 = list(v0)[1] 1513 self.assertEqual(n0_0.get_shape(), (1, 2, 2)) 1514 self.assertEqual(n0_1.get_shape(), (1, 2, 2)) 1515 1516 n1_0 = list(v1)[0] 1517 n1_1 = list(v1)[1] 1518 self.assertEqual(n1_0.get_shape(), (2, 1, 2)) 1519 
self.assertEqual(n1_1.get_shape(), (2, 1, 2)) 1520 1521 @test_util.run_in_graph_and_eager_modes 1522 @run_inside_wrap_function_in_eager_mode 1523 def testPartitionConcatenatesAlongCorrectAxis(self): 1524 self._testPartitionConcatenatesAlongCorrectAxis(use_resource=False) 1525 1526 @test_util.run_in_graph_and_eager_modes 1527 @run_inside_wrap_function_in_eager_mode 1528 def testPartitionConcatenatesAlongCorrectAxisResource(self): 1529 self._testPartitionConcatenatesAlongCorrectAxis(use_resource=True) 1530 1531 def testPartitionConcatenatesAlongCorrectAxisResourceInEager(self): 1532 with context.eager_mode(): 1533 self._testPartitionConcatenatesAlongCorrectAxis(use_resource=True) 1534 1535 1536class VariableScopeWithCustomGetterTest(test.TestCase): 1537 1538 @test_util.run_in_graph_and_eager_modes 1539 @run_inside_wrap_function_in_eager_mode 1540 def testNonCallableGetterFails(self): 1541 with self.assertRaisesRegex(ValueError, r"custom_getter .* not callable:"): 1542 with variable_scope.variable_scope("scope0", custom_getter=3): 1543 variable_scope.get_variable("name0") 1544 with self.assertRaisesRegex(ValueError, r"custom_getter .* not callable:"): 1545 variable_scope.get_variable("name0", custom_getter=3) 1546 1547 @test_util.run_in_graph_and_eager_modes 1548 @run_inside_wrap_function_in_eager_mode 1549 def testNoSideEffectsWithIdentityCustomGetter(self): 1550 called = [0] 1551 1552 def custom_getter(getter, *args, **kwargs): 1553 called[0] += 1 1554 return getter(*args, **kwargs) 1555 1556 with variable_scope.variable_scope( 1557 "scope", custom_getter=custom_getter) as scope: 1558 v = variable_scope.get_variable("v", [1]) 1559 with variable_scope.variable_scope(scope, reuse=True): 1560 v2 = variable_scope.get_variable("v", [1]) 1561 with variable_scope.variable_scope("new_scope") as new_scope: 1562 v3 = variable_scope.get_variable("v3", [1]) 1563 with variable_scope.variable_scope( 1564 new_scope, reuse=True, custom_getter=custom_getter): 1565 v4 = 
variable_scope.get_variable("v3", [1]) 1566 1567 self.assertIs(v, v2) 1568 self.assertIs(v3, v4) 1569 self.assertEqual(3, called[0]) # skipped one in the first new_scope 1570 1571 @test_util.run_in_graph_and_eager_modes 1572 @run_inside_wrap_function_in_eager_mode 1573 def testSynchronizationAndAggregationWithCustomGetter(self): 1574 called = [0] 1575 synchronization = variable_scope.VariableSynchronization.AUTO 1576 aggregation = variable_scope.VariableAggregation.NONE 1577 1578 def custom_getter(getter, *args, **kwargs): 1579 called[0] += 1 1580 1581 # Verify synchronization and aggregation kwargs are as expected. 1582 self.assertEqual(kwargs["synchronization"], synchronization) 1583 self.assertEqual(kwargs["aggregation"], aggregation) 1584 return getter(*args, **kwargs) 1585 1586 with variable_scope.variable_scope("scope", custom_getter=custom_getter): 1587 variable_scope.get_variable("v", [1]) 1588 self.assertEqual(1, called[0]) 1589 1590 with variable_scope.variable_scope("scope", custom_getter=custom_getter): 1591 synchronization = variable_scope.VariableSynchronization.ON_READ 1592 aggregation = variable_scope.VariableAggregation.MEAN 1593 variable_scope.get_variable( 1594 "v1", [1], synchronization=synchronization, aggregation=aggregation) 1595 1596 self.assertEqual(2, called[0]) 1597 1598 @test_util.run_in_graph_and_eager_modes 1599 @run_inside_wrap_function_in_eager_mode 1600 def testCustomGetterWithReuse(self): 1601 # Custom getter can choose to behave differently on reused variables. 1602 def custom_getter(getter, *args, **kwargs): 1603 var = getter(*args, **kwargs) 1604 if kwargs["reuse"]: 1605 # This can be used, e.g., for changing the caching device if needed. 
1606 return array_ops.identity(var, name="reused") 1607 else: 1608 return array_ops.identity(var, name="not_reused") 1609 1610 with variable_scope.variable_scope( 1611 "scope", custom_getter=custom_getter) as scope: 1612 v = variable_scope.get_variable("v", [1]) 1613 with variable_scope.variable_scope(scope, reuse=True): 1614 v2 = variable_scope.get_variable("v", [1]) 1615 1616 self.assertEqual(v.name, "not_reused:0") 1617 self.assertEqual(v2.name, "reused:0") 1618 1619 # TODO(mihaimaruseac): Not converted to use wrap_function because of 1620 # ValueError: Fetch argument <tf.Tensor 'custom_getter/add:0' shape=(1, 2, 3) 1621 # dtype=float32> cannot be interpreted as a Tensor. (Tensor 1622 # Tensor("custom_getter/add:0", shape=(1, 2, 3), dtype=float32) is not an 1623 # element of this graph.) 1624 @test_util.run_deprecated_v1 1625 def testGetterThatCreatesTwoVariablesAndSumsThem(self): 1626 1627 def custom_getter(getter, name, *args, **kwargs): 1628 g_0 = getter("%s/0" % name, *args, **kwargs) 1629 g_1 = getter("%s/1" % name, *args, **kwargs) 1630 with ops.name_scope("custom_getter"): 1631 return g_0 + g_1 1632 1633 with variable_scope.variable_scope("scope", custom_getter=custom_getter): 1634 v = variable_scope.get_variable("v", [1, 2, 3]) 1635 1636 self.assertEqual([1, 2, 3], v.get_shape()) 1637 true_vars = variables_lib.trainable_variables() 1638 self.assertEqual(2, len(true_vars)) 1639 self.assertEqual("scope/v/0:0", true_vars[0].name) 1640 self.assertEqual("scope/v/1:0", true_vars[1].name) 1641 self.assertEqual("custom_getter/add:0", v.name) 1642 with self.cached_session() as sess: 1643 variables_lib.global_variables_initializer().run() 1644 np_vars, np_v = self.evaluate([true_vars, v]) 1645 self.assertAllClose(np_v, sum(np_vars)) 1646 1647 # TODO(mihaimaruseac): Not converted to use wrap_function because of 1648 # ValueError: Fetch argument <tf.Tensor 'sum_getter_2/add:0' shape=(1, 2, 3) 1649 # dtype=float32> cannot be interpreted as a Tensor. 
(Tensor 1650 # Tensor("sum_getter_2/add:0", shape=(1, 2, 3), dtype=float32) is not an 1651 # element of this graph.) 1652 @test_util.run_deprecated_v1 1653 def testNestedCustomGetters(self): 1654 1655 def sum_getter(getter, name, *args, **kwargs): 1656 g_0 = getter("%s/sum_0" % name, *args, **kwargs) 1657 g_1 = getter("%s/sum_1" % name, *args, **kwargs) 1658 with ops.name_scope("sum_getter"): 1659 return g_0 + g_1 1660 1661 def prod_getter(getter, name, *args, **kwargs): 1662 g_0 = getter("%s/prod_0" % name, *args, **kwargs) 1663 g_1 = getter("%s/prod_1" % name, *args, **kwargs) 1664 with ops.name_scope("prod_getter"): 1665 return g_0 * g_1 1666 1667 with variable_scope.variable_scope("prod_scope", custom_getter=prod_getter): 1668 with variable_scope.variable_scope("sum_scope", custom_getter=sum_getter): 1669 with variable_scope.variable_scope( 1670 "inner_sum_scope", custom_getter=sum_getter): 1671 # take sums of sums of products 1672 v = variable_scope.get_variable("v", [1, 2, 3]) 1673 1674 self.assertEqual([1, 2, 3], v.get_shape()) 1675 true_vars = variables_lib.trainable_variables() 1676 self.assertEqual(8, len(true_vars)) 1677 template = ( 1678 "prod_scope/sum_scope/inner_sum_scope/v/sum_%d/sum_%d/prod_%d:0") 1679 self.assertEqual(template % (0, 0, 0), true_vars[0].name) 1680 self.assertEqual(template % (0, 0, 1), true_vars[1].name) 1681 self.assertEqual(template % (0, 1, 0), true_vars[2].name) 1682 self.assertEqual(template % (0, 1, 1), true_vars[3].name) 1683 self.assertEqual(template % (1, 0, 0), true_vars[4].name) 1684 self.assertEqual(template % (1, 0, 1), true_vars[5].name) 1685 self.assertEqual(template % (1, 1, 0), true_vars[6].name) 1686 self.assertEqual(template % (1, 1, 1), true_vars[7].name) 1687 1688 with self.cached_session() as sess: 1689 variables_lib.global_variables_initializer().run() 1690 np_vars, np_v = self.evaluate([true_vars, v]) 1691 # take products of sums of products 1692 self.assertAllClose( 1693 np_v, (((np_vars[0] * np_vars[1]) + 
(np_vars[2] * np_vars[3])) + ( 1694 (np_vars[4] * np_vars[5]) + (np_vars[6] * np_vars[7])))) 1695 1696 @test_util.run_in_graph_and_eager_modes 1697 @run_inside_wrap_function_in_eager_mode 1698 def testVariableCreator(self): 1699 variable_names = [] 1700 1701 def creator_a(next_creator, **kwargs): 1702 variable_names.append(kwargs.get("name", "")) 1703 return next_creator(**kwargs) 1704 1705 def creator_b(next_creator, **kwargs): 1706 kwargs["name"] = "forced_name" 1707 return next_creator(**kwargs) 1708 1709 with variable_scope.variable_creator_scope(creator_a): 1710 with variable_scope.variable_creator_scope(creator_b): 1711 variable_scope.variable(1.0, name="one_name") 1712 1713 self.assertEqual(variable_names[0], "forced_name") 1714 1715 called = [False] 1716 1717 def creater_c(next_creator, **kwargs): 1718 called[0] = True 1719 self.assertEqual(kwargs["synchronization"], 1720 variable_scope.VariableSynchronization.ON_WRITE) 1721 self.assertEqual(kwargs["aggregation"], 1722 variable_scope.VariableAggregation.MEAN) 1723 return next_creator(**kwargs) 1724 1725 with variable_scope.variable_creator_scope(creater_c): 1726 variable_scope.get_variable( 1727 "v", [], 1728 synchronization=variable_scope.VariableSynchronization.ON_WRITE, 1729 aggregation=variable_scope.VariableAggregation.MEAN) 1730 self.assertTrue(called[0]) 1731 1732 @test_util.run_in_graph_and_eager_modes 1733 @run_inside_wrap_function_in_eager_mode 1734 def testVariableCreatorNestingError(self): 1735 1736 def creator(next_creator, **kwargs): 1737 return next_creator(**kwargs) 1738 1739 # Save the state so we can clean up at the end. 
1740 graph = ops.get_default_graph() 1741 old_creator_stack = graph._variable_creator_stack 1742 1743 try: 1744 scope = variable_scope.variable_creator_scope(creator) 1745 scope.__enter__() 1746 with variable_scope.variable_creator_scope(creator): 1747 with self.assertRaises(RuntimeError): 1748 scope.__exit__(None, None, None) 1749 finally: 1750 graph._variable_creator_stack = old_creator_stack 1751 1752 1753class PartitionInfoTest(test.TestCase): 1754 1755 @test_util.run_in_graph_and_eager_modes 1756 @run_inside_wrap_function_in_eager_mode 1757 def testConstructorChecks(self): 1758 # Invalid arg types. 1759 with self.assertRaises(TypeError): 1760 variable_scope._PartitionInfo(full_shape=None, var_offset=[0, 1]) 1761 with self.assertRaises(TypeError): 1762 variable_scope._PartitionInfo(full_shape=[0, 1], var_offset=None) 1763 with self.assertRaises(TypeError): 1764 variable_scope._PartitionInfo(full_shape="foo", var_offset=[0, 1]) 1765 with self.assertRaises(TypeError): 1766 variable_scope._PartitionInfo(full_shape=[0, 1], var_offset="foo") 1767 1768 # full_shape and var_offset must have same length. 1769 with self.assertRaises(ValueError): 1770 variable_scope._PartitionInfo(full_shape=[0, 1], var_offset=[0]) 1771 # Offset must always be less than shape. 1772 with self.assertRaises(ValueError): 1773 variable_scope._PartitionInfo(full_shape=[1, 1], var_offset=[0, 1]) 1774 1775 @test_util.run_in_graph_and_eager_modes 1776 @run_inside_wrap_function_in_eager_mode 1777 def testSingleOffset(self): 1778 partition_info = variable_scope._PartitionInfo( 1779 full_shape=[9, 3], var_offset=[4, 0]) 1780 self.assertEqual(4, partition_info.single_offset([1, 3])) 1781 1782 # Tests when the variable isn't partitioned at all. 
1783 partition_info = variable_scope._PartitionInfo( 1784 full_shape=[9, 3], var_offset=[0, 0]) 1785 self.assertEqual(0, partition_info.single_offset([9, 3])) 1786 1787 @test_util.run_in_graph_and_eager_modes 1788 @run_inside_wrap_function_in_eager_mode 1789 def testSingleSliceDim(self): 1790 partition_info = variable_scope._PartitionInfo( 1791 full_shape=[9, 3], var_offset=[4, 0]) 1792 # Invalid shape. 1793 with self.assertRaises(TypeError): 1794 partition_info.single_slice_dim(None) 1795 1796 # Rank of shape differs from full_shape. 1797 with self.assertRaises(ValueError): 1798 partition_info.single_slice_dim([1, 2, 3]) 1799 1800 # Shape is too large given var_offset (4+6 > 9). 1801 with self.assertRaises(ValueError): 1802 partition_info.single_slice_dim([6, 3]) 1803 1804 # Multiple possible slice dim from shape. 1805 with self.assertRaises(ValueError): 1806 partition_info.single_slice_dim([1, 1]) 1807 1808 partition_info = variable_scope._PartitionInfo( 1809 full_shape=[9, 3], var_offset=[0, 0]) 1810 self.assertEqual(1, partition_info.single_slice_dim([9, 2])) 1811 partition_info = variable_scope._PartitionInfo( 1812 full_shape=[9, 3], var_offset=[4, 0]) 1813 self.assertEqual(0, partition_info.single_slice_dim([2, 3])) 1814 1815 1816class VariableScopeMultithreadedTest(test.TestCase): 1817 1818 @test_util.run_in_graph_and_eager_modes 1819 @run_inside_wrap_function_in_eager_mode 1820 def testTwoThreadsDisjointScopeEntry(self): 1821 1822 def thread_fn(i, graph): 1823 with graph.as_default(): 1824 with variable_scope.variable_scope("foo"): 1825 if i == 0: 1826 v = variable_scope.get_variable("v", []) 1827 self.assertEqual("foo/v:0", v.name) 1828 else: 1829 # Any thread after the first one should fail to create variable 1830 # with the same name. 
1831 with self.assertRaises(ValueError): 1832 variable_scope.get_variable("v", []) 1833 1834 graph = ops.get_default_graph() 1835 threads = [ 1836 threading.Thread(target=thread_fn, args=( 1837 i, 1838 graph, 1839 )) for i in range(2) 1840 ] 1841 1842 threads[0].start() 1843 # Allow thread 0 to finish before starting thread 1. 1844 threads[0].join() 1845 threads[1].start() 1846 threads[1].join() 1847 1848 @test_util.run_in_graph_and_eager_modes 1849 @run_inside_wrap_function_in_eager_mode 1850 def testTwoThreadsNestedScopeEntry(self): 1851 1852 def thread_fn(i, graph, run_event, pause_event): 1853 with graph.as_default(): 1854 with variable_scope.variable_scope("foo"): 1855 if i == 0: 1856 v = variable_scope.get_variable("v", []) 1857 self.assertEqual("foo/v:0", v.name) 1858 else: 1859 # Any thread after the first one should fail to create variable 1860 # with the same name. 1861 with self.assertRaises(ValueError): 1862 variable_scope.get_variable("v", []) 1863 pause_event.set() 1864 run_event.wait() 1865 1866 graph = ops.get_default_graph() 1867 run_events = [threading.Event() for _ in range(2)] 1868 pause_events = [threading.Event() for _ in range(2)] 1869 threads = [ 1870 threading.Thread( 1871 target=thread_fn, args=(i, graph, run_events[i], pause_events[i])) 1872 for i in range(2) 1873 ] 1874 1875 # Start first thread. 1876 threads[0].start() 1877 pause_events[0].wait() 1878 # Start next thread once the first thread has paused. 1879 threads[1].start() 1880 pause_events[1].wait() 1881 # Resume both threads. 1882 run_events[0].set() 1883 run_events[1].set() 1884 threads[0].join() 1885 threads[1].join() 1886 1887 @test_util.run_in_graph_and_eager_modes 1888 @run_inside_wrap_function_in_eager_mode 1889 def testReenterMainScope(self): 1890 1891 def thread_fn(graph, main_thread_scope): 1892 with graph.as_default(): 1893 # Variable created with main scope will have prefix "main". 
1894 with variable_scope.variable_scope(main_thread_scope): 1895 with variable_scope.variable_scope("foo"): 1896 v = variable_scope.get_variable("v", []) 1897 self.assertEqual("main/foo/v:0", v.name) 1898 1899 # Variable created outside main scope will not have prefix "main". 1900 with variable_scope.variable_scope("bar"): 1901 v = variable_scope.get_variable("v", []) 1902 self.assertEqual("bar/v:0", v.name) 1903 1904 graph = ops.get_default_graph() 1905 with variable_scope.variable_scope("main") as main_thread_scope: 1906 thread = threading.Thread( 1907 target=thread_fn, args=(graph, main_thread_scope)) 1908 thread.start() 1909 thread.join() 1910 1911 1912if __name__ == "__main__": 1913 test.main() 1914