• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for variable store."""
16
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools
import gc
import threading

import numpy

from tensorflow.python.eager import context
from tensorflow.python.eager import function
from tensorflow.python.eager import wrap_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.layers import core as core_layers
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.platform import test
from tensorflow.python.util import compat
from tensorflow.python.util import tf_inspect
46
47
def run_inside_wrap_function_in_eager_mode(graph_function):
  """Decorator to execute the same graph code in eager and graph modes.

  In graph mode, we just execute the graph_function passed as argument. In eager
  mode, we wrap the function using wrap_function and then execute the wrapped
  result.

  Args:
    graph_function: python function containing graph code to be wrapped. It is
      called with the test case instance as its only argument.

  Returns:
    decorated function taking only the test case instance (`self`). It carries
    `graph_function`'s name and docstring (via `functools.wraps`), so test
    reporting and introspection still see the original test method.
  """

  @functools.wraps(graph_function)
  def wrap_and_execute(self):
    if context.executing_eagerly():
      # Trace graph_function with wrap_function (passing `[self]` as its
      # signature argument) and run the wrapped graph function.
      wrapped = wrap_function.wrap_function(graph_function, [self])
      wrapped()
    else:
      # Already in graph mode: call the original function directly.
      graph_function(self)

  return wrap_and_execute
70
71
72class VariableScopeTest(test.TestCase):
73
74  def tearDown(self):
75    gc.collect()
76    # This will only contain uncollectable garbage, i.e. reference cycles
77    # involving objects with __del__ defined.
78    self.assertEqual(0, len(gc.garbage))
79
80  @test_util.run_in_graph_and_eager_modes
81  @run_inside_wrap_function_in_eager_mode
82  def testGetVar(self):
83    vs = variable_scope._get_default_variable_store()
84    v = vs.get_variable("v", [1])
85    v1 = vs.get_variable("v", [1])
86    self.assertEqual(v, v1)
87
88  @test_util.run_in_graph_and_eager_modes
89  @run_inside_wrap_function_in_eager_mode
90  def testResource(self):
91    vs = variable_scope._get_default_variable_store()
92    v1 = vs.get_variable("v", [1], use_resource=True)
93    self.assertTrue(isinstance(v1, resource_variable_ops.ResourceVariable))
94
95  @test_util.run_in_graph_and_eager_modes
96  @run_inside_wrap_function_in_eager_mode
97  def testNameExists(self):
98    vs = variable_scope._get_default_variable_store()
99    # No check by default, so we can both create and get existing names.
100    v = vs.get_variable("v", [1])
101    v1 = vs.get_variable("v", [1])
102    self.assertEqual(v, v1)
103
104    # When reuse is False, we fail when variables are already there.
105    vs.get_variable("w", [1], reuse=False)  # That's ok.
106    with self.assertRaises(ValueError):
107      vs.get_variable("v", [1], reuse=False)  # That fails.
108    # When reuse is True, we fail when variables are new.
109    vs.get_variable("v", [1], reuse=True)  # That's ok.
110    with self.assertRaises(ValueError):
111      vs.get_variable("u", [1], reuse=True)  # That fails.
112
113  @test_util.run_in_graph_and_eager_modes
114  @run_inside_wrap_function_in_eager_mode
115  def testNamelessStore(self):
116    vs = variable_scope._get_default_variable_store()
117    vs.get_variable("v1", [2])
118    vs.get_variable("v2", [2])
119    expected_names = ["%s:0" % name for name in ["v1", "v2"]]
120    self.assertEqual(
121        set(expected_names), set([v.name for v in vs._vars.values()]))
122
123  # TODO(mihaimaruseac): Not converted to use wrap_function because of
124  # TypeError: Expected tf.group() expected Tensor arguments not 'None' with
125  # type '<type 'NoneType'>'
126  @test_util.run_in_graph_and_eager_modes
127  def testVarScopeInitializer(self):
128    init = init_ops.constant_initializer(0.3)
129    with variable_scope.variable_scope("tower0") as tower:
130      with variable_scope.variable_scope("foo", initializer=init):
131        v = variable_scope.get_variable("v", [])
132        self.evaluate(variables_lib.variables_initializer([v]))
133        self.assertAllClose(self.evaluate(v.value()), 0.3)
134      with variable_scope.variable_scope(tower, initializer=init):
135        w = variable_scope.get_variable("w", [])
136        self.evaluate(variables_lib.variables_initializer([w]))
137        self.assertAllClose(self.evaluate(w.value()), 0.3)
138
139  @test_util.run_in_graph_and_eager_modes
140  @run_inside_wrap_function_in_eager_mode
141  def testVarScopeConstraint(self):
142    constraint = lambda x: 0. * x
143    with variable_scope.variable_scope("tower1") as tower:
144      with variable_scope.variable_scope("foo", constraint=constraint):
145        v = variable_scope.get_variable("v", [])
146        self.assertEqual(v.constraint, constraint)
147      with variable_scope.variable_scope(tower, constraint=constraint):
148        w = variable_scope.get_variable("w", [])
149        self.assertEqual(w.constraint, constraint)
150
151  # TODO(mihaimaruseac): Not converted to use wrap_function because of
152  # TypeError: Fetch argument <tf.Variable 'string:0' shape=() dtype=string>
153  # has invalid type <class '...ResourceVariable'>, must be a string or Tensor.
154  # (Can not convert a ResourceVariable into a Tensor or Operation.)
155  @test_util.run_deprecated_v1
156  def testStringDefaultInitializer(self):
157    with self.cached_session():
158      v = variable_scope.get_variable("string", shape=[], dtype=dtypes.string)
159      variables_lib.global_variables_initializer().run()
160      self.assertAllEqual(compat.as_bytes(self.evaluate(v)), b"")
161
162  @test_util.run_in_graph_and_eager_modes
163  @run_inside_wrap_function_in_eager_mode
164  def testVarScopeDType(self):
165    with variable_scope.variable_scope("tower2") as tower:
166      with variable_scope.variable_scope("foo", dtype=dtypes.float16):
167        v = variable_scope.get_variable("v", [])
168        self.assertEqual(v.dtype.base_dtype, dtypes.float16)
169      with variable_scope.variable_scope(tower, dtype=dtypes.float16):
170        w = variable_scope.get_variable("w", [])
171        self.assertEqual(w.dtype.base_dtype, dtypes.float16)
172
173  def testGetVariableInGraphNestedUnderEagerContext(self):
174    with context.eager_mode():
175
176      @function.defun
177      def f():
178        v = variable_scope.get_variable("should_be_resource", [])
179        self.assertEqual(type(v), resource_variable_ops.ResourceVariable)
180
181      f()
182
183  def testEagerVariableStore(self):
184    with context.eager_mode():
185      store = variable_scope.EagerVariableStore()
186      with store.as_default():
187        v = variable_scope.get_variable("v", shape=(), trainable=True)
188        w = variable_scope.get_variable("w", shape=(), trainable=False)
189
190      self.assertTrue(v in store.variables())
191      self.assertTrue(w in store.variables())
192      self.assertTrue(v in store.trainable_variables())
193      self.assertFalse(w in store.trainable_variables())
194      self.assertFalse(v in store.non_trainable_variables())
195      self.assertTrue(w in store.non_trainable_variables())
196
197      # Test copying.
198      new_store = store.copy()
199      with new_store.as_default():
200        new_v = variable_scope.get_variable("v")
201        new_w = variable_scope.get_variable("w")
202      self.assertEqual(new_v.numpy(), v.numpy())
203      self.assertEqual(new_w.numpy(), w.numpy())
204      self.assertTrue(new_v in new_store.variables())
205      self.assertTrue(new_w in new_store.variables())
206      self.assertTrue(new_v in new_store.trainable_variables())
207      self.assertFalse(new_w in new_store.trainable_variables())
208      self.assertFalse(new_v in new_store.non_trainable_variables())
209      self.assertTrue(new_w in new_store.non_trainable_variables())
210
211      # Check that variables are separate instances.
212      for v in store.variables():
213        v.assign(-1)
214      for v in new_store.variables():
215        v.assign(1)
216      for v in store.variables():
217        self.assertEqual(v.numpy(), -1)
218      for v in new_store.variables():
219        self.assertEqual(v.numpy(), 1)
220
221  def testEagerVariableStoreWithEagerDefun(self):
222    with context.eager_mode():
223
224      @function.defun
225      def f():
226        x = constant_op.constant([[2.0]])
227        d1 = core_layers.Dense(
228            1, name="my_dense", kernel_initializer=init_ops.ones_initializer())
229        _ = d1(x)  # create variables
230        self.assertEqual(len(d1.variables), 2)
231        v1, v2 = d1.variables
232        d2 = core_layers.Dense(
233            1,
234            name="my_dense",
235            kernel_initializer=init_ops.ones_initializer(),
236            _reuse=True)
237        _ = d2(x)
238        self.assertEqual(len(d2.variables), 2)
239        v3, v4 = d2.variables
240        self.assertEqual(v1, v3)
241        self.assertEqual(v2, v4)
242      f()
243
244  # TODO(mihaimaruseac): Not converted to use wrap_function because of
245  # obtaining different results in the eager case compared to the graph one
246  @test_util.run_in_graph_and_eager_modes
247  def testEagerVariablesStoreAddsToCollections(self):
248    store = variable_scope.EagerVariableStore()
249    with store.as_default():
250      trainable = variable_scope.get_variable("v1", [], trainable=True)
251      not_trainable = variable_scope.get_variable("v2", [], trainable=False)
252      concat = variable_scope.get_variable(
253          "v3", [], collections=[ops.GraphKeys.CONCATENATED_VARIABLES])
254      self.assertEqual(
255          ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES),
256          [trainable, not_trainable])
257      self.assertEqual(
258          ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES),
259          [trainable, concat])
260      self.assertEqual(
261          ops.get_collection(ops.GraphKeys.CONCATENATED_VARIABLES), [concat])
262
263  def testEagerVariablesOutsideStoreNotAddedToCollections(self):
264    with context.eager_mode():
265      variable_scope.get_variable("v1", [], trainable=True)
266      variable_scope.get_variable("v2", [], trainable=False)
267      self.assertFalse(ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES))
268      self.assertFalse(ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES))
269
270  # TODO(mihaimaruseac): Not converted to use wrap_function because of
271  # TypeError: Expected tf.group() expected Tensor arguments not 'None' with
272  # type '<type 'NoneType'>'.
273  @test_util.run_in_graph_and_eager_modes
274  def testInitFromNonTensorValue(self):
275    v = variable_scope.get_variable("v4", initializer=4, dtype=dtypes.int32)
276    self.evaluate(variables_lib.variables_initializer([v]))
277    self.assertAllClose(self.evaluate(v.value()), 4)
278
279    w = variable_scope.get_variable(
280        "w4", initializer=numpy.array([1, 2, 3]), dtype=dtypes.int64)
281    self.evaluate(variables_lib.variables_initializer([w]))
282    self.assertAllClose(self.evaluate(w.value()), [1, 2, 3])
283
284    # A quirk to be revisited?
285    error = ValueError if context.executing_eagerly() else TypeError
286    with self.assertRaises(error):
287      variable_scope.get_variable("x4", initializer={})
288
289  # TODO(mihaimaruseac): Not converted to use wrap_function because of
290  # InvalidArgumentError=: You must feed a value for placeholder tensor
291  # 'ReadVariableOp/resource' with dtype resource
292  @test_util.run_in_graph_and_eager_modes
293  def testInitFromNonInitializer(self):
294    # Test various dtypes with zeros initializer as following:
295    types = [
296        dtypes.int8, dtypes.uint8, dtypes.int16, dtypes.uint16, dtypes.int32,
297        dtypes.int64, dtypes.bool
298    ]
299
300    # Use different variable_name to distinguish various dtypes
301    for (i, dtype) in enumerate(types):
302      x = variable_scope.get_variable(
303          name="xx%d" % i, shape=(3, 4), dtype=dtype)
304      y = variable_scope.get_variable(
305          name="yy%d" % i,
306          shape=(3, 4),
307          dtype=dtype,
308          initializer=init_ops.zeros_initializer(dtype=dtype))
309
310      self.evaluate(variables_lib.global_variables_initializer())
311      self.assertAllEqual(self.evaluate(x.value()), self.evaluate(y.value()))
312
  # TODO(mihaimaruseac): Not converted to use wrap_function because of
  # InvalidArgumentError: /job:moo/replica:0/task:0/device:CPU:0 unknown device.
  @test_util.run_deprecated_v1
  def testVarScopeCachingDevice(self):
    """Checks where a scope's caching_device places variable reads.

    The test only inspects the device strings recorded on `v.value()`; it never
    evaluates anything, so the "/job:moo" and "/job:live" devices are only
    compared as string prefixes.
    """
    with self.cached_session():
      caching_device = "/job:moo"
      with variable_scope.variable_scope("tower"):
        with variable_scope.variable_scope(
            "caching", caching_device=caching_device):
          # Variables created directly in the caching scope are read there.
          v = variable_scope.get_variable("v", [])
          self.assertTrue(v.value().device.startswith(caching_device))

          # Child scopes inherit the caching device.
          with variable_scope.variable_scope("child"):
            v2 = variable_scope.get_variable("v", [])
            self.assertTrue(v2.value().device.startswith(caching_device))

          # An empty caching_device turns caching off for this child scope.
          with variable_scope.variable_scope("not_cached", caching_device=""):
            v2_not_cached = variable_scope.get_variable("v", [])
            self.assertFalse(
                v2_not_cached.value().device.startswith(caching_device))

          # A callable caching_device is accepted as well; this one returns
          # `op.device`, and the read is expected not to land on /job:moo.
          with variable_scope.variable_scope(
              "not_cached_identity_device",
              caching_device=lambda op: op.device):
            v2_identity_device = variable_scope.get_variable("v", [])
            self.assertFalse(
                v2_identity_device.value().device.startswith(caching_device))

          # The caching device can also be set on a scope after entering it.
          with variable_scope.variable_scope("we_will_do_it_live") as vs_live:
            vs_live.set_caching_device("/job:live")
            v_live = variable_scope.get_variable("v", [])
            self.assertTrue(v_live.value().device.startswith("/job:live"))

        # Outside the "caching" scope, reads are not on the caching device.
        v_tower = variable_scope.get_variable("v", [])
        self.assertFalse(v_tower.value().device.startswith(caching_device))
348
  # TODO(mihaimaruseac): Not converted to use wrap_function because of
  # AttributeError: Tensor.name is meaningless when eager execution is enabled.
  @test_util.run_in_graph_and_eager_modes
  def testVarScopeRegularizer(self):
    """Checks scope-level regularizers and the ways to disable them.

    Each regularized variable appends one tensor to REGULARIZATION_LOSSES.
    v, u and w are initialized to 0.3, so regularizer1 evaluates to 0.4 and
    regularizer2 to 0.5 for them.
    """
    init = init_ops.constant_initializer(0.3)

    def regularizer1(v):
      return math_ops.reduce_mean(v) + 0.1

    def regularizer2(v):
      return math_ops.reduce_mean(v) + 0.2

    with variable_scope.variable_scope(
        "tower3", regularizer=regularizer1) as tower:
      # "foo" inherits regularizer1 from "tower3".
      with variable_scope.variable_scope("foo", initializer=init):
        v = variable_scope.get_variable("v", [])
        self.evaluate(variables_lib.variables_initializer([v]))
        losses = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
        self.assertEqual(1, len(losses))
        self.assertAllClose(self.evaluate(losses[0]), 0.4)
      with variable_scope.variable_scope(tower, initializer=init) as vs:
        # u still gets regularizer1; set_regularizer switches w to regularizer2.
        u = variable_scope.get_variable("u", [])
        vs.set_regularizer(regularizer2)
        w = variable_scope.get_variable("w", [])
        # The next 3 variables are not regularized, to test the three ways of
        # disabling regularization: per variable, via a nested scope, and by
        # resetting the scope's regularizer.
        x = variable_scope.get_variable(
            "x", [], regularizer=variable_scope.no_regularizer)
        with variable_scope.variable_scope(
            "baz", regularizer=variable_scope.no_regularizer):
          y = variable_scope.get_variable("y", [])
        vs.set_regularizer(variable_scope.no_regularizer)
        z = variable_scope.get_variable("z", [])
        # Check results: losses are in creation order v (0.4), u (0.4),
        # w (0.5); x, y and z added none.
        losses = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
        self.assertEqual(3, len(losses))
        self.evaluate(variables_lib.variables_initializer([u, w, x, y, z]))
        self.assertAllClose(self.evaluate(losses[0]), 0.4)
        self.assertAllClose(self.evaluate(losses[1]), 0.4)
        self.assertAllClose(self.evaluate(losses[2]), 0.5)
      with variable_scope.variable_scope("foo", reuse=True):
        # reuse=True is for now only supported when eager execution is disabled.
        if not context.executing_eagerly():
          v = variable_scope.get_variable("v",
                                          [])  # "v" is already there, reused
          losses = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
          self.assertEqual(3, len(losses))  # No new loss added.
395
396  # TODO(mihaimaruseac): Not converted to use wrap_function because of
397  # ValueError: Tensor-typed variable initializers must either be wrapped in an
398  # init_scope or callable...
399  @test_util.run_in_graph_and_eager_modes
400  def testInitializeFromValue(self):
401    init = constant_op.constant(0.1)
402    w = variable_scope.get_variable("v", initializer=init)
403    self.evaluate(variables_lib.variables_initializer([w]))
404    self.assertAllClose(self.evaluate(w.value()), 0.1)
405
406    with self.assertRaisesRegexp(ValueError, "shape"):
407      # We disallow explicit shape specification when initializer is constant.
408      variable_scope.get_variable("u", [1], initializer=init)
409
410    with variable_scope.variable_scope("foo", initializer=init):
411      # Constant initializer can be passed through scopes if needed.
412      v = variable_scope.get_variable("v")
413      self.evaluate(variables_lib.variables_initializer([v]))
414      self.assertAllClose(self.evaluate(v.value()), 0.1)
415
416    # Check that non-float32 initializer creates a non-float32 variable.
417    init = constant_op.constant(1, dtype=dtypes.int32)
418    t = variable_scope.get_variable("t", initializer=init)
419    self.assertEqual(t.dtype.base_dtype, dtypes.int32)
420
421    # Raise error if `initializer` dtype and `dtype` are not identical.
422    with self.assertRaisesRegexp(ValueError, "don't match"):
423      variable_scope.get_variable("s", initializer=init, dtype=dtypes.float64)
424
425  # TODO(mihaimaruseac): Not converted to use wrap_function because of
426  # TypeError: Fetch argument <tf.Variable 'v0:0' shape=(1,) dtype=float32> has
427  # invalid type <class '...ops.resource_variable_ops.ResourceVariable'>, must
428  # be a string or Tensor. (Can not convert a ResourceVariable into a Tensor or
429  # Operation.)
430  @test_util.run_deprecated_v1
431  def testControlDeps(self):
432    with self.cached_session() as sess:
433      v0 = variable_scope.get_variable(
434          "v0", [1], initializer=init_ops.constant_initializer(0))
435      with ops.control_dependencies([v0.value()]):
436        v1 = variable_scope.get_variable(
437            "v1", [1], initializer=init_ops.constant_initializer(1))
438        add = v1 + v0
439      # v0 should be uninitialized.
440      with self.assertRaisesRegexp(errors.OpError, "uninitialized"):
441        self.evaluate(v0)
442      # We should be able to initialize and run v1 without initializing
443      # v0, even if the variable was created with a control dep on v0.
444      self.evaluate(v1.initializer)
445      self.assertEqual(1, self.evaluate(v1))
446      # v0 should still be uninitialized.
447      with self.assertRaisesRegexp(errors.OpError, "uninitialized"):
448        self.evaluate(v0)
449      with self.assertRaisesRegexp(errors.OpError, "uninitialized"):
450        self.evaluate(add)
451      # If we initialize v0 we should be able to run 'add'.
452      self.evaluate(v0.initializer)
453      self.evaluate(add)
454
455  # TODO(mihaimaruseac): Not converted to use wrap_function because of
456  # AssertionError: True is not false (last assertFalse)
457  @test_util.run_deprecated_v1
458  def testEnableResourceVariables(self):
459    old = variable_scope._DEFAULT_USE_RESOURCE
460    try:
461      variable_scope.enable_resource_variables()
462      self.assertTrue(isinstance(variables_lib.VariableV1(1.0),
463                                 resource_variable_ops.ResourceVariable))
464      variable_scope.disable_resource_variables()
465      self.assertFalse(isinstance(variables_lib.VariableV1(1.0),
466                                  resource_variable_ops.ResourceVariable))
467    finally:
468      variable_scope._DEFAULT_USE_RESOURCE = old
469
470  # TODO(mihaimaruseac): Not converted to use wrap_function because of
471  # TypeError: Fetch argument None has invalid type <type 'NoneType'>
472  @test_util.run_deprecated_v1
473  def testControlFlow(self):
474    with self.cached_session() as sess:
475      v0 = variable_scope.get_variable(
476          "v0", [], initializer=init_ops.constant_initializer(0))
477      var_dict = {}
478
479      # Call get_variable in each of the cond clauses.
480      def var_in_then_clause():
481        v1 = variable_scope.get_variable(
482            "v1", [1], initializer=init_ops.constant_initializer(1))
483        var_dict["v1"] = v1
484        return v1 + v0
485
486      def var_in_else_clause():
487        v2 = variable_scope.get_variable(
488            "v2", [1], initializer=init_ops.constant_initializer(2))
489        var_dict["v2"] = v2
490        return v2 + v0
491
492      add = control_flow_ops.cond(
493          math_ops.less(v0, 10), var_in_then_clause, var_in_else_clause)
494      v1 = var_dict["v1"]
495      v2 = var_dict["v2"]
496      # We should be able to initialize and run v1 and v2 without initializing
497      # v0, even if the variable was created with a control dep on v0.
498      self.evaluate(v1.initializer)
499      self.assertEqual([1], self.evaluate(v1))
500      self.evaluate(v2.initializer)
501      self.assertEqual([2], self.evaluate(v2))
502      # v0 should still be uninitialized.
503      with self.assertRaisesRegexp(errors.OpError, "uninitialized"):
504        self.evaluate(v0)
505      # We should not be able to run 'add' yet.
506      with self.assertRaisesRegexp(errors.OpError, "uninitialized"):
507        self.evaluate(add)
508      # If we initialize v0 we should be able to run 'add'.
509      self.evaluate(v0.initializer)
510      self.evaluate(add)
511
512  # TODO(mihaimaruseac): Not converted to use wrap_function because of
513  # TypeError: Expected tf.group() expected Tensor arguments not 'None' with
514  # type '<type 'NoneType'>'.
515  @test_util.run_in_graph_and_eager_modes
516  def testGetVariableScope(self):
517    # Test the get_variable_scope() function and setting properties of result.
518    init = init_ops.constant_initializer(0.3)
519    with variable_scope.variable_scope("bar"):
520      new_init1 = variable_scope.get_variable_scope().initializer
521      self.assertEqual(new_init1, None)
522      # Check that we can set initializer like this.
523      variable_scope.get_variable_scope().set_initializer(init)
524      v = variable_scope.get_variable("v", [])
525      self.evaluate(variables_lib.variables_initializer([v]))
526      self.assertAllClose(self.evaluate(v.value()), 0.3)
527      if not context.executing_eagerly():
528        # Check that we can set reuse.
529        variable_scope.get_variable_scope().reuse_variables()
530        with self.assertRaises(ValueError):  # Fail, w does not exist yet.
531          variable_scope.get_variable("w", [1])
532    # Check that the set initializer goes away.
533    new_init = variable_scope.get_variable_scope().initializer
534    self.assertEqual(new_init, None)
535
  @test_util.run_in_graph_and_eager_modes
  @run_inside_wrap_function_in_eager_mode
  def testVarScope(self):
    """Checks scope names and the name scopes opened inside variable scopes.

    Re-entering a captured scope object keeps its variable-scope name
    ("tower4") while its name scope is nested under the current name scope.
    """
    with variable_scope.variable_scope("tower4") as tower:
      self.assertEqual(tower.name, "tower4")
      with ops.name_scope("scope") as sc:
        self.assertEqual(sc, "tower4/scope/")

    # Nested variable scopes join their names with "/".
    with variable_scope.variable_scope("tower5"):
      with variable_scope.variable_scope("bar") as bar:
        self.assertEqual(bar.name, "tower5/bar")
        with ops.name_scope("scope") as sc:
          self.assertEqual(sc, "tower5/bar/scope/")

    # Re-enter the captured "tower4" scope object from inside "tower6".
    with variable_scope.variable_scope("tower6"):
      with variable_scope.variable_scope(tower, reuse=True) as tower_shared:
        self.assertEqual(tower_shared.name, "tower4")
        with ops.name_scope("scope") as sc:
          self.assertEqual(sc, "tower6/tower4/scope/")
555
  @test_util.run_in_graph_and_eager_modes
  @run_inside_wrap_function_in_eager_mode
  def testVarScopeNameScope(self):
    """Checks the name scopes produced when (re-)entering variable scopes.

    In graph mode, re-entering a variable scope (by object or by string)
    inside the same name scope gets a uniquified name scope ("tower_1",
    "tower_2", ...).
    """
    with ops.name_scope("testVarScopeNameScope1"):
      with variable_scope.variable_scope("tower") as tower:
        with ops.name_scope("scope2") as sc2:
          self.assertEqual(sc2, "testVarScopeNameScope1/tower/scope2/")
      if not context.executing_eagerly():
        with variable_scope.variable_scope(
            tower):  # Re-entering acts like another "tower".
          with ops.name_scope("scope2") as sc2:
            self.assertEqual(sc2, "testVarScopeNameScope1/tower_1/scope2/")
        with variable_scope.variable_scope(
            "tower"):  # Re-entering by string acts the same.
          with ops.name_scope("scope2") as sc2:
            self.assertEqual(sc2, "testVarScopeNameScope1/tower_2/scope2/")

    with ops.name_scope("testVarScopeNameScope2"):
      with variable_scope.variable_scope("tower"):
        with ops.name_scope("scope2") as sc2:
          self.assertEqual(sc2, "testVarScopeNameScope2/tower/scope2/")
      if not context.executing_eagerly():
        # `tower` is the scope object captured under testVarScopeNameScope1
        # above; re-entering it here is uniquified within the current name
        # scope.
        with variable_scope.variable_scope(tower):
          with ops.name_scope("scope2") as sc2:
            self.assertEqual(sc2, "testVarScopeNameScope2/tower_1/scope2/")

    # Re-entering the root variable scope adds no name-scope component.
    root_var_scope = variable_scope.get_variable_scope()
    with ops.name_scope("testVarScopeNameScope3"):
      with variable_scope.variable_scope(root_var_scope):
        with ops.name_scope("scope2") as sc2:
          self.assertEqual(sc2, "testVarScopeNameScope3/scope2/")
581
582    root_var_scope = variable_scope.get_variable_scope()
583    with ops.name_scope("testVarScopeNameScope3"):
584      with variable_scope.variable_scope(root_var_scope):
585        with ops.name_scope("scope2") as sc2:
586          self.assertEqual(sc2, "testVarScopeNameScope3/scope2/")
587
588  @test_util.run_in_graph_and_eager_modes
589  @run_inside_wrap_function_in_eager_mode
590  def testVarScopeOriginalNameScope(self):
591    with self.cached_session():
592      with ops.name_scope("scope1"):
593        with variable_scope.variable_scope("tower") as tower:
594          self.assertEqual(tower.original_name_scope, "scope1/tower/")
595          with ops.name_scope("scope2") as sc2:
596            self.assertEqual(sc2, "scope1/tower/scope2/")
597      with ops.name_scope("scope2"):
598        with variable_scope.variable_scope(tower) as tower1:
599          # Re-entering preserves original name scope.
600          self.assertEqual(tower1.original_name_scope, "scope1/tower/")
601          with ops.name_scope("foo") as sc2:
602            self.assertEqual(sc2, "scope2/tower/foo/")
603        # Test re-entering original name scope.
604        with ops.name_scope(tower.original_name_scope):
605          with ops.name_scope("bar") as sc3:
606            self.assertEqual(sc3, "scope1/tower/bar/")
607      with ops.name_scope("scope2"):
608        with variable_scope.variable_scope(tower):
609          with ops.name_scope(tower.original_name_scope):
610            with ops.name_scope("bar") as sc3:
611              self.assertEqual(sc3, "scope1/tower/bar_1/")
612
613  @test_util.run_in_graph_and_eager_modes
614  @run_inside_wrap_function_in_eager_mode
615  def testVarScopeObjectReuse(self):
616    with self.cached_session():
617      vs = None
618      with variable_scope.variable_scope("jump", reuse=True) as scope:
619        vs = scope
620
621      with variable_scope.variable_scope(vs) as jump:
622        self.assertTrue(jump.reuse)
623
624      with variable_scope.variable_scope(vs, reuse=True) as jump_reuse:
625        self.assertTrue(jump_reuse.reuse)
626
627      with variable_scope.variable_scope(vs, reuse=False) as jump_no_reuse:
628        self.assertTrue(jump_no_reuse.reuse)  # Inherited, cannot be undone.
629
630      with variable_scope.variable_scope("jump", reuse=False) as scope:
631        vs = scope
632
633      with variable_scope.variable_scope(vs) as jump:
634        self.assertFalse(jump.reuse)
635
636      with variable_scope.variable_scope(vs, reuse=True) as jump_reuse:
637        self.assertTrue(jump_reuse.reuse)
638
639      with variable_scope.variable_scope(vs, reuse=False) as jump_no_reuse:
640        self.assertFalse(jump_no_reuse.reuse)
641
642  @test_util.run_in_graph_and_eager_modes
643  @run_inside_wrap_function_in_eager_mode
644  def testVarScopeGetOrCreateReuse(self):
645    with self.cached_session():
646
647      def test_value(value):
648        x = constant_op.constant(value)
649        with variable_scope.variable_scope(
650            "testVarScopeGetOrCreateReuse_bar",
651            reuse=variable_scope.AUTO_REUSE):
652          _ = state_ops.assign(variable_scope.get_variable("var", []), x)
653        with variable_scope.variable_scope(
654            "testVarScopeGetOrCreateReuse_bar",
655            reuse=variable_scope.AUTO_REUSE):
656          _ = variable_scope.get_variable("var", [])
657        self.assertEqual(value, self.evaluate(x))
658
659      test_value(42.)  # Variable is created.
660      test_value(13.)  # Variable is reused hereafter.
661      test_value(17.)
662
663  @test_util.run_in_graph_and_eager_modes
664  @run_inside_wrap_function_in_eager_mode
665  def testVarOpScope(self):
666    with self.cached_session():
667      with ops.name_scope("testVarOpScope1"):
668        with variable_scope.variable_scope("tower", "default", []):
669          self.assertEqual(
670              variable_scope.get_variable("w", []).name, "tower/w:0")
671          with ops.name_scope("testVarOpScope2") as sc2:
672            self.assertEqual(sc2, "testVarOpScope1/tower/testVarOpScope2/")
673        with variable_scope.variable_scope("tower", "default", []):
674          with self.assertRaises(ValueError):
675            variable_scope.get_variable("w", [])
676          with ops.name_scope("testVarOpScope2") as sc2:
677            self.assertEqual(sc2, "testVarOpScope1/tower_1/testVarOpScope2/")
678
679      with ops.name_scope("testVarOpScope2"):
680        with variable_scope.variable_scope(None, "default", []):
681          self.assertEqual(
682              variable_scope.get_variable("w", []).name, "default/w:0")
683          with ops.name_scope("testVarOpScope2") as sc2:
684            self.assertEqual(sc2, "testVarOpScope2/default/testVarOpScope2/")
685        with variable_scope.variable_scope(None, "default", []):
686          self.assertEqual(
687              variable_scope.get_variable("w", []).name, "default_1/w:0")
688          with ops.name_scope("testVarOpScope2") as sc2:
689            self.assertEqual(sc2, "testVarOpScope2/default_1/testVarOpScope2/")
690
  @test_util.run_in_graph_and_eager_modes
  @run_inside_wrap_function_in_eager_mode
  def testVarOpScopeUniqueNamesInterleavedSubstringScopes(self):
    """Default-name uniquification compares whole names, not prefixes.

    Opening "defaultScope" between uses of "defaultScope1" yields the plain
    "defaultScope" and does not disturb the "defaultScope1" counter, which
    continues with "defaultScope1_2".
    """
    with self.cached_session():
      with variable_scope.variable_scope(None, "defaultScope1"):
        with variable_scope.variable_scope(None, "layer"):
          self.assertEqual(
              variable_scope.get_variable("w", []).name,
              "defaultScope1/layer/w:0")
      with variable_scope.variable_scope(None, "defaultScope1"):
        with variable_scope.variable_scope(None, "layer"):
          self.assertEqual(
              variable_scope.get_variable("w", []).name,
              "defaultScope1_1/layer/w:0")
      # "defaultScope" is a prefix of "defaultScope1" but a distinct name.
      with variable_scope.variable_scope(None, "defaultScope"):
        with variable_scope.variable_scope(None, "layer"):
          self.assertEqual(
              variable_scope.get_variable("w", []).name,
              "defaultScope/layer/w:0")
      with variable_scope.variable_scope(None, "defaultScope1"):
        with variable_scope.variable_scope(None, "layer"):
          self.assertEqual(
              variable_scope.get_variable("w", []).name,
              "defaultScope1_2/layer/w:0")
715
  @test_util.run_in_graph_and_eager_modes
  @run_inside_wrap_function_in_eager_mode
  def testVarOpScopeUniqueNamesWithJump(self):
    """Re-entering a parent scope does not reset its children's numbering.

    Between two default-named "layer" children, "default" is re-entered via
    its scope object; the next child is still "layer_2".
    """
    with self.cached_session():
      with variable_scope.variable_scope("default") as default:
        with variable_scope.variable_scope(None, "layer"):
          self.assertEqual(
              variable_scope.get_variable("w", []).name, "default/layer/w:0")
        with variable_scope.variable_scope(None, "layer"):
          self.assertEqual(
              variable_scope.get_variable("w", []).name,
              "default/layer_1/w:0")
        with variable_scope.variable_scope(default):
          pass
        # No matter the jump in the middle, unique numbering continues.
        with variable_scope.variable_scope(None, "layer"):
          self.assertEqual(
              variable_scope.get_variable("w", []).name,
              "default/layer_2/w:0")
735
  @test_util.run_in_graph_and_eager_modes
  @run_inside_wrap_function_in_eager_mode
  def testVarOpScopeReuse(self):
    """Re-entering an outer scope with reuse=True finds the nested variables.

    Variables created under "outer/tower" and "outer/default" are fetched again
    under the same names, while the name scopes are opened under the
    uniquified "outer_1".
    """
    with self.cached_session():
      with variable_scope.variable_scope("outer") as outer:
        with variable_scope.variable_scope("tower", "default", []):
          self.assertEqual(
              variable_scope.get_variable("w", []).name, "outer/tower/w:0")
          with ops.name_scope("scope2") as sc2:
            self.assertEqual(sc2, "outer/tower/scope2/")
        with variable_scope.variable_scope(None, "default", []):
          self.assertEqual(
              variable_scope.get_variable("w", []).name, "outer/default/w:0")
          with ops.name_scope("scope2") as sc2:
            self.assertEqual(sc2, "outer/default/scope2/")

      # Second pass with reuse=True: same variable names, new name scope.
      with variable_scope.variable_scope(outer, reuse=True) as outer:
        with variable_scope.variable_scope("tower", "default", []):
          self.assertEqual(
              variable_scope.get_variable("w", []).name, "outer/tower/w:0")
          with ops.name_scope("scope2") as sc2:
            self.assertEqual(sc2, "outer_1/tower/scope2/")
        with variable_scope.variable_scope(None, "default", []):
          self.assertEqual(
              variable_scope.get_variable("w", []).name, "outer/default/w:0")
          with ops.name_scope("scope2") as sc2:
            self.assertEqual(sc2, "outer_1/default/scope2/")
763
764  @test_util.run_in_graph_and_eager_modes
765  @run_inside_wrap_function_in_eager_mode
766  def testVarScopeGetVar(self):
767    with self.cached_session():
768      with variable_scope.variable_scope("root"):
769        with variable_scope.variable_scope("towerA") as tower_a:
770          va = variable_scope.get_variable("v", [1])
771          self.assertEqual(va.name, "root/towerA/v:0")
772
773        with variable_scope.variable_scope(tower_a, reuse=True):
774          va2 = variable_scope.get_variable("v", [1])
775          self.assertEqual(va2, va)
776
777        with variable_scope.variable_scope("towerB"):
778          vb = variable_scope.get_variable("v", [1])
779          self.assertEqual(vb.name, "root/towerB/v:0")
780
781        with self.assertRaises(ValueError):
782          with variable_scope.variable_scope("towerA"):
783            va2 = variable_scope.get_variable("v", [1])
784
785        with variable_scope.variable_scope("towerA", reuse=True):
786          va2 = variable_scope.get_variable("v", [1])
787          self.assertEqual(va2, va)
788
789        with variable_scope.variable_scope("foo"):
790          with variable_scope.variable_scope("bar"):
791            v = variable_scope.get_variable("v", [1])
792            self.assertEqual(v.name, "root/foo/bar/v:0")
793            with variable_scope.variable_scope(tower_a, reuse=True):
794              va3 = variable_scope.get_variable("v", [1])
795              self.assertEqual(va, va3)
796
797        with self.assertRaises(ValueError):
798          with variable_scope.variable_scope(tower_a, reuse=True):
799            with variable_scope.variable_scope("baz"):
800              variable_scope.get_variable("v", [1])
801
802        with self.assertRaises(ValueError) as exc:
803          with variable_scope.variable_scope(tower_a, reuse=True):
804            variable_scope.get_variable("v", [2])  # Different shape.
805        self.assertEqual("shape" in str(exc.exception), True)
806
807        with self.assertRaises(ValueError) as exc:
808          with variable_scope.variable_scope(tower_a, reuse=True):
809            variable_scope.get_variable("v", [1], dtype=dtypes.int32)
810        self.assertEqual("dtype" in str(exc.exception), True)
811
  @test_util.run_in_graph_and_eager_modes
  @run_inside_wrap_function_in_eager_mode
  def testVarScopeOuterScope(self):
    """Re-entering a captured VariableScope object at the top level.

    Variables always resolve under "outer/". Each re-entry opens a new,
    uniquified name scope: "outer_1", then "outer_2".
    """
    with self.cached_session():
      with variable_scope.variable_scope("outer") as outer:
        pass
      with variable_scope.variable_scope(outer):
        self.assertEqual(
            variable_scope.get_variable("w", []).name, "outer/w:0")
        with ops.name_scope("scope2") as sc2:
          self.assertEqual(sc2, "outer_1/scope2/")
        with variable_scope.variable_scope("default"):
          self.assertEqual(
              variable_scope.get_variable("w", []).name, "outer/default/w:0")
          with ops.name_scope("scope2") as sc2:
            self.assertEqual(sc2, "outer_1/default/scope2/")

      # Same variables fetched with reuse=True; name scope is now "outer_2".
      with variable_scope.variable_scope(outer, reuse=True):
        self.assertEqual(
            variable_scope.get_variable("w", []).name, "outer/w:0")
        with ops.name_scope("scope2") as sc2:
          self.assertEqual(sc2, "outer_2/scope2/")
        with variable_scope.variable_scope("default", reuse=True):
          self.assertEqual(
              variable_scope.get_variable("w", []).name, "outer/default/w:0")
          with ops.name_scope("scope2") as sc2:
            self.assertEqual(sc2, "outer_2/default/scope2/")
839
  @test_util.run_in_graph_and_eager_modes
  @run_inside_wrap_function_in_eager_mode
  def testVarScopeNestedOuterScope(self):
    """Re-entering a VariableScope object from inside itself.

    Variable names stay "outer/...". The name scope nests as "outer/outer/"
    and is uniquified on later entries: "outer/outer_1/" and
    "outer/default_1/".
    """
    with self.cached_session():
      with variable_scope.variable_scope("outer") as outer:
        with variable_scope.variable_scope(outer):
          self.assertEqual(
              variable_scope.get_variable("w", []).name, "outer/w:0")
          with ops.name_scope("scope2") as sc2:
            self.assertEqual(sc2, "outer/outer/scope2/")
        with variable_scope.variable_scope("default"):
          self.assertEqual(
              variable_scope.get_variable("w", []).name, "outer/default/w:0")
          with ops.name_scope("scope2") as sc2:
            self.assertEqual(sc2, "outer/default/scope2/")

        # Still inside the original "outer": re-enter both with reuse=True.
        with variable_scope.variable_scope(outer, reuse=True):
          self.assertEqual(
              variable_scope.get_variable("w", []).name, "outer/w:0")
          with ops.name_scope("scope2") as sc2:
            self.assertEqual(sc2, "outer/outer_1/scope2/")
        with variable_scope.variable_scope("default", reuse=True):
          self.assertEqual(
              variable_scope.get_variable("w", []).name, "outer/default/w:0")
          with ops.name_scope("scope2") as sc2:
            self.assertEqual(sc2, "outer/default_1/scope2/")
866
  @test_util.run_in_graph_and_eager_modes
  @run_inside_wrap_function_in_eager_mode
  def testVarOpScopeReuseParam(self):
    """Two ways of enabling reuse inside a re-entered outer scope.

    The first is the `reuse=True` keyword on a nested scope. The second is
    calling `reuse_variables()` on the re-entered outer scope. Both fetch the
    variables created on the first pass.
    """
    with self.cached_session():
      with variable_scope.variable_scope("outer") as outer:
        with variable_scope.variable_scope("tower", "default", []):
          self.assertEqual(
              variable_scope.get_variable("w", []).name, "outer/tower/w:0")
          with ops.name_scope("scope2") as sc2:
            self.assertEqual(sc2, "outer/tower/scope2/")
        with variable_scope.variable_scope(None, "default", []):
          self.assertEqual(
              variable_scope.get_variable("w", []).name, "outer/default/w:0")
          with ops.name_scope("scope2") as sc2:
            self.assertEqual(sc2, "outer/default/scope2/")

      with variable_scope.variable_scope(outer) as outer:
        # Reuse enabled per-scope via the keyword argument.
        with variable_scope.variable_scope("tower", "default", reuse=True):
          self.assertEqual(
              variable_scope.get_variable("w", []).name, "outer/tower/w:0")
          with ops.name_scope("scope2") as sc2:
            self.assertEqual(sc2, "outer_1/tower/scope2/")
        # Reuse enabled on the enclosing scope for everything that follows.
        outer.reuse_variables()
        with variable_scope.variable_scope(None, "default", []):
          self.assertEqual(
              variable_scope.get_variable("w", []).name, "outer/default/w:0")
          with ops.name_scope("scope2") as sc2:
            self.assertEqual(sc2, "outer_1/default/scope2/")
895
896  @test_util.run_in_graph_and_eager_modes
897  @run_inside_wrap_function_in_eager_mode
898  def testVarOpScopeReuseError(self):
899    with self.cached_session():
900      with self.assertRaises(ValueError):
901        with variable_scope.variable_scope(None, "default", reuse=True):
902          self.assertEqual(
903              variable_scope.get_variable("w", []).name, "outer/tower/w:0")
904
  @test_util.run_in_graph_and_eager_modes
  @run_inside_wrap_function_in_eager_mode
  def testVarOpScopeOuterScope(self):
    """Like testVarScopeOuterScope, but passes a default_name as well.

    When a scope object is given, variables resolve under "outer/" and each
    top-level re-entry opens a uniquified name scope ("outer_1", "outer_2").
    """
    with self.cached_session():
      with variable_scope.variable_scope("outer") as outer:
        pass
      with variable_scope.variable_scope(outer, "default", []):
        self.assertEqual(
            variable_scope.get_variable("w", []).name, "outer/w:0")
        with ops.name_scope("scope2") as sc2:
          self.assertEqual(sc2, "outer_1/scope2/")
        with variable_scope.variable_scope(None, "default", []):
          self.assertEqual(
              variable_scope.get_variable("w", []).name, "outer/default/w:0")
          with ops.name_scope("scope2") as sc2:
            self.assertEqual(sc2, "outer_1/default/scope2/")

      with variable_scope.variable_scope(outer, "default", reuse=True):
        self.assertEqual(
            variable_scope.get_variable("w", []).name, "outer/w:0")
        with ops.name_scope("scope2") as sc2:
          self.assertEqual(sc2, "outer_2/scope2/")
        # Switch the re-entered scope to reuse mode for the nested lookup.
        outer.reuse_variables()
        with variable_scope.variable_scope(None, "default", []):
          self.assertEqual(
              variable_scope.get_variable("w", []).name, "outer/default/w:0")
          with ops.name_scope("scope2") as sc2:
            self.assertEqual(sc2, "outer_2/default/scope2/")
933
  @test_util.run_in_graph_and_eager_modes
  @run_inside_wrap_function_in_eager_mode
  def testVarOpScopeNestedOuterScope(self):
    """Re-entering a scope object nested in itself, then again at top level.

    Inside "outer", re-entering `outer` opens the name scope "outer/outer/".
    The later top-level re-entry with reuse=True opens "outer_1/". Variable
    names stay "outer/..." throughout.
    """
    with self.cached_session():
      with variable_scope.variable_scope("outer") as outer:
        with variable_scope.variable_scope(outer, "default", []):
          self.assertEqual(
              variable_scope.get_variable("w", []).name, "outer/w:0")
          with ops.name_scope("scope2") as sc2:
            self.assertEqual(sc2, "outer/outer/scope2/")
        with variable_scope.variable_scope(None, "default", []):
          self.assertEqual(
              variable_scope.get_variable("w", []).name, "outer/default/w:0")
          with ops.name_scope("scope2") as sc2:
            self.assertEqual(sc2, "outer/default/scope2/")

      with variable_scope.variable_scope(outer, "default", reuse=True):
        self.assertEqual(
            variable_scope.get_variable("w", []).name, "outer/w:0")
        with ops.name_scope("scope2") as sc2:
          self.assertEqual(sc2, "outer_1/scope2/")
        with variable_scope.variable_scope(None, "default", []):
          self.assertEqual(
              variable_scope.get_variable("w", []).name, "outer/default/w:0")
          with ops.name_scope("scope2") as sc2:
            self.assertEqual(sc2, "outer_1/default/scope2/")
960
  @test_util.run_in_graph_and_eager_modes
  @run_inside_wrap_function_in_eager_mode
  def testBasicWhenAuxiliaryNameScopeIsFalse(self):
    """With auxiliary_name_scope=False only variable names get the prefix.

    Ops keep the current name scope, `original_name_scope` is the enclosing
    name scope, and no name scope is reserved: a later `ops.name_scope` with
    the same name is not uniquified.
    """
    with self.cached_session():
      with variable_scope.variable_scope(
          "scope", auxiliary_name_scope=False) as scope:
        self.assertEqual(scope.original_name_scope, "")
        self.assertEqual(
            variable_scope.get_variable("w", []).name, "scope/w:0")
        self.assertEqual(constant_op.constant([], name="c").name, "c:0")
      with variable_scope.variable_scope(scope, auxiliary_name_scope=False):
        self.assertEqual(scope.original_name_scope, "")
        self.assertEqual(
            variable_scope.get_variable("w1", []).name, "scope/w1:0")
        self.assertEqual(constant_op.constant([], name="c1").name, "c1:0")
      # Recheck: no "scope" name scope was created above, so this one is not
      # uniquified.
      with ops.name_scope("scope"):
        self.assertEqual(constant_op.constant([], name="c").name, "scope/c:0")

      with variable_scope.variable_scope("outer"):
        with variable_scope.variable_scope(
            "inner", auxiliary_name_scope=False) as inner:
          self.assertEqual(inner.original_name_scope, "outer/")
          self.assertEqual(
              variable_scope.get_variable("w", []).name, "outer/inner/w:0")
          self.assertEqual(
              constant_op.constant([], name="c").name, "outer/c:0")
        with variable_scope.variable_scope(
            inner, auxiliary_name_scope=False) as inner1:
          self.assertEqual(inner1.original_name_scope, "outer/")
          self.assertEqual(
              variable_scope.get_variable("w1", []).name, "outer/inner/w1:0")
          self.assertEqual(
              constant_op.constant([], name="c1").name, "outer/c1:0")
        # Recheck: no "outer/inner" name scope was created above, so this one
        # is not uniquified.
        with ops.name_scope("inner"):
          self.assertEqual(
              constant_op.constant([], name="c").name, "outer/inner/c:0")
999
  @test_util.run_in_graph_and_eager_modes
  @run_inside_wrap_function_in_eager_mode
  def testCreatedByDefaultNameWhenAuxiliaryNameScopeIsFalse(self):
    """auxiliary_name_scope=False also applies to default-named scopes.

    Variables are placed under "default/". Ops stay in the current name scope,
    and no "default" name scope is reserved.
    """
    with self.cached_session():
      with variable_scope.variable_scope(
          None, default_name="default", auxiliary_name_scope=False) as scope:
        self.assertEqual(scope.original_name_scope, "")
        self.assertEqual(
            variable_scope.get_variable("w", []).name, "default/w:0")
        self.assertEqual(constant_op.constant([], name="c").name, "c:0")
      # Recheck: no "default" name scope was created above, so this one is not
      # uniquified.
      with ops.name_scope("default"):
        self.assertEqual(
            constant_op.constant([], name="c").name, "default/c:0")

      with variable_scope.variable_scope("outer"):
        with variable_scope.variable_scope(
            None, default_name="default",
            auxiliary_name_scope=False) as inner:
          self.assertEqual(inner.original_name_scope, "outer/")
          self.assertEqual(
              variable_scope.get_variable("w", []).name, "outer/default/w:0")
          self.assertEqual(
              constant_op.constant([], name="c").name, "outer/c:0")
        # Recheck: no "outer/default" name scope was created above, so this one
        # is not uniquified.
        with ops.name_scope("default"):
          self.assertEqual(
              constant_op.constant([], name="c").name, "outer/default/c:0")
1028
  @test_util.run_in_graph_and_eager_modes
  @run_inside_wrap_function_in_eager_mode
  def testReenterRootScopeWhenAuxiliaryNameScopeIsFalse(self):
    """Re-entering the root variable scope with auxiliary_name_scope=False.

    Variables get no prefix, even when re-entered from inside "outer". Ops
    keep whatever name scope is current at that point ("outer/c1:0").
    """
    with self.cached_session():
      root_scope = variable_scope.get_variable_scope()
      with variable_scope.variable_scope(
          root_scope, auxiliary_name_scope=False) as scope:
        self.assertEqual(scope.original_name_scope, "")
        self.assertEqual(variable_scope.get_variable("w", []).name, "w:0")
        self.assertEqual(constant_op.constant([], name="c").name, "c:0")

      with variable_scope.variable_scope("outer"):
        with variable_scope.variable_scope(
            root_scope, auxiliary_name_scope=False) as inner:
          self.assertEqual(inner.original_name_scope, "")
          self.assertEqual(variable_scope.get_variable("w1", []).name, "w1:0")
          # The op name still picks up the enclosing "outer" name scope.
          self.assertEqual(
              constant_op.constant([], name="c1").name, "outer/c1:0")
1047
1048  @test_util.run_in_graph_and_eager_modes
1049  @run_inside_wrap_function_in_eager_mode
1050  def testAuxiliaryNameScopeIsInvalid(self):
1051    with self.cached_session():
1052      with self.assertRaisesRegexp(TypeError, "auxiliary_name_scope"):
1053        with variable_scope.variable_scope(
1054            None, default_name="scope", auxiliary_name_scope="invalid"):
1055          pass
1056
1057      with self.assertRaisesRegexp(TypeError, "auxiliary_name_scope"):
1058        with variable_scope.variable_scope(
1059            "scope", auxiliary_name_scope="invalid"):
1060          pass
1061
1062      with variable_scope.variable_scope("scope") as scope:
1063        pass
1064      with self.assertRaisesRegexp(TypeError, "auxiliary_name_scope"):
1065        with variable_scope.variable_scope(
1066            scope, auxiliary_name_scope="invalid"):
1067          pass
1068
  @test_util.run_in_graph_and_eager_modes
  @run_inside_wrap_function_in_eager_mode
  def testReuseScopeWithoutNameScopeCollision(self):
    """Re-entering a scope without creating a colliding name scope.

    The scope is re-entered with auxiliary_name_scope=False, and its
    `original_name_scope` is then opened explicitly. This places ops under
    "outer/inner/" without reserving a new "inner" name scope at the current
    level.
    """
    # Github issue: #13429
    with self.cached_session():
      with variable_scope.variable_scope("outer"):
        with variable_scope.variable_scope("inner") as inner:
          pass

      with variable_scope.variable_scope(
          inner, auxiliary_name_scope=False) as scope:
        with ops.name_scope(scope.original_name_scope):
          self.assertEqual(
              variable_scope.get_variable("w", []).name, "outer/inner/w:0")
          self.assertEqual(
              constant_op.constant([], name="c").name, "outer/inner/c:0")
        # A top-level "inner" name scope is still free (not "inner_1").
        with ops.name_scope("inner"):
          self.assertEqual(
              constant_op.constant([], name="c").name, "inner/c:0")

      with variable_scope.variable_scope("another"):
        with variable_scope.variable_scope(
            inner, auxiliary_name_scope=False) as scope1:
          with ops.name_scope(scope1.original_name_scope):
            self.assertEqual(
                variable_scope.get_variable("w1", []).name,
                "outer/inner/w1:0")
            self.assertEqual(
                constant_op.constant([], name="c1").name, "outer/inner/c1:0")
          # Likewise, "another/inner" has not been taken.
          with ops.name_scope("inner"):
            self.assertEqual(
                constant_op.constant([], name="c").name, "another/inner/c:0")
1101
1102  # TODO(mihaimaruseac): Not converted to use wrap_function because of
1103  # obtaining different results in the eager case compared to the graph one
1104  # (different assertions failing after wrapping, in both execution modes)
  @test_util.run_in_graph_and_eager_modes
  def testGetLocalVar(self):
    """get_local_variable follows scope naming, collections and `reuse`.

    Collection membership and reuse are only checked when not executing
    eagerly.
    """
    # Check that local variable respects naming.
    with variable_scope.variable_scope("outer") as outer:
      with variable_scope.variable_scope(outer, "default", []):
        local_var = variable_scope.get_local_variable(
            "w", [], collections=["foo"])
        self.assertEqual(local_var.name, "outer/w:0")

    if not context.executing_eagerly():
      # Since variable is local, it should be in the local variable collection
      # but not the trainable collection.
      self.assertIn(local_var,
                    ops.get_collection(ops.GraphKeys.LOCAL_VARIABLES))
      self.assertIn(local_var, ops.get_collection("foo"))
      self.assertNotIn(local_var,
                       ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES))
      # Check that local variable respects `reuse`.
      with variable_scope.variable_scope(outer, "default", reuse=True):
        self.assertEqual(
            variable_scope.get_local_variable("w", []).name, "outer/w:0")
1126
1127  @test_util.run_in_graph_and_eager_modes
1128  @run_inside_wrap_function_in_eager_mode
1129  def testSignatureGetVarVsGetLocalVar(self):
1130    """get_{local,}variable() must take the same list of args."""
1131    arg_names = tf_inspect.getargspec(variable_scope.get_variable)[0]
1132    local_arg_names = tf_inspect.getargspec(
1133        variable_scope.get_local_variable)[0]
1134    self.assertEqual(arg_names, local_arg_names)
1135
1136  @test_util.run_in_graph_and_eager_modes
1137  @run_inside_wrap_function_in_eager_mode
1138  def testGetVarWithDevice(self):
1139    g = ops.Graph()
1140    varname_type = []
1141
1142    def device_func(op):
1143      if op.type in ["Variable", "VariableV2", "VarHandleOp"]:
1144        varname_type.append((op.name, op.get_attr("dtype")))
1145      return "/device:GPU:0"
1146
1147    with g.as_default():
1148      with ops.device(device_func):
1149        _ = variable_scope.get_variable("x", (100, 200))
1150        _ = variable_scope.get_variable(
1151            "y", dtype=dtypes.int64, initializer=numpy.arange(73))
1152    self.assertEqual(varname_type[0], ("x", dtypes.float32))
1153    self.assertEqual(varname_type[1], ("y", dtypes.int64))
1154
1155  # TODO(mihaimaruseac): Not converted to use wrap_function because of
1156  # obtaining different results in the eager case compared to the graph one
  @test_util.run_deprecated_v1
  def testGetCollection(self):
    """VariableScope.get_collection filters by the scope's name prefix.

    "testGetCollection_foo" must not pick up variables from the scope
    "testGetCollection_foo_", which shares its leading characters. The root
    scope returns every variable in creation order.
    """
    with self.cached_session():
      _ = variable_scope.get_variable("testGetCollection_a", [])
      _ = variable_scope.get_variable(
          "testGetCollection_b", [], trainable=False)
      with variable_scope.variable_scope("testGetCollection_foo_") as scope1:
        _ = variable_scope.get_variable("testGetCollection_a", [])
        _ = variable_scope.get_variable(
            "testGetCollection_b", [], trainable=False)
        self.assertEqual([
            v.name
            for v in scope1.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
        ], ["testGetCollection_foo_/testGetCollection_a:0"])
        self.assertEqual([
            v.name
            for v in scope1.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
        ], [
            "testGetCollection_foo_/testGetCollection_a:0",
            "testGetCollection_foo_/testGetCollection_b:0"
        ])
      # "testGetCollection_foo" is a prefix of "testGetCollection_foo_"; its
      # collections must exclude the variables created in the scope above.
      with variable_scope.variable_scope("testGetCollection_foo") as scope2:
        _ = variable_scope.get_variable("testGetCollection_a", [])
        _ = variable_scope.get_variable(
            "testGetCollection_b", [], trainable=False)
        self.assertEqual([
            v.name
            for v in scope2.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
        ], ["testGetCollection_foo/testGetCollection_a:0"])
        self.assertEqual([
            v.name
            for v in scope2.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
        ], [
            "testGetCollection_foo/testGetCollection_a:0",
            "testGetCollection_foo/testGetCollection_b:0"
        ])
      # The root scope sees everything, in creation order.
      scope = variable_scope.get_variable_scope()
      self.assertEqual([
          v.name for v in scope.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
      ], [
          "testGetCollection_a:0", "testGetCollection_b:0",
          "testGetCollection_foo_/testGetCollection_a:0",
          "testGetCollection_foo_/testGetCollection_b:0",
          "testGetCollection_foo/testGetCollection_a:0",
          "testGetCollection_foo/testGetCollection_b:0"
      ])
      self.assertEqual([
          v.name
          for v in scope.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
      ], [
          "testGetCollection_a:0",
          "testGetCollection_foo_/testGetCollection_a:0",
          "testGetCollection_foo/testGetCollection_a:0"
      ])
1211
1212  # TODO(mihaimaruseac): Not converted to use wrap_function because of
1213  # obtaining different results in the eager case compared to the graph one
  @test_util.run_deprecated_v1
  def testGetTrainableVariablesWithGetVariable(self):
    """scope.trainable_variables() for variables made with get_variable.

    Each case is checked via the scope's trainable list:

    * trainable=False excludes a variable.
    * synchronization=ON_READ (without trainable) also excludes it.
    * ON_WRITE keeps it trainable.
    * ON_READ combined with an explicit trainable=True is rejected with
      ValueError.
    """
    with self.cached_session():
      # Created outside the scope, so it never appears in the scope's list.
      _ = variable_scope.get_variable("testGetTrainableVariables_a", [])
      with variable_scope.variable_scope(
          "testGetTrainableVariables_foo") as scope:
        _ = variable_scope.get_variable("testGetTrainableVariables_b", [])
        _ = variable_scope.get_variable(
            "testGetTrainableVariables_c", [], trainable=False)

        # sync `ON_READ` sets trainable=False
        _ = variable_scope.get_variable(
            "testGetTrainableVariables_d", [],
            synchronization=variable_scope.VariableSynchronization.ON_READ)
        self.assertEqual(
            [v.name for v in scope.trainable_variables()],
            ["testGetTrainableVariables_foo/testGetTrainableVariables_b:0"])

        # All other sync values sets trainable=True
        _ = variable_scope.get_variable(
            "testGetTrainableVariables_e", [],
            synchronization=variable_scope.VariableSynchronization.ON_WRITE)
        self.assertEqual([v.name for v in scope.trainable_variables()], [
            "testGetTrainableVariables_foo/testGetTrainableVariables_b:0",
            "testGetTrainableVariables_foo/testGetTrainableVariables_e:0"
        ])

      # The message is used as a regex; the unescaped "." characters match
      # any character, which is harmless here.
      with self.assertRaisesRegexp(
          ValueError, "Synchronization value can be set to "
          "VariableSynchronization.ON_READ only for non-trainable variables. "
          "You have specified trainable=True and "
          "synchronization=VariableSynchronization.ON_READ."):
        _ = variable_scope.get_variable(
            "testGetTrainableVariables_e", [],
            synchronization=variable_scope.VariableSynchronization.ON_READ,
            trainable=True)
1250
1251  # TODO(mihaimaruseac): Not converted to use wrap_function because of
1252  # obtaining different results in the eager case compared to the graph one
  @test_util.run_deprecated_v1
  def testGetTrainableVariablesWithVariable(self):
    """Same checks as the get_variable variant, using variable_scope.variable.

    Each case is checked via the scope's trainable list:

    * trainable=False excludes a variable.
    * synchronization=ON_READ (without trainable) also excludes it.
    * ON_WRITE keeps it trainable.
    * ON_READ combined with an explicit trainable=True raises ValueError.
    """
    with self.cached_session():
      # Created outside the scope, so it never appears in the scope's list.
      _ = variable_scope.variable(1.0, name="testGetTrainableVariables_a")
      with variable_scope.variable_scope(
          "testGetTrainableVariables_foo") as scope:
        _ = variable_scope.variable(1.0, name="testGetTrainableVariables_b")
        _ = variable_scope.variable(
            1.0, name="testGetTrainableVariables_c", trainable=False)

        # sync `ON_READ` sets trainable=False
        _ = variable_scope.variable(
            1.0,
            name="testGetTrainableVariables_d",
            synchronization=variable_scope.VariableSynchronization.ON_READ)
        self.assertEqual(
            [v.name for v in scope.trainable_variables()],
            ["testGetTrainableVariables_foo/testGetTrainableVariables_b:0"])

        # All other sync values sets trainable=True
        _ = variable_scope.variable(
            1.0,
            name="testGetTrainableVariables_e",
            synchronization=variable_scope.VariableSynchronization.ON_WRITE)
        self.assertEqual([v.name for v in scope.trainable_variables()], [
            "testGetTrainableVariables_foo/testGetTrainableVariables_b:0",
            "testGetTrainableVariables_foo/testGetTrainableVariables_e:0"
        ])

      # The message is used as a regex; the unescaped "." characters match
      # any character, which is harmless here.
      with self.assertRaisesRegexp(
          ValueError, "Synchronization value can be set to "
          "VariableSynchronization.ON_READ only for non-trainable variables. "
          "You have specified trainable=True and "
          "synchronization=VariableSynchronization.ON_READ."):
        _ = variable_scope.variable(
            1.0,
            name="testGetTrainableVariables_e",
            synchronization=variable_scope.VariableSynchronization.ON_READ,
            trainable=True)
1292
1293  # TODO(mihaimaruseac): Not converted to use wrap_function because of
1294  # obtaining different results in the eager case compared to the graph one
1295  @test_util.run_deprecated_v1
1296  def testGetGlobalVariables(self):
1297    with self.cached_session():
1298      _ = variable_scope.get_variable("testGetGlobalVariables_a", [])
1299      with variable_scope.variable_scope("testGetGlobalVariables_foo") as scope:
1300        _ = variable_scope.get_variable("testGetGlobalVariables_b", [])
1301        self.assertEqual(
1302            [v.name for v in scope.global_variables()],
1303            ["testGetGlobalVariables_foo/"
1304             "testGetGlobalVariables_b:0"])
1305
1306  # TODO(mihaimaruseac): Not converted to use wrap_function because of
1307  # obtaining different results in the eager case compared to the graph one
1308  @test_util.run_deprecated_v1
1309  def testGetLocalVariables(self):
1310    with self.cached_session():
1311      _ = variable_scope.get_variable(
1312          "a", [], collections=[ops.GraphKeys.LOCAL_VARIABLES])
1313      with variable_scope.variable_scope("foo") as scope:
1314        _ = variable_scope.get_variable(
1315            "b", [], collections=[ops.GraphKeys.LOCAL_VARIABLES])
1316        _ = variable_scope.get_variable("c", [])
1317        self.assertEqual([v.name for v in scope.local_variables()], ["foo/b:0"])
1318
1319  @test_util.run_in_graph_and_eager_modes
1320  @run_inside_wrap_function_in_eager_mode
1321  def testGetVariableWithRefDtype(self):
1322    v = variable_scope.get_variable("v", shape=[3, 4], dtype=dtypes.float32)
1323    # Ensure it is possible to do get_variable with a _ref dtype passed in.
1324    _ = variable_scope.get_variable("w", shape=[5, 6], dtype=v.dtype)
1325
1326  @test_util.run_in_graph_and_eager_modes
1327  @run_inside_wrap_function_in_eager_mode
1328  def testGetVariableWithInitializerWhichTakesNoArgs(self):
1329    v = variable_scope.get_variable("foo", initializer=lambda: [2])
1330    self.assertEqual(v.name, "foo:0")
1331
1332  @test_util.run_in_graph_and_eager_modes
1333  @run_inside_wrap_function_in_eager_mode
1334  def testGetVariableWithInitializerWhichTakesOptionalArgs(self):
1335    v = variable_scope.get_variable("foo", initializer=lambda x=True: [2])
1336    self.assertEqual(v.name, "foo:0")
1337
1338  @test_util.run_in_graph_and_eager_modes
1339  @run_inside_wrap_function_in_eager_mode
1340  def testGetVariableWithInitializerWhichTakesUnprovidedArgsAndNoShape(self):
1341    with self.assertRaisesRegexp(
1342        ValueError,
1343        "The initializer passed is not valid. It should be a callable with no "
1344        "arguments and the shape should not be provided or an instance of "
1345        "`tf.keras.initializers.*' and `shape` should be fully defined."):
1346      variable_scope.get_variable("foo", initializer=lambda x: [2])
1347
1348  @test_util.run_in_graph_and_eager_modes
1349  @run_inside_wrap_function_in_eager_mode
1350  def testTwoGraphs(self):
1351
1352    def f():
1353      g1 = ops.Graph()
1354      g2 = ops.Graph()
1355      with g1.as_default():
1356        with g2.as_default():
1357          with variable_scope.variable_scope("_"):
1358            pass
1359
1360    self.assertRaisesRegexp(ValueError, "'_' is not a valid scope name", f)
1361
1362
def axis0_into1_partitioner(shape=None, **unused_kwargs):
  """Partitioner that leaves every axis whole (one part per axis).

  Args:
    shape: The full variable shape; only its length is used.
    **unused_kwargs: Any other partitioner keyword arguments; ignored.

  Returns:
    A list of `len(shape)` ones.
  """
  return [1] * len(shape)
1366
1367
def axis0_into2_partitioner(shape=None, **unused_kwargs):
  """Partitioner that splits axis 0 in two and leaves other axes whole.

  Args:
    shape: The full variable shape; only its length is used.
    **unused_kwargs: Any other partitioner keyword arguments; ignored.

  Returns:
    A list of `len(shape)` partition counts: 2 for axis 0, 1 elsewhere.

  Raises:
    IndexError: If `shape` is empty, since there is no axis 0 to split.
  """
  partitions = [1] * len(shape)
  partitions[0] = 2
  return partitions
1372
1373
def axis0_into3_partitioner(shape=None, **unused_kwargs):
  """Partitioner that splits axis 0 in three and leaves other axes whole.

  Args:
    shape: The full variable shape; only its length is used.
    **unused_kwargs: Any other partitioner keyword arguments; ignored.

  Returns:
    A list of `len(shape)` partition counts: 3 for axis 0, 1 elsewhere.

  Raises:
    IndexError: If `shape` is empty, since there is no axis 0 to split.
  """
  partitions = [1] * len(shape)
  partitions[0] = 3
  return partitions
1378
1379
class VariableScopeWithPartitioningTest(test.TestCase):
  """Tests get_variable inside variable scopes that carry a partitioner."""

  # TODO(mihaimaruseac): Not converted to use wrap_function because of
  # obtaining different results in the eager case compared to the graph one
  @test_util.run_deprecated_v1
  def testResultNameMatchesRequested(self):
    with variable_scope.variable_scope(
        "scope0", partitioner=axis0_into2_partitioner):
      partitioned = variable_scope.get_variable("name0", shape=(3, 1, 1))
      self.assertEqual(partitioned.name, "scope0/name0")
      concatenated = partitioned.as_tensor()
      self.assertEqual(concatenated.name, "scope0/name0:0")
      # The shards appear in GLOBAL_VARIABLES as part_0 and part_1;
      # axis0_into2_partitioner yields no third shard.
      global_names = [
          var.name
          for var in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
      ]
      self.assertIn("scope0/name0/part_0:0", global_names)
      self.assertIn("scope0/name0/part_1:0", global_names)
      self.assertNotIn("scope0/name0/part_2:0", global_names)

  @test_util.run_in_graph_and_eager_modes
  @run_inside_wrap_function_in_eager_mode
  def testBreaksIfPartitioningChanges(self):
    with variable_scope.variable_scope(
        "scope0", partitioner=axis0_into2_partitioner):
      variable_scope.get_variable("name0", shape=(3, 1, 1))

    # Reusing the 2-way partitioned variable with a 3-way or a 1-way
    # partitioner must be rejected.
    for bad_partitioner in (axis0_into3_partitioner, axis0_into1_partitioner):
      with variable_scope.variable_scope(
          "scope0", partitioner=bad_partitioner, reuse=True):
        with self.assertRaisesRegexp(
            ValueError,
            "Trying to reuse partitioned variable .* but specified partitions "
            ".* and found partitions .*"):
          variable_scope.get_variable("name0", shape=(3, 1, 1))

  @test_util.run_in_graph_and_eager_modes
  @run_inside_wrap_function_in_eager_mode
  def testReturnsExistingConcatenatedValueIfReuse(self):
    with variable_scope.variable_scope(
        "scope0", partitioner=axis0_into2_partitioner):
      created = variable_scope.get_variable("name0", shape=(3, 1, 1))
      variable_scope.get_variable_scope().reuse_variables()
      reused = variable_scope.get_variable("name0", shape=(3, 1, 1))
      self.assertEqual(created, reused)

  @test_util.run_in_graph_and_eager_modes
  @run_inside_wrap_function_in_eager_mode
  def testAllowsReuseWithoutPartitioner(self):
    with variable_scope.variable_scope(
        "scope0", partitioner=axis0_into2_partitioner):
      created = variable_scope.get_variable("name0", shape=(3, 1, 1))
    # Reopening without a partitioner (and without a shape) still finds it.
    with variable_scope.variable_scope("scope0", reuse=True):
      reused = variable_scope.get_variable("name0")
    self.assertEqual(created, reused)

  def testNoReuseInEagerByDefault(self):
    with context.eager_mode():
      with variable_scope.variable_scope(
          "scope0", partitioner=axis0_into2_partitioner):
        first = variable_scope.get_variable("name0", shape=(3, 1, 1))
        second = variable_scope.get_variable("name0", shape=(3, 1, 1))
        # Requesting the same name twice yields two distinct objects.
        self.assertIsNot(first, second)

  @test_util.run_in_graph_and_eager_modes
  @run_inside_wrap_function_in_eager_mode
  def testPropagatePartitionerOnReopening(self):
    with variable_scope.variable_scope(
        "scope0", partitioner=axis0_into2_partitioner) as outer_scope:
      self.assertEqual(axis0_into2_partitioner, outer_scope.partitioner)
      with variable_scope.variable_scope(outer_scope) as reopened_scope:
        self.assertEqual(axis0_into2_partitioner, reopened_scope.partitioner)

  # TODO(mihaimaruseac): Not converted to use wrap_function because of
  # obtaining different results in the eager case compared to the graph one
  @test_util.run_deprecated_v1
  def testScalarIgnoresPartitioner(self):
    with variable_scope.variable_scope(
        "scope0", partitioner=axis0_into2_partitioner):
      scalar = variable_scope.get_variable("name0", shape=())
      # No "/part_N" suffix: the scalar was not split.
      self.assertEqual(scalar.name, "scope0/name0:0")
      global_names = [
          var.name
          for var in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
      ]
      self.assertIn("scope0/name0:0", global_names)

  def _testPartitionConcatenatesAlongCorrectAxis(self, use_resource):
    """Checks full and per-shard shapes for axis-0 and axis-1 partitioning.

    Args:
      use_resource: Forwarded to the enclosing variable_scope.
    """

    def _part_axis_0(**unused_kwargs):
      return (2, 1, 1)

    def _part_axis_1(**unused_kwargs):
      return (1, 2, 1)

    with variable_scope.variable_scope("root", use_resource=use_resource):
      v0 = variable_scope.get_variable(
          "n0", shape=(2, 2, 2), partitioner=_part_axis_0)
      v1 = variable_scope.get_variable(
          "n1", shape=(2, 2, 2), partitioner=_part_axis_1)

    self.assertEqual(v0.get_shape(), (2, 2, 2))
    self.assertEqual(v1.get_shape(), (2, 2, 2))

    # Splitting axis 0 in two halves each shard's first dimension; splitting
    # axis 1 halves the second.
    for partitioned, shard_shape in ((v0, (1, 2, 2)), (v1, (2, 1, 2))):
      shards = list(partitioned)
      self.assertEqual(shards[0].get_shape(), shard_shape)
      self.assertEqual(shards[1].get_shape(), shard_shape)

  @test_util.run_in_graph_and_eager_modes
  @run_inside_wrap_function_in_eager_mode
  def testPartitionConcatenatesAlongCorrectAxis(self):
    self._testPartitionConcatenatesAlongCorrectAxis(use_resource=False)

  @test_util.run_in_graph_and_eager_modes
  @run_inside_wrap_function_in_eager_mode
  def testPartitionConcatenatesAlongCorrectAxisResource(self):
    self._testPartitionConcatenatesAlongCorrectAxis(use_resource=True)

  def testPartitionConcatenatesAlongCorrectAxisResourceInEager(self):
    with context.eager_mode():
      self._testPartitionConcatenatesAlongCorrectAxis(use_resource=True)
1507
1508
class VariableScopeWithCustomGetterTest(test.TestCase):
  """Tests `custom_getter` on variable scopes and variable creator scopes."""

  @test_util.run_in_graph_and_eager_modes
  @run_inside_wrap_function_in_eager_mode
  def testNonCallableGetterFails(self):
    # A non-callable custom_getter raises ValueError, whether it is set on the
    # enclosing scope or passed directly to get_variable.
    with self.assertRaisesRegexp(ValueError,
                                 r"custom_getter .* not callable:"):
      with variable_scope.variable_scope("scope0", custom_getter=3):
        variable_scope.get_variable("name0")
    with self.assertRaisesRegexp(ValueError,
                                 r"custom_getter .* not callable:"):
      variable_scope.get_variable("name0", custom_getter=3)

  @test_util.run_in_graph_and_eager_modes
  @run_inside_wrap_function_in_eager_mode
  def testNoSideEffectsWithIdentityCustomGetter(self):
    called = [0]

    def custom_getter(getter, *args, **kwargs):
      called[0] += 1
      return getter(*args, **kwargs)

    with variable_scope.variable_scope(
        "scope", custom_getter=custom_getter) as scope:
      v = variable_scope.get_variable("v", [1])
    with variable_scope.variable_scope(scope, reuse=True):
      v2 = variable_scope.get_variable("v", [1])
    with variable_scope.variable_scope("new_scope") as new_scope:
      v3 = variable_scope.get_variable("v3", [1])
    with variable_scope.variable_scope(
        new_scope, reuse=True, custom_getter=custom_getter):
      v4 = variable_scope.get_variable("v3", [1])

    # An identity getter must not change which variable is returned.
    self.assertEqual(v, v2)
    self.assertEqual(v3, v4)
    # Only three of the four lookups go through the getter: v3 is created
    # while "new_scope" is first entered without a custom_getter.
    self.assertEqual(3, called[0])

  @test_util.run_in_graph_and_eager_modes
  @run_inside_wrap_function_in_eager_mode
  def testSynchronizationAndAggregationWithCustomGetter(self):
    called = [0]
    synchronization = variable_scope.VariableSynchronization.AUTO
    aggregation = variable_scope.VariableAggregation.NONE

    def custom_getter(getter, *args, **kwargs):
      called[0] += 1

      # Verify synchronization and aggregation kwargs are as expected. The
      # closure reads `synchronization`/`aggregation` at call time, so the
      # reassignment further down changes what is expected for "v1".
      self.assertEqual(kwargs["synchronization"], synchronization)
      self.assertEqual(kwargs["aggregation"], aggregation)
      return getter(*args, **kwargs)

    # Without explicit arguments the getter receives AUTO / NONE.
    with variable_scope.variable_scope("scope", custom_getter=custom_getter):
      variable_scope.get_variable("v", [1])
    self.assertEqual(1, called[0])

    # Explicit arguments reach the getter unchanged.
    with variable_scope.variable_scope("scope", custom_getter=custom_getter):
      synchronization = variable_scope.VariableSynchronization.ON_READ
      aggregation = variable_scope.VariableAggregation.MEAN
      variable_scope.get_variable(
          "v1", [1], synchronization=synchronization, aggregation=aggregation)

    self.assertEqual(2, called[0])

  @test_util.run_in_graph_and_eager_modes
  @run_inside_wrap_function_in_eager_mode
  def testCustomGetterWithReuse(self):
    # Custom getter can choose to behave differently on reused variables.
    def custom_getter(getter, *args, **kwargs):
      var = getter(*args, **kwargs)
      if kwargs["reuse"]:
        # This can be used, e.g., for changing the caching device if needed.
        return array_ops.identity(var, name="reused")
      else:
        return array_ops.identity(var, name="not_reused")

    with variable_scope.variable_scope(
        "scope", custom_getter=custom_getter) as scope:
      v = variable_scope.get_variable("v", [1])
    with variable_scope.variable_scope(scope, reuse=True):
      v2 = variable_scope.get_variable("v", [1])

    self.assertEqual(v.name, "not_reused:0")
    self.assertEqual(v2.name, "reused:0")

  # TODO(mihaimaruseac): Not converted to use wrap_function because of
  # ValueError: Fetch argument <tf.Tensor 'custom_getter/add:0' shape=(1, 2, 3)
  # dtype=float32> cannot be interpreted as a Tensor. (Tensor
  # Tensor("custom_getter/add:0", shape=(1, 2, 3), dtype=float32) is not an
  # element of this graph.)
  @test_util.run_deprecated_v1
  def testGetterThatCreatesTwoVariablesAndSumsThem(self):

    def custom_getter(getter, name, *args, **kwargs):
      g_0 = getter("%s/0" % name, *args, **kwargs)
      g_1 = getter("%s/1" % name, *args, **kwargs)
      with ops.name_scope("custom_getter"):
        return g_0 + g_1

    with variable_scope.variable_scope("scope", custom_getter=custom_getter):
      v = variable_scope.get_variable("v", [1, 2, 3])

    self.assertEqual([1, 2, 3], v.get_shape())
    true_vars = variables_lib.trainable_variables()
    self.assertEqual(2, len(true_vars))
    self.assertEqual("scope/v/0:0", true_vars[0].name)
    self.assertEqual("scope/v/1:0", true_vars[1].name)
    self.assertEqual("custom_getter/add:0", v.name)
    # The session is only entered to act as the default session; no handle to
    # it is needed.
    with self.cached_session():
      variables_lib.global_variables_initializer().run()
      np_vars, np_v = self.evaluate([true_vars, v])
      self.assertAllClose(np_v, sum(np_vars))

  # TODO(mihaimaruseac): Not converted to use wrap_function because of
  # ValueError: Fetch argument <tf.Tensor 'sum_getter_2/add:0' shape=(1, 2, 3)
  # dtype=float32> cannot be interpreted as a Tensor. (Tensor
  # Tensor("sum_getter_2/add:0", shape=(1, 2, 3), dtype=float32) is not an
  # element of this graph.)
  @test_util.run_deprecated_v1
  def testNestedCustomGetters(self):

    def sum_getter(getter, name, *args, **kwargs):
      g_0 = getter("%s/sum_0" % name, *args, **kwargs)
      g_1 = getter("%s/sum_1" % name, *args, **kwargs)
      with ops.name_scope("sum_getter"):
        return g_0 + g_1

    def prod_getter(getter, name, *args, **kwargs):
      g_0 = getter("%s/prod_0" % name, *args, **kwargs)
      g_1 = getter("%s/prod_1" % name, *args, **kwargs)
      with ops.name_scope("prod_getter"):
        return g_0 * g_1

    with variable_scope.variable_scope("prod_scope", custom_getter=prod_getter):
      with variable_scope.variable_scope("sum_scope", custom_getter=sum_getter):
        with variable_scope.variable_scope(
            "inner_sum_scope", custom_getter=sum_getter):
          # take sums of sums of products
          v = variable_scope.get_variable("v", [1, 2, 3])

    self.assertEqual([1, 2, 3], v.get_shape())
    true_vars = variables_lib.trainable_variables()
    self.assertEqual(8, len(true_vars))
    template = (
        "prod_scope/sum_scope/inner_sum_scope/v/sum_%d/sum_%d/prod_%d:0")
    self.assertEqual(template % (0, 0, 0), true_vars[0].name)
    self.assertEqual(template % (0, 0, 1), true_vars[1].name)
    self.assertEqual(template % (0, 1, 0), true_vars[2].name)
    self.assertEqual(template % (0, 1, 1), true_vars[3].name)
    self.assertEqual(template % (1, 0, 0), true_vars[4].name)
    self.assertEqual(template % (1, 0, 1), true_vars[5].name)
    self.assertEqual(template % (1, 1, 0), true_vars[6].name)
    self.assertEqual(template % (1, 1, 1), true_vars[7].name)

    # The session is only entered to act as the default session; no handle to
    # it is needed.
    with self.cached_session():
      variables_lib.global_variables_initializer().run()
      np_vars, np_v = self.evaluate([true_vars, v])
      # take products of sums of products
      self.assertAllClose(
          np_v, (((np_vars[0] * np_vars[1]) + (np_vars[2] * np_vars[3])) + (
              (np_vars[4] * np_vars[5]) + (np_vars[6] * np_vars[7]))))

  @test_util.run_in_graph_and_eager_modes
  @run_inside_wrap_function_in_eager_mode
  def testVariableCreator(self):
    variable_names = []

    def creator_a(next_creator, **kwargs):
      variable_names.append(kwargs.get("name", ""))
      return next_creator(**kwargs)

    def creator_b(next_creator, **kwargs):
      kwargs["name"] = "forced_name"
      return next_creator(**kwargs)

    # creator_b is the innermost scope, and the recorded name shows it ran
    # before creator_a.
    with variable_scope.variable_creator_scope(creator_a):
      with variable_scope.variable_creator_scope(creator_b):
        variable_scope.variable(1.0, name="one_name")

    self.assertEqual(variable_names[0], "forced_name")

    called = [False]

    def creator_c(next_creator, **kwargs):
      called[0] = True
      self.assertEqual(kwargs["synchronization"],
                       variable_scope.VariableSynchronization.ON_WRITE)
      self.assertEqual(kwargs["aggregation"],
                       variable_scope.VariableAggregation.MEAN)
      return next_creator(**kwargs)

    # get_variable must forward synchronization/aggregation to creators.
    with variable_scope.variable_creator_scope(creator_c):
      variable_scope.get_variable(
          "v", [],
          synchronization=variable_scope.VariableSynchronization.ON_WRITE,
          aggregation=variable_scope.VariableAggregation.MEAN)
    self.assertTrue(called[0])
1706
1707
class PartitionInfoTest(test.TestCase):
  """Tests argument validation and helpers of variable_scope._PartitionInfo."""

  def _partition_info_9x3(self, var_offset):
    """Builds a _PartitionInfo for a [9, 3] variable at `var_offset`."""
    return variable_scope._PartitionInfo(
        full_shape=[9, 3], var_offset=var_offset)

  @test_util.run_in_graph_and_eager_modes
  @run_inside_wrap_function_in_eager_mode
  def testConstructorChecks(self):
    # None or a string in either argument raises TypeError.
    bad_type_args = (
        (None, [0, 1]),
        ([0, 1], None),
        ("foo", [0, 1]),
        ([0, 1], "foo"),
    )
    for full_shape, var_offset in bad_type_args:
      with self.assertRaises(TypeError):
        variable_scope._PartitionInfo(
            full_shape=full_shape, var_offset=var_offset)

    # ValueError for: lengths that differ, then an offset that is not
    # strictly less than the shape.
    bad_value_args = (
        ([0, 1], [0]),
        ([1, 1], [0, 1]),
    )
    for full_shape, var_offset in bad_value_args:
      with self.assertRaises(ValueError):
        variable_scope._PartitionInfo(
            full_shape=full_shape, var_offset=var_offset)

  @test_util.run_in_graph_and_eager_modes
  @run_inside_wrap_function_in_eager_mode
  def testSingleOffset(self):
    offset_info = self._partition_info_9x3([4, 0])
    self.assertEqual(4, offset_info.single_offset([1, 3]))

    # A variable that is not partitioned at all has offset 0.
    whole_info = self._partition_info_9x3([0, 0])
    self.assertEqual(0, whole_info.single_offset([9, 3]))

  @test_util.run_in_graph_and_eager_modes
  @run_inside_wrap_function_in_eager_mode
  def testSingleSliceDim(self):
    info = self._partition_info_9x3([4, 0])

    # Not a shape at all.
    with self.assertRaises(TypeError):
      info.single_slice_dim(None)

    # Rejected in turn: rank differs from full_shape; too large given the
    # offset (4+6 > 9); more than one possible slice dim.
    for bad_shape in ([1, 2, 3], [6, 3], [1, 1]):
      with self.assertRaises(ValueError):
        info.single_slice_dim(bad_shape)

    unoffset_info = self._partition_info_9x3([0, 0])
    self.assertEqual(1, unoffset_info.single_slice_dim([9, 2]))
    offset_info = self._partition_info_9x3([4, 0])
    self.assertEqual(0, offset_info.single_slice_dim([2, 3]))
1769
1770
class VariableScopeMultithreadedTest(test.TestCase):
  """Tests variable_scope when scopes are entered from several threads.

  Workers are started with `self.checkedThread` rather than a bare
  `threading.Thread`. An uncaught exception in a bare thread, including a
  failed assertion, is only reported by the thread machinery and never
  reaches the test. A checked thread instead fails the test when it is
  joined.
  """

  @test_util.run_in_graph_and_eager_modes
  @run_inside_wrap_function_in_eager_mode
  def testTwoThreadsDisjointScopeEntry(self):

    def thread_fn(i, graph):
      with graph.as_default():
        with variable_scope.variable_scope("foo"):
          if i == 0:
            v = variable_scope.get_variable("v", [])
            self.assertEqual("foo/v:0", v.name)
          else:
            # Any thread after the first one should fail to create variable
            # with the same name.
            with self.assertRaises(ValueError):
              variable_scope.get_variable("v", [])

    graph = ops.get_default_graph()
    threads = [
        self.checkedThread(target=thread_fn, args=(i, graph))
        for i in range(2)
    ]

    threads[0].start()
    # Allow thread 0 to finish before starting thread 1.
    threads[0].join()
    threads[1].start()
    threads[1].join()

  @test_util.run_in_graph_and_eager_modes
  @run_inside_wrap_function_in_eager_mode
  def testTwoThreadsNestedScopeEntry(self):

    def thread_fn(i, graph, run_event, pause_event):
      try:
        with graph.as_default():
          with variable_scope.variable_scope("foo"):
            if i == 0:
              v = variable_scope.get_variable("v", [])
              self.assertEqual("foo/v:0", v.name)
            else:
              # Any thread after the first one should fail to create variable
              # with the same name.
              with self.assertRaises(ValueError):
                variable_scope.get_variable("v", [])
            # Signal the main thread while still inside "foo", then hold the
            # scope open until told to resume.
            pause_event.set()
            run_event.wait()
      finally:
        # If anything above raised before the pause point, still release the
        # main thread. It would otherwise block forever on pause_event.wait().
        # The error itself is re-raised by join().
        pause_event.set()

    graph = ops.get_default_graph()
    run_events = [threading.Event() for _ in range(2)]
    pause_events = [threading.Event() for _ in range(2)]
    threads = [
        self.checkedThread(
            target=thread_fn, args=(i, graph, run_events[i], pause_events[i]))
        for i in range(2)
    ]

    # Start first thread.
    threads[0].start()
    pause_events[0].wait()
    # Start next thread once the first thread has paused.
    threads[1].start()
    pause_events[1].wait()
    # Resume both threads.
    run_events[0].set()
    run_events[1].set()
    threads[0].join()
    threads[1].join()

  @test_util.run_in_graph_and_eager_modes
  @run_inside_wrap_function_in_eager_mode
  def testReenterMainScope(self):

    def thread_fn(graph, main_thread_scope):
      with graph.as_default():
        # Variable created with main scope will have prefix "main".
        with variable_scope.variable_scope(main_thread_scope):
          with variable_scope.variable_scope("foo"):
            v = variable_scope.get_variable("v", [])
            self.assertEqual("main/foo/v:0", v.name)

        # Variable created outside main scope will not have prefix "main".
        with variable_scope.variable_scope("bar"):
          v = variable_scope.get_variable("v", [])
          self.assertEqual("bar/v:0", v.name)

    graph = ops.get_default_graph()
    with variable_scope.variable_scope("main") as main_thread_scope:
      thread = self.checkedThread(
          target=thread_fn, args=(graph, main_thread_scope))
      thread.start()
      thread.join()
1865
1866
if __name__ == "__main__":
  # Script entry point: hand off to TensorFlow's test runner
  # (tensorflow.python.platform.test.main).
  test.main()
1869