1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for variable store.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import gc 22import threading 23 24import numpy 25 26from tensorflow.python.eager import context 27from tensorflow.python.eager import function 28from tensorflow.python.eager import wrap_function 29from tensorflow.python.framework import constant_op 30from tensorflow.python.framework import dtypes 31from tensorflow.python.framework import errors 32from tensorflow.python.framework import ops 33from tensorflow.python.framework import test_util 34from tensorflow.python.layers import core as core_layers 35from tensorflow.python.ops import array_ops 36from tensorflow.python.ops import control_flow_ops 37from tensorflow.python.ops import init_ops 38from tensorflow.python.ops import math_ops 39from tensorflow.python.ops import resource_variable_ops 40from tensorflow.python.ops import state_ops 41from tensorflow.python.ops import variable_scope 42from tensorflow.python.ops import variables as variables_lib 43from tensorflow.python.platform import test 44from tensorflow.python.util import compat 45from tensorflow.python.util import tf_inspect 46 47 48def run_inside_wrap_function_in_eager_mode(graph_function): 49 """Decorator to execute the same graph code in eager and graph modes. 50 51 In graph mode, we just execute the graph_function passed as argument. In eager 52 mode, we wrap the function using wrap_function and then execute the wrapped 53 result. 54 55 Args: 56 graph_function: python function containing graph code to be wrapped 57 58 Returns: 59 decorated function 60 """ 61 def wrap_and_execute(self): 62 if context.executing_eagerly(): 63 wrapped = wrap_function.wrap_function(graph_function, [self]) 64 # use the wrapped graph function 65 wrapped() 66 else: 67 # use the original function 68 graph_function(self) 69 return wrap_and_execute 70 71 72class VariableScopeTest(test.TestCase): 73 74 def tearDown(self): 75 gc.collect() 76 # This will only contain uncollectable garbage, i.e. reference cycles 77 # involving objects with __del__ defined. 78 self.assertEqual(0, len(gc.garbage)) 79 80 @test_util.run_in_graph_and_eager_modes 81 @run_inside_wrap_function_in_eager_mode 82 def testGetVar(self): 83 vs = variable_scope._get_default_variable_store() 84 v = vs.get_variable("v", [1]) 85 v1 = vs.get_variable("v", [1]) 86 self.assertEqual(v, v1) 87 88 @test_util.run_in_graph_and_eager_modes 89 @run_inside_wrap_function_in_eager_mode 90 def testResource(self): 91 vs = variable_scope._get_default_variable_store() 92 v1 = vs.get_variable("v", [1], use_resource=True) 93 self.assertTrue(isinstance(v1, resource_variable_ops.ResourceVariable)) 94 95 @test_util.run_in_graph_and_eager_modes 96 @run_inside_wrap_function_in_eager_mode 97 def testNameExists(self): 98 vs = variable_scope._get_default_variable_store() 99 # No check by default, so we can both create and get existing names. 100 v = vs.get_variable("v", [1]) 101 v1 = vs.get_variable("v", [1]) 102 self.assertEqual(v, v1) 103 104 # When reuse is False, we fail when variables are already there. 105 vs.get_variable("w", [1], reuse=False) # That's ok. 106 with self.assertRaises(ValueError): 107 vs.get_variable("v", [1], reuse=False) # That fails. 108 # When reuse is True, we fail when variables are new. 109 vs.get_variable("v", [1], reuse=True) # That's ok. 110 with self.assertRaises(ValueError): 111 vs.get_variable("u", [1], reuse=True) # That fails. 112 113 @test_util.run_in_graph_and_eager_modes 114 @run_inside_wrap_function_in_eager_mode 115 def testNamelessStore(self): 116 vs = variable_scope._get_default_variable_store() 117 vs.get_variable("v1", [2]) 118 vs.get_variable("v2", [2]) 119 expected_names = ["%s:0" % name for name in ["v1", "v2"]] 120 self.assertEqual( 121 set(expected_names), set([v.name for v in vs._vars.values()])) 122 123 # TODO(mihaimaruseac): Not converted to use wrap_function because of 124 # TypeError: Expected tf.group() expected Tensor arguments not 'None' with 125 # type '<type 'NoneType'>' 126 @test_util.run_in_graph_and_eager_modes 127 def testVarScopeInitializer(self): 128 init = init_ops.constant_initializer(0.3) 129 with variable_scope.variable_scope("tower0") as tower: 130 with variable_scope.variable_scope("foo", initializer=init): 131 v = variable_scope.get_variable("v", []) 132 self.evaluate(variables_lib.variables_initializer([v])) 133 self.assertAllClose(self.evaluate(v.value()), 0.3) 134 with variable_scope.variable_scope(tower, initializer=init): 135 w = variable_scope.get_variable("w", []) 136 self.evaluate(variables_lib.variables_initializer([w])) 137 self.assertAllClose(self.evaluate(w.value()), 0.3) 138 139 @test_util.run_in_graph_and_eager_modes 140 @run_inside_wrap_function_in_eager_mode 141 def testVarScopeConstraint(self): 142 constraint = lambda x: 0. * x 143 with variable_scope.variable_scope("tower1") as tower: 144 with variable_scope.variable_scope("foo", constraint=constraint): 145 v = variable_scope.get_variable("v", []) 146 self.assertEqual(v.constraint, constraint) 147 with variable_scope.variable_scope(tower, constraint=constraint): 148 w = variable_scope.get_variable("w", []) 149 self.assertEqual(w.constraint, constraint) 150 151 # TODO(mihaimaruseac): Not converted to use wrap_function because of 152 # TypeError: Fetch argument <tf.Variable 'string:0' shape=() dtype=string> 153 # has invalid type <class '...ResourceVariable'>, must be a string or Tensor. 154 # (Can not convert a ResourceVariable into a Tensor or Operation.) 155 @test_util.run_deprecated_v1 156 def testStringDefaultInitializer(self): 157 with self.cached_session(): 158 v = variable_scope.get_variable("string", shape=[], dtype=dtypes.string) 159 variables_lib.global_variables_initializer().run() 160 self.assertAllEqual(compat.as_bytes(self.evaluate(v)), b"") 161 162 @test_util.run_in_graph_and_eager_modes 163 @run_inside_wrap_function_in_eager_mode 164 def testVarScopeDType(self): 165 with variable_scope.variable_scope("tower2") as tower: 166 with variable_scope.variable_scope("foo", dtype=dtypes.float16): 167 v = variable_scope.get_variable("v", []) 168 self.assertEqual(v.dtype.base_dtype, dtypes.float16) 169 with variable_scope.variable_scope(tower, dtype=dtypes.float16): 170 w = variable_scope.get_variable("w", []) 171 self.assertEqual(w.dtype.base_dtype, dtypes.float16) 172 173 def testGetVariableInGraphNestedUnderEagerContext(self): 174 with context.eager_mode(): 175 176 @function.defun 177 def f(): 178 v = variable_scope.get_variable("should_be_resource", []) 179 self.assertEqual(type(v), resource_variable_ops.ResourceVariable) 180 181 f() 182 183 def testEagerVariableStore(self): 184 with context.eager_mode(): 185 store = variable_scope.EagerVariableStore() 186 with store.as_default(): 187 v = variable_scope.get_variable("v", shape=(), trainable=True) 188 w = variable_scope.get_variable("w", shape=(), trainable=False) 189 190 self.assertTrue(v in store.variables()) 191 self.assertTrue(w in store.variables()) 192 self.assertTrue(v in store.trainable_variables()) 193 self.assertFalse(w in store.trainable_variables()) 194 self.assertFalse(v in store.non_trainable_variables()) 195 self.assertTrue(w in store.non_trainable_variables()) 196 197 # Test copying. 198 new_store = store.copy() 199 with new_store.as_default(): 200 new_v = variable_scope.get_variable("v") 201 new_w = variable_scope.get_variable("w") 202 self.assertEqual(new_v.numpy(), v.numpy()) 203 self.assertEqual(new_w.numpy(), w.numpy()) 204 self.assertTrue(new_v in new_store.variables()) 205 self.assertTrue(new_w in new_store.variables()) 206 self.assertTrue(new_v in new_store.trainable_variables()) 207 self.assertFalse(new_w in new_store.trainable_variables()) 208 self.assertFalse(new_v in new_store.non_trainable_variables()) 209 self.assertTrue(new_w in new_store.non_trainable_variables()) 210 211 # Check that variables are separate instances. 212 for v in store.variables(): 213 v.assign(-1) 214 for v in new_store.variables(): 215 v.assign(1) 216 for v in store.variables(): 217 self.assertEqual(v.numpy(), -1) 218 for v in new_store.variables(): 219 self.assertEqual(v.numpy(), 1) 220 221 def testEagerVariableStoreWithEagerDefun(self): 222 with context.eager_mode(): 223 224 @function.defun 225 def f(): 226 x = constant_op.constant([[2.0]]) 227 d1 = core_layers.Dense( 228 1, name="my_dense", kernel_initializer=init_ops.ones_initializer()) 229 _ = d1(x) # create variables 230 self.assertEqual(len(d1.variables), 2) 231 v1, v2 = d1.variables 232 d2 = core_layers.Dense( 233 1, 234 name="my_dense", 235 kernel_initializer=init_ops.ones_initializer(), 236 _reuse=True) 237 _ = d2(x) 238 self.assertEqual(len(d2.variables), 2) 239 v3, v4 = d2.variables 240 self.assertEqual(v1, v3) 241 self.assertEqual(v2, v4) 242 f() 243 244 # TODO(mihaimaruseac): Not converted to use wrap_function because of 245 # obtaining different results in the eager case compared to the graph one 246 @test_util.run_in_graph_and_eager_modes 247 def testEagerVariablesStoreAddsToCollections(self): 248 store = variable_scope.EagerVariableStore() 249 with store.as_default(): 250 trainable = variable_scope.get_variable("v1", [], trainable=True) 251 not_trainable = variable_scope.get_variable("v2", [], trainable=False) 252 concat = variable_scope.get_variable( 253 "v3", [], collections=[ops.GraphKeys.CONCATENATED_VARIABLES]) 254 self.assertEqual( 255 ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES), 256 [trainable, not_trainable]) 257 self.assertEqual( 258 ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES), 259 [trainable, concat]) 260 self.assertEqual( 261 ops.get_collection(ops.GraphKeys.CONCATENATED_VARIABLES), [concat]) 262 263 def testEagerVariablesOutsideStoreNotAddedToCollections(self): 264 with context.eager_mode(): 265 variable_scope.get_variable("v1", [], trainable=True) 266 variable_scope.get_variable("v2", [], trainable=False) 267 self.assertFalse(ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)) 268 self.assertFalse(ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)) 269 270 # TODO(mihaimaruseac): Not converted to use wrap_function because of 271 # TypeError: Expected tf.group() expected Tensor arguments not 'None' with 272 # type '<type 'NoneType'>'. 273 @test_util.run_in_graph_and_eager_modes 274 def testInitFromNonTensorValue(self): 275 v = variable_scope.get_variable("v4", initializer=4, dtype=dtypes.int32) 276 self.evaluate(variables_lib.variables_initializer([v])) 277 self.assertAllClose(self.evaluate(v.value()), 4) 278 279 w = variable_scope.get_variable( 280 "w4", initializer=numpy.array([1, 2, 3]), dtype=dtypes.int64) 281 self.evaluate(variables_lib.variables_initializer([w])) 282 self.assertAllClose(self.evaluate(w.value()), [1, 2, 3]) 283 284 # A quirk to be revisited? 285 error = ValueError if context.executing_eagerly() else TypeError 286 with self.assertRaises(error): 287 variable_scope.get_variable("x4", initializer={}) 288 289 # TODO(mihaimaruseac): Not converted to use wrap_function because of 290 # InvalidArgumentError=: You must feed a value for placeholder tensor 291 # 'ReadVariableOp/resource' with dtype resource 292 @test_util.run_in_graph_and_eager_modes 293 def testInitFromNonInitializer(self): 294 # Test various dtypes with zeros initializer as following: 295 types = [ 296 dtypes.int8, dtypes.uint8, dtypes.int16, dtypes.uint16, dtypes.int32, 297 dtypes.int64, dtypes.bool 298 ] 299 300 # Use different variable_name to distinguish various dtypes 301 for (i, dtype) in enumerate(types): 302 x = variable_scope.get_variable( 303 name="xx%d" % i, shape=(3, 4), dtype=dtype) 304 y = variable_scope.get_variable( 305 name="yy%d" % i, 306 shape=(3, 4), 307 dtype=dtype, 308 initializer=init_ops.zeros_initializer(dtype=dtype)) 309 310 self.evaluate(variables_lib.global_variables_initializer()) 311 self.assertAllEqual(self.evaluate(x.value()), self.evaluate(y.value())) 312 313 # TODO(mihaimaruseac): Not converted to use wrap_function because of 314 # InvalidArgumentError: /job:moo/replica:0/task:0/device:CPU:0 unknown device. 315 @test_util.run_deprecated_v1 316 def testVarScopeCachingDevice(self): 317 with self.cached_session(): 318 caching_device = "/job:moo" 319 with variable_scope.variable_scope("tower"): 320 with variable_scope.variable_scope( 321 "caching", caching_device=caching_device): 322 v = variable_scope.get_variable("v", []) 323 self.assertTrue(v.value().device.startswith(caching_device)) 324 325 with variable_scope.variable_scope("child"): 326 v2 = variable_scope.get_variable("v", []) 327 self.assertTrue(v2.value().device.startswith(caching_device)) 328 329 with variable_scope.variable_scope("not_cached", caching_device=""): 330 v2_not_cached = variable_scope.get_variable("v", []) 331 self.assertFalse( 332 v2_not_cached.value().device.startswith(caching_device)) 333 334 with variable_scope.variable_scope( 335 "not_cached_identity_device", 336 caching_device=lambda op: op.device): 337 v2_identity_device = variable_scope.get_variable("v", []) 338 self.assertFalse( 339 v2_identity_device.value().device.startswith(caching_device)) 340 341 with variable_scope.variable_scope("we_will_do_it_live") as vs_live: 342 vs_live.set_caching_device("/job:live") 343 v_live = variable_scope.get_variable("v", []) 344 self.assertTrue(v_live.value().device.startswith("/job:live")) 345 346 v_tower = variable_scope.get_variable("v", []) 347 self.assertFalse(v_tower.value().device.startswith(caching_device)) 348 349 # TODO(mihaimaruseac): Not converted to use wrap_function because of 350 # AttributeError: Tensor.name is meaningless when eager execution is enabled. 351 @test_util.run_in_graph_and_eager_modes 352 def testVarScopeRegularizer(self): 353 init = init_ops.constant_initializer(0.3) 354 355 def regularizer1(v): 356 return math_ops.reduce_mean(v) + 0.1 357 358 def regularizer2(v): 359 return math_ops.reduce_mean(v) + 0.2 360 361 with variable_scope.variable_scope( 362 "tower3", regularizer=regularizer1) as tower: 363 with variable_scope.variable_scope("foo", initializer=init): 364 v = variable_scope.get_variable("v", []) 365 self.evaluate(variables_lib.variables_initializer([v])) 366 losses = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES) 367 self.assertEqual(1, len(losses)) 368 self.assertAllClose(self.evaluate(losses[0]), 0.4) 369 with variable_scope.variable_scope(tower, initializer=init) as vs: 370 u = variable_scope.get_variable("u", []) 371 vs.set_regularizer(regularizer2) 372 w = variable_scope.get_variable("w", []) 373 # Next 3 variable not regularized to test disabling regularization. 374 x = variable_scope.get_variable( 375 "x", [], regularizer=variable_scope.no_regularizer) 376 with variable_scope.variable_scope( 377 "baz", regularizer=variable_scope.no_regularizer): 378 y = variable_scope.get_variable("y", []) 379 vs.set_regularizer(variable_scope.no_regularizer) 380 z = variable_scope.get_variable("z", []) 381 # Check results. 382 losses = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES) 383 self.assertEqual(3, len(losses)) 384 self.evaluate(variables_lib.variables_initializer([u, w, x, y, z])) 385 self.assertAllClose(self.evaluate(losses[0]), 0.4) 386 self.assertAllClose(self.evaluate(losses[1]), 0.4) 387 self.assertAllClose(self.evaluate(losses[2]), 0.5) 388 with variable_scope.variable_scope("foo", reuse=True): 389 # reuse=True is for now only supported when eager execution is disabled. 390 if not context.executing_eagerly(): 391 v = variable_scope.get_variable("v", 392 []) # "v" is already there, reused 393 losses = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES) 394 self.assertEqual(3, len(losses)) # No new loss added. 395 396 # TODO(mihaimaruseac): Not converted to use wrap_function because of 397 # ValueError: Tensor-typed variable initializers must either be wrapped in an 398 # init_scope or callable... 399 @test_util.run_in_graph_and_eager_modes 400 def testInitializeFromValue(self): 401 init = constant_op.constant(0.1) 402 w = variable_scope.get_variable("v", initializer=init) 403 self.evaluate(variables_lib.variables_initializer([w])) 404 self.assertAllClose(self.evaluate(w.value()), 0.1) 405 406 with self.assertRaisesRegexp(ValueError, "shape"): 407 # We disallow explicit shape specification when initializer is constant. 408 variable_scope.get_variable("u", [1], initializer=init) 409 410 with variable_scope.variable_scope("foo", initializer=init): 411 # Constant initializer can be passed through scopes if needed. 412 v = variable_scope.get_variable("v") 413 self.evaluate(variables_lib.variables_initializer([v])) 414 self.assertAllClose(self.evaluate(v.value()), 0.1) 415 416 # Check that non-float32 initializer creates a non-float32 variable. 417 init = constant_op.constant(1, dtype=dtypes.int32) 418 t = variable_scope.get_variable("t", initializer=init) 419 self.assertEqual(t.dtype.base_dtype, dtypes.int32) 420 421 # Raise error if `initializer` dtype and `dtype` are not identical. 422 with self.assertRaisesRegexp(ValueError, "don't match"): 423 variable_scope.get_variable("s", initializer=init, dtype=dtypes.float64) 424 425 # TODO(mihaimaruseac): Not converted to use wrap_function because of 426 # TypeError: Fetch argument <tf.Variable 'v0:0' shape=(1,) dtype=float32> has 427 # invalid type <class '...ops.resource_variable_ops.ResourceVariable'>, must 428 # be a string or Tensor. (Can not convert a ResourceVariable into a Tensor or 429 # Operation.) 430 @test_util.run_deprecated_v1 431 def testControlDeps(self): 432 with self.cached_session() as sess: 433 v0 = variable_scope.get_variable( 434 "v0", [1], initializer=init_ops.constant_initializer(0)) 435 with ops.control_dependencies([v0.value()]): 436 v1 = variable_scope.get_variable( 437 "v1", [1], initializer=init_ops.constant_initializer(1)) 438 add = v1 + v0 439 # v0 should be uninitialized. 440 with self.assertRaisesRegexp(errors.OpError, "uninitialized"): 441 self.evaluate(v0) 442 # We should be able to initialize and run v1 without initializing 443 # v0, even if the variable was created with a control dep on v0. 444 self.evaluate(v1.initializer) 445 self.assertEqual(1, self.evaluate(v1)) 446 # v0 should still be uninitialized. 447 with self.assertRaisesRegexp(errors.OpError, "uninitialized"): 448 self.evaluate(v0) 449 with self.assertRaisesRegexp(errors.OpError, "uninitialized"): 450 self.evaluate(add) 451 # If we initialize v0 we should be able to run 'add'. 452 self.evaluate(v0.initializer) 453 self.evaluate(add) 454 455 # TODO(mihaimaruseac): Not converted to use wrap_function because of 456 # AssertionError: True is not false (last assertFalse) 457 @test_util.run_deprecated_v1 458 def testEnableResourceVariables(self): 459 old = variable_scope._DEFAULT_USE_RESOURCE 460 try: 461 variable_scope.enable_resource_variables() 462 self.assertTrue(isinstance(variables_lib.VariableV1(1.0), 463 resource_variable_ops.ResourceVariable)) 464 variable_scope.disable_resource_variables() 465 self.assertFalse(isinstance(variables_lib.VariableV1(1.0), 466 resource_variable_ops.ResourceVariable)) 467 finally: 468 variable_scope._DEFAULT_USE_RESOURCE = old 469 470 # TODO(mihaimaruseac): Not converted to use wrap_function because of 471 # TypeError: Fetch argument None has invalid type <type 'NoneType'> 472 @test_util.run_deprecated_v1 473 def testControlFlow(self): 474 with self.cached_session() as sess: 475 v0 = variable_scope.get_variable( 476 "v0", [], initializer=init_ops.constant_initializer(0)) 477 var_dict = {} 478 479 # Call get_variable in each of the cond clauses. 480 def var_in_then_clause(): 481 v1 = variable_scope.get_variable( 482 "v1", [1], initializer=init_ops.constant_initializer(1)) 483 var_dict["v1"] = v1 484 return v1 + v0 485 486 def var_in_else_clause(): 487 v2 = variable_scope.get_variable( 488 "v2", [1], initializer=init_ops.constant_initializer(2)) 489 var_dict["v2"] = v2 490 return v2 + v0 491 492 add = control_flow_ops.cond( 493 math_ops.less(v0, 10), var_in_then_clause, var_in_else_clause) 494 v1 = var_dict["v1"] 495 v2 = var_dict["v2"] 496 # We should be able to initialize and run v1 and v2 without initializing 497 # v0, even if the variable was created with a control dep on v0. 498 self.evaluate(v1.initializer) 499 self.assertEqual([1], self.evaluate(v1)) 500 self.evaluate(v2.initializer) 501 self.assertEqual([2], self.evaluate(v2)) 502 # v0 should still be uninitialized. 503 with self.assertRaisesRegexp(errors.OpError, "uninitialized"): 504 self.evaluate(v0) 505 # We should not be able to run 'add' yet. 506 with self.assertRaisesRegexp(errors.OpError, "uninitialized"): 507 self.evaluate(add) 508 # If we initialize v0 we should be able to run 'add'. 509 self.evaluate(v0.initializer) 510 self.evaluate(add) 511 512 # TODO(mihaimaruseac): Not converted to use wrap_function because of 513 # TypeError: Expected tf.group() expected Tensor arguments not 'None' with 514 # type '<type 'NoneType'>'. 515 @test_util.run_in_graph_and_eager_modes 516 def testGetVariableScope(self): 517 # Test the get_variable_scope() function and setting properties of result. 518 init = init_ops.constant_initializer(0.3) 519 with variable_scope.variable_scope("bar"): 520 new_init1 = variable_scope.get_variable_scope().initializer 521 self.assertEqual(new_init1, None) 522 # Check that we can set initializer like this. 523 variable_scope.get_variable_scope().set_initializer(init) 524 v = variable_scope.get_variable("v", []) 525 self.evaluate(variables_lib.variables_initializer([v])) 526 self.assertAllClose(self.evaluate(v.value()), 0.3) 527 if not context.executing_eagerly(): 528 # Check that we can set reuse. 529 variable_scope.get_variable_scope().reuse_variables() 530 with self.assertRaises(ValueError): # Fail, w does not exist yet. 531 variable_scope.get_variable("w", [1]) 532 # Check that the set initializer goes away. 533 new_init = variable_scope.get_variable_scope().initializer 534 self.assertEqual(new_init, None) 535 536 @test_util.run_in_graph_and_eager_modes 537 @run_inside_wrap_function_in_eager_mode 538 def testVarScope(self): 539 with variable_scope.variable_scope("tower4") as tower: 540 self.assertEqual(tower.name, "tower4") 541 with ops.name_scope("scope") as sc: 542 self.assertEqual(sc, "tower4/scope/") 543 544 with variable_scope.variable_scope("tower5"): 545 with variable_scope.variable_scope("bar") as bar: 546 self.assertEqual(bar.name, "tower5/bar") 547 with ops.name_scope("scope") as sc: 548 self.assertEqual(sc, "tower5/bar/scope/") 549 550 with variable_scope.variable_scope("tower6"): 551 with variable_scope.variable_scope(tower, reuse=True) as tower_shared: 552 self.assertEqual(tower_shared.name, "tower4") 553 with ops.name_scope("scope") as sc: 554 self.assertEqual(sc, "tower6/tower4/scope/") 555 556 @test_util.run_in_graph_and_eager_modes 557 @run_inside_wrap_function_in_eager_mode 558 def testVarScopeNameScope(self): 559 with ops.name_scope("testVarScopeNameScope1"): 560 with variable_scope.variable_scope("tower") as tower: 561 with ops.name_scope("scope2") as sc2: 562 self.assertEqual(sc2, "testVarScopeNameScope1/tower/scope2/") 563 if not context.executing_eagerly(): 564 with variable_scope.variable_scope( 565 tower): # Re-entering acts like another "tower". 566 with ops.name_scope("scope2") as sc2: 567 self.assertEqual(sc2, "testVarScopeNameScope1/tower_1/scope2/") 568 with variable_scope.variable_scope( 569 "tower"): # Re-entering by string acts the same. 570 with ops.name_scope("scope2") as sc2: 571 self.assertEqual(sc2, "testVarScopeNameScope1/tower_2/scope2/") 572 573 with ops.name_scope("testVarScopeNameScope2"): 574 with variable_scope.variable_scope("tower"): 575 with ops.name_scope("scope2") as sc2: 576 self.assertEqual(sc2, "testVarScopeNameScope2/tower/scope2/") 577 if not context.executing_eagerly(): 578 with variable_scope.variable_scope(tower): 579 with ops.name_scope("scope2") as sc2: 580 self.assertEqual(sc2, "testVarScopeNameScope2/tower_1/scope2/") 581 582 root_var_scope = variable_scope.get_variable_scope() 583 with ops.name_scope("testVarScopeNameScope3"): 584 with variable_scope.variable_scope(root_var_scope): 585 with ops.name_scope("scope2") as sc2: 586 self.assertEqual(sc2, "testVarScopeNameScope3/scope2/") 587 588 @test_util.run_in_graph_and_eager_modes 589 @run_inside_wrap_function_in_eager_mode 590 def testVarScopeOriginalNameScope(self): 591 with self.cached_session(): 592 with ops.name_scope("scope1"): 593 with variable_scope.variable_scope("tower") as tower: 594 self.assertEqual(tower.original_name_scope, "scope1/tower/") 595 with ops.name_scope("scope2") as sc2: 596 self.assertEqual(sc2, "scope1/tower/scope2/") 597 with ops.name_scope("scope2"): 598 with variable_scope.variable_scope(tower) as tower1: 599 # Re-entering preserves original name scope. 600 self.assertEqual(tower1.original_name_scope, "scope1/tower/") 601 with ops.name_scope("foo") as sc2: 602 self.assertEqual(sc2, "scope2/tower/foo/") 603 # Test re-entering original name scope. 604 with ops.name_scope(tower.original_name_scope): 605 with ops.name_scope("bar") as sc3: 606 self.assertEqual(sc3, "scope1/tower/bar/") 607 with ops.name_scope("scope2"): 608 with variable_scope.variable_scope(tower): 609 with ops.name_scope(tower.original_name_scope): 610 with ops.name_scope("bar") as sc3: 611 self.assertEqual(sc3, "scope1/tower/bar_1/") 612 613 @test_util.run_in_graph_and_eager_modes 614 @run_inside_wrap_function_in_eager_mode 615 def testVarScopeObjectReuse(self): 616 with self.cached_session(): 617 vs = None 618 with variable_scope.variable_scope("jump", reuse=True) as scope: 619 vs = scope 620 621 with variable_scope.variable_scope(vs) as jump: 622 self.assertTrue(jump.reuse) 623 624 with variable_scope.variable_scope(vs, reuse=True) as jump_reuse: 625 self.assertTrue(jump_reuse.reuse) 626 627 with variable_scope.variable_scope(vs, reuse=False) as jump_no_reuse: 628 self.assertTrue(jump_no_reuse.reuse) # Inherited, cannot be undone. 629 630 with variable_scope.variable_scope("jump", reuse=False) as scope: 631 vs = scope 632 633 with variable_scope.variable_scope(vs) as jump: 634 self.assertFalse(jump.reuse) 635 636 with variable_scope.variable_scope(vs, reuse=True) as jump_reuse: 637 self.assertTrue(jump_reuse.reuse) 638 639 with variable_scope.variable_scope(vs, reuse=False) as jump_no_reuse: 640 self.assertFalse(jump_no_reuse.reuse) 641 642 @test_util.run_in_graph_and_eager_modes 643 @run_inside_wrap_function_in_eager_mode 644 def testVarScopeGetOrCreateReuse(self): 645 with self.cached_session(): 646 647 def test_value(value): 648 x = constant_op.constant(value) 649 with variable_scope.variable_scope( 650 "testVarScopeGetOrCreateReuse_bar", 651 reuse=variable_scope.AUTO_REUSE): 652 _ = state_ops.assign(variable_scope.get_variable("var", []), x) 653 with variable_scope.variable_scope( 654 "testVarScopeGetOrCreateReuse_bar", 655 reuse=variable_scope.AUTO_REUSE): 656 _ = variable_scope.get_variable("var", []) 657 self.assertEqual(value, self.evaluate(x)) 658 659 test_value(42.) # Variable is created. 660 test_value(13.) # Variable is reused hereafter. 661 test_value(17.) 662 663 @test_util.run_in_graph_and_eager_modes 664 @run_inside_wrap_function_in_eager_mode 665 def testVarOpScope(self): 666 with self.cached_session(): 667 with ops.name_scope("testVarOpScope1"): 668 with variable_scope.variable_scope("tower", "default", []): 669 self.assertEqual( 670 variable_scope.get_variable("w", []).name, "tower/w:0") 671 with ops.name_scope("testVarOpScope2") as sc2: 672 self.assertEqual(sc2, "testVarOpScope1/tower/testVarOpScope2/") 673 with variable_scope.variable_scope("tower", "default", []): 674 with self.assertRaises(ValueError): 675 variable_scope.get_variable("w", []) 676 with ops.name_scope("testVarOpScope2") as sc2: 677 self.assertEqual(sc2, "testVarOpScope1/tower_1/testVarOpScope2/") 678 679 with ops.name_scope("testVarOpScope2"): 680 with variable_scope.variable_scope(None, "default", []): 681 self.assertEqual( 682 variable_scope.get_variable("w", []).name, "default/w:0") 683 with ops.name_scope("testVarOpScope2") as sc2: 684 self.assertEqual(sc2, "testVarOpScope2/default/testVarOpScope2/") 685 with variable_scope.variable_scope(None, "default", []): 686 self.assertEqual( 687 variable_scope.get_variable("w", []).name, "default_1/w:0") 688 with ops.name_scope("testVarOpScope2") as sc2: 689 self.assertEqual(sc2, "testVarOpScope2/default_1/testVarOpScope2/") 690 691 @test_util.run_in_graph_and_eager_modes 692 @run_inside_wrap_function_in_eager_mode 693 def testVarOpScopeUniqueNamesInterleavedSubstringScopes(self): 694 with self.cached_session(): 695 with variable_scope.variable_scope(None, "defaultScope1"): 696 with variable_scope.variable_scope(None, "layer"): 697 self.assertEqual( 698 variable_scope.get_variable("w", []).name, 699 "defaultScope1/layer/w:0") 700 with variable_scope.variable_scope(None, "defaultScope1"): 701 with variable_scope.variable_scope(None, "layer"): 702 self.assertEqual( 703 variable_scope.get_variable("w", []).name, 704 "defaultScope1_1/layer/w:0") 705 with variable_scope.variable_scope(None, "defaultScope"): 706 with variable_scope.variable_scope(None, "layer"): 707 self.assertEqual( 708 variable_scope.get_variable("w", []).name, 709 "defaultScope/layer/w:0") 710 with variable_scope.variable_scope(None, "defaultScope1"): 711 with variable_scope.variable_scope(None, "layer"): 712 self.assertEqual( 713 variable_scope.get_variable("w", []).name, 714 "defaultScope1_2/layer/w:0") 715 716 @test_util.run_in_graph_and_eager_modes 717 @run_inside_wrap_function_in_eager_mode 718 def testVarOpScopeUniqueNamesWithJump(self): 719 with self.cached_session(): 720 with variable_scope.variable_scope("default") as default: 721 with variable_scope.variable_scope(None, "layer"): 722 self.assertEqual( 723 variable_scope.get_variable("w", []).name, "default/layer/w:0") 724 with variable_scope.variable_scope(None, "layer"): 725 self.assertEqual( 726 variable_scope.get_variable("w", []).name, 727 "default/layer_1/w:0") 728 with variable_scope.variable_scope(default): 729 pass 730 # No matter the jump in the middle, unique numbering continues. 731 with variable_scope.variable_scope(None, "layer"): 732 self.assertEqual( 733 variable_scope.get_variable("w", []).name, 734 "default/layer_2/w:0") 735 736 @test_util.run_in_graph_and_eager_modes 737 @run_inside_wrap_function_in_eager_mode 738 def testVarOpScopeReuse(self): 739 with self.cached_session(): 740 with variable_scope.variable_scope("outer") as outer: 741 with variable_scope.variable_scope("tower", "default", []): 742 self.assertEqual( 743 variable_scope.get_variable("w", []).name, "outer/tower/w:0") 744 with ops.name_scope("scope2") as sc2: 745 self.assertEqual(sc2, "outer/tower/scope2/") 746 with variable_scope.variable_scope(None, "default", []): 747 self.assertEqual( 748 variable_scope.get_variable("w", []).name, "outer/default/w:0") 749 with ops.name_scope("scope2") as sc2: 750 self.assertEqual(sc2, "outer/default/scope2/") 751 752 with variable_scope.variable_scope(outer, reuse=True) as outer: 753 with variable_scope.variable_scope("tower", "default", []): 754 self.assertEqual( 755 variable_scope.get_variable("w", []).name, "outer/tower/w:0") 756 with ops.name_scope("scope2") as sc2: 757 self.assertEqual(sc2, "outer_1/tower/scope2/") 758 with variable_scope.variable_scope(None, "default", []): 759 self.assertEqual( 760 variable_scope.get_variable("w", []).name, "outer/default/w:0") 761 with ops.name_scope("scope2") as sc2: 762 self.assertEqual(sc2, "outer_1/default/scope2/") 763 764 @test_util.run_in_graph_and_eager_modes 765 @run_inside_wrap_function_in_eager_mode 766 def testVarScopeGetVar(self): 767 with self.cached_session(): 768 with variable_scope.variable_scope("root"): 769 with variable_scope.variable_scope("towerA") as tower_a: 770 va = variable_scope.get_variable("v", [1]) 771 self.assertEqual(va.name, "root/towerA/v:0") 772 773 with variable_scope.variable_scope(tower_a, reuse=True): 774 va2 = variable_scope.get_variable("v", [1]) 775 self.assertEqual(va2, va) 776 777 with variable_scope.variable_scope("towerB"): 778 vb = variable_scope.get_variable("v", [1]) 779 self.assertEqual(vb.name, "root/towerB/v:0") 780 781 with self.assertRaises(ValueError): 782 with variable_scope.variable_scope("towerA"): 783 va2 = variable_scope.get_variable("v", [1]) 784 785 with variable_scope.variable_scope("towerA", reuse=True): 786 va2 = variable_scope.get_variable("v", [1]) 787 self.assertEqual(va2, va) 788 789 with variable_scope.variable_scope("foo"): 790 with variable_scope.variable_scope("bar"): 791 v = variable_scope.get_variable("v", [1]) 792 self.assertEqual(v.name, "root/foo/bar/v:0") 793 with variable_scope.variable_scope(tower_a, reuse=True): 794 va3 = variable_scope.get_variable("v", [1]) 795 self.assertEqual(va, va3) 796 797 with self.assertRaises(ValueError): 798 with variable_scope.variable_scope(tower_a, reuse=True): 799 with variable_scope.variable_scope("baz"): 800 variable_scope.get_variable("v", [1]) 801 802 with self.assertRaises(ValueError) as exc: 803 with variable_scope.variable_scope(tower_a, reuse=True): 804 variable_scope.get_variable("v", [2]) # Different shape. 805 self.assertEqual("shape" in str(exc.exception), True) 806 807 with self.assertRaises(ValueError) as exc: 808 with variable_scope.variable_scope(tower_a, reuse=True): 809 variable_scope.get_variable("v", [1], dtype=dtypes.int32) 810 self.assertEqual("dtype" in str(exc.exception), True) 811 812 @test_util.run_in_graph_and_eager_modes 813 @run_inside_wrap_function_in_eager_mode 814 def testVarScopeOuterScope(self): 815 with self.cached_session(): 816 with variable_scope.variable_scope("outer") as outer: 817 pass 818 with variable_scope.variable_scope(outer): 819 self.assertEqual( 820 variable_scope.get_variable("w", []).name, "outer/w:0") 821 with ops.name_scope("scope2") as sc2: 822 self.assertEqual(sc2, "outer_1/scope2/") 823 with variable_scope.variable_scope("default"): 824 self.assertEqual( 825 variable_scope.get_variable("w", []).name, "outer/default/w:0") 826 with ops.name_scope("scope2") as sc2: 827 self.assertEqual(sc2, "outer_1/default/scope2/") 828 829 with variable_scope.variable_scope(outer, reuse=True): 830 self.assertEqual( 831 variable_scope.get_variable("w", []).name, "outer/w:0") 832 with ops.name_scope("scope2") as sc2: 833 self.assertEqual(sc2, "outer_2/scope2/") 834 with variable_scope.variable_scope("default", reuse=True): 835 self.assertEqual( 836 variable_scope.get_variable("w", []).name, "outer/default/w:0") 837 with ops.name_scope("scope2") as sc2: 838 self.assertEqual(sc2, "outer_2/default/scope2/") 839 840 @test_util.run_in_graph_and_eager_modes 841 @run_inside_wrap_function_in_eager_mode 842 def testVarScopeNestedOuterScope(self): 843 with self.cached_session(): 844 with variable_scope.variable_scope("outer") as outer: 845 with variable_scope.variable_scope(outer): 846 self.assertEqual( 847 variable_scope.get_variable("w", []).name, "outer/w:0") 848 with ops.name_scope("scope2") as sc2: 849 self.assertEqual(sc2, "outer/outer/scope2/") 850 with variable_scope.variable_scope("default"): 851 self.assertEqual( 852 variable_scope.get_variable("w", []).name, "outer/default/w:0") 853 with ops.name_scope("scope2") as sc2: 854 self.assertEqual(sc2, "outer/default/scope2/") 855 856 with variable_scope.variable_scope(outer, reuse=True): 857 self.assertEqual( 858 variable_scope.get_variable("w", []).name, "outer/w:0") 859 with ops.name_scope("scope2") as sc2: 860 self.assertEqual(sc2, "outer/outer_1/scope2/") 861 with variable_scope.variable_scope("default", reuse=True): 862 self.assertEqual( 863 variable_scope.get_variable("w", []).name, "outer/default/w:0") 864 with ops.name_scope("scope2") as sc2: 865 self.assertEqual(sc2, "outer/default_1/scope2/") 866 867 @test_util.run_in_graph_and_eager_modes 868 @run_inside_wrap_function_in_eager_mode 869 def testVarOpScopeReuseParam(self): 870 with self.cached_session(): 871 with variable_scope.variable_scope("outer") as outer: 872 with variable_scope.variable_scope("tower", "default", []): 873 self.assertEqual( 874 variable_scope.get_variable("w", []).name, "outer/tower/w:0") 875 with ops.name_scope("scope2") as sc2: 876 self.assertEqual(sc2, "outer/tower/scope2/") 877 with variable_scope.variable_scope(None, "default", []): 878 self.assertEqual( 879 variable_scope.get_variable("w", []).name, "outer/default/w:0") 880 with ops.name_scope("scope2") as sc2: 881 self.assertEqual(sc2, "outer/default/scope2/") 882 883 with variable_scope.variable_scope(outer) as outer: 884 with variable_scope.variable_scope("tower", "default", reuse=True): 885 self.assertEqual( 886 variable_scope.get_variable("w", []).name, "outer/tower/w:0") 887 with ops.name_scope("scope2") as sc2: 888 self.assertEqual(sc2, "outer_1/tower/scope2/") 889 outer.reuse_variables() 890 with variable_scope.variable_scope(None, "default", []): 891 self.assertEqual( 892 variable_scope.get_variable("w", []).name, "outer/default/w:0") 893 with ops.name_scope("scope2") as sc2: 894 self.assertEqual(sc2, "outer_1/default/scope2/") 895 896 @test_util.run_in_graph_and_eager_modes 897 @run_inside_wrap_function_in_eager_mode 898 def testVarOpScopeReuseError(self): 899 with self.cached_session(): 900 with self.assertRaises(ValueError): 901 with variable_scope.variable_scope(None, "default", reuse=True): 902 self.assertEqual( 903 variable_scope.get_variable("w", []).name, "outer/tower/w:0") 904 905 @test_util.run_in_graph_and_eager_modes 906 @run_inside_wrap_function_in_eager_mode 907 def testVarOpScopeOuterScope(self): 908 with self.cached_session(): 909 with variable_scope.variable_scope("outer") as outer: 910 pass 911 with variable_scope.variable_scope(outer, "default", []): 912 self.assertEqual( 913 variable_scope.get_variable("w", []).name, "outer/w:0") 914 with ops.name_scope("scope2") as sc2: 915 self.assertEqual(sc2, "outer_1/scope2/") 916 with variable_scope.variable_scope(None, "default", []): 917 self.assertEqual( 918 variable_scope.get_variable("w", []).name, "outer/default/w:0") 919 with ops.name_scope("scope2") as sc2: 920 self.assertEqual(sc2, "outer_1/default/scope2/") 921 922 with variable_scope.variable_scope(outer, "default", reuse=True): 923 self.assertEqual( 924 variable_scope.get_variable("w", []).name, "outer/w:0") 925 with ops.name_scope("scope2") as sc2: 926 self.assertEqual(sc2, "outer_2/scope2/") 927 outer.reuse_variables() 928 with variable_scope.variable_scope(None, "default", []): 929 self.assertEqual( 930 variable_scope.get_variable("w", []).name, "outer/default/w:0") 931 with ops.name_scope("scope2") as sc2: 932 self.assertEqual(sc2, "outer_2/default/scope2/") 933 934 @test_util.run_in_graph_and_eager_modes 935 @run_inside_wrap_function_in_eager_mode 936 def testVarOpScopeNestedOuterScope(self): 937 with self.cached_session(): 938 with variable_scope.variable_scope("outer") as outer: 939 with variable_scope.variable_scope(outer, "default", []): 940 self.assertEqual( 941 variable_scope.get_variable("w", []).name, "outer/w:0") 942 with ops.name_scope("scope2") as sc2: 943 self.assertEqual(sc2, "outer/outer/scope2/") 944 with variable_scope.variable_scope(None, "default", []): 945 self.assertEqual( 946 variable_scope.get_variable("w", []).name, "outer/default/w:0") 947 with ops.name_scope("scope2") as sc2: 948 self.assertEqual(sc2, "outer/default/scope2/") 949 950 with variable_scope.variable_scope(outer, "default", reuse=True): 951 self.assertEqual( 952 variable_scope.get_variable("w", []).name, "outer/w:0") 953 with ops.name_scope("scope2") as sc2: 954 self.assertEqual(sc2, "outer_1/scope2/") 955 with variable_scope.variable_scope(None, "default", []): 956 self.assertEqual( 957 variable_scope.get_variable("w", []).name, "outer/default/w:0") 958 with ops.name_scope("scope2") as sc2: 959 self.assertEqual(sc2, "outer_1/default/scope2/") 960 961 @test_util.run_in_graph_and_eager_modes 962 @run_inside_wrap_function_in_eager_mode 963 def testBasicWhenAuxiliaryNameScopeIsFalse(self): 964 with self.cached_session(): 965 with variable_scope.variable_scope( 966 "scope", auxiliary_name_scope=False) as scope: 967 self.assertEqual(scope.original_name_scope, "") 968 self.assertEqual( 969 variable_scope.get_variable("w", []).name, "scope/w:0") 970 self.assertEqual(constant_op.constant([], name="c").name, "c:0") 971 with variable_scope.variable_scope(scope, auxiliary_name_scope=False): 972 self.assertEqual(scope.original_name_scope, "") 973 self.assertEqual( 974 variable_scope.get_variable("w1", []).name, "scope/w1:0") 975 self.assertEqual(constant_op.constant([], name="c1").name, "c1:0") 976 # Recheck: new name scope is NOT created before 977 with ops.name_scope("scope"): 978 self.assertEqual(constant_op.constant([], name="c").name, "scope/c:0") 979 980 with variable_scope.variable_scope("outer"): 981 with variable_scope.variable_scope( 982 "inner", auxiliary_name_scope=False) as inner: 983 self.assertEqual(inner.original_name_scope, "outer/") 984 self.assertEqual( 985 variable_scope.get_variable("w", []).name, "outer/inner/w:0") 986 self.assertEqual( 987 constant_op.constant([], name="c").name, "outer/c:0") 988 with variable_scope.variable_scope( 989 inner, auxiliary_name_scope=False) as inner1: 990 self.assertEqual(inner1.original_name_scope, "outer/") 991 self.assertEqual( 992 variable_scope.get_variable("w1", []).name, "outer/inner/w1:0") 993 self.assertEqual( 994 constant_op.constant([], name="c1").name, "outer/c1:0") 995 # Recheck: new name scope is NOT created before 996 with ops.name_scope("inner"): 997 self.assertEqual( 998 constant_op.constant([], name="c").name, "outer/inner/c:0") 999 1000 @test_util.run_in_graph_and_eager_modes 1001 @run_inside_wrap_function_in_eager_mode 1002 def testCreatedByDefaultNameWhenAuxiliaryNameScopeIsFalse(self): 1003 with self.cached_session(): 1004 with variable_scope.variable_scope( 1005 None, default_name="default", auxiliary_name_scope=False) as scope: 1006 self.assertEqual(scope.original_name_scope, "") 1007 self.assertEqual( 1008 variable_scope.get_variable("w", []).name, "default/w:0") 1009 self.assertEqual(constant_op.constant([], name="c").name, "c:0") 1010 # Recheck: new name scope is NOT created before 1011 with ops.name_scope("default"): 1012 self.assertEqual( 1013 constant_op.constant([], name="c").name, "default/c:0") 1014 1015 with variable_scope.variable_scope("outer"): 1016 with variable_scope.variable_scope( 1017 None, default_name="default", 1018 auxiliary_name_scope=False) as inner: 1019 self.assertEqual(inner.original_name_scope, "outer/") 1020 self.assertEqual( 1021 variable_scope.get_variable("w", []).name, "outer/default/w:0") 1022 self.assertEqual( 1023 constant_op.constant([], name="c").name, "outer/c:0") 1024 # Recheck: new name scope is NOT created before 1025 with ops.name_scope("default"): 1026 self.assertEqual( 1027 constant_op.constant([], name="c").name, "outer/default/c:0") 1028 1029 @test_util.run_in_graph_and_eager_modes 1030 @run_inside_wrap_function_in_eager_mode 1031 def testReenterRootScopeWhenAuxiliaryNameScopeIsFalse(self): 1032 with self.cached_session(): 1033 root_scope = variable_scope.get_variable_scope() 1034 with variable_scope.variable_scope( 1035 root_scope, auxiliary_name_scope=False) as scope: 1036 self.assertEqual(scope.original_name_scope, "") 1037 self.assertEqual(variable_scope.get_variable("w", []).name, "w:0") 1038 self.assertEqual(constant_op.constant([], name="c").name, "c:0") 1039 1040 with variable_scope.variable_scope("outer"): 1041 with variable_scope.variable_scope( 1042 root_scope, auxiliary_name_scope=False) as inner: 1043 self.assertEqual(inner.original_name_scope, "") 1044 self.assertEqual(variable_scope.get_variable("w1", []).name, "w1:0") 1045 self.assertEqual( 1046 constant_op.constant([], name="c1").name, "outer/c1:0") 1047 1048 @test_util.run_in_graph_and_eager_modes 1049 @run_inside_wrap_function_in_eager_mode 1050 def testAuxiliaryNameScopeIsInvalid(self): 1051 with self.cached_session(): 1052 with self.assertRaisesRegexp(TypeError, "auxiliary_name_scope"): 1053 with variable_scope.variable_scope( 1054 None, default_name="scope", auxiliary_name_scope="invalid"): 1055 pass 1056 1057 with self.assertRaisesRegexp(TypeError, "auxiliary_name_scope"): 1058 with variable_scope.variable_scope( 1059 "scope", auxiliary_name_scope="invalid"): 1060 pass 1061 1062 with variable_scope.variable_scope("scope") as scope: 1063 pass 1064 with self.assertRaisesRegexp(TypeError, "auxiliary_name_scope"): 1065 with variable_scope.variable_scope( 1066 scope, auxiliary_name_scope="invalid"): 1067 pass 1068 1069 @test_util.run_in_graph_and_eager_modes 1070 @run_inside_wrap_function_in_eager_mode 1071 def testReuseScopeWithoutNameScopeCollision(self): 1072 # Github issue: #13429 1073 with self.cached_session(): 1074 with variable_scope.variable_scope("outer"): 1075 with variable_scope.variable_scope("inner") as inner: 1076 pass 1077 1078 with variable_scope.variable_scope( 1079 inner, auxiliary_name_scope=False) as scope: 1080 with ops.name_scope(scope.original_name_scope): 1081 self.assertEqual( 1082 variable_scope.get_variable("w", []).name, "outer/inner/w:0") 1083 self.assertEqual( 1084 constant_op.constant([], name="c").name, "outer/inner/c:0") 1085 with ops.name_scope("inner"): 1086 self.assertEqual( 1087 constant_op.constant([], name="c").name, "inner/c:0") 1088 1089 with variable_scope.variable_scope("another"): 1090 with variable_scope.variable_scope( 1091 inner, auxiliary_name_scope=False) as scope1: 1092 with ops.name_scope(scope1.original_name_scope): 1093 self.assertEqual( 1094 variable_scope.get_variable("w1", []).name, 1095 "outer/inner/w1:0") 1096 self.assertEqual( 1097 constant_op.constant([], name="c1").name, "outer/inner/c1:0") 1098 with ops.name_scope("inner"): 1099 self.assertEqual( 1100 constant_op.constant([], name="c").name, "another/inner/c:0") 1101 1102 # TODO(mihaimaruseac): Not converted to use wrap_function because of 1103 # obtaining different results in the eager case compared to the graph one 1104 # (different assertions failing after wrapping, in both execution modes) 1105 @test_util.run_in_graph_and_eager_modes 1106 def testGetLocalVar(self): 1107 # Check that local variable respects naming. 1108 with variable_scope.variable_scope("outer") as outer: 1109 with variable_scope.variable_scope(outer, "default", []): 1110 local_var = variable_scope.get_local_variable( 1111 "w", [], collections=["foo"]) 1112 self.assertEqual(local_var.name, "outer/w:0") 1113 1114 if not context.executing_eagerly(): 1115 # Since variable is local, it should be in the local variable collection 1116 # but not the trainable collection. 1117 self.assertIn(local_var, 1118 ops.get_collection(ops.GraphKeys.LOCAL_VARIABLES)) 1119 self.assertIn(local_var, ops.get_collection("foo")) 1120 self.assertNotIn(local_var, 1121 ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)) 1122 # Check that local variable respects `reuse`. 1123 with variable_scope.variable_scope(outer, "default", reuse=True): 1124 self.assertEqual( 1125 variable_scope.get_local_variable("w", []).name, "outer/w:0") 1126 1127 @test_util.run_in_graph_and_eager_modes 1128 @run_inside_wrap_function_in_eager_mode 1129 def testSignatureGetVarVsGetLocalVar(self): 1130 """get_{local,}variable() must take the same list of args.""" 1131 arg_names = tf_inspect.getargspec(variable_scope.get_variable)[0] 1132 local_arg_names = tf_inspect.getargspec( 1133 variable_scope.get_local_variable)[0] 1134 self.assertEqual(arg_names, local_arg_names) 1135 1136 @test_util.run_in_graph_and_eager_modes 1137 @run_inside_wrap_function_in_eager_mode 1138 def testGetVarWithDevice(self): 1139 g = ops.Graph() 1140 varname_type = [] 1141 1142 def device_func(op): 1143 if op.type in ["Variable", "VariableV2", "VarHandleOp"]: 1144 varname_type.append((op.name, op.get_attr("dtype"))) 1145 return "/device:GPU:0" 1146 1147 with g.as_default(): 1148 with ops.device(device_func): 1149 _ = variable_scope.get_variable("x", (100, 200)) 1150 _ = variable_scope.get_variable( 1151 "y", dtype=dtypes.int64, initializer=numpy.arange(73)) 1152 self.assertEqual(varname_type[0], ("x", dtypes.float32)) 1153 self.assertEqual(varname_type[1], ("y", dtypes.int64)) 1154 1155 # TODO(mihaimaruseac): Not converted to use wrap_function because of 1156 # obtaining different results in the eager case compared to the graph one 1157 @test_util.run_deprecated_v1 1158 def testGetCollection(self): 1159 with self.cached_session(): 1160 _ = variable_scope.get_variable("testGetCollection_a", []) 1161 _ = variable_scope.get_variable( 1162 "testGetCollection_b", [], trainable=False) 1163 with variable_scope.variable_scope("testGetCollection_foo_") as scope1: 1164 _ = variable_scope.get_variable("testGetCollection_a", []) 1165 _ = variable_scope.get_variable( 1166 "testGetCollection_b", [], trainable=False) 1167 self.assertEqual([ 1168 v.name 1169 for v in scope1.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES) 1170 ], ["testGetCollection_foo_/testGetCollection_a:0"]) 1171 self.assertEqual([ 1172 v.name 1173 for v in scope1.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) 1174 ], [ 1175 "testGetCollection_foo_/testGetCollection_a:0", 1176 "testGetCollection_foo_/testGetCollection_b:0" 1177 ]) 1178 with variable_scope.variable_scope("testGetCollection_foo") as scope2: 1179 _ = variable_scope.get_variable("testGetCollection_a", []) 1180 _ = variable_scope.get_variable( 1181 "testGetCollection_b", [], trainable=False) 1182 self.assertEqual([ 1183 v.name 1184 for v in scope2.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES) 1185 ], ["testGetCollection_foo/testGetCollection_a:0"]) 1186 self.assertEqual([ 1187 v.name 1188 for v in scope2.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) 1189 ], [ 1190 "testGetCollection_foo/testGetCollection_a:0", 1191 "testGetCollection_foo/testGetCollection_b:0" 1192 ]) 1193 scope = variable_scope.get_variable_scope() 1194 self.assertEqual([ 1195 v.name for v in scope.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) 1196 ], [ 1197 "testGetCollection_a:0", "testGetCollection_b:0", 1198 "testGetCollection_foo_/testGetCollection_a:0", 1199 "testGetCollection_foo_/testGetCollection_b:0", 1200 "testGetCollection_foo/testGetCollection_a:0", 1201 "testGetCollection_foo/testGetCollection_b:0" 1202 ]) 1203 self.assertEqual([ 1204 v.name 1205 for v in scope.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES) 1206 ], [ 1207 "testGetCollection_a:0", 1208 "testGetCollection_foo_/testGetCollection_a:0", 1209 "testGetCollection_foo/testGetCollection_a:0" 1210 ]) 1211 1212 # TODO(mihaimaruseac): Not converted to use wrap_function because of 1213 # obtaining different results in the eager case compared to the graph one 1214 @test_util.run_deprecated_v1 1215 def testGetTrainableVariablesWithGetVariable(self): 1216 with self.cached_session(): 1217 _ = variable_scope.get_variable("testGetTrainableVariables_a", []) 1218 with variable_scope.variable_scope( 1219 "testGetTrainableVariables_foo") as scope: 1220 _ = variable_scope.get_variable("testGetTrainableVariables_b", []) 1221 _ = variable_scope.get_variable( 1222 "testGetTrainableVariables_c", [], trainable=False) 1223 1224 # sync `ON_READ` sets trainable=False 1225 _ = variable_scope.get_variable( 1226 "testGetTrainableVariables_d", [], 1227 synchronization=variable_scope.VariableSynchronization.ON_READ) 1228 self.assertEqual( 1229 [v.name for v in scope.trainable_variables()], 1230 ["testGetTrainableVariables_foo/testGetTrainableVariables_b:0"]) 1231 1232 # All other sync values sets trainable=True 1233 _ = variable_scope.get_variable( 1234 "testGetTrainableVariables_e", [], 1235 synchronization=variable_scope.VariableSynchronization.ON_WRITE) 1236 self.assertEqual([v.name for v in scope.trainable_variables()], [ 1237 "testGetTrainableVariables_foo/testGetTrainableVariables_b:0", 1238 "testGetTrainableVariables_foo/testGetTrainableVariables_e:0" 1239 ]) 1240 1241 with self.assertRaisesRegexp( 1242 ValueError, "Synchronization value can be set to " 1243 "VariableSynchronization.ON_READ only for non-trainable variables. " 1244 "You have specified trainable=True and " 1245 "synchronization=VariableSynchronization.ON_READ."): 1246 _ = variable_scope.get_variable( 1247 "testGetTrainableVariables_e", [], 1248 synchronization=variable_scope.VariableSynchronization.ON_READ, 1249 trainable=True) 1250 1251 # TODO(mihaimaruseac): Not converted to use wrap_function because of 1252 # obtaining different results in the eager case compared to the graph one 1253 @test_util.run_deprecated_v1 1254 def testGetTrainableVariablesWithVariable(self): 1255 with self.cached_session(): 1256 _ = variable_scope.variable(1.0, name="testGetTrainableVariables_a") 1257 with variable_scope.variable_scope( 1258 "testGetTrainableVariables_foo") as scope: 1259 _ = variable_scope.variable(1.0, name="testGetTrainableVariables_b") 1260 _ = variable_scope.variable( 1261 1.0, name="testGetTrainableVariables_c", trainable=False) 1262 1263 # sync `ON_READ` sets trainable=False 1264 _ = variable_scope.variable( 1265 1.0, 1266 name="testGetTrainableVariables_d", 1267 synchronization=variable_scope.VariableSynchronization.ON_READ) 1268 self.assertEqual( 1269 [v.name for v in scope.trainable_variables()], 1270 ["testGetTrainableVariables_foo/testGetTrainableVariables_b:0"]) 1271 1272 # All other sync values sets trainable=True 1273 _ = variable_scope.variable( 1274 1.0, 1275 name="testGetTrainableVariables_e", 1276 synchronization=variable_scope.VariableSynchronization.ON_WRITE) 1277 self.assertEqual([v.name for v in scope.trainable_variables()], [ 1278 "testGetTrainableVariables_foo/testGetTrainableVariables_b:0", 1279 "testGetTrainableVariables_foo/testGetTrainableVariables_e:0" 1280 ]) 1281 1282 with self.assertRaisesRegexp( 1283 ValueError, "Synchronization value can be set to " 1284 "VariableSynchronization.ON_READ only for non-trainable variables. " 1285 "You have specified trainable=True and " 1286 "synchronization=VariableSynchronization.ON_READ."): 1287 _ = variable_scope.variable( 1288 1.0, 1289 name="testGetTrainableVariables_e", 1290 synchronization=variable_scope.VariableSynchronization.ON_READ, 1291 trainable=True) 1292 1293 # TODO(mihaimaruseac): Not converted to use wrap_function because of 1294 # obtaining different results in the eager case compared to the graph one 1295 @test_util.run_deprecated_v1 1296 def testGetGlobalVariables(self): 1297 with self.cached_session(): 1298 _ = variable_scope.get_variable("testGetGlobalVariables_a", []) 1299 with variable_scope.variable_scope("testGetGlobalVariables_foo") as scope: 1300 _ = variable_scope.get_variable("testGetGlobalVariables_b", []) 1301 self.assertEqual( 1302 [v.name for v in scope.global_variables()], 1303 ["testGetGlobalVariables_foo/" 1304 "testGetGlobalVariables_b:0"]) 1305 1306 # TODO(mihaimaruseac): Not converted to use wrap_function because of 1307 # obtaining different results in the eager case compared to the graph one 1308 @test_util.run_deprecated_v1 1309 def testGetLocalVariables(self): 1310 with self.cached_session(): 1311 _ = variable_scope.get_variable( 1312 "a", [], collections=[ops.GraphKeys.LOCAL_VARIABLES]) 1313 with variable_scope.variable_scope("foo") as scope: 1314 _ = variable_scope.get_variable( 1315 "b", [], collections=[ops.GraphKeys.LOCAL_VARIABLES]) 1316 _ = variable_scope.get_variable("c", []) 1317 self.assertEqual([v.name for v in scope.local_variables()], ["foo/b:0"]) 1318 1319 @test_util.run_in_graph_and_eager_modes 1320 @run_inside_wrap_function_in_eager_mode 1321 def testGetVariableWithRefDtype(self): 1322 v = variable_scope.get_variable("v", shape=[3, 4], dtype=dtypes.float32) 1323 # Ensure it is possible to do get_variable with a _ref dtype passed in. 1324 _ = variable_scope.get_variable("w", shape=[5, 6], dtype=v.dtype) 1325 1326 @test_util.run_in_graph_and_eager_modes 1327 @run_inside_wrap_function_in_eager_mode 1328 def testGetVariableWithInitializerWhichTakesNoArgs(self): 1329 v = variable_scope.get_variable("foo", initializer=lambda: [2]) 1330 self.assertEqual(v.name, "foo:0") 1331 1332 @test_util.run_in_graph_and_eager_modes 1333 @run_inside_wrap_function_in_eager_mode 1334 def testGetVariableWithInitializerWhichTakesOptionalArgs(self): 1335 v = variable_scope.get_variable("foo", initializer=lambda x=True: [2]) 1336 self.assertEqual(v.name, "foo:0") 1337 1338 @test_util.run_in_graph_and_eager_modes 1339 @run_inside_wrap_function_in_eager_mode 1340 def testGetVariableWithInitializerWhichTakesUnprovidedArgsAndNoShape(self): 1341 with self.assertRaisesRegexp( 1342 ValueError, 1343 "The initializer passed is not valid. It should be a callable with no " 1344 "arguments and the shape should not be provided or an instance of " 1345 "`tf.keras.initializers.*' and `shape` should be fully defined."): 1346 variable_scope.get_variable("foo", initializer=lambda x: [2]) 1347 1348 @test_util.run_in_graph_and_eager_modes 1349 @run_inside_wrap_function_in_eager_mode 1350 def testTwoGraphs(self): 1351 1352 def f(): 1353 g1 = ops.Graph() 1354 g2 = ops.Graph() 1355 with g1.as_default(): 1356 with g2.as_default(): 1357 with variable_scope.variable_scope("_"): 1358 pass 1359 1360 self.assertRaisesRegexp(ValueError, "'_' is not a valid scope name", f) 1361 1362 1363def axis0_into1_partitioner(shape=None, **unused_kwargs): 1364 part = [1] * len(shape) 1365 return part 1366 1367 1368def axis0_into2_partitioner(shape=None, **unused_kwargs): 1369 part = [1] * len(shape) 1370 part[0] = 2 1371 return part 1372 1373 1374def axis0_into3_partitioner(shape=None, **unused_kwargs): 1375 part = [1] * len(shape) 1376 part[0] = 3 1377 return part 1378 1379 1380class VariableScopeWithPartitioningTest(test.TestCase): 1381 1382 # TODO(mihaimaruseac): Not converted to use wrap_function because of 1383 # obtaining different results in the eager case compared to the graph one 1384 @test_util.run_deprecated_v1 1385 def testResultNameMatchesRequested(self): 1386 with variable_scope.variable_scope( 1387 "scope0", partitioner=axis0_into2_partitioner): 1388 v = variable_scope.get_variable("name0", shape=(3, 1, 1)) 1389 self.assertEqual(v.name, "scope0/name0") 1390 v_concat = v.as_tensor() 1391 self.assertEqual(v_concat.name, "scope0/name0:0") 1392 variables = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) 1393 self.assertIn("scope0/name0/part_0:0", [x.name for x in variables]) 1394 self.assertIn("scope0/name0/part_1:0", [x.name for x in variables]) 1395 self.assertNotIn("scope0/name0/part_2:0", [x.name for x in variables]) 1396 1397 @test_util.run_in_graph_and_eager_modes 1398 @run_inside_wrap_function_in_eager_mode 1399 def testBreaksIfPartitioningChanges(self): 1400 with variable_scope.variable_scope( 1401 "scope0", partitioner=axis0_into2_partitioner): 1402 variable_scope.get_variable("name0", shape=(3, 1, 1)) 1403 1404 with variable_scope.variable_scope( 1405 "scope0", partitioner=axis0_into3_partitioner, reuse=True): 1406 with self.assertRaisesRegexp( 1407 ValueError, 1408 "Trying to reuse partitioned variable .* but specified partitions " 1409 ".* and found partitions .*"): 1410 variable_scope.get_variable("name0", shape=(3, 1, 1)) 1411 1412 with variable_scope.variable_scope( 1413 "scope0", partitioner=axis0_into1_partitioner, reuse=True): 1414 with self.assertRaisesRegexp( 1415 ValueError, 1416 "Trying to reuse partitioned variable .* but specified partitions " 1417 ".* and found partitions .*"): 1418 variable_scope.get_variable("name0", shape=(3, 1, 1)) 1419 1420 @test_util.run_in_graph_and_eager_modes 1421 @run_inside_wrap_function_in_eager_mode 1422 def testReturnsExistingConcatenatedValueIfReuse(self): 1423 with variable_scope.variable_scope( 1424 "scope0", partitioner=axis0_into2_partitioner): 1425 v_concat = variable_scope.get_variable("name0", shape=(3, 1, 1)) 1426 variable_scope.get_variable_scope().reuse_variables() 1427 v_concat_2 = variable_scope.get_variable("name0", shape=(3, 1, 1)) 1428 self.assertEqual(v_concat, v_concat_2) 1429 1430 @test_util.run_in_graph_and_eager_modes 1431 @run_inside_wrap_function_in_eager_mode 1432 def testAllowsReuseWithoutPartitioner(self): 1433 with variable_scope.variable_scope( 1434 "scope0", partitioner=axis0_into2_partitioner): 1435 v = variable_scope.get_variable("name0", shape=(3, 1, 1)) 1436 with variable_scope.variable_scope("scope0", reuse=True): 1437 v_reused = variable_scope.get_variable("name0") 1438 self.assertEqual(v, v_reused) 1439 1440 def testNoReuseInEagerByDefault(self): 1441 with context.eager_mode(): 1442 with variable_scope.variable_scope( 1443 "scope0", partitioner=axis0_into2_partitioner): 1444 v1 = variable_scope.get_variable("name0", shape=(3, 1, 1)) 1445 v2 = variable_scope.get_variable("name0", shape=(3, 1, 1)) 1446 self.assertIsNot(v1, v2) 1447 1448 @test_util.run_in_graph_and_eager_modes 1449 @run_inside_wrap_function_in_eager_mode 1450 def testPropagatePartitionerOnReopening(self): 1451 with variable_scope.variable_scope( 1452 "scope0", partitioner=axis0_into2_partitioner) as vs: 1453 self.assertEqual(axis0_into2_partitioner, vs.partitioner) 1454 with variable_scope.variable_scope(vs) as vs1: 1455 self.assertEqual(axis0_into2_partitioner, vs1.partitioner) 1456 1457 # TODO(mihaimaruseac): Not converted to use wrap_function because of 1458 # obtaining different results in the eager case compared to the graph one 1459 @test_util.run_deprecated_v1 1460 def testScalarIgnoresPartitioner(self): 1461 with variable_scope.variable_scope( 1462 "scope0", partitioner=axis0_into2_partitioner): 1463 v = variable_scope.get_variable("name0", shape=()) 1464 self.assertEqual(v.name, "scope0/name0:0") 1465 variables = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) 1466 self.assertIn("scope0/name0:0", [x.name for x in variables]) 1467 1468 def _testPartitionConcatenatesAlongCorrectAxis(self, use_resource): 1469 def _part_axis_0(**unused_kwargs): 1470 return (2, 1, 1) 1471 1472 def _part_axis_1(**unused_kwargs): 1473 return (1, 2, 1) 1474 1475 with variable_scope.variable_scope("root", use_resource=use_resource): 1476 v0 = variable_scope.get_variable( 1477 "n0", shape=(2, 2, 2), partitioner=_part_axis_0) 1478 v1 = variable_scope.get_variable( 1479 "n1", shape=(2, 2, 2), partitioner=_part_axis_1) 1480 1481 self.assertEqual(v0.get_shape(), (2, 2, 2)) 1482 self.assertEqual(v1.get_shape(), (2, 2, 2)) 1483 1484 n0_0 = list(v0)[0] 1485 n0_1 = list(v0)[1] 1486 self.assertEqual(n0_0.get_shape(), (1, 2, 2)) 1487 self.assertEqual(n0_1.get_shape(), (1, 2, 2)) 1488 1489 n1_0 = list(v1)[0] 1490 n1_1 = list(v1)[1] 1491 self.assertEqual(n1_0.get_shape(), (2, 1, 2)) 1492 self.assertEqual(n1_1.get_shape(), (2, 1, 2)) 1493 1494 @test_util.run_in_graph_and_eager_modes 1495 @run_inside_wrap_function_in_eager_mode 1496 def testPartitionConcatenatesAlongCorrectAxis(self): 1497 self._testPartitionConcatenatesAlongCorrectAxis(use_resource=False) 1498 1499 @test_util.run_in_graph_and_eager_modes 1500 @run_inside_wrap_function_in_eager_mode 1501 def testPartitionConcatenatesAlongCorrectAxisResource(self): 1502 self._testPartitionConcatenatesAlongCorrectAxis(use_resource=True) 1503 1504 def testPartitionConcatenatesAlongCorrectAxisResourceInEager(self): 1505 with context.eager_mode(): 1506 self._testPartitionConcatenatesAlongCorrectAxis(use_resource=True) 1507 1508 1509class VariableScopeWithCustomGetterTest(test.TestCase): 1510 1511 @test_util.run_in_graph_and_eager_modes 1512 @run_inside_wrap_function_in_eager_mode 1513 def testNonCallableGetterFails(self): 1514 with self.assertRaisesRegexp(ValueError, 1515 r"custom_getter .* not callable:"): 1516 with variable_scope.variable_scope("scope0", custom_getter=3): 1517 variable_scope.get_variable("name0") 1518 with self.assertRaisesRegexp(ValueError, 1519 r"custom_getter .* not callable:"): 1520 variable_scope.get_variable("name0", custom_getter=3) 1521 1522 @test_util.run_in_graph_and_eager_modes 1523 @run_inside_wrap_function_in_eager_mode 1524 def testNoSideEffectsWithIdentityCustomGetter(self): 1525 called = [0] 1526 1527 def custom_getter(getter, *args, **kwargs): 1528 called[0] += 1 1529 return getter(*args, **kwargs) 1530 1531 with variable_scope.variable_scope( 1532 "scope", custom_getter=custom_getter) as scope: 1533 v = variable_scope.get_variable("v", [1]) 1534 with variable_scope.variable_scope(scope, reuse=True): 1535 v2 = variable_scope.get_variable("v", [1]) 1536 with variable_scope.variable_scope("new_scope") as new_scope: 1537 v3 = variable_scope.get_variable("v3", [1]) 1538 with variable_scope.variable_scope( 1539 new_scope, reuse=True, custom_getter=custom_getter): 1540 v4 = variable_scope.get_variable("v3", [1]) 1541 1542 self.assertEqual(v, v2) 1543 self.assertEqual(v3, v4) 1544 self.assertEqual(3, called[0]) # skipped one in the first new_scope 1545 1546 @test_util.run_in_graph_and_eager_modes 1547 @run_inside_wrap_function_in_eager_mode 1548 def testSynchronizationAndAggregationWithCustomGetter(self): 1549 called = [0] 1550 synchronization = variable_scope.VariableSynchronization.AUTO 1551 aggregation = variable_scope.VariableAggregation.NONE 1552 1553 def custom_getter(getter, *args, **kwargs): 1554 called[0] += 1 1555 1556 # Verify synchronization and aggregation kwargs are as expected. 1557 self.assertEqual(kwargs["synchronization"], synchronization) 1558 self.assertEqual(kwargs["aggregation"], aggregation) 1559 return getter(*args, **kwargs) 1560 1561 with variable_scope.variable_scope("scope", custom_getter=custom_getter): 1562 variable_scope.get_variable("v", [1]) 1563 self.assertEqual(1, called[0]) 1564 1565 with variable_scope.variable_scope("scope", custom_getter=custom_getter): 1566 synchronization = variable_scope.VariableSynchronization.ON_READ 1567 aggregation = variable_scope.VariableAggregation.MEAN 1568 variable_scope.get_variable( 1569 "v1", [1], synchronization=synchronization, aggregation=aggregation) 1570 1571 self.assertEqual(2, called[0]) 1572 1573 @test_util.run_in_graph_and_eager_modes 1574 @run_inside_wrap_function_in_eager_mode 1575 def testCustomGetterWithReuse(self): 1576 # Custom getter can choose to behave differently on reused variables. 1577 def custom_getter(getter, *args, **kwargs): 1578 var = getter(*args, **kwargs) 1579 if kwargs["reuse"]: 1580 # This can be used, e.g., for changing the caching device if needed. 1581 return array_ops.identity(var, name="reused") 1582 else: 1583 return array_ops.identity(var, name="not_reused") 1584 1585 with variable_scope.variable_scope( 1586 "scope", custom_getter=custom_getter) as scope: 1587 v = variable_scope.get_variable("v", [1]) 1588 with variable_scope.variable_scope(scope, reuse=True): 1589 v2 = variable_scope.get_variable("v", [1]) 1590 1591 self.assertEqual(v.name, "not_reused:0") 1592 self.assertEqual(v2.name, "reused:0") 1593 1594 # TODO(mihaimaruseac): Not converted to use wrap_function because of 1595 # ValueError: Fetch argument <tf.Tensor 'custom_getter/add:0' shape=(1, 2, 3) 1596 # dtype=float32> cannot be interpreted as a Tensor. (Tensor 1597 # Tensor("custom_getter/add:0", shape=(1, 2, 3), dtype=float32) is not an 1598 # element of this graph.) 1599 @test_util.run_deprecated_v1 1600 def testGetterThatCreatesTwoVariablesAndSumsThem(self): 1601 1602 def custom_getter(getter, name, *args, **kwargs): 1603 g_0 = getter("%s/0" % name, *args, **kwargs) 1604 g_1 = getter("%s/1" % name, *args, **kwargs) 1605 with ops.name_scope("custom_getter"): 1606 return g_0 + g_1 1607 1608 with variable_scope.variable_scope("scope", custom_getter=custom_getter): 1609 v = variable_scope.get_variable("v", [1, 2, 3]) 1610 1611 self.assertEqual([1, 2, 3], v.get_shape()) 1612 true_vars = variables_lib.trainable_variables() 1613 self.assertEqual(2, len(true_vars)) 1614 self.assertEqual("scope/v/0:0", true_vars[0].name) 1615 self.assertEqual("scope/v/1:0", true_vars[1].name) 1616 self.assertEqual("custom_getter/add:0", v.name) 1617 with self.cached_session() as sess: 1618 variables_lib.global_variables_initializer().run() 1619 np_vars, np_v = self.evaluate([true_vars, v]) 1620 self.assertAllClose(np_v, sum(np_vars)) 1621 1622 # TODO(mihaimaruseac): Not converted to use wrap_function because of 1623 # ValueError: Fetch argument <tf.Tensor 'sum_getter_2/add:0' shape=(1, 2, 3) 1624 # dtype=float32> cannot be interpreted as a Tensor. (Tensor 1625 # Tensor("sum_getter_2/add:0", shape=(1, 2, 3), dtype=float32) is not an 1626 # element of this graph.) 1627 @test_util.run_deprecated_v1 1628 def testNestedCustomGetters(self): 1629 1630 def sum_getter(getter, name, *args, **kwargs): 1631 g_0 = getter("%s/sum_0" % name, *args, **kwargs) 1632 g_1 = getter("%s/sum_1" % name, *args, **kwargs) 1633 with ops.name_scope("sum_getter"): 1634 return g_0 + g_1 1635 1636 def prod_getter(getter, name, *args, **kwargs): 1637 g_0 = getter("%s/prod_0" % name, *args, **kwargs) 1638 g_1 = getter("%s/prod_1" % name, *args, **kwargs) 1639 with ops.name_scope("prod_getter"): 1640 return g_0 * g_1 1641 1642 with variable_scope.variable_scope("prod_scope", custom_getter=prod_getter): 1643 with variable_scope.variable_scope("sum_scope", custom_getter=sum_getter): 1644 with variable_scope.variable_scope( 1645 "inner_sum_scope", custom_getter=sum_getter): 1646 # take sums of sums of products 1647 v = variable_scope.get_variable("v", [1, 2, 3]) 1648 1649 self.assertEqual([1, 2, 3], v.get_shape()) 1650 true_vars = variables_lib.trainable_variables() 1651 self.assertEqual(8, len(true_vars)) 1652 template = ( 1653 "prod_scope/sum_scope/inner_sum_scope/v/sum_%d/sum_%d/prod_%d:0") 1654 self.assertEqual(template % (0, 0, 0), true_vars[0].name) 1655 self.assertEqual(template % (0, 0, 1), true_vars[1].name) 1656 self.assertEqual(template % (0, 1, 0), true_vars[2].name) 1657 self.assertEqual(template % (0, 1, 1), true_vars[3].name) 1658 self.assertEqual(template % (1, 0, 0), true_vars[4].name) 1659 self.assertEqual(template % (1, 0, 1), true_vars[5].name) 1660 self.assertEqual(template % (1, 1, 0), true_vars[6].name) 1661 self.assertEqual(template % (1, 1, 1), true_vars[7].name) 1662 1663 with self.cached_session() as sess: 1664 variables_lib.global_variables_initializer().run() 1665 np_vars, np_v = self.evaluate([true_vars, v]) 1666 # take products of sums of products 1667 self.assertAllClose( 1668 np_v, (((np_vars[0] * np_vars[1]) + (np_vars[2] * np_vars[3])) + ( 1669 (np_vars[4] * np_vars[5]) + (np_vars[6] * np_vars[7])))) 1670 1671 @test_util.run_in_graph_and_eager_modes 1672 @run_inside_wrap_function_in_eager_mode 1673 def testVariableCreator(self): 1674 variable_names = [] 1675 1676 def creator_a(next_creator, **kwargs): 1677 variable_names.append(kwargs.get("name", "")) 1678 return next_creator(**kwargs) 1679 1680 def creator_b(next_creator, **kwargs): 1681 kwargs["name"] = "forced_name" 1682 return next_creator(**kwargs) 1683 1684 with variable_scope.variable_creator_scope(creator_a): 1685 with variable_scope.variable_creator_scope(creator_b): 1686 variable_scope.variable(1.0, name="one_name") 1687 1688 self.assertEqual(variable_names[0], "forced_name") 1689 1690 called = [False] 1691 1692 def creater_c(next_creator, **kwargs): 1693 called[0] = True 1694 self.assertEqual(kwargs["synchronization"], 1695 variable_scope.VariableSynchronization.ON_WRITE) 1696 self.assertEqual(kwargs["aggregation"], 1697 variable_scope.VariableAggregation.MEAN) 1698 return next_creator(**kwargs) 1699 1700 with variable_scope.variable_creator_scope(creater_c): 1701 variable_scope.get_variable( 1702 "v", [], 1703 synchronization=variable_scope.VariableSynchronization.ON_WRITE, 1704 aggregation=variable_scope.VariableAggregation.MEAN) 1705 self.assertTrue(called[0]) 1706 1707 1708class PartitionInfoTest(test.TestCase): 1709 1710 @test_util.run_in_graph_and_eager_modes 1711 @run_inside_wrap_function_in_eager_mode 1712 def testConstructorChecks(self): 1713 # Invalid arg types. 1714 with self.assertRaises(TypeError): 1715 variable_scope._PartitionInfo(full_shape=None, var_offset=[0, 1]) 1716 with self.assertRaises(TypeError): 1717 variable_scope._PartitionInfo(full_shape=[0, 1], var_offset=None) 1718 with self.assertRaises(TypeError): 1719 variable_scope._PartitionInfo(full_shape="foo", var_offset=[0, 1]) 1720 with self.assertRaises(TypeError): 1721 variable_scope._PartitionInfo(full_shape=[0, 1], var_offset="foo") 1722 1723 # full_shape and var_offset must have same length. 1724 with self.assertRaises(ValueError): 1725 variable_scope._PartitionInfo(full_shape=[0, 1], var_offset=[0]) 1726 # Offset must always be less than shape. 1727 with self.assertRaises(ValueError): 1728 variable_scope._PartitionInfo(full_shape=[1, 1], var_offset=[0, 1]) 1729 1730 @test_util.run_in_graph_and_eager_modes 1731 @run_inside_wrap_function_in_eager_mode 1732 def testSingleOffset(self): 1733 partition_info = variable_scope._PartitionInfo( 1734 full_shape=[9, 3], var_offset=[4, 0]) 1735 self.assertEqual(4, partition_info.single_offset([1, 3])) 1736 1737 # Tests when the variable isn't partitioned at all. 1738 partition_info = variable_scope._PartitionInfo( 1739 full_shape=[9, 3], var_offset=[0, 0]) 1740 self.assertEqual(0, partition_info.single_offset([9, 3])) 1741 1742 @test_util.run_in_graph_and_eager_modes 1743 @run_inside_wrap_function_in_eager_mode 1744 def testSingleSliceDim(self): 1745 partition_info = variable_scope._PartitionInfo( 1746 full_shape=[9, 3], var_offset=[4, 0]) 1747 # Invalid shape. 1748 with self.assertRaises(TypeError): 1749 partition_info.single_slice_dim(None) 1750 1751 # Rank of shape differs from full_shape. 1752 with self.assertRaises(ValueError): 1753 partition_info.single_slice_dim([1, 2, 3]) 1754 1755 # Shape is too large given var_offset (4+6 > 9). 1756 with self.assertRaises(ValueError): 1757 partition_info.single_slice_dim([6, 3]) 1758 1759 # Multiple possible slice dim from shape. 1760 with self.assertRaises(ValueError): 1761 partition_info.single_slice_dim([1, 1]) 1762 1763 partition_info = variable_scope._PartitionInfo( 1764 full_shape=[9, 3], var_offset=[0, 0]) 1765 self.assertEqual(1, partition_info.single_slice_dim([9, 2])) 1766 partition_info = variable_scope._PartitionInfo( 1767 full_shape=[9, 3], var_offset=[4, 0]) 1768 self.assertEqual(0, partition_info.single_slice_dim([2, 3])) 1769 1770 1771class VariableScopeMultithreadedTest(test.TestCase): 1772 1773 @test_util.run_in_graph_and_eager_modes 1774 @run_inside_wrap_function_in_eager_mode 1775 def testTwoThreadsDisjointScopeEntry(self): 1776 1777 def thread_fn(i, graph): 1778 with graph.as_default(): 1779 with variable_scope.variable_scope("foo"): 1780 if i == 0: 1781 v = variable_scope.get_variable("v", []) 1782 self.assertEquals("foo/v:0", v.name) 1783 else: 1784 # Any thread after the first one should fail to create variable 1785 # with the same name. 1786 with self.assertRaises(ValueError): 1787 variable_scope.get_variable("v", []) 1788 1789 graph = ops.get_default_graph() 1790 threads = [ 1791 threading.Thread(target=thread_fn, args=( 1792 i, 1793 graph, 1794 )) for i in range(2) 1795 ] 1796 1797 threads[0].start() 1798 # Allow thread 0 to finish before starting thread 1. 1799 threads[0].join() 1800 threads[1].start() 1801 threads[1].join() 1802 1803 @test_util.run_in_graph_and_eager_modes 1804 @run_inside_wrap_function_in_eager_mode 1805 def testTwoThreadsNestedScopeEntry(self): 1806 1807 def thread_fn(i, graph, run_event, pause_event): 1808 with graph.as_default(): 1809 with variable_scope.variable_scope("foo"): 1810 if i == 0: 1811 v = variable_scope.get_variable("v", []) 1812 self.assertEquals("foo/v:0", v.name) 1813 else: 1814 # Any thread after the first one should fail to create variable 1815 # with the same name. 1816 with self.assertRaises(ValueError): 1817 variable_scope.get_variable("v", []) 1818 pause_event.set() 1819 run_event.wait() 1820 1821 graph = ops.get_default_graph() 1822 run_events = [threading.Event() for _ in range(2)] 1823 pause_events = [threading.Event() for _ in range(2)] 1824 threads = [ 1825 threading.Thread( 1826 target=thread_fn, args=(i, graph, run_events[i], pause_events[i])) 1827 for i in range(2) 1828 ] 1829 1830 # Start first thread. 1831 threads[0].start() 1832 pause_events[0].wait() 1833 # Start next thread once the first thread has paused. 1834 threads[1].start() 1835 pause_events[1].wait() 1836 # Resume both threads. 1837 run_events[0].set() 1838 run_events[1].set() 1839 threads[0].join() 1840 threads[1].join() 1841 1842 @test_util.run_in_graph_and_eager_modes 1843 @run_inside_wrap_function_in_eager_mode 1844 def testReenterMainScope(self): 1845 1846 def thread_fn(graph, main_thread_scope): 1847 with graph.as_default(): 1848 # Variable created with main scope will have prefix "main". 1849 with variable_scope.variable_scope(main_thread_scope): 1850 with variable_scope.variable_scope("foo"): 1851 v = variable_scope.get_variable("v", []) 1852 self.assertEquals("main/foo/v:0", v.name) 1853 1854 # Variable created outside main scope will not have prefix "main". 1855 with variable_scope.variable_scope("bar"): 1856 v = variable_scope.get_variable("v", []) 1857 self.assertEquals("bar/v:0", v.name) 1858 1859 graph = ops.get_default_graph() 1860 with variable_scope.variable_scope("main") as main_thread_scope: 1861 thread = threading.Thread( 1862 target=thread_fn, args=(graph, main_thread_scope)) 1863 thread.start() 1864 thread.join() 1865 1866 1867if __name__ == "__main__": 1868 test.main() 1869