• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for variable store."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import gc
22import threading
23
24import numpy
25
26from tensorflow.python.eager import context
27from tensorflow.python.eager import function
28from tensorflow.python.eager import wrap_function
29from tensorflow.python.framework import constant_op
30from tensorflow.python.framework import dtypes
31from tensorflow.python.framework import errors
32from tensorflow.python.framework import ops
33from tensorflow.python.framework import test_util
34from tensorflow.python.layers import core as core_layers
35from tensorflow.python.ops import array_ops
36from tensorflow.python.ops import control_flow_ops
37from tensorflow.python.ops import init_ops
38from tensorflow.python.ops import math_ops
39from tensorflow.python.ops import resource_variable_ops
40from tensorflow.python.ops import state_ops
41from tensorflow.python.ops import variable_scope
42from tensorflow.python.ops import variables as variables_lib
43from tensorflow.python.platform import test
44from tensorflow.python.util import compat
45from tensorflow.python.util import tf_inspect
46
47
def run_inside_wrap_function_in_eager_mode(graph_function):
  """Decorator to execute the same graph code in eager and graph modes.

  In graph mode, we just execute the graph_function passed as argument. In eager
  mode, we wrap the function using wrap_function and then execute the wrapped
  result.

  The returned wrapper copies `graph_function`'s metadata (name, docstring,
  `__wrapped__`) via `functools.wraps`, so failures and stacked decorators
  report the real test name instead of `wrap_and_execute`.

  Args:
    graph_function: python function containing graph code to be wrapped. It is
      called with the test case instance as its only argument.

  Returns:
    decorated function
  """
  import functools  # pylint: disable=g-import-not-at-top

  @functools.wraps(graph_function)
  def wrap_and_execute(self):
    if context.executing_eagerly():
      # `[self]` is handed to wrap_function as the argument list used when the
      # wrapped call is traced.
      wrapped = wrap_function.wrap_function(graph_function, [self])
      # use the wrapped graph function
      wrapped()
    else:
      # use the original function
      graph_function(self)
  return wrap_and_execute
70
71
72class VariableScopeTest(test.TestCase):
73
74  def tearDown(self):
75    gc.collect()
76    # This will only contain uncollectable garbage, i.e. reference cycles
77    # involving objects with __del__ defined.
78    self.assertEqual(0, len(gc.garbage))
79
80  @test_util.run_in_graph_and_eager_modes
81  @run_inside_wrap_function_in_eager_mode
82  def testGetVar(self):
83    vs = variable_scope._get_default_variable_store()
84    v = vs.get_variable("v", [1])
85    v1 = vs.get_variable("v", [1])
86    self.assertIs(v, v1)
87
88  @test_util.run_in_graph_and_eager_modes
89  @run_inside_wrap_function_in_eager_mode
90  def testResource(self):
91    vs = variable_scope._get_default_variable_store()
92    v1 = vs.get_variable("v", [1], use_resource=True)
93    self.assertTrue(isinstance(v1, resource_variable_ops.ResourceVariable))
94
95  @test_util.run_in_graph_and_eager_modes
96  @run_inside_wrap_function_in_eager_mode
97  def testNameExists(self):
98    vs = variable_scope._get_default_variable_store()
99    # No check by default, so we can both create and get existing names.
100    v = vs.get_variable("v", [1])
101    v1 = vs.get_variable("v", [1])
102    self.assertIs(v, v1)
103
104    # When reuse is False, we fail when variables are already there.
105    vs.get_variable("w", [1], reuse=False)  # That's ok.
106    with self.assertRaises(ValueError):
107      vs.get_variable("v", [1], reuse=False)  # That fails.
108    # When reuse is True, we fail when variables are new.
109    vs.get_variable("v", [1], reuse=True)  # That's ok.
110    with self.assertRaises(ValueError):
111      vs.get_variable("u", [1], reuse=True)  # That fails.
112
113  @test_util.run_in_graph_and_eager_modes
114  @run_inside_wrap_function_in_eager_mode
115  def testNamelessStore(self):
116    vs = variable_scope._get_default_variable_store()
117    vs.get_variable("v1", [2])
118    vs.get_variable("v2", [2])
119    expected_names = ["%s:0" % name for name in ["v1", "v2"]]
120    self.assertEqual(
121        set(expected_names), set(v.name for v in vs._vars.values()))
122
123  # TODO(mihaimaruseac): Not converted to use wrap_function because of
124  # TypeError: Expected tf.group() expected Tensor arguments not 'None' with
125  # type '<type 'NoneType'>'
126  @test_util.run_in_graph_and_eager_modes
127  def testVarScopeInitializer(self):
128    init = init_ops.constant_initializer(0.3)
129    with variable_scope.variable_scope("tower0") as tower:
130      with variable_scope.variable_scope("foo", initializer=init):
131        v = variable_scope.get_variable("v", [])
132        self.evaluate(variables_lib.variables_initializer([v]))
133        self.assertAllClose(self.evaluate(v.value()), 0.3)
134      with variable_scope.variable_scope(tower, initializer=init):
135        w = variable_scope.get_variable("w", [])
136        self.evaluate(variables_lib.variables_initializer([w]))
137        self.assertAllClose(self.evaluate(w.value()), 0.3)
138
139  @test_util.run_in_graph_and_eager_modes
140  @run_inside_wrap_function_in_eager_mode
141  def testVarScopeConstraint(self):
142    constraint = lambda x: 0. * x
143    with variable_scope.variable_scope("tower1") as tower:
144      with variable_scope.variable_scope("foo", constraint=constraint):
145        v = variable_scope.get_variable("v", [])
146        self.assertEqual(v.constraint, constraint)
147      with variable_scope.variable_scope(tower, constraint=constraint):
148        w = variable_scope.get_variable("w", [])
149        self.assertEqual(w.constraint, constraint)
150
  @test_util.run_in_graph_and_eager_modes
  @run_inside_wrap_function_in_eager_mode
  def testVarScopeNestingError(self):
    """Exiting a variable scope out of nesting order raises RuntimeError."""
    with variable_scope.variable_scope("aa"):
      # Enter "bb" by hand so it can be exited while "cc" is still open.
      scope = variable_scope.variable_scope("bb")
      scope.__enter__()
      with variable_scope.variable_scope("cc"):
        # "cc" is innermost here, so exiting "bb" is an ordering violation.
        with self.assertRaises(RuntimeError):
          scope.__exit__(None, None, None)
      # Once "cc" has closed, "bb" is innermost again and is expected to exit
      # without raising.
      scope.__exit__(None, None, None)
161
162  # TODO(mihaimaruseac): Not converted to use wrap_function because of
163  # TypeError: Fetch argument <tf.Variable 'string:0' shape=() dtype=string>
164  # has invalid type <class '...ResourceVariable'>, must be a string or Tensor.
165  # (Can not convert a ResourceVariable into a Tensor or Operation.)
166  @test_util.run_deprecated_v1
167  def testStringDefaultInitializer(self):
168    with self.cached_session():
169      v = variable_scope.get_variable("string", shape=[], dtype=dtypes.string)
170      variables_lib.global_variables_initializer().run()
171      self.assertAllEqual(compat.as_bytes(self.evaluate(v)), b"")
172
173  @test_util.run_in_graph_and_eager_modes
174  @run_inside_wrap_function_in_eager_mode
175  def testVarScopeDType(self):
176    with variable_scope.variable_scope("tower2") as tower:
177      with variable_scope.variable_scope("foo", dtype=dtypes.float16):
178        v = variable_scope.get_variable("v", [])
179        self.assertEqual(v.dtype.base_dtype, dtypes.float16)
180      with variable_scope.variable_scope(tower, dtype=dtypes.float16):
181        w = variable_scope.get_variable("w", [])
182        self.assertEqual(w.dtype.base_dtype, dtypes.float16)
183
184  def testGetVariableInGraphNestedUnderEagerContext(self):
185    with context.eager_mode():
186
187      @function.defun
188      def f():
189        v = variable_scope.get_variable("should_be_resource", [])
190        self.assertEqual(type(v), resource_variable_ops.ResourceVariable)
191
192      f()
193
194  def testEagerVariableStore(self):
195    with context.eager_mode():
196      store = variable_scope.EagerVariableStore()
197      with store.as_default():
198        v = variable_scope.get_variable("v", shape=(), trainable=True)
199        w = variable_scope.get_variable("w", shape=(), trainable=False)
200
201      self.assertTrue(v in store.variables())
202      self.assertTrue(w in store.variables())
203      self.assertTrue(v in store.trainable_variables())
204      self.assertFalse(w in store.trainable_variables())
205      self.assertFalse(v in store.non_trainable_variables())
206      self.assertTrue(w in store.non_trainable_variables())
207
208      # Test copying.
209      new_store = store.copy()
210      with new_store.as_default():
211        new_v = variable_scope.get_variable("v")
212        new_w = variable_scope.get_variable("w")
213      self.assertEqual(new_v.numpy(), v.numpy())
214      self.assertEqual(new_w.numpy(), w.numpy())
215      self.assertTrue(new_v in new_store.variables())
216      self.assertTrue(new_w in new_store.variables())
217      self.assertTrue(new_v in new_store.trainable_variables())
218      self.assertFalse(new_w in new_store.trainable_variables())
219      self.assertFalse(new_v in new_store.non_trainable_variables())
220      self.assertTrue(new_w in new_store.non_trainable_variables())
221
222      # Check that variables are separate instances.
223      for v in store.variables():
224        v.assign(-1)
225      for v in new_store.variables():
226        v.assign(1)
227      for v in store.variables():
228        self.assertEqual(v.numpy(), -1)
229      for v in new_store.variables():
230        self.assertEqual(v.numpy(), 1)
231
232  def testEagerVariableStoreWithEagerDefun(self):
233    with context.eager_mode():
234
235      @function.defun
236      def f():
237        x = constant_op.constant([[2.0]])
238        d1 = core_layers.Dense(
239            1, name="my_dense", kernel_initializer=init_ops.ones_initializer())
240        _ = d1(x)  # create variables
241        self.assertEqual(len(d1.variables), 2)
242        v1, v2 = d1.variables
243        d2 = core_layers.Dense(
244            1,
245            name="my_dense",
246            kernel_initializer=init_ops.ones_initializer(),
247            _reuse=True)
248        _ = d2(x)
249        self.assertEqual(len(d2.variables), 2)
250        v3, v4 = d2.variables
251        self.assertIs(v1, v3)
252        self.assertIs(v2, v4)
253
254      f()
255
  # TODO(mihaimaruseac): Not converted to use wrap_function because of
  # obtaining different results in the eager case compared to the graph one
  @test_util.run_in_graph_and_eager_modes
  def testEagerVariablesStoreAddsToCollections(self):
    """Variables created through an EagerVariableStore populate collections.

    Checks GLOBAL_VARIABLES, TRAINABLE_VARIABLES and an explicitly requested
    collection; the expected lists follow creation order.
    """
    store = variable_scope.EagerVariableStore()
    with store.as_default():
      trainable = variable_scope.get_variable("v1", [], trainable=True)
      not_trainable = variable_scope.get_variable("v2", [], trainable=False)
      # As the assertions below expect, passing `collections` explicitly keeps
      # "v3" out of GLOBAL_VARIABLES, while it is still trainable.
      concat = variable_scope.get_variable(
          "v3", [], collections=[ops.GraphKeys.CONCATENATED_VARIABLES])
      self.assertEqual(
          ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES),
          [trainable, not_trainable])
      self.assertEqual(
          ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES),
          [trainable, concat])
      self.assertEqual(
          ops.get_collection(ops.GraphKeys.CONCATENATED_VARIABLES), [concat])
274
275  def testEagerVariablesOutsideStoreNotAddedToCollections(self):
276    with context.eager_mode():
277      variable_scope.get_variable("v1", [], trainable=True)
278      variable_scope.get_variable("v2", [], trainable=False)
279      self.assertFalse(ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES))
280      self.assertFalse(ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES))
281
282  def testEagerVariableStoreWithFunctionalLayer(self):
283    with context.eager_mode():
284      container = variable_scope.EagerVariableStore()
285      x = constant_op.constant([[2.0]])
286      with container.as_default():
287        y = core_layers.dense(x, 1, name="my_dense",
288                              kernel_initializer=init_ops.ones_initializer())
289      self.assertAllEqual(y, [[2.0]])
290      self.assertEqual(len(container.variables()), 2)
291      # Recreate the layer to test reuse.
292      with container.as_default():
293        core_layers.dense(x, 1, name="my_dense",
294                          kernel_initializer=init_ops.ones_initializer())
295      self.assertEqual(len(container.variables()), 2)
296
297  # TODO(mihaimaruseac): Not converted to use wrap_function because of
298  # TypeError: Expected tf.group() expected Tensor arguments not 'None' with
299  # type '<type 'NoneType'>'.
300  @test_util.run_in_graph_and_eager_modes
301  def testInitFromNonTensorValue(self):
302    v = variable_scope.get_variable("v4", initializer=4, dtype=dtypes.int32)
303    self.evaluate(variables_lib.variables_initializer([v]))
304    self.assertAllClose(self.evaluate(v.value()), 4)
305
306    w = variable_scope.get_variable(
307        "w4", initializer=numpy.array([1, 2, 3]), dtype=dtypes.int64)
308    self.evaluate(variables_lib.variables_initializer([w]))
309    self.assertAllClose(self.evaluate(w.value()), [1, 2, 3])
310
311    # A quirk to be revisited?
312    error = ValueError if context.executing_eagerly() else TypeError
313    with self.assertRaises(error):
314      variable_scope.get_variable("x4", initializer={})
315
316  # TODO(mihaimaruseac): Not converted to use wrap_function because of
317  # InvalidArgumentError=: You must feed a value for placeholder tensor
318  # 'ReadVariableOp/resource' with dtype resource
319  @test_util.run_in_graph_and_eager_modes
320  def testInitFromNonInitializer(self):
321    # Test various dtypes with zeros initializer as following:
322    types = [
323        dtypes.int8, dtypes.uint8, dtypes.int16, dtypes.uint16, dtypes.int32,
324        dtypes.int64, dtypes.bool
325    ]
326
327    # Use different variable_name to distinguish various dtypes
328    for (i, dtype) in enumerate(types):
329      x = variable_scope.get_variable(
330          name="xx%d" % i, shape=(3, 4), dtype=dtype)
331      y = variable_scope.get_variable(
332          name="yy%d" % i,
333          shape=(3, 4),
334          dtype=dtype,
335          initializer=init_ops.zeros_initializer(dtype=dtype))
336
337      self.evaluate(variables_lib.global_variables_initializer())
338      self.assertAllEqual(self.evaluate(x.value()), self.evaluate(y.value()))
339
  # TODO(mihaimaruseac): Not converted to use wrap_function because of
  # InvalidArgumentError: /job:moo/replica:0/task:0/device:CPU:0 unknown device.
  @test_util.run_deprecated_v1
  def testVarScopeCachingDevice(self):
    """Checks how a scope's caching_device affects the device of value().

    The caching device propagates to child scopes. It can be disabled with "",
    replaced by a callable, or set later with set_caching_device. It does not
    leak out of the scope that declared it.
    """
    with self.cached_session():
      caching_device = "/job:moo"
      with variable_scope.variable_scope("tower"):
        with variable_scope.variable_scope(
            "caching", caching_device=caching_device):
          v = variable_scope.get_variable("v", [])
          self.assertTrue(v.value().device.startswith(caching_device))

          # Child scopes inherit the caching device.
          with variable_scope.variable_scope("child"):
            v2 = variable_scope.get_variable("v", [])
            self.assertTrue(v2.value().device.startswith(caching_device))

          # An empty string switches caching off for this child scope.
          with variable_scope.variable_scope("not_cached", caching_device=""):
            v2_not_cached = variable_scope.get_variable("v", [])
            self.assertFalse(
                v2_not_cached.value().device.startswith(caching_device))

          # A callable caching device: returning op.device leaves the read on
          # the op's own device rather than on `caching_device`.
          with variable_scope.variable_scope(
              "not_cached_identity_device",
              caching_device=lambda op: op.device):
            v2_identity_device = variable_scope.get_variable("v", [])
            self.assertFalse(
                v2_identity_device.value().device.startswith(caching_device))

          # The caching device can also be set after entering the scope.
          with variable_scope.variable_scope("we_will_do_it_live") as vs_live:
            vs_live.set_caching_device("/job:live")
            v_live = variable_scope.get_variable("v", [])
            self.assertTrue(v_live.value().device.startswith("/job:live"))

        # Back in "tower", outside "caching": no caching device applies.
        v_tower = variable_scope.get_variable("v", [])
        self.assertFalse(v_tower.value().device.startswith(caching_device))
375
  # TODO(mihaimaruseac): Not converted to use wrap_function because of
  # AttributeError: Tensor.name is meaningless when eager execution is enabled.
  @test_util.run_in_graph_and_eager_modes
  def testVarScopeRegularizer(self):
    """Checks scope-level regularizers and the ways to disable them.

    Every variable here is initialized to 0.3, so each regularization loss is
    0.3 plus the regularizer's offset (0.1 for regularizer1, 0.2 for
    regularizer2).
    """
    init = init_ops.constant_initializer(0.3)

    def regularizer1(v):
      return math_ops.reduce_mean(v) + 0.1

    def regularizer2(v):
      return math_ops.reduce_mean(v) + 0.2

    with variable_scope.variable_scope(
        "tower3", regularizer=regularizer1) as tower:
      with variable_scope.variable_scope("foo", initializer=init):
        # "foo" inherits regularizer1 from "tower3": loss 0.3 + 0.1.
        v = variable_scope.get_variable("v", [])
        self.evaluate(variables_lib.variables_initializer([v]))
        losses = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
        self.assertEqual(1, len(losses))
        self.assertAllClose(self.evaluate(losses[0]), 0.4)
      with variable_scope.variable_scope(tower, initializer=init) as vs:
        # Created before set_regularizer, so still regularizer1: 0.4.
        u = variable_scope.get_variable("u", [])
        vs.set_regularizer(regularizer2)
        # Uses regularizer2: 0.3 + 0.2 = 0.5.
        w = variable_scope.get_variable("w", [])
        # Next 3 variables are not regularized, to test disabling
        # regularization per variable, per child scope and per scope.
        x = variable_scope.get_variable(
            "x", [], regularizer=variable_scope.no_regularizer)
        with variable_scope.variable_scope(
            "baz", regularizer=variable_scope.no_regularizer):
          y = variable_scope.get_variable("y", [])
        vs.set_regularizer(variable_scope.no_regularizer)
        z = variable_scope.get_variable("z", [])
        # Check results: losses for v, u and w in creation order; none for
        # x, y, z.
        losses = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
        self.assertEqual(3, len(losses))
        self.evaluate(variables_lib.variables_initializer([u, w, x, y, z]))
        self.assertAllClose(self.evaluate(losses[0]), 0.4)
        self.assertAllClose(self.evaluate(losses[1]), 0.4)
        self.assertAllClose(self.evaluate(losses[2]), 0.5)
      with variable_scope.variable_scope("foo", reuse=True):
        # reuse=True is for now only supported when eager execution is disabled.
        if not context.executing_eagerly():
          v = variable_scope.get_variable("v",
                                          [])  # "v" is already there, reused
          losses = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
          self.assertEqual(3, len(losses))  # No new loss added.
422
423  # TODO(mihaimaruseac): Not converted to use wrap_function because of
424  # ValueError: Tensor-typed variable initializers must either be wrapped in an
425  # init_scope or callable...
426  @test_util.run_in_graph_and_eager_modes
427  def testInitializeFromValue(self):
428    init = constant_op.constant(0.1)
429    w = variable_scope.get_variable("v", initializer=init)
430    self.evaluate(variables_lib.variables_initializer([w]))
431    self.assertAllClose(self.evaluate(w.value()), 0.1)
432
433    with self.assertRaisesRegex(ValueError, "shape"):
434      # We disallow explicit shape specification when initializer is constant.
435      variable_scope.get_variable("u", [1], initializer=init)
436
437    with variable_scope.variable_scope("foo", initializer=init):
438      # Constant initializer can be passed through scopes if needed.
439      v = variable_scope.get_variable("v")
440      self.evaluate(variables_lib.variables_initializer([v]))
441      self.assertAllClose(self.evaluate(v.value()), 0.1)
442
443    # Check that non-float32 initializer creates a non-float32 variable.
444    init = constant_op.constant(1, dtype=dtypes.int32)
445    t = variable_scope.get_variable("t", initializer=init)
446    self.assertEqual(t.dtype.base_dtype, dtypes.int32)
447
448    # Raise error if `initializer` dtype and `dtype` are not identical.
449    with self.assertRaisesRegex(ValueError, "don't match"):
450      variable_scope.get_variable("s", initializer=init, dtype=dtypes.float64)
451
452  # TODO(mihaimaruseac): Not converted to use wrap_function because of
453  # TypeError: Fetch argument <tf.Variable 'v0:0' shape=(1,) dtype=float32> has
454  # invalid type <class '...ops.resource_variable_ops.ResourceVariable'>, must
455  # be a string or Tensor. (Can not convert a ResourceVariable into a Tensor or
456  # Operation.)
457  @test_util.run_deprecated_v1
458  def testControlDeps(self):
459    with self.cached_session() as sess:
460      v0 = variable_scope.get_variable(
461          "v0", [1], initializer=init_ops.constant_initializer(0))
462      with ops.control_dependencies([v0.value()]):
463        v1 = variable_scope.get_variable(
464            "v1", [1], initializer=init_ops.constant_initializer(1))
465        add = v1 + v0
466      # v0 should be uninitialized.
467      with self.assertRaisesRegex(errors.OpError, "uninitialized"):
468        self.evaluate(v0)
469      # We should be able to initialize and run v1 without initializing
470      # v0, even if the variable was created with a control dep on v0.
471      self.evaluate(v1.initializer)
472      self.assertEqual(1, self.evaluate(v1))
473      # v0 should still be uninitialized.
474      with self.assertRaisesRegex(errors.OpError, "uninitialized"):
475        self.evaluate(v0)
476      with self.assertRaisesRegex(errors.OpError, "uninitialized"):
477        self.evaluate(add)
478      # If we initialize v0 we should be able to run 'add'.
479      self.evaluate(v0.initializer)
480      self.evaluate(add)
481
482  # TODO(mihaimaruseac): Not converted to use wrap_function because of
483  # AssertionError: True is not false (last assertFalse)
484  @test_util.run_deprecated_v1
485  def testEnableResourceVariables(self):
486    old = variable_scope._DEFAULT_USE_RESOURCE
487    try:
488      variable_scope.enable_resource_variables()
489      self.assertTrue(isinstance(variables_lib.VariableV1(1.0),
490                                 resource_variable_ops.ResourceVariable))
491      variable_scope.disable_resource_variables()
492      self.assertFalse(isinstance(variables_lib.VariableV1(1.0),
493                                  resource_variable_ops.ResourceVariable))
494    finally:
495      variable_scope._DEFAULT_USE_RESOURCE = old
496
497  # TODO(mihaimaruseac): Not converted to use wrap_function because of
498  # TypeError: Fetch argument None has invalid type <type 'NoneType'>
499  @test_util.run_deprecated_v1
500  def testControlFlow(self):
501    with self.cached_session() as sess:
502      v0 = variable_scope.get_variable(
503          "v0", [], initializer=init_ops.constant_initializer(0))
504      var_dict = {}
505
506      # Call get_variable in each of the cond clauses.
507      def var_in_then_clause():
508        v1 = variable_scope.get_variable(
509            "v1", [1], initializer=init_ops.constant_initializer(1))
510        var_dict["v1"] = v1
511        return v1 + v0
512
513      def var_in_else_clause():
514        v2 = variable_scope.get_variable(
515            "v2", [1], initializer=init_ops.constant_initializer(2))
516        var_dict["v2"] = v2
517        return v2 + v0
518
519      add = control_flow_ops.cond(
520          math_ops.less(v0, 10), var_in_then_clause, var_in_else_clause)
521      v1 = var_dict["v1"]
522      v2 = var_dict["v2"]
523      # We should be able to initialize and run v1 and v2 without initializing
524      # v0, even if the variable was created with a control dep on v0.
525      self.evaluate(v1.initializer)
526      self.assertEqual([1], self.evaluate(v1))
527      self.evaluate(v2.initializer)
528      self.assertEqual([2], self.evaluate(v2))
529      # v0 should still be uninitialized.
530      with self.assertRaisesRegex(errors.OpError, "uninitialized"):
531        self.evaluate(v0)
532      # We should not be able to run 'add' yet.
533      with self.assertRaisesRegex(errors.OpError, "uninitialized"):
534        self.evaluate(add)
535      # If we initialize v0 we should be able to run 'add'.
536      self.evaluate(v0.initializer)
537      self.evaluate(add)
538
  # TODO(mihaimaruseac): Not converted to use wrap_function because of
  # TypeError: Expected tf.group() expected Tensor arguments not 'None' with
  # type '<type 'NoneType'>'.
  @test_util.run_in_graph_and_eager_modes
  def testGetVariableScope(self):
    """Tests get_variable_scope() and setting properties on its result.

    Properties set on the current scope (initializer, reuse) apply inside it
    and no longer apply once the scope is exited.
    """
    init = init_ops.constant_initializer(0.3)
    with variable_scope.variable_scope("bar"):
      new_init1 = variable_scope.get_variable_scope().initializer
      self.assertEqual(new_init1, None)
      # Check that we can set initializer like this.
      variable_scope.get_variable_scope().set_initializer(init)
      v = variable_scope.get_variable("v", [])
      self.evaluate(variables_lib.variables_initializer([v]))
      self.assertAllClose(self.evaluate(v.value()), 0.3)
      if not context.executing_eagerly():
        # Check that we can set reuse.
        variable_scope.get_variable_scope().reuse_variables()
        with self.assertRaises(ValueError):  # Fail, w does not exist yet.
          variable_scope.get_variable("w", [1])
    # Check that the set initializer goes away.
    new_init = variable_scope.get_variable_scope().initializer
    self.assertEqual(new_init, None)
562
  @test_util.run_in_graph_and_eager_modes
  @run_inside_wrap_function_in_eager_mode
  def testVarScope(self):
    """Checks variable-scope names and the name scopes they open."""
    with variable_scope.variable_scope("tower4") as tower:
      self.assertEqual(tower.name, "tower4")
      with ops.name_scope("scope") as sc:
        self.assertEqual(sc, "tower4/scope/")

    # Nested variable scopes join their names with "/".
    with variable_scope.variable_scope("tower5"):
      with variable_scope.variable_scope("bar") as bar:
        self.assertEqual(bar.name, "tower5/bar")
        with ops.name_scope("scope") as sc:
          self.assertEqual(sc, "tower5/bar/scope/")

    # Re-entering a captured scope object keeps its variable-scope name
    # ("tower4"), but its name scope nests under the current one ("tower6").
    with variable_scope.variable_scope("tower6"):
      with variable_scope.variable_scope(tower, reuse=True) as tower_shared:
        self.assertEqual(tower_shared.name, "tower4")
        with ops.name_scope("scope") as sc:
          self.assertEqual(sc, "tower6/tower4/scope/")
582
  @test_util.run_in_graph_and_eager_modes
  @run_inside_wrap_function_in_eager_mode
  def testVarScopeNameScope(self):
    """Checks the name scopes opened when (re-)entering variable scopes.

    Re-entering a variable scope, by object or by name, opens a fresh
    uniquified name scope ("tower_1", "tower_2", ...). Entering
    `root_var_scope` adds no name-scope component.
    """
    with ops.name_scope("testVarScopeNameScope1"):
      with variable_scope.variable_scope("tower") as tower:
        with ops.name_scope("scope2") as sc2:
          self.assertEqual(sc2, "testVarScopeNameScope1/tower/scope2/")
      # The re-entry checks below are skipped when executing eagerly.
      if not context.executing_eagerly():
        with variable_scope.variable_scope(
            tower):  # Re-entering acts like another "tower".
          with ops.name_scope("scope2") as sc2:
            self.assertEqual(sc2, "testVarScopeNameScope1/tower_1/scope2/")
        with variable_scope.variable_scope(
            "tower"):  # Re-entering by string acts the same.
          with ops.name_scope("scope2") as sc2:
            self.assertEqual(sc2, "testVarScopeNameScope1/tower_2/scope2/")

    with ops.name_scope("testVarScopeNameScope2"):
      with variable_scope.variable_scope("tower"):
        with ops.name_scope("scope2") as sc2:
          self.assertEqual(sc2, "testVarScopeNameScope2/tower/scope2/")
      if not context.executing_eagerly():
        # `tower` is still the scope object captured in the first block. Under
        # this name scope, "tower" is already taken, so it gets "tower_1".
        with variable_scope.variable_scope(tower):
          with ops.name_scope("scope2") as sc2:
            self.assertEqual(sc2, "testVarScopeNameScope2/tower_1/scope2/")

    root_var_scope = variable_scope.get_variable_scope()
    with ops.name_scope("testVarScopeNameScope3"):
      # Entering root_var_scope adds no component to the current name scope.
      with variable_scope.variable_scope(root_var_scope):
        with ops.name_scope("scope2") as sc2:
          self.assertEqual(sc2, "testVarScopeNameScope3/scope2/")
614
  @test_util.run_in_graph_and_eager_modes
  @run_inside_wrap_function_in_eager_mode
  def testVarScopeOriginalNameScope(self):
    """Checks that original_name_scope survives re-entering a variable scope."""
    with self.cached_session():
      with ops.name_scope("scope1"):
        with variable_scope.variable_scope("tower") as tower:
          self.assertEqual(tower.original_name_scope, "scope1/tower/")
          with ops.name_scope("scope2") as sc2:
            self.assertEqual(sc2, "scope1/tower/scope2/")
      with ops.name_scope("scope2"):
        with variable_scope.variable_scope(tower) as tower1:
          # Re-entering preserves original name scope, while the new name
          # scope nests under the current "scope2".
          self.assertEqual(tower1.original_name_scope, "scope1/tower/")
          with ops.name_scope("foo") as sc2:
            self.assertEqual(sc2, "scope2/tower/foo/")
        # Test re-entering original name scope.
        with ops.name_scope(tower.original_name_scope):
          with ops.name_scope("bar") as sc3:
            self.assertEqual(sc3, "scope1/tower/bar/")
      with ops.name_scope("scope2"):
        with variable_scope.variable_scope(tower):
          with ops.name_scope(tower.original_name_scope):
            # "bar" was already used under "scope1/tower/" above, so it is
            # uniquified to "bar_1".
            with ops.name_scope("bar") as sc3:
              self.assertEqual(sc3, "scope1/tower/bar_1/")
639
  @test_util.run_in_graph_and_eager_modes
  @run_inside_wrap_function_in_eager_mode
  def testVarScopeObjectReuse(self):
    """Checks how `reuse` combines with a captured scope object's setting."""
    with self.cached_session():
      vs = None
      with variable_scope.variable_scope("jump", reuse=True) as scope:
        vs = scope

      # A scope captured with reuse=True keeps reuse on when re-entered...
      with variable_scope.variable_scope(vs) as jump:
        self.assertTrue(jump.reuse)

      with variable_scope.variable_scope(vs, reuse=True) as jump_reuse:
        self.assertTrue(jump_reuse.reuse)

      # ...and passing reuse=False does not switch it back off.
      with variable_scope.variable_scope(vs, reuse=False) as jump_no_reuse:
        self.assertTrue(jump_no_reuse.reuse)  # Inherited, cannot be undone.

      # A scope captured with reuse=False keeps that setting by default, but
      # can be upgraded to reuse=True.
      with variable_scope.variable_scope("jump", reuse=False) as scope:
        vs = scope

      with variable_scope.variable_scope(vs) as jump:
        self.assertFalse(jump.reuse)

      with variable_scope.variable_scope(vs, reuse=True) as jump_reuse:
        self.assertTrue(jump_reuse.reuse)

      with variable_scope.variable_scope(vs, reuse=False) as jump_no_reuse:
        self.assertFalse(jump_no_reuse.reuse)
668
  @test_util.run_in_graph_and_eager_modes
  @run_inside_wrap_function_in_eager_mode
  def testVarScopeGetOrCreateReuse(self):
    """AUTO_REUSE scopes create a variable once and reuse it afterwards."""
    with self.cached_session():

      def test_value(value):
        """Assigns `value` to "var" under AUTO_REUSE, then fetches "var" again."""
        x = constant_op.constant(value)
        # Creates "var" on the first call and reuses it on later calls.
        with variable_scope.variable_scope(
            "testVarScopeGetOrCreateReuse_bar",
            reuse=variable_scope.AUTO_REUSE):
          _ = state_ops.assign(variable_scope.get_variable("var", []), x)
        # Entering the same scope again must reuse "var" rather than raise.
        with variable_scope.variable_scope(
            "testVarScopeGetOrCreateReuse_bar",
            reuse=variable_scope.AUTO_REUSE):
          _ = variable_scope.get_variable("var", [])
        # NOTE(review): this evaluates the constant `x`, not the variable, so
        # it never checks the assigned value. The effective coverage is that
        # the get_variable calls above do not raise.
        self.assertEqual(value, self.evaluate(x))

      test_value(42.)  # Variable is created.
      test_value(13.)  # Variable is reused hereafter.
      test_value(17.)
689
690  @test_util.run_in_graph_and_eager_modes
691  @run_inside_wrap_function_in_eager_mode
  def testVarOpScope(self):
    """Variable names vs. name scopes for named and default-named scopes.

    Variable names ignore an enclosing ops.name_scope ("tower/w:0"). The name
    scope opened by variable_scope nests under it and is uniquified each time
    the same scope name is opened again.
    """
    with self.cached_session():
      with ops.name_scope("testVarOpScope1"):
        with variable_scope.variable_scope("tower", "default", []):
          self.assertEqual(
              variable_scope.get_variable("w", []).name, "tower/w:0")
          with ops.name_scope("testVarOpScope2") as sc2:
            self.assertEqual(sc2, "testVarOpScope1/tower/testVarOpScope2/")
        # Re-opening "tower" without reuse: creating "w" again must fail, and
        # the auxiliary name scope becomes "tower_1".
        with variable_scope.variable_scope("tower", "default", []):
          with self.assertRaises(ValueError):
            variable_scope.get_variable("w", [])
          with ops.name_scope("testVarOpScope2") as sc2:
            self.assertEqual(sc2, "testVarOpScope1/tower_1/testVarOpScope2/")

      # With name_or_scope=None the default name is made unique per use, for
      # both the variable scope and its name scope (default, default_1).
      with ops.name_scope("testVarOpScope2"):
        with variable_scope.variable_scope(None, "default", []):
          self.assertEqual(
              variable_scope.get_variable("w", []).name, "default/w:0")
          with ops.name_scope("testVarOpScope2") as sc2:
            self.assertEqual(sc2, "testVarOpScope2/default/testVarOpScope2/")
        with variable_scope.variable_scope(None, "default", []):
          self.assertEqual(
              variable_scope.get_variable("w", []).name, "default_1/w:0")
          with ops.name_scope("testVarOpScope2") as sc2:
            self.assertEqual(sc2, "testVarOpScope2/default_1/testVarOpScope2/")
717
718  @test_util.run_in_graph_and_eager_modes
719  @run_inside_wrap_function_in_eager_mode
  def testVarOpScopeUniqueNamesInterleavedSubstringScopes(self):
    """Default names that are prefixes of each other are counted separately.

    "defaultScope" is a prefix of "defaultScope1". Using it in between must
    neither pick up a suffix itself nor disturb the "defaultScope1" counter,
    which continues with "_2".
    """
    with self.cached_session():
      with variable_scope.variable_scope(None, "defaultScope1"):
        with variable_scope.variable_scope(None, "layer"):
          self.assertEqual(
              variable_scope.get_variable("w", []).name,
              "defaultScope1/layer/w:0")
      with variable_scope.variable_scope(None, "defaultScope1"):
        with variable_scope.variable_scope(None, "layer"):
          self.assertEqual(
              variable_scope.get_variable("w", []).name,
              "defaultScope1_1/layer/w:0")
      # Prefix of "defaultScope1": gets its own, unsuffixed name.
      with variable_scope.variable_scope(None, "defaultScope"):
        with variable_scope.variable_scope(None, "layer"):
          self.assertEqual(
              variable_scope.get_variable("w", []).name,
              "defaultScope/layer/w:0")
      with variable_scope.variable_scope(None, "defaultScope1"):
        with variable_scope.variable_scope(None, "layer"):
          self.assertEqual(
              variable_scope.get_variable("w", []).name,
              "defaultScope1_2/layer/w:0")
742
743  @test_util.run_in_graph_and_eager_modes
744  @run_inside_wrap_function_in_eager_mode
  def testVarOpScopeUniqueNamesWithJump(self):
    """Default-name numbering survives a jump back into the parent scope.

    Re-entering `default` through its scope object between the "layer"
    scopes does not reset their counter: the third one is "layer_2".
    """
    with self.cached_session():
      with variable_scope.variable_scope("default") as default:
        with variable_scope.variable_scope(None, "layer"):
          self.assertEqual(
              variable_scope.get_variable("w", []).name, "default/layer/w:0")
        with variable_scope.variable_scope(None, "layer"):
          self.assertEqual(
              variable_scope.get_variable("w", []).name,
              "default/layer_1/w:0")
        with variable_scope.variable_scope(default):
          pass
        # No matter the jump in the middle, unique numbering continues.
        with variable_scope.variable_scope(None, "layer"):
          self.assertEqual(
              variable_scope.get_variable("w", []).name,
              "default/layer_2/w:0")
762
763  @test_util.run_in_graph_and_eager_modes
764  @run_inside_wrap_function_in_eager_mode
  def testVarOpScopeReuse(self):
    """Re-entering a scope object with reuse=True reuses its variables.

    Sub-scopes opened inside the reusing `outer` resolve to the variables
    created in the first pass. Their name scopes live under a fresh,
    uniquified "outer_1".
    """
    with self.cached_session():
      # First pass: create "outer/tower/w" and "outer/default/w".
      with variable_scope.variable_scope("outer") as outer:
        with variable_scope.variable_scope("tower", "default", []):
          self.assertEqual(
              variable_scope.get_variable("w", []).name, "outer/tower/w:0")
          with ops.name_scope("scope2") as sc2:
            self.assertEqual(sc2, "outer/tower/scope2/")
        with variable_scope.variable_scope(None, "default", []):
          self.assertEqual(
              variable_scope.get_variable("w", []).name, "outer/default/w:0")
          with ops.name_scope("scope2") as sc2:
            self.assertEqual(sc2, "outer/default/scope2/")

      # Second pass: same variable names, new "outer_1" name scope.
      with variable_scope.variable_scope(outer, reuse=True) as outer:
        with variable_scope.variable_scope("tower", "default", []):
          self.assertEqual(
              variable_scope.get_variable("w", []).name, "outer/tower/w:0")
          with ops.name_scope("scope2") as sc2:
            self.assertEqual(sc2, "outer_1/tower/scope2/")
        with variable_scope.variable_scope(None, "default", []):
          self.assertEqual(
              variable_scope.get_variable("w", []).name, "outer/default/w:0")
          with ops.name_scope("scope2") as sc2:
            self.assertEqual(sc2, "outer_1/default/scope2/")
790
791  @test_util.run_in_graph_and_eager_modes
792  @run_inside_wrap_function_in_eager_mode
793  def testVarScopeGetVar(self):
794    with self.cached_session():
795      with variable_scope.variable_scope("root"):
796        with variable_scope.variable_scope("towerA") as tower_a:
797          va = variable_scope.get_variable("v", [1])
798          self.assertEqual(va.name, "root/towerA/v:0")
799
800        with variable_scope.variable_scope(tower_a, reuse=True):
801          va2 = variable_scope.get_variable("v", [1])
802          self.assertIs(va2, va)
803
804        with variable_scope.variable_scope("towerB"):
805          vb = variable_scope.get_variable("v", [1])
806          self.assertEqual(vb.name, "root/towerB/v:0")
807
808        with self.assertRaises(ValueError):
809          with variable_scope.variable_scope("towerA"):
810            va2 = variable_scope.get_variable("v", [1])
811
812        with variable_scope.variable_scope("towerA", reuse=True):
813          va2 = variable_scope.get_variable("v", [1])
814          self.assertIs(va2, va)
815
816        with variable_scope.variable_scope("foo"):
817          with variable_scope.variable_scope("bar"):
818            v = variable_scope.get_variable("v", [1])
819            self.assertEqual(v.name, "root/foo/bar/v:0")
820            with variable_scope.variable_scope(tower_a, reuse=True):
821              va3 = variable_scope.get_variable("v", [1])
822              self.assertIs(va, va3)
823
824        with self.assertRaises(ValueError):
825          with variable_scope.variable_scope(tower_a, reuse=True):
826            with variable_scope.variable_scope("baz"):
827              variable_scope.get_variable("v", [1])
828
829        with self.assertRaises(ValueError) as exc:
830          with variable_scope.variable_scope(tower_a, reuse=True):
831            variable_scope.get_variable("v", [2])  # Different shape.
832        self.assertEqual("shape" in str(exc.exception), True)
833
834        with self.assertRaises(ValueError) as exc:
835          with variable_scope.variable_scope(tower_a, reuse=True):
836            variable_scope.get_variable("v", [1], dtype=dtypes.int32)
837        self.assertEqual("dtype" in str(exc.exception), True)
838
839  @test_util.run_in_graph_and_eager_modes
840  @run_inside_wrap_function_in_eager_mode
  def testVarScopeOuterScope(self):
    """Re-entering a top-level scope object keeps variable names only.

    Every re-entry of `outer` puts variables under "outer/". The auxiliary
    name scope is uniquified on each re-entry ("outer_1", then "outer_2").
    """
    with self.cached_session():
      with variable_scope.variable_scope("outer") as outer:
        pass
      with variable_scope.variable_scope(outer):
        self.assertEqual(
            variable_scope.get_variable("w", []).name, "outer/w:0")
        with ops.name_scope("scope2") as sc2:
          self.assertEqual(sc2, "outer_1/scope2/")
        with variable_scope.variable_scope("default"):
          self.assertEqual(
              variable_scope.get_variable("w", []).name, "outer/default/w:0")
          with ops.name_scope("scope2") as sc2:
            self.assertEqual(sc2, "outer_1/default/scope2/")

      # Second re-entry: reuse the variables created above.
      with variable_scope.variable_scope(outer, reuse=True):
        self.assertEqual(
            variable_scope.get_variable("w", []).name, "outer/w:0")
        with ops.name_scope("scope2") as sc2:
          self.assertEqual(sc2, "outer_2/scope2/")
        with variable_scope.variable_scope("default", reuse=True):
          self.assertEqual(
              variable_scope.get_variable("w", []).name, "outer/default/w:0")
          with ops.name_scope("scope2") as sc2:
            self.assertEqual(sc2, "outer_2/default/scope2/")
866
867  @test_util.run_in_graph_and_eager_modes
868  @run_inside_wrap_function_in_eager_mode
  def testVarScopeNestedOuterScope(self):
    """Re-entering a scope object from inside that same scope.

    Jumping to `outer` while already inside it resolves variables to
    "outer/..." rather than "outer/outer/...". The name scope nests as
    "outer/outer" and is uniquified on later re-entries.
    """
    with self.cached_session():
      with variable_scope.variable_scope("outer") as outer:
        with variable_scope.variable_scope(outer):
          self.assertEqual(
              variable_scope.get_variable("w", []).name, "outer/w:0")
          with ops.name_scope("scope2") as sc2:
            self.assertEqual(sc2, "outer/outer/scope2/")
        with variable_scope.variable_scope("default"):
          self.assertEqual(
              variable_scope.get_variable("w", []).name, "outer/default/w:0")
          with ops.name_scope("scope2") as sc2:
            self.assertEqual(sc2, "outer/default/scope2/")

        # Same checks with reuse: variable names repeat, name scopes get _1.
        with variable_scope.variable_scope(outer, reuse=True):
          self.assertEqual(
              variable_scope.get_variable("w", []).name, "outer/w:0")
          with ops.name_scope("scope2") as sc2:
            self.assertEqual(sc2, "outer/outer_1/scope2/")
        with variable_scope.variable_scope("default", reuse=True):
          self.assertEqual(
              variable_scope.get_variable("w", []).name, "outer/default/w:0")
          with ops.name_scope("scope2") as sc2:
            self.assertEqual(sc2, "outer/default_1/scope2/")
893
894  @test_util.run_in_graph_and_eager_modes
895  @run_inside_wrap_function_in_eager_mode
  def testVarOpScopeReuseParam(self):
    """Reuse can be requested per sub-scope or switched on mid-scope.

    The second pass re-enters `outer` without reuse. It then enables reuse in
    two ways: via reuse=True for "tower", and via outer.reuse_variables()
    before the default-named scope. Both resolve to the first-pass variables.
    """
    with self.cached_session():
      # First pass: create "outer/tower/w" and "outer/default/w".
      with variable_scope.variable_scope("outer") as outer:
        with variable_scope.variable_scope("tower", "default", []):
          self.assertEqual(
              variable_scope.get_variable("w", []).name, "outer/tower/w:0")
          with ops.name_scope("scope2") as sc2:
            self.assertEqual(sc2, "outer/tower/scope2/")
        with variable_scope.variable_scope(None, "default", []):
          self.assertEqual(
              variable_scope.get_variable("w", []).name, "outer/default/w:0")
          with ops.name_scope("scope2") as sc2:
            self.assertEqual(sc2, "outer/default/scope2/")

      # Rebinding `outer` here lets reuse_variables() act on the re-entered
      # scope.
      with variable_scope.variable_scope(outer) as outer:
        with variable_scope.variable_scope("tower", "default", reuse=True):
          self.assertEqual(
              variable_scope.get_variable("w", []).name, "outer/tower/w:0")
          with ops.name_scope("scope2") as sc2:
            self.assertEqual(sc2, "outer_1/tower/scope2/")
        outer.reuse_variables()
        with variable_scope.variable_scope(None, "default", []):
          self.assertEqual(
              variable_scope.get_variable("w", []).name, "outer/default/w:0")
          with ops.name_scope("scope2") as sc2:
            self.assertEqual(sc2, "outer_1/default/scope2/")
922
923  @test_util.run_in_graph_and_eager_modes
924  @run_inside_wrap_function_in_eager_mode
925  def testVarOpScopeReuseError(self):
926    with self.cached_session():
927      with self.assertRaises(ValueError):
928        with variable_scope.variable_scope(None, "default", reuse=True):
929          self.assertEqual(
930              variable_scope.get_variable("w", []).name, "outer/tower/w:0")
931
932  @test_util.run_in_graph_and_eager_modes
933  @run_inside_wrap_function_in_eager_mode
  def testVarOpScopeOuterScope(self):
    """Re-entering a scope object while also passing a default_name.

    Because name_or_scope is a scope object, "default" is not used for the
    scope itself. Variables land under "outer/", and the name scope is
    uniquified per re-entry ("outer_1", "outer_2").
    """
    with self.cached_session():
      with variable_scope.variable_scope("outer") as outer:
        pass
      with variable_scope.variable_scope(outer, "default", []):
        self.assertEqual(
            variable_scope.get_variable("w", []).name, "outer/w:0")
        with ops.name_scope("scope2") as sc2:
          self.assertEqual(sc2, "outer_1/scope2/")
        with variable_scope.variable_scope(None, "default", []):
          self.assertEqual(
              variable_scope.get_variable("w", []).name, "outer/default/w:0")
          with ops.name_scope("scope2") as sc2:
            self.assertEqual(sc2, "outer_1/default/scope2/")

      with variable_scope.variable_scope(outer, "default", reuse=True):
        self.assertEqual(
            variable_scope.get_variable("w", []).name, "outer/w:0")
        with ops.name_scope("scope2") as sc2:
          self.assertEqual(sc2, "outer_2/scope2/")
        # NOTE(review): `outer` is the object captured by the first `with`.
        # The re-entered scope is not bound with `as`, and reuse is already
        # requested via reuse=True. Confirm this call is meant to act on the
        # captured object.
        outer.reuse_variables()
        with variable_scope.variable_scope(None, "default", []):
          self.assertEqual(
              variable_scope.get_variable("w", []).name, "outer/default/w:0")
          with ops.name_scope("scope2") as sc2:
            self.assertEqual(sc2, "outer_2/default/scope2/")
960
961  @test_util.run_in_graph_and_eager_modes
962  @run_inside_wrap_function_in_eager_mode
  def testVarOpScopeNestedOuterScope(self):
    """Re-entering a scope object (with a default_name) from inside itself.

    Inside `outer`, jumping to `outer` resolves variables to "outer/w" while
    the name scope nests as "outer/outer". A later top-level re-entry with
    reuse=True reuses the variables under a fresh "outer_1" name scope.
    """
    with self.cached_session():
      with variable_scope.variable_scope("outer") as outer:
        with variable_scope.variable_scope(outer, "default", []):
          self.assertEqual(
              variable_scope.get_variable("w", []).name, "outer/w:0")
          with ops.name_scope("scope2") as sc2:
            self.assertEqual(sc2, "outer/outer/scope2/")
        with variable_scope.variable_scope(None, "default", []):
          self.assertEqual(
              variable_scope.get_variable("w", []).name, "outer/default/w:0")
          with ops.name_scope("scope2") as sc2:
            self.assertEqual(sc2, "outer/default/scope2/")

      with variable_scope.variable_scope(outer, "default", reuse=True):
        self.assertEqual(
            variable_scope.get_variable("w", []).name, "outer/w:0")
        with ops.name_scope("scope2") as sc2:
          self.assertEqual(sc2, "outer_1/scope2/")
        with variable_scope.variable_scope(None, "default", []):
          self.assertEqual(
              variable_scope.get_variable("w", []).name, "outer/default/w:0")
          with ops.name_scope("scope2") as sc2:
            self.assertEqual(sc2, "outer_1/default/scope2/")
987
988  @test_util.run_in_graph_and_eager_modes
989  @run_inside_wrap_function_in_eager_mode
990  def testBasicWhenAuxiliaryNameScopeIsFalse(self):
991    with self.cached_session():
992      with variable_scope.variable_scope(
993          "scope", auxiliary_name_scope=False) as scope:
994        self.assertEqual(scope.original_name_scope, "")
995        self.assertEqual(
996            variable_scope.get_variable("w", []).name, "scope/w:0")
997        self.assertEqual(constant_op.constant([], name="c").name, "c:0")
998      with variable_scope.variable_scope(scope, auxiliary_name_scope=False):
999        self.assertEqual(scope.original_name_scope, "")
1000        self.assertEqual(
1001            variable_scope.get_variable("w1", []).name, "scope/w1:0")
1002        self.assertEqual(constant_op.constant([], name="c1").name, "c1:0")
1003      # Recheck: new name scope is NOT created before
1004      with ops.name_scope("scope"):
1005        self.assertEqual(constant_op.constant([], name="c").name, "scope/c:0")
1006
1007      with variable_scope.variable_scope("outer"):
1008        with variable_scope.variable_scope(
1009            "inner", auxiliary_name_scope=False) as inner:
1010          self.assertEqual(inner.original_name_scope, "outer/")
1011          self.assertEqual(
1012              variable_scope.get_variable("w", []).name, "outer/inner/w:0")
1013          self.assertEqual(
1014              constant_op.constant([], name="c").name, "outer/c:0")
1015        with variable_scope.variable_scope(
1016            inner, auxiliary_name_scope=False) as inner1:
1017          self.assertEqual(inner1.original_name_scope, "outer/")
1018          self.assertEqual(
1019              variable_scope.get_variable("w1", []).name, "outer/inner/w1:0")
1020          self.assertEqual(
1021              constant_op.constant([], name="c1").name, "outer/c1:0")
1022        # Recheck: new name scope is NOT created before
1023        with ops.name_scope("inner"):
1024          self.assertEqual(
1025              constant_op.constant([], name="c").name, "outer/inner/c:0")
1026
1027  @test_util.run_in_graph_and_eager_modes
1028  @run_inside_wrap_function_in_eager_mode
  def testCreatedByDefaultNameWhenAuxiliaryNameScopeIsFalse(self):
    """default_name scopes with auxiliary_name_scope=False.

    The default name prefixes variables ("default/w"), but no name scope is
    opened. Ops keep the enclosing name scope, and a later
    name_scope("default") is not uniquified.
    """
    with self.cached_session():
      with variable_scope.variable_scope(
          None, default_name="default", auxiliary_name_scope=False) as scope:
        self.assertEqual(scope.original_name_scope, "")
        self.assertEqual(
            variable_scope.get_variable("w", []).name, "default/w:0")
        self.assertEqual(constant_op.constant([], name="c").name, "c:0")
      # Recheck: no "default" name scope was created above, so this one gets
      # the plain, non-uniquified name.
      with ops.name_scope("default"):
        self.assertEqual(
            constant_op.constant([], name="c").name, "default/c:0")

      with variable_scope.variable_scope("outer"):
        with variable_scope.variable_scope(
            None, default_name="default",
            auxiliary_name_scope=False) as inner:
          self.assertEqual(inner.original_name_scope, "outer/")
          self.assertEqual(
              variable_scope.get_variable("w", []).name, "outer/default/w:0")
          self.assertEqual(
              constant_op.constant([], name="c").name, "outer/c:0")
        # Recheck: no "outer/default" name scope was created above, so this
        # one is not uniquified either.
        with ops.name_scope("default"):
          self.assertEqual(
              constant_op.constant([], name="c").name, "outer/default/c:0")
1055
1056  @test_util.run_in_graph_and_eager_modes
1057  @run_inside_wrap_function_in_eager_mode
  def testReenterRootScopeWhenAuxiliaryNameScopeIsFalse(self):
    """Re-entering the root variable scope with auxiliary_name_scope=False.

    Variables go back to the root with no prefix, even from inside "outer".
    Ops keep the surrounding name scope ("outer/").
    """
    with self.cached_session():
      root_scope = variable_scope.get_variable_scope()
      with variable_scope.variable_scope(
          root_scope, auxiliary_name_scope=False) as scope:
        self.assertEqual(scope.original_name_scope, "")
        self.assertEqual(variable_scope.get_variable("w", []).name, "w:0")
        self.assertEqual(constant_op.constant([], name="c").name, "c:0")

      with variable_scope.variable_scope("outer"):
        with variable_scope.variable_scope(
            root_scope, auxiliary_name_scope=False) as inner:
          self.assertEqual(inner.original_name_scope, "")
          # Variable escapes "outer"; the op below does not.
          self.assertEqual(variable_scope.get_variable("w1", []).name, "w1:0")
          self.assertEqual(
              constant_op.constant([], name="c1").name, "outer/c1:0")
1074
1075  @test_util.run_in_graph_and_eager_modes
1076  @run_inside_wrap_function_in_eager_mode
1077  def testAuxiliaryNameScopeIsInvalid(self):
1078    with self.cached_session():
1079      with self.assertRaisesRegex(TypeError, "auxiliary_name_scope"):
1080        with variable_scope.variable_scope(
1081            None, default_name="scope", auxiliary_name_scope="invalid"):
1082          pass
1083
1084      with self.assertRaisesRegex(TypeError, "auxiliary_name_scope"):
1085        with variable_scope.variable_scope(
1086            "scope", auxiliary_name_scope="invalid"):
1087          pass
1088
1089      with variable_scope.variable_scope("scope") as scope:
1090        pass
1091      with self.assertRaisesRegex(TypeError, "auxiliary_name_scope"):
1092        with variable_scope.variable_scope(
1093            scope, auxiliary_name_scope="invalid"):
1094          pass
1095
1096  @test_util.run_in_graph_and_eager_modes
1097  @run_inside_wrap_function_in_eager_mode
  def testReuseScopeWithoutNameScopeCollision(self):
    """Re-enter a scope object's original name scope without a new one.

    With auxiliary_name_scope=False, ops.name_scope(original_name_scope) puts
    ops back under "outer/inner/", whoever the caller is. A plain
    name_scope("inner") at the call site still resolves relative to the
    current name scope.
    """
    # Github issue: #13429
    with self.cached_session():
      with variable_scope.variable_scope("outer"):
        with variable_scope.variable_scope("inner") as inner:
          pass

      with variable_scope.variable_scope(
          inner, auxiliary_name_scope=False) as scope:
        with ops.name_scope(scope.original_name_scope):
          self.assertEqual(
              variable_scope.get_variable("w", []).name, "outer/inner/w:0")
          self.assertEqual(
              constant_op.constant([], name="c").name, "outer/inner/c:0")
        with ops.name_scope("inner"):
          self.assertEqual(
              constant_op.constant([], name="c").name, "inner/c:0")

      # Same re-entry from inside an unrelated "another" scope.
      with variable_scope.variable_scope("another"):
        with variable_scope.variable_scope(
            inner, auxiliary_name_scope=False) as scope1:
          with ops.name_scope(scope1.original_name_scope):
            self.assertEqual(
                variable_scope.get_variable("w1", []).name,
                "outer/inner/w1:0")
            self.assertEqual(
                constant_op.constant([], name="c1").name, "outer/inner/c1:0")
          with ops.name_scope("inner"):
            self.assertEqual(
                constant_op.constant([], name="c").name, "another/inner/c:0")
1128
1129  # TODO(mihaimaruseac): Not converted to use wrap_function because of
1130  # obtaining different results in the eager case compared to the graph one
1131  # (different assertions failing after wrapping, in both execution modes)
1132  @test_util.run_in_graph_and_eager_modes
  def testGetLocalVar(self):
    """get_local_variable() honors scope naming, collections and reuse."""
    # Check that local variable respects naming.
    with variable_scope.variable_scope("outer") as outer:
      with variable_scope.variable_scope(outer, "default", []):
        local_var = variable_scope.get_local_variable(
            "w", [], collections=["foo"])
        self.assertEqual(local_var.name, "outer/w:0")

    # Collection membership and reuse are only checked in graph mode.
    if not context.executing_eagerly():
      # Since variable is local, it should be in the local variable collection
      # (in addition to the explicitly requested "foo") but not the trainable
      # collection.
      self.assertIn(local_var,
                    ops.get_collection(ops.GraphKeys.LOCAL_VARIABLES))
      self.assertIn(local_var, ops.get_collection("foo"))
      self.assertNotIn(local_var,
                       ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES))
      # Check that local variable respects `reuse`.
      with variable_scope.variable_scope(outer, "default", reuse=True):
        self.assertEqual(
            variable_scope.get_local_variable("w", []).name, "outer/w:0")
1153
1154  @test_util.run_in_graph_and_eager_modes
1155  @run_inside_wrap_function_in_eager_mode
1156  def testSignatureGetVarVsGetLocalVar(self):
1157    """get_{local,}variable() must take the same list of args."""
1158    arg_names = tf_inspect.getargspec(variable_scope.get_variable)[0]
1159    local_arg_names = tf_inspect.getargspec(
1160        variable_scope.get_local_variable)[0]
1161    self.assertEqual(arg_names, local_arg_names)
1162
1163  @test_util.run_in_graph_and_eager_modes
1164  @run_inside_wrap_function_in_eager_mode
1165  def testGetVarWithDevice(self):
1166    g = ops.Graph()
1167    varname_type = []
1168
1169    def device_func(op):
1170      if op.type in ["Variable", "VariableV2", "VarHandleOp"]:
1171        varname_type.append((op.name, op.get_attr("dtype")))
1172      return "/device:GPU:0"
1173
1174    with g.as_default():
1175      with ops.device(device_func):
1176        _ = variable_scope.get_variable("x", (100, 200))
1177        _ = variable_scope.get_variable(
1178            "y", dtype=dtypes.int64, initializer=numpy.arange(73))
1179    self.assertEqual(varname_type[0], ("x", dtypes.float32))
1180    self.assertEqual(varname_type[1], ("y", dtypes.int64))
1181
1182  # TODO(mihaimaruseac): Not converted to use wrap_function because of
1183  # obtaining different results in the eager case compared to the graph one
1184  @test_util.run_deprecated_v1
  def testGetCollection(self):
    """scope.get_collection() filters a graph collection by scope.

    "testGetCollection_foo_" and "testGetCollection_foo" share a prefix. Each
    scope must see only its own variables. The root scope sees all of them in
    creation order. Variables created with trainable=False appear only in
    GLOBAL_VARIABLES.
    """
    with self.cached_session():
      _ = variable_scope.get_variable("testGetCollection_a", [])
      _ = variable_scope.get_variable(
          "testGetCollection_b", [], trainable=False)
      with variable_scope.variable_scope("testGetCollection_foo_") as scope1:
        _ = variable_scope.get_variable("testGetCollection_a", [])
        _ = variable_scope.get_variable(
            "testGetCollection_b", [], trainable=False)
        self.assertEqual([
            v.name
            for v in scope1.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
        ], ["testGetCollection_foo_/testGetCollection_a:0"])
        self.assertEqual([
            v.name
            for v in scope1.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
        ], [
            "testGetCollection_foo_/testGetCollection_a:0",
            "testGetCollection_foo_/testGetCollection_b:0"
        ])
      # Prefix of the scope above: must not pick up "..._foo_/" variables.
      with variable_scope.variable_scope("testGetCollection_foo") as scope2:
        _ = variable_scope.get_variable("testGetCollection_a", [])
        _ = variable_scope.get_variable(
            "testGetCollection_b", [], trainable=False)
        self.assertEqual([
            v.name
            for v in scope2.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
        ], ["testGetCollection_foo/testGetCollection_a:0"])
        self.assertEqual([
            v.name
            for v in scope2.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
        ], [
            "testGetCollection_foo/testGetCollection_a:0",
            "testGetCollection_foo/testGetCollection_b:0"
        ])
      # The root scope sees everything created above.
      scope = variable_scope.get_variable_scope()
      self.assertEqual([
          v.name for v in scope.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
      ], [
          "testGetCollection_a:0", "testGetCollection_b:0",
          "testGetCollection_foo_/testGetCollection_a:0",
          "testGetCollection_foo_/testGetCollection_b:0",
          "testGetCollection_foo/testGetCollection_a:0",
          "testGetCollection_foo/testGetCollection_b:0"
      ])
      self.assertEqual([
          v.name
          for v in scope.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
      ], [
          "testGetCollection_a:0",
          "testGetCollection_foo_/testGetCollection_a:0",
          "testGetCollection_foo/testGetCollection_a:0"
      ])
1238
1239  # TODO(mihaimaruseac): Not converted to use wrap_function because of
1240  # obtaining different results in the eager case compared to the graph one
1241  @test_util.run_deprecated_v1
  def testGetTrainableVariablesWithGetVariable(self):
    """scope.trainable_variables() with get_variable() and synchronization.

    Only variables created in the scope are listed. trainable=False excludes
    a variable. ON_READ synchronization excludes it unless trainable=True is
    passed explicitly, while ON_WRITE keeps it trainable.
    """
    with self.cached_session():
      # Created outside the scope, so it must never be listed.
      _ = variable_scope.get_variable("testGetTrainableVariables_a", [])
      with variable_scope.variable_scope(
          "testGetTrainableVariables_foo") as scope:
        _ = variable_scope.get_variable("testGetTrainableVariables_b", [])
        _ = variable_scope.get_variable(
            "testGetTrainableVariables_c", [], trainable=False)

        # Sync `ON_READ` defaults to trainable=False.
        _ = variable_scope.get_variable(
            "testGetTrainableVariables_d", [],
            synchronization=variable_scope.VariableSynchronization.ON_READ)
        self.assertEqual(
            [v.name for v in scope.trainable_variables()],
            ["testGetTrainableVariables_foo/testGetTrainableVariables_b:0"])

        # An explicit trainable=True overrides the ON_READ default.
        _ = variable_scope.get_variable(
            "testGetTrainableVariables_e", [],
            synchronization=variable_scope.VariableSynchronization.ON_READ,
            trainable=True)
        self.assertEqual([v.name for v in scope.trainable_variables()], [
            "testGetTrainableVariables_foo/testGetTrainableVariables_b:0",
            "testGetTrainableVariables_foo/testGetTrainableVariables_e:0",
        ])

        # Other sync values (ON_WRITE here) leave trainable=True.
        _ = variable_scope.get_variable(
            "testGetTrainableVariables_f", [],
            synchronization=variable_scope.VariableSynchronization.ON_WRITE)
        self.assertEqual([v.name for v in scope.trainable_variables()], [
            "testGetTrainableVariables_foo/testGetTrainableVariables_b:0",
            "testGetTrainableVariables_foo/testGetTrainableVariables_e:0",
            "testGetTrainableVariables_foo/testGetTrainableVariables_f:0",
        ])
1277
1278  # TODO(mihaimaruseac): Not converted to use wrap_function because of
1279  # obtaining different results in the eager case compared to the graph one
1280  @test_util.run_deprecated_v1
  def testGetTrainableVariablesWithVariable(self):
    """scope.trainable_variables() with variable() and synchronization.

    Same expectations as with get_variable(): trainable=False and a bare
    ON_READ synchronization exclude a variable. ON_READ plus trainable=True,
    or ON_WRITE, include it.
    """
    with self.cached_session():
      # Created outside the scope, so it must never be listed.
      _ = variable_scope.variable(1.0, name="testGetTrainableVariables_a")
      with variable_scope.variable_scope(
          "testGetTrainableVariables_foo") as scope:
        _ = variable_scope.variable(1.0, name="testGetTrainableVariables_b")
        _ = variable_scope.variable(
            1.0, name="testGetTrainableVariables_c", trainable=False)

        # Sync `ON_READ` defaults to trainable=False.
        _ = variable_scope.variable(
            1.0,
            name="testGetTrainableVariables_d",
            synchronization=variable_scope.VariableSynchronization.ON_READ)
        self.assertEqual(
            [v.name for v in scope.trainable_variables()],
            ["testGetTrainableVariables_foo/testGetTrainableVariables_b:0"])

        # An explicit trainable=True overrides the ON_READ default.
        _ = variable_scope.variable(
            1.0,
            name="testGetTrainableVariables_e",
            synchronization=variable_scope.VariableSynchronization.ON_READ,
            trainable=True)
        self.assertEqual([v.name for v in scope.trainable_variables()], [
            "testGetTrainableVariables_foo/testGetTrainableVariables_b:0",
            "testGetTrainableVariables_foo/testGetTrainableVariables_e:0",
        ])

        # Other sync values (ON_WRITE here) leave trainable=True.
        _ = variable_scope.variable(
            1.0,
            name="testGetTrainableVariables_f",
            synchronization=variable_scope.VariableSynchronization.ON_WRITE)
        self.assertEqual([v.name for v in scope.trainable_variables()], [
            "testGetTrainableVariables_foo/testGetTrainableVariables_b:0",
            "testGetTrainableVariables_foo/testGetTrainableVariables_e:0",
            "testGetTrainableVariables_foo/testGetTrainableVariables_f:0",
        ])
1319
1320  # TODO(mihaimaruseac): Not converted to use wrap_function because of
1321  # obtaining different results in the eager case compared to the graph one
1322  @test_util.run_deprecated_v1
1323  def testGetGlobalVariables(self):
1324    with self.cached_session():
1325      _ = variable_scope.get_variable("testGetGlobalVariables_a", [])
1326      with variable_scope.variable_scope("testGetGlobalVariables_foo") as scope:
1327        _ = variable_scope.get_variable("testGetGlobalVariables_b", [])
1328        self.assertEqual(
1329            [v.name for v in scope.global_variables()],
1330            ["testGetGlobalVariables_foo/"
1331             "testGetGlobalVariables_b:0"])
1332
  # TODO(mihaimaruseac): Not converted to use wrap_function because it
  # produces different results in eager mode than in graph mode.
1335  @test_util.run_deprecated_v1
1336  def testGetLocalVariables(self):
1337    with self.cached_session():
1338      _ = variable_scope.get_variable(
1339          "a", [], collections=[ops.GraphKeys.LOCAL_VARIABLES])
1340      with variable_scope.variable_scope("foo") as scope:
1341        _ = variable_scope.get_variable(
1342            "b", [], collections=[ops.GraphKeys.LOCAL_VARIABLES])
1343        _ = variable_scope.get_variable("c", [])
1344        self.assertEqual([v.name for v in scope.local_variables()], ["foo/b:0"])
1345
1346  @test_util.run_in_graph_and_eager_modes
1347  @run_inside_wrap_function_in_eager_mode
1348  def testGetVariableWithRefDtype(self):
1349    v = variable_scope.get_variable("v", shape=[3, 4], dtype=dtypes.float32)
1350    # Ensure it is possible to do get_variable with a _ref dtype passed in.
1351    _ = variable_scope.get_variable("w", shape=[5, 6], dtype=v.dtype)
1352
1353  @test_util.run_in_graph_and_eager_modes
1354  @run_inside_wrap_function_in_eager_mode
1355  def testGetVariableWithInitializerWhichTakesNoArgs(self):
1356    v = variable_scope.get_variable("foo", initializer=lambda: [2])
1357    self.assertEqual(v.name, "foo:0")
1358
1359  @test_util.run_in_graph_and_eager_modes
1360  @run_inside_wrap_function_in_eager_mode
1361  def testGetVariableWithInitializerWhichTakesOptionalArgs(self):
1362    v = variable_scope.get_variable("foo", initializer=lambda x=True: [2])
1363    self.assertEqual(v.name, "foo:0")
1364
1365  @test_util.run_in_graph_and_eager_modes
1366  @run_inside_wrap_function_in_eager_mode
1367  def testGetVariableWithInitializerWhichTakesUnprovidedArgsAndNoShape(self):
1368    with self.assertRaisesRegex(
1369        ValueError,
1370        "The initializer passed is not valid. It should be a callable with no "
1371        "arguments and the shape should not be provided or an instance of "
1372        "`tf.keras.initializers.*' and `shape` should be fully defined."):
1373      variable_scope.get_variable("foo", initializer=lambda x: [2])
1374
1375  @test_util.run_in_graph_and_eager_modes
1376  @run_inside_wrap_function_in_eager_mode
1377  def testTwoGraphs(self):
1378
1379    def f():
1380      g1 = ops.Graph()
1381      g2 = ops.Graph()
1382      with g1.as_default():
1383        with g2.as_default():
1384          with variable_scope.variable_scope("_"):
1385            pass
1386
1387    self.assertRaisesRegex(ValueError, "'_' is not a valid scope name", f)
1388
1389
def axis0_into1_partitioner(shape=None, **unused_kwargs):
  """Partitioner that leaves every axis whole (one shard per axis).

  Args:
    shape: Shape of the variable being partitioned; only its length is used.
    **unused_kwargs: Any other partitioner arguments; ignored.

  Returns:
    A list of `len(shape)` ones.
  """
  return [1] * len(shape)
1393
1394
def axis0_into2_partitioner(shape=None, **unused_kwargs):
  """Partitioner that splits axis 0 into 2 shards; other axes stay whole.

  Args:
    shape: Shape of the variable being partitioned; only its length is used.
    **unused_kwargs: Any other partitioner arguments; ignored.

  Returns:
    A list of length `len(shape)`: 2 for axis 0 and 1 for every other axis.

  Raises:
    IndexError: If `shape` is empty (there is no axis 0 to split).
  """
  shards_per_axis = [1] * len(shape)
  shards_per_axis[0] = 2
  return shards_per_axis
1399
1400
def axis0_into3_partitioner(shape=None, **unused_kwargs):
  """Partitioner that splits axis 0 into 3 shards; other axes stay whole.

  Args:
    shape: Shape of the variable being partitioned; only its length is used.
    **unused_kwargs: Any other partitioner arguments; ignored.

  Returns:
    A list of length `len(shape)`: 3 for axis 0 and 1 for every other axis.

  Raises:
    IndexError: If `shape` is empty (there is no axis 0 to split).
  """
  shards_per_axis = [1] * len(shape)
  shards_per_axis[0] = 3
  return shards_per_axis
1405
1406
class VariableScopeWithPartitioningTest(test.TestCase):
  """Tests get_variable inside variable scopes that carry a partitioner."""

  # TODO(mihaimaruseac): Not converted to use wrap_function because it
  # produces different results in eager mode than in graph mode.
  @test_util.run_deprecated_v1
  def testResultNameMatchesRequested(self):
    with variable_scope.variable_scope(
        "scope0", partitioner=axis0_into2_partitioner):
      partitioned = variable_scope.get_variable("name0", shape=(3, 1, 1))
      self.assertEqual(partitioned.name, "scope0/name0")
      concatenated = partitioned.as_tensor()
      self.assertEqual(concatenated.name, "scope0/name0:0")
      # Expect shards part_0 and part_1, and no part_2.
      global_names = [
          var.name
          for var in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
      ]
      self.assertIn("scope0/name0/part_0:0", global_names)
      self.assertIn("scope0/name0/part_1:0", global_names)
      self.assertNotIn("scope0/name0/part_2:0", global_names)

  @test_util.run_in_graph_and_eager_modes
  @run_inside_wrap_function_in_eager_mode
  def testBreaksIfPartitioningChanges(self):
    with variable_scope.variable_scope(
        "scope0", partitioner=axis0_into2_partitioner):
      variable_scope.get_variable("name0", shape=(3, 1, 1))

    # Reusing the variable with a partitioner that yields more shards, and
    # then one that yields fewer, must both be rejected.
    for mismatched_partitioner in (axis0_into3_partitioner,
                                   axis0_into1_partitioner):
      with variable_scope.variable_scope(
          "scope0", partitioner=mismatched_partitioner, reuse=True):
        with self.assertRaisesRegex(
            ValueError,
            "Trying to reuse partitioned variable .* but specified partitions "
            ".* and found partitions .*"):
          variable_scope.get_variable("name0", shape=(3, 1, 1))

  @test_util.run_in_graph_and_eager_modes
  @run_inside_wrap_function_in_eager_mode
  def testReturnsExistingConcatenatedValueIfReuse(self):
    with variable_scope.variable_scope(
        "scope0", partitioner=axis0_into2_partitioner):
      first = variable_scope.get_variable("name0", shape=(3, 1, 1))
      variable_scope.get_variable_scope().reuse_variables()
      second = variable_scope.get_variable("name0", shape=(3, 1, 1))
      self.assertEqual(first, second)

  @test_util.run_in_graph_and_eager_modes
  @run_inside_wrap_function_in_eager_mode
  def testAllowsReuseWithoutPartitioner(self):
    with variable_scope.variable_scope(
        "scope0", partitioner=axis0_into2_partitioner):
      created = variable_scope.get_variable("name0", shape=(3, 1, 1))
    # Reopening without a partitioner (and without a shape) still finds it.
    with variable_scope.variable_scope("scope0", reuse=True):
      reused = variable_scope.get_variable("name0")
    self.assertIs(created, reused)

  def testNoReuseInEagerByDefault(self):
    with context.eager_mode(), variable_scope.variable_scope(
        "scope0", partitioner=axis0_into2_partitioner):
      first = variable_scope.get_variable("name0", shape=(3, 1, 1))
      second = variable_scope.get_variable("name0", shape=(3, 1, 1))
      self.assertIsNot(first, second)

  @test_util.run_in_graph_and_eager_modes
  @run_inside_wrap_function_in_eager_mode
  def testPropagatePartitionerOnReopening(self):
    with variable_scope.variable_scope(
        "scope0", partitioner=axis0_into2_partitioner) as outer:
      self.assertEqual(axis0_into2_partitioner, outer.partitioner)
      # Reentering through the scope object keeps its partitioner.
      with variable_scope.variable_scope(outer) as reopened:
        self.assertEqual(axis0_into2_partitioner, reopened.partitioner)

  # TODO(mihaimaruseac): Not converted to use wrap_function because it
  # produces different results in eager mode than in graph mode.
  @test_util.run_deprecated_v1
  def testScalarIgnoresPartitioner(self):
    with variable_scope.variable_scope(
        "scope0", partitioner=axis0_into2_partitioner):
      scalar = variable_scope.get_variable("name0", shape=())
      # A scalar comes back as a single variable without a /part_N suffix.
      self.assertEqual(scalar.name, "scope0/name0:0")
      global_names = [
          var.name
          for var in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
      ]
      self.assertIn("scope0/name0:0", global_names)

  def _testPartitionConcatenatesAlongCorrectAxis(self, use_resource):
    """Checks shard shapes for partitioners that split axis 0 and axis 1."""

    def _part_axis_0(**unused_kwargs):
      return (2, 1, 1)

    def _part_axis_1(**unused_kwargs):
      return (1, 2, 1)

    with variable_scope.variable_scope("root", use_resource=use_resource):
      split_on_axis0 = variable_scope.get_variable(
          "n0", shape=(2, 2, 2), partitioner=_part_axis_0)
      split_on_axis1 = variable_scope.get_variable(
          "n1", shape=(2, 2, 2), partitioner=_part_axis_1)

    self.assertEqual(split_on_axis0.get_shape(), (2, 2, 2))
    self.assertEqual(split_on_axis1.get_shape(), (2, 2, 2))

    # Each shard is halved along the axis its partitioner split.
    axis0_shards = list(split_on_axis0)
    self.assertEqual(axis0_shards[0].get_shape(), (1, 2, 2))
    self.assertEqual(axis0_shards[1].get_shape(), (1, 2, 2))

    axis1_shards = list(split_on_axis1)
    self.assertEqual(axis1_shards[0].get_shape(), (2, 1, 2))
    self.assertEqual(axis1_shards[1].get_shape(), (2, 1, 2))

  @test_util.run_in_graph_and_eager_modes
  @run_inside_wrap_function_in_eager_mode
  def testPartitionConcatenatesAlongCorrectAxis(self):
    self._testPartitionConcatenatesAlongCorrectAxis(use_resource=False)

  @test_util.run_in_graph_and_eager_modes
  @run_inside_wrap_function_in_eager_mode
  def testPartitionConcatenatesAlongCorrectAxisResource(self):
    self._testPartitionConcatenatesAlongCorrectAxis(use_resource=True)

  def testPartitionConcatenatesAlongCorrectAxisResourceInEager(self):
    with context.eager_mode():
      self._testPartitionConcatenatesAlongCorrectAxis(use_resource=True)
1534
1535
class VariableScopeWithCustomGetterTest(test.TestCase):
  """Tests custom getters and variable creators used with variable_scope."""

  @test_util.run_in_graph_and_eager_modes
  @run_inside_wrap_function_in_eager_mode
  def testNonCallableGetterFails(self):
    """custom_getter=3 is rejected both on a scope and on get_variable."""
    with self.assertRaisesRegex(ValueError, r"custom_getter .* not callable:"):
      with variable_scope.variable_scope("scope0", custom_getter=3):
        variable_scope.get_variable("name0")
    with self.assertRaisesRegex(ValueError, r"custom_getter .* not callable:"):
      variable_scope.get_variable("name0", custom_getter=3)

  @test_util.run_in_graph_and_eager_modes
  @run_inside_wrap_function_in_eager_mode
  def testNoSideEffectsWithIdentityCustomGetter(self):
    """An identity custom getter does not interfere with variable reuse."""
    called = [0]

    def custom_getter(getter, *args, **kwargs):
      called[0] += 1
      return getter(*args, **kwargs)

    with variable_scope.variable_scope(
        "scope", custom_getter=custom_getter) as scope:
      v = variable_scope.get_variable("v", [1])
    with variable_scope.variable_scope(scope, reuse=True):
      v2 = variable_scope.get_variable("v", [1])
    with variable_scope.variable_scope("new_scope") as new_scope:
      v3 = variable_scope.get_variable("v3", [1])
    with variable_scope.variable_scope(
        new_scope, reuse=True, custom_getter=custom_getter):
      v4 = variable_scope.get_variable("v3", [1])

    self.assertIs(v, v2)
    self.assertIs(v3, v4)
    # v3 was created while new_scope had no custom getter, so only three of
    # the four get_variable calls went through custom_getter.
    self.assertEqual(3, called[0])

  @test_util.run_in_graph_and_eager_modes
  @run_inside_wrap_function_in_eager_mode
  def testSynchronizationAndAggregationWithCustomGetter(self):
    """Custom getters receive the synchronization and aggregation kwargs."""
    called = [0]
    # custom_getter reads these through its closure, so rebinding them below
    # changes the values it expects for the second variable.
    synchronization = variable_scope.VariableSynchronization.AUTO
    aggregation = variable_scope.VariableAggregation.NONE

    def custom_getter(getter, *args, **kwargs):
      called[0] += 1

      # Verify synchronization and aggregation kwargs are as expected.
      self.assertEqual(kwargs["synchronization"], synchronization)
      self.assertEqual(kwargs["aggregation"], aggregation)
      return getter(*args, **kwargs)

    # Not passing them explicitly: the getter sees AUTO / NONE.
    with variable_scope.variable_scope("scope", custom_getter=custom_getter):
      variable_scope.get_variable("v", [1])
    self.assertEqual(1, called[0])

    with variable_scope.variable_scope("scope", custom_getter=custom_getter):
      synchronization = variable_scope.VariableSynchronization.ON_READ
      aggregation = variable_scope.VariableAggregation.MEAN
      variable_scope.get_variable(
          "v1", [1], synchronization=synchronization, aggregation=aggregation)

    self.assertEqual(2, called[0])

  @test_util.run_in_graph_and_eager_modes
  @run_inside_wrap_function_in_eager_mode
  def testCustomGetterWithReuse(self):
    """Custom getters see kwargs["reuse"] and may act on it."""

    # Custom getter can choose to behave differently on reused variables.
    def custom_getter(getter, *args, **kwargs):
      var = getter(*args, **kwargs)
      if kwargs["reuse"]:
        # This can be used, e.g., for changing the caching device if needed.
        return array_ops.identity(var, name="reused")
      else:
        return array_ops.identity(var, name="not_reused")

    with variable_scope.variable_scope(
        "scope", custom_getter=custom_getter) as scope:
      v = variable_scope.get_variable("v", [1])
    with variable_scope.variable_scope(scope, reuse=True):
      v2 = variable_scope.get_variable("v", [1])

    self.assertEqual(v.name, "not_reused:0")
    self.assertEqual(v2.name, "reused:0")

  # TODO(mihaimaruseac): Not converted to use wrap_function because of
  # ValueError: Fetch argument <tf.Tensor 'custom_getter/add:0' shape=(1, 2, 3)
  # dtype=float32> cannot be interpreted as a Tensor. (Tensor
  # Tensor("custom_getter/add:0", shape=(1, 2, 3), dtype=float32) is not an
  # element of this graph.)
  @test_util.run_deprecated_v1
  def testGetterThatCreatesTwoVariablesAndSumsThem(self):
    """A getter may create two variables and return their sum."""

    def custom_getter(getter, name, *args, **kwargs):
      g_0 = getter("%s/0" % name, *args, **kwargs)
      g_1 = getter("%s/1" % name, *args, **kwargs)
      with ops.name_scope("custom_getter"):
        return g_0 + g_1

    with variable_scope.variable_scope("scope", custom_getter=custom_getter):
      v = variable_scope.get_variable("v", [1, 2, 3])

    self.assertEqual([1, 2, 3], v.get_shape())
    true_vars = variables_lib.trainable_variables()
    self.assertEqual(2, len(true_vars))
    self.assertEqual("scope/v/0:0", true_vars[0].name)
    self.assertEqual("scope/v/1:0", true_vars[1].name)
    self.assertEqual("custom_getter/add:0", v.name)
    with self.cached_session():
      variables_lib.global_variables_initializer().run()
      np_vars, np_v = self.evaluate([true_vars, v])
      self.assertAllClose(np_v, sum(np_vars))

  # TODO(mihaimaruseac): Not converted to use wrap_function because of
  # ValueError: Fetch argument <tf.Tensor 'sum_getter_2/add:0' shape=(1, 2, 3)
  # dtype=float32> cannot be interpreted as a Tensor. (Tensor
  # Tensor("sum_getter_2/add:0", shape=(1, 2, 3), dtype=float32) is not an
  # element of this graph.)
  @test_util.run_deprecated_v1
  def testNestedCustomGetters(self):
    """Getters of nested scopes compose; the innermost scope's applies first."""

    def sum_getter(getter, name, *args, **kwargs):
      g_0 = getter("%s/sum_0" % name, *args, **kwargs)
      g_1 = getter("%s/sum_1" % name, *args, **kwargs)
      with ops.name_scope("sum_getter"):
        return g_0 + g_1

    def prod_getter(getter, name, *args, **kwargs):
      g_0 = getter("%s/prod_0" % name, *args, **kwargs)
      g_1 = getter("%s/prod_1" % name, *args, **kwargs)
      with ops.name_scope("prod_getter"):
        return g_0 * g_1

    with variable_scope.variable_scope("prod_scope", custom_getter=prod_getter):
      with variable_scope.variable_scope("sum_scope", custom_getter=sum_getter):
        with variable_scope.variable_scope(
            "inner_sum_scope", custom_getter=sum_getter):
          # take sums of sums of products
          v = variable_scope.get_variable("v", [1, 2, 3])

    self.assertEqual([1, 2, 3], v.get_shape())
    true_vars = variables_lib.trainable_variables()
    self.assertEqual(8, len(true_vars))
    template = (
        "prod_scope/sum_scope/inner_sum_scope/v/sum_%d/sum_%d/prod_%d:0")
    self.assertEqual(template % (0, 0, 0), true_vars[0].name)
    self.assertEqual(template % (0, 0, 1), true_vars[1].name)
    self.assertEqual(template % (0, 1, 0), true_vars[2].name)
    self.assertEqual(template % (0, 1, 1), true_vars[3].name)
    self.assertEqual(template % (1, 0, 0), true_vars[4].name)
    self.assertEqual(template % (1, 0, 1), true_vars[5].name)
    self.assertEqual(template % (1, 1, 0), true_vars[6].name)
    self.assertEqual(template % (1, 1, 1), true_vars[7].name)

    with self.cached_session():
      variables_lib.global_variables_initializer().run()
      np_vars, np_v = self.evaluate([true_vars, v])
      # take sums of sums of products
      self.assertAllClose(
          np_v, (((np_vars[0] * np_vars[1]) + (np_vars[2] * np_vars[3])) + (
              (np_vars[4] * np_vars[5]) + (np_vars[6] * np_vars[7]))))

  @test_util.run_in_graph_and_eager_modes
  @run_inside_wrap_function_in_eager_mode
  def testVariableCreator(self):
    """Variable creators compose and receive get_variable's kwargs."""
    variable_names = []

    def creator_a(next_creator, **kwargs):
      variable_names.append(kwargs.get("name", ""))
      return next_creator(**kwargs)

    def creator_b(next_creator, **kwargs):
      kwargs["name"] = "forced_name"
      return next_creator(**kwargs)

    with variable_scope.variable_creator_scope(creator_a):
      with variable_scope.variable_creator_scope(creator_b):
        variable_scope.variable(1.0, name="one_name")

    # creator_b (innermost scope) runs first, so creator_a sees its rename.
    self.assertEqual(variable_names[0], "forced_name")

    called = [False]

    def creator_c(next_creator, **kwargs):
      called[0] = True
      self.assertEqual(kwargs["synchronization"],
                       variable_scope.VariableSynchronization.ON_WRITE)
      self.assertEqual(kwargs["aggregation"],
                       variable_scope.VariableAggregation.MEAN)
      return next_creator(**kwargs)

    with variable_scope.variable_creator_scope(creator_c):
      variable_scope.get_variable(
          "v", [],
          synchronization=variable_scope.VariableSynchronization.ON_WRITE,
          aggregation=variable_scope.VariableAggregation.MEAN)
    self.assertTrue(called[0])

  @test_util.run_in_graph_and_eager_modes
  @run_inside_wrap_function_in_eager_mode
  def testVariableCreatorNestingError(self):
    """Exiting a creator scope while a nested one is active is an error."""

    def creator(next_creator, **kwargs):
      return next_creator(**kwargs)

    # Save the state so we can clean up at the end.
    graph = ops.get_default_graph()
    old_creator_stack = graph._variable_creator_stack  # pylint: disable=protected-access

    try:
      scope = variable_scope.variable_creator_scope(creator)
      scope.__enter__()
      with variable_scope.variable_creator_scope(creator):
        with self.assertRaises(RuntimeError):
          scope.__exit__(None, None, None)
    finally:
      graph._variable_creator_stack = old_creator_stack  # pylint: disable=protected-access
1751
1752
class PartitionInfoTest(test.TestCase):
  """Tests argument validation and offset helpers of _PartitionInfo."""

  @test_util.run_in_graph_and_eager_modes
  @run_inside_wrap_function_in_eager_mode
  def testConstructorChecks(self):
    # Arguments of the wrong type are rejected with TypeError.
    for full_shape, var_offset in [(None, [0, 1]),
                                   ([0, 1], None),
                                   ("foo", [0, 1]),
                                   ([0, 1], "foo")]:
      with self.assertRaises(TypeError):
        variable_scope._PartitionInfo(
            full_shape=full_shape, var_offset=var_offset)

    # Well-typed but inconsistent arguments are rejected with ValueError:
    # first full_shape and var_offset differ in length, then an offset is not
    # smaller than the matching dimension of full_shape.
    for full_shape, var_offset in [([0, 1], [0]),
                                   ([1, 1], [0, 1])]:
      with self.assertRaises(ValueError):
        variable_scope._PartitionInfo(
            full_shape=full_shape, var_offset=var_offset)

  @test_util.run_in_graph_and_eager_modes
  @run_inside_wrap_function_in_eager_mode
  def testSingleOffset(self):
    shard_info = variable_scope._PartitionInfo(
        full_shape=[9, 3], var_offset=[4, 0])
    self.assertEqual(4, shard_info.single_offset([1, 3]))

    # A variable that isn't partitioned at all starts at offset 0.
    whole_info = variable_scope._PartitionInfo(
        full_shape=[9, 3], var_offset=[0, 0])
    self.assertEqual(0, whole_info.single_offset([9, 3]))

  @test_util.run_in_graph_and_eager_modes
  @run_inside_wrap_function_in_eager_mode
  def testSingleSliceDim(self):
    shard_info = variable_scope._PartitionInfo(
        full_shape=[9, 3], var_offset=[4, 0])
    # Not a shape at all.
    with self.assertRaises(TypeError):
      shard_info.single_slice_dim(None)

    # Each shape is invalid for full_shape=[9, 3], var_offset=[4, 0]:
    #   [1, 2, 3]: rank differs from full_shape.
    #   [6, 3]: too large given var_offset (4 + 6 > 9).
    #   [1, 1]: more than one dimension could be the slice dimension.
    for bad_shape in ([1, 2, 3], [6, 3], [1, 1]):
      with self.assertRaises(ValueError):
        shard_info.single_slice_dim(bad_shape)

    origin_info = variable_scope._PartitionInfo(
        full_shape=[9, 3], var_offset=[0, 0])
    self.assertEqual(1, origin_info.single_slice_dim([9, 2]))
    offset_info = variable_scope._PartitionInfo(
        full_shape=[9, 3], var_offset=[4, 0])
    self.assertEqual(0, offset_info.single_slice_dim([2, 3]))
1814
1815
class VariableScopeMultithreadedTest(test.TestCase):
  """Tests variable_scope use from several threads sharing one graph.

  Worker threads are created with `self.checkedThread`, whose `join()` fails
  the test if the thread raised. With a plain `threading.Thread`, an assertion
  failing inside the worker is only printed and the test would still pass.
  """

  @test_util.run_in_graph_and_eager_modes
  @run_inside_wrap_function_in_eager_mode
  def testTwoThreadsDisjointScopeEntry(self):
    """A later thread cannot recreate a variable made by an earlier one."""

    def thread_fn(i, graph):
      with graph.as_default():
        with variable_scope.variable_scope("foo"):
          if i == 0:
            v = variable_scope.get_variable("v", [])
            self.assertEqual("foo/v:0", v.name)
          else:
            # Any thread after the first one should fail to create variable
            # with the same name.
            with self.assertRaises(ValueError):
              variable_scope.get_variable("v", [])

    graph = ops.get_default_graph()
    # Run the threads one at a time: thread 0 finishes before thread 1 starts.
    for i in range(2):
      worker = self.checkedThread(target=thread_fn, args=(i, graph))
      worker.start()
      worker.join()

  @test_util.run_in_graph_and_eager_modes
  @run_inside_wrap_function_in_eager_mode
  def testTwoThreadsNestedScopeEntry(self):
    """Overlapping scope entry across threads still detects the name clash."""

    def thread_fn(i, graph, run_event, pause_event):
      try:
        with graph.as_default():
          with variable_scope.variable_scope("foo"):
            if i == 0:
              v = variable_scope.get_variable("v", [])
              self.assertEqual("foo/v:0", v.name)
            else:
              # Any thread after the first one should fail to create variable
              # with the same name.
              with self.assertRaises(ValueError):
                variable_scope.get_variable("v", [])
            # Pause while still inside "foo" so that both threads are inside
            # the scope at the same time.
            pause_event.set()
            run_event.wait()
      finally:
        # Also signal on failure; otherwise the main thread would block
        # forever in pause_event.wait().
        pause_event.set()

    graph = ops.get_default_graph()
    run_events = [threading.Event() for _ in range(2)]
    pause_events = [threading.Event() for _ in range(2)]
    threads = [
        self.checkedThread(
            target=thread_fn, args=(i, graph, run_events[i], pause_events[i]))
        for i in range(2)
    ]

    # Start first thread.
    threads[0].start()
    pause_events[0].wait()
    # Start next thread once the first thread has paused.
    threads[1].start()
    pause_events[1].wait()
    # Resume both threads.
    run_events[0].set()
    run_events[1].set()
    # Join both threads even if the first one reports a failure.
    try:
      threads[0].join()
    finally:
      threads[1].join()

  @test_util.run_in_graph_and_eager_modes
  @run_inside_wrap_function_in_eager_mode
  def testReenterMainScope(self):
    """A thread can reenter a scope object captured on the main thread."""

    def thread_fn(graph, main_thread_scope):
      with graph.as_default():
        # Variable created with main scope will have prefix "main".
        with variable_scope.variable_scope(main_thread_scope):
          with variable_scope.variable_scope("foo"):
            v = variable_scope.get_variable("v", [])
            self.assertEqual("main/foo/v:0", v.name)

        # Variable created outside main scope will not have prefix "main".
        with variable_scope.variable_scope("bar"):
          v = variable_scope.get_variable("v", [])
          self.assertEqual("bar/v:0", v.name)

    graph = ops.get_default_graph()
    with variable_scope.variable_scope("main") as main_thread_scope:
      thread = self.checkedThread(
          target=thread_fn, args=(graph, main_thread_scope))
      thread.start()
      thread.join()
1910
1911
if __name__ == "__main__":
  # Run the test cases in this module through TensorFlow's test entry point.
  test.main()
1914