• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Tests for make_template."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import functools
21import traceback
22
23from tensorflow.python.client import session
24from tensorflow.python.eager import context
25from tensorflow.python.framework import ops
26from tensorflow.python.framework import random_seed
27from tensorflow.python.framework import test_util
28from tensorflow.python.ops import array_ops
29from tensorflow.python.ops import init_ops
30from tensorflow.python.ops import math_ops
31from tensorflow.python.ops import template
32from tensorflow.python.ops import variable_scope
33from tensorflow.python.ops import variables
34import tensorflow.python.ops.nn_grad  # pylint: disable=unused-import
35from tensorflow.python.platform import test
36from tensorflow.python.training import gradient_descent
37
38
def variable_scoped_function(trainable=True):
  """Gets the zero-initialized, shape-[1] variable "dummy" from the scope.

  Args:
    trainable: Forwarded to `get_variable`; False yields a non-trainable
      variable.

  Returns:
    The "dummy" variable.
  """
  var_kwargs = dict(shape=[1],
                    trainable=trainable,
                    initializer=init_ops.zeros_initializer())
  return variable_scope.get_variable("dummy", **var_kwargs)
43
44
def internally_variable_scoped_function(scope_name):
  """Gets the "dummy" variable inside a nested variable scope.

  Args:
    scope_name: Name of the variable scope opened around `get_variable`.

  Returns:
    The zero-initialized, shape-[1] variable `<scope_name>/dummy`, relative
    to the enclosing variable scope.
  """
  with variable_scope.variable_scope(scope_name):
    zeros = init_ops.zeros_initializer()
    dummy = variable_scope.get_variable("dummy", shape=[1], initializer=zeros)
  return dummy
49
50
def function_with_create(trainable):
  """Creates a variable as a side effect using tf.Variable.

  Args:
    trainable: Trainability of the side-effect `tf.Variable`.

  Returns:
    The zero-initialized, shape-[1] "dummy" variable from `get_variable`.
  """
  # Unlike `get_variable`, the `tf.Variable` constructor does not reuse an
  # existing variable, so every call creates a new one.
  variables.Variable(0, trainable=trainable)
  zeros = init_ops.zeros_initializer()
  return variable_scope.get_variable("dummy", shape=[1], initializer=zeros)
56
57
def function_with_side_create(trainable, name="side"):
  """Creates a variable as a side effect using tf.get_variable.

  Args:
    trainable: Trainability of the side-effect variable.
    name: Name of the side-effect variable. Passing a different name on each
      call makes each call request a different variable.

  Returns:
    The zero-initialized, shape-[1] "dummy" variable.
  """
  variable_scope.get_variable(name, trainable=trainable, shape=[1])
  dummy_init = init_ops.zeros_initializer()
  return variable_scope.get_variable(
      "dummy", initializer=dummy_init, shape=[1])
63
64
def variable_scoped_function_with_local_variable():
  """Creates the local variable "local", then returns the variable "dummy".

  Both variables have shape [1] and are zero-initialized.
  """
  zeros = init_ops.zeros_initializer
  variable_scope.get_local_variable("local", shape=[1], initializer=zeros())
  return variable_scope.get_variable("dummy", shape=[1], initializer=zeros())
70
71
class TemplateTest(test.TestCase):
  """Tests for the `template` module's make_template helpers."""

  @test_util.run_deprecated_v1
  def test_end_to_end(self):
    """This test shows a very simple line model with test_loss.

    The template is used to share parameters between a training and test model.
    """
    # y = 2x + 1
    training_input, training_output = ([1., 2., 3., 4.], [2.8, 5.1, 7.2, 8.7])
    test_input, test_output = ([5., 6., 7., 8.], [11, 13, 15, 17])

    random_seed.set_random_seed(1234)

    def test_line(x):
      m = variable_scope.get_variable(
          "w", shape=[], initializer=init_ops.truncated_normal_initializer())
      b = variable_scope.get_variable(
          "b", shape=[], initializer=init_ops.truncated_normal_initializer())
      return x * m + b

    line_template = template.make_template("line", test_line)

    train_prediction = line_template(training_input)
    test_prediction = line_template(test_input)

    train_loss = math_ops.reduce_mean(
        math_ops.square(train_prediction - training_output))
    test_loss = math_ops.reduce_mean(
        math_ops.square(test_prediction - test_output))

    optimizer = gradient_descent.GradientDescentOptimizer(0.1)
    train_op = optimizer.minimize(train_loss)

    # The session is entered only so it becomes the default session; it is
    # never referenced by name.
    with session.Session():
      self.evaluate(variables.global_variables_initializer())
      initial_test_loss = self.evaluate(test_loss)
      self.evaluate(train_op)
      final_test_loss = self.evaluate(test_loss)

    # Parameters are tied, so the loss should have gone down when we trained it.
    self.assertLess(final_test_loss, initial_test_loss)

  def test_end_to_end_eager(self):
    """This test shows a very simple line model with test_loss in eager mode.

    The template is used to share parameters between a training and test model.
    """
    with context.eager_mode():
      # y = 2x + 1
      training_input, training_output = ([1., 2., 3., 4.], [2.8, 5.1, 7.2, 8.7])
      test_input, test_output = ([5., 6., 7., 8.], [11, 13, 15, 17])

      random_seed.set_random_seed(1234)

      def test_line(x):
        m = variable_scope.get_variable(
            "w", shape=[], initializer=init_ops.truncated_normal_initializer())
        b = variable_scope.get_variable(
            "b", shape=[], initializer=init_ops.truncated_normal_initializer())
        return x * m + b

      line_template = template.make_template("line", test_line)

      def train_loss():
        train_prediction = line_template(training_input)
        return math_ops.reduce_mean(
            math_ops.square(train_prediction - training_output))

      def test_loss():
        test_prediction = line_template(test_input)
        return math_ops.reduce_mean(
            math_ops.square(test_prediction - test_output))

      optimizer = gradient_descent.GradientDescentOptimizer(0.1)
      initial_test_loss = test_loss()
      optimizer.minimize(train_loss)
      final_test_loss = test_loss()

      # Parameters are tied, so the loss should have gone down after training.
      self.assertLess(final_test_loss.numpy(), initial_test_loss.numpy())

  @test_util.run_in_graph_and_eager_modes
  def test_skip_stack_frames(self):
    first = traceback.format_stack()
    second = traceback.format_stack()
    result = template._skip_common_stack_elements(first, second)
    self.assertEqual(1, len(result))
    self.assertNotEqual(len(first), len(result))

  @test_util.run_in_graph_and_eager_modes
  def test_template_with_empty_name(self):
    tpl = template.make_template("", variable_scoped_function)
    with variable_scope.variable_scope("outer"):
      x = variable_scope.get_variable("x", [])
      v = tpl()
    self.assertEqual("outer/", tpl.variable_scope_name)
    self.assertEqual("outer//dummy:0", v.name)
    if context.executing_eagerly():
      # In eager mode `x` is not visible to the template since the template does
      # not rely on global collections.
      self.assertEqual(1, len(tpl.variables))
      self.assertIs(v, tpl.variables[0])
    else:
      self.assertEqual([x, v], tpl.variables)

  @test_util.run_in_graph_and_eager_modes
  def test_template_with_name(self):
    tmpl1 = template.make_template("s1", variable_scoped_function)
    tmpl2 = template.make_template("s1", variable_scoped_function)

    v1 = tmpl1()
    v2 = tmpl1()
    v3 = tmpl2()
    self.assertIs(v1, v2)
    self.assertIsNot(v1, v3)
    self.assertEqual("s1/dummy:0", v1.name)
    self.assertEqual("s1_1/dummy:0", v3.name)

  @test_util.run_deprecated_v1
  def test_same_unique_name_raise_error(self):
    tmpl1 = template.make_template(
        "_", variable_scoped_function, unique_name_="s1")
    tmpl1()
    tmpl2 = template.make_template(
        "_", variable_scoped_function, unique_name_="s1")
    with self.assertRaisesRegex(
        ValueError, "Variable s1/dummy already exists, disallowed.*"):
      tmpl2()

  def test_unique_name_raise_error_in_eager(self):
    with context.eager_mode():
      with self.assertRaisesRegex(
          ValueError,
          "unique_name_ cannot be used when eager execution is enabled."):
        template.make_template(
            "_", variable_scoped_function, unique_name_="s1")

  @test_util.run_deprecated_v1
  def test_unique_name_and_reuse(self):
    tmpl1 = template.make_template(
        "_", variable_scoped_function, unique_name_="s1")
    v1 = tmpl1()
    v2 = tmpl1()

    variable_scope.get_variable_scope().reuse_variables()
    tmpl2 = template.make_template(
        "_", variable_scoped_function, unique_name_="s1")
    v3 = tmpl2()

    self.assertIs(v1, v2)
    self.assertIs(v1, v3)
    self.assertEqual("s1/dummy:0", v1.name)

  @test_util.run_in_graph_and_eager_modes
  def test_template_in_scope(self):
    tmpl1 = template.make_template("s1", variable_scoped_function)
    tmpl2 = template.make_template("s1", variable_scoped_function)

    with variable_scope.variable_scope("scope"):
      v1 = tmpl1()
      v3 = tmpl2()

    # The template contract requires the following to ignore scope2.
    with variable_scope.variable_scope("scope2"):
      v2 = tmpl1()
    self.assertIs(v1, v2)
    self.assertIsNot(v1, v3)
    self.assertEqual("scope/s1/dummy:0", v1.name)
    self.assertEqual("scope/s1_1/dummy:0", v3.name)

  @test_util.run_in_graph_and_eager_modes
  def test_template_with_internal_reuse(self):
    tmpl1 = template.make_template("s1", internally_variable_scoped_function)
    tmpl2 = template.make_template("s1", internally_variable_scoped_function)

    v1 = tmpl1("test")
    v2 = tmpl1("test")
    v3 = tmpl2("test")
    self.assertIs(v1, v2)
    self.assertIsNot(v1, v3)
    self.assertEqual("s1/test/dummy:0", v1.name)
    self.assertEqual("s1_1/test/dummy:0", v3.name)

    with self.assertRaises(ValueError):
      tmpl1("not_test")

  @test_util.run_in_graph_and_eager_modes
  def test_template_without_name(self):
    with self.assertRaisesRegex(ValueError, "name cannot be None."):
      template.make_template(None, variable_scoped_function)

  @test_util.run_in_graph_and_eager_modes
  def test_make_template(self):
    # Keyword arguments given to make_template are bound to the wrapped
    # function, so the templates can then be called without arguments.
    tmpl1 = template.make_template(
        "s1", internally_variable_scoped_function, scope_name="test")
    tmpl2 = template.make_template(
        "s1", internally_variable_scoped_function, scope_name="test")

    v1 = tmpl1()
    v2 = tmpl1()
    v3 = tmpl2()
    self.assertIs(v1, v2)
    self.assertIsNot(v1, v3)
    self.assertEqual("s1/test/dummy:0", v1.name)
    self.assertEqual("s1_1/test/dummy:0", v3.name)

  @test_util.run_deprecated_v1
  def test_enforces_no_extra_trainable_variables(self):
    tmpl = template.make_template("s", function_with_create, trainable=True)

    tmpl()
    with self.assertRaises(ValueError):
      tmpl()

  @test_util.run_in_graph_and_eager_modes
  def test_enforces_no_extra_trainable_variables_eager(self):
    tmpl = template.make_template("s",
                                  function_with_side_create,
                                  trainable=True)

    tmpl(name="1")
    with self.assertRaises(ValueError):
      tmpl(name="2")

  def test_permits_extra_non_trainable_variables(self):
    tmpl = template.make_template("s", function_with_create, trainable=False)
    self.assertIs(tmpl(), tmpl())

  def test_permits_extra_non_trainable_variables_eager(self):
    with context.eager_mode():
      tmpl = template.make_template("s",
                                    function_with_side_create,
                                    trainable=False)
      self.assertIs(tmpl(name="1"), tmpl(name="2"))

  @test_util.run_in_graph_and_eager_modes
  def test_internal_variable_reuse(self):

    def nested():
      with variable_scope.variable_scope("nested") as vs:
        v1 = variable_scope.get_variable(
            "x", initializer=init_ops.zeros_initializer(), shape=[])
      with variable_scope.variable_scope(vs, reuse=True):
        v2 = variable_scope.get_variable("x")
      self.assertIs(v1, v2)
      return v1

    tmpl1 = template.make_template("s1", nested)
    tmpl2 = template.make_template("s1", nested)

    v1 = tmpl1()
    v2 = tmpl1()
    v3 = tmpl2()
    self.assertIs(v1, v2)
    self.assertIsNot(v1, v3)
    self.assertEqual("s1/nested/x:0", v1.name)
    self.assertEqual("s1_1/nested/x:0", v3.name)

  @test_util.run_in_graph_and_eager_modes
  def test_nested_templates(self):

    def nested_template():
      nested1 = template.make_template("nested", variable_scoped_function)
      nested2 = template.make_template("nested", variable_scoped_function)
      v1 = nested1()
      v2 = nested2()

      # nested1 and nested2 should not share variables
      self.assertIsNot(v1, v2)

      # Variables created by nested1 should be isolated from variables
      # created by nested2.
      self.assertEqual(1, len(nested1.variables))
      self.assertEqual(1, len(nested2.variables))
      self.assertIs(nested1.variables[0], v1)
      self.assertIs(nested2.variables[0], v2)
      self.assertEqual(1, len(nested1.trainable_variables))
      self.assertEqual(1, len(nested2.trainable_variables))
      self.assertIs(nested1.trainable_variables[0], v1)
      self.assertIs(nested2.trainable_variables[0], v2)
      self.assertEqual(len(nested1.non_trainable_variables), 0)
      self.assertEqual(len(nested2.non_trainable_variables), 0)
      return v1, v2

    tmpl1 = template.make_template("s1", nested_template)
    tmpl2 = template.make_template("s1", nested_template)

    v1, v2 = tmpl1()
    v3, v4 = tmpl1()
    v5, v6 = tmpl2()

    # The second invocation of tmpl1 should reuse the variables
    # created in the first invocation.
    self.assertIs(v1, v3)
    self.assertIs(v2, v4)
    for v, w in zip(tmpl1.variables, [v1, v2]):
      self.assertIs(v, w)
    for v, w in zip(tmpl1.trainable_variables, [v1, v2]):
      self.assertIs(v, w)
    self.assertEqual(len(tmpl1.non_trainable_variables), 0)

    # tmpl1 and tmpl2 should not share variables.
    self.assertIsNot(v1, v5)
    self.assertIsNot(v2, v6)
    for v, w in zip(tmpl2.variables, [v5, v6]):
      self.assertIs(v, w)
    for v, w in zip(tmpl2.trainable_variables, [v5, v6]):
      self.assertIs(v, w)
    self.assertEqual(len(tmpl2.non_trainable_variables), 0)
    self.assertEqual("s1/nested/dummy:0", v1.name)
    self.assertEqual("s1/nested_1/dummy:0", v2.name)
    self.assertEqual("s1_1/nested/dummy:0", v5.name)
    self.assertEqual("s1_1/nested_1/dummy:0", v6.name)

    self.assertEqual(2, len(tmpl1._checkpoint_dependencies))
    self.assertEqual("nested", tmpl1._checkpoint_dependencies[0].name)
    self.assertEqual("nested_1", tmpl1._checkpoint_dependencies[1].name)

  @test_util.run_in_graph_and_eager_modes
  def test_nested_templates_with_defun(self):

    def variable_scoped_function_no_return_value(trainable=True):
      # defun cannot compile functions that return non-Tensor objects
      _ = variable_scope.get_variable(
          "dummy",
          shape=[1],
          trainable=trainable,
          initializer=init_ops.zeros_initializer())

    def nested_template():
      nested1 = template.make_template_internal(
          "nested",
          variable_scoped_function_no_return_value,
          create_graph_function_=True)
      nested2 = template.make_template_internal(
          "nested",
          variable_scoped_function_no_return_value,
          create_graph_function_=True)
      nested1()
      nested2()
      v1 = nested1.variables
      v2 = nested2.variables

      self.assertEqual(len(v1), 1)
      self.assertEqual(len(v2), 1)

      # nested1 and nested2 should not share variables
      self.assertIsNot(v1[0], v2[0])
      self.assertIs(nested1.trainable_variables[0], v1[0])
      self.assertIs(nested2.trainable_variables[0], v2[0])
      self.assertEqual(len(nested1.non_trainable_variables), 0)
      self.assertEqual(len(nested2.non_trainable_variables), 0)

    tmpl1 = template.make_template("s1", nested_template)
    tmpl2 = template.make_template("s1", nested_template)

    tmpl1()
    v1 = tmpl1.variables
    tmpl1()
    v2 = tmpl1.variables
    tmpl2()
    v3 = tmpl2.variables

    # The second invocation of tmpl1 should reuse the variables
    # created in the first invocation.
    for v, w in zip(v1, v2):
      self.assertIs(v, w)

    # tmpl1 and tmpl2 should not share variables.
    for v, w in zip(v1, v3):
      self.assertIsNot(v, w)

    self.assertEqual("s1/nested/dummy:0", v1[0].name)
    self.assertEqual("s1/nested_1/dummy:0", v1[1].name)
    self.assertEqual("s1_1/nested/dummy:0", v3[0].name)
    self.assertEqual("s1_1/nested_1/dummy:0", v3[1].name)

  def test_graph_function_no_name(self):
    with context.eager_mode():

      def f(_, y):
        return y + 1

      partial = functools.partial(f, 1.0)
      tmpl = template.make_template_internal(
          "a", partial, create_graph_function_=True)
      self.assertAllEqual(tmpl(ops.convert_to_tensor(1.0)), 2.0)

  @test_util.run_in_graph_and_eager_modes
  def test_immediate_scope_creation(self):
    # Create templates in scope a then call in scope b. make_template should
    # capture the scope the first time it is called, and make_immediate_template
    # should capture the scope at construction time.
    with variable_scope.variable_scope("ctor_scope"):
      # Create scope here:
      tmpl_immed = template.make_template("a", variable_scoped_function,
                                          True)
      # default: create scope at __call__
      tmpl_defer = template.make_template(
          "b", variable_scoped_function, False)
    with variable_scope.variable_scope("call_scope"):
      inner_imm_var = tmpl_immed()
      inner_defer_var = tmpl_defer()
    outer_imm_var = tmpl_immed()
    outer_defer_var = tmpl_defer()

    self.assertIsNot(inner_imm_var, inner_defer_var)
    self.assertIs(outer_imm_var, inner_imm_var)
    self.assertIs(outer_defer_var, inner_defer_var)

    self.assertEqual("ctor_scope/a/dummy:0", inner_imm_var.name)
    self.assertEqual("call_scope/b/dummy:0", inner_defer_var.name)

  @test_util.run_in_graph_and_eager_modes
  def test_scope_access(self):
    # Ensure that we can access the scope inside the template, because the name
    # of that scope may be different from the name we pass to make_template, due
    # to having been made unique by variable_scope.
    with variable_scope.variable_scope("foo"):
      # Create two templates with the same name, ensure scopes are made unique.
      ta = template.make_template("bar", variable_scoped_function, True)
      tb = template.make_template("bar", variable_scoped_function, True)

    # Ensure we can get the scopes before either template is actually called.
    self.assertEqual(ta.variable_scope.name, "foo/bar")
    self.assertEqual(tb.variable_scope.name, "foo/bar_1")

    with variable_scope.variable_scope("foo_2"):
      # Create a template which defers scope creation.
      tc = template.make_template("blah", variable_scoped_function, False)

    # Before we call the template, the scope property will be set to None.
    self.assertIsNone(tc.variable_scope)
    tc()

    # Template is called at the top level, so there is no preceding "foo_2".
    self.assertEqual(tc.variable_scope.name, "blah")

  @test_util.run_in_graph_and_eager_modes
  def test_custom_getter(self):
    # Custom getter that maintains call count and forwards to true getter
    custom_getter_count = [0]

    def custom_getter(getter, name, *args, **kwargs):
      custom_getter_count[0] += 1
      return getter(name, *args, **kwargs)

    # Test that custom getter is called both when variables are created and
    # subsequently accessed
    tmpl1 = template.make_template(
        "s1", variable_scoped_function, custom_getter_=custom_getter)
    self.assertEqual(custom_getter_count[0], 0)
    tmpl1()
    self.assertEqual(custom_getter_count[0], 1)
    tmpl1()
    self.assertEqual(custom_getter_count[0], 2)

    # Test that custom getter is called when the variable scope is created
    # during construction
    custom_getter_count[0] = 0
    tmpl2 = template.make_template(
        "s2",
        variable_scoped_function,
        custom_getter_=custom_getter,
        create_scope_now_=True)
    self.assertEqual(custom_getter_count[0], 0)
    tmpl2()
    self.assertEqual(custom_getter_count[0], 1)
    tmpl2()
    self.assertEqual(custom_getter_count[0], 2)

  @test_util.run_in_graph_and_eager_modes
  def test_fails_gracefully(self):
    for create_scope_now in [True, False]:
      def module_function_with_one_arg(inputs):
        w = variable_scope.get_variable(
            "w", shape=[1], initializer=init_ops.zeros_initializer())
        return inputs * w

      templatized_function = template.make_template(
          "f1", module_function_with_one_arg,
          create_scope_now_=create_scope_now)
      data = array_ops.zeros([1])
      # Connecting with an unsupported kwarg must fail. Assert it explicitly,
      # otherwise the test would pass silently if no error were raised.
      with self.assertRaises(TypeError):
        templatized_function(data, is_training=True)

      # The failed __call__ hasn't modified the inner state.
      self.assertFalse(templatized_function._variables_created)
      templatized_function(data)
      self.assertTrue(templatized_function._variables_created)

  @test_util.run_in_graph_and_eager_modes
  def test_name_scopes_for_variable_scopes(self):
    # Test that name scopes are not unnecessarily uniquified (but are
    # still uniquified when necessary).
    def linear_module(x, output_size):
      w = variable_scope.get_variable(
          "w", shape=[x.get_shape()[1], output_size],
          initializer=init_ops.zeros_initializer())
      b = variable_scope.get_variable(
          "b", shape=[output_size],
          initializer=init_ops.zeros_initializer())
      return (math_ops.matmul(x, w) + b), w

    def make_linear_module(output_size, name):
      return template.make_template(
          name,
          linear_module,
          output_size=output_size,
          create_scope_now_=True)

    inputs = array_ops.ones((3, 4))

    linear1 = make_linear_module(output_size=2, name="foo")
    outputs_a, w1 = linear1(inputs)
    outputs_b, _ = linear1(inputs)
    self.assertEqual("foo", linear1.variable_scope.name)
    self.assertEqual("foo/w:0", w1.name)
    if not context.executing_eagerly():
      self.assertEqual(
          "foo/add:0", outputs_a.name,
          "First application of template should get "
          "same name scope as variables.")
      self.assertEqual(
          "foo_1/add:0", outputs_b.name,
          "Second application of template should get "
          "a freshly uniquified name scope.")

    linear2 = make_linear_module(output_size=2, name="foo")
    outputs_c, w2 = linear2(inputs)
    outputs_d, _ = linear2(inputs)
    self.assertEqual(
        "foo_1", linear2.variable_scope.name,
        "New template gets a freshly uniquified variable scope "
        "because 'foo' is already taken.")
    self.assertEqual("foo_1/w:0", w2.name)
    if not context.executing_eagerly():
      self.assertEqual(
          "foo_1_1/add:0", outputs_c.name,
          "First application of template would get "
          "same name scope as variables, but 'foo_1' is already "
          "a name scope.")
      self.assertEqual(
          "foo_1_2/add:0", outputs_d.name,
          "Second application of template should also get "
          "a freshly uniquified name scope.")

  @test_util.run_in_graph_and_eager_modes
  def test_global_variables(self):
    # Make sure global_variables are created.
    with variable_scope.variable_scope("foo"):
      # Create two templates; tb's function also creates a non-trainable
      # side-effect variable, so tb ends up owning two variables.
      ta = template.make_template("bar", variable_scoped_function, True)
      if context.executing_eagerly():
        tb = template.make_template("s", function_with_side_create,
                                    trainable=False)
      else:
        tb = template.make_template("s", function_with_create, trainable=False)

    # Initially no variables have been created.
    self.assertEqual([], list(ta.global_variables))
    self.assertEqual([], list(tb.global_variables))
    # After calling there are variables created.
    ta()
    tb()
    # Each template now reports exactly the variables its function created.
    self.assertEqual(1, len(ta.global_variables))
    self.assertEqual(2, len(tb.global_variables))

  @test_util.run_in_graph_and_eager_modes
  def test_trainable_variables(self):
    # Make sure trainable_variables are created.
    with variable_scope.variable_scope("foo2"):
      # Create two templates with the same name, ensure scopes are made unique.
      ta = template.make_template("bar", variable_scoped_function, True)
      tb = template.make_template("bar", variable_scoped_function, True)

    # Initially no variables have been created.
    self.assertEqual([], list(ta.trainable_variables))
    self.assertEqual([], list(tb.trainable_variables))
    # After calling there are variables created.
    ta()
    tb()
    # Each template now reports exactly the variable its function created.
    self.assertEqual(1, len(ta.trainable_variables))
    self.assertEqual(1, len(tb.trainable_variables))
    # No non-trainable variables were created.
    self.assertEqual([], list(ta.non_trainable_variables))
    self.assertEqual([], list(tb.non_trainable_variables))
    # Ensure variables returns all the variables.
    self.assertEqual(1, len(ta.variables))
    self.assertEqual(1, len(tb.variables))

  @test_util.run_in_graph_and_eager_modes
  def test_non_trainable_variables(self):
    # Make sure non_trainable_variables are created.
    with variable_scope.variable_scope("foo2"):
      ta = template.make_template("a", variable_scoped_function,
                                  trainable=True)
      tb = template.make_template("b", variable_scoped_function,
                                  trainable=False)
    # Initially no variables have been created.
    self.assertEqual([], list(ta.variables))
    self.assertEqual([], list(tb.variables))
    # After calling there are variables created.
    ta()
    tb()
    # Check the trainable and non_trainable variables.
    self.assertEqual(1, len(ta.trainable_variables))
    self.assertEqual([], list(ta.non_trainable_variables))

    self.assertEqual([], list(tb.trainable_variables))
    self.assertEqual(1, len(tb.non_trainable_variables))
    # Ensure variables returns all the variables.
    self.assertEqual(1, len(ta.variables))
    self.assertEqual(1, len(tb.variables))

  # TODO(apassos) handle local variables in Eager
  @test_util.run_deprecated_v1
  def test_local_variables(self):
    # Make sure local_variables are created.
    with variable_scope.variable_scope("foo3"):
      # Create two templates with the same name, ensure scopes are made unique.
      ta = template.make_template("bar", variable_scoped_function, True)
      tb = template.make_template("bar",
                                  variable_scoped_function_with_local_variable)

    # Initially no variables have been created.
    self.assertEqual([], list(ta.local_variables))
    self.assertEqual([], list(tb.local_variables))
    # After calling there are variables created.
    ta()
    tb()
    # Only tb's function creates a local variable.
    self.assertEqual(0, len(ta.local_variables))
    self.assertEqual(1, len(tb.local_variables))

  @test_util.run_in_graph_and_eager_modes
  def test_make_template_with_defun(self):

    def variable_scoped_function_no_return_value(scope_name):
      # defun cannot compile functions that return non-Tensor objects
      with variable_scope.variable_scope(scope_name):
        _ = variable_scope.get_variable(
            "dummy", shape=[1], initializer=init_ops.zeros_initializer())

    tmpl = template.make_template_internal(
        "s1",
        variable_scoped_function_no_return_value,
        create_graph_function_=True,
        scope_name="test")

    # The first invocation of tmpl creates variables, the second should
    # be executed as a graph function.
    tmpl()
    v1 = tmpl.variables
    tmpl()
    v2 = tmpl.variables

    self.assertEqual(len(v1), len(v2))
    for v, w in zip(v1, v2):
      self.assertIs(v, w)
    self.assertEqual("s1/test/dummy:0", v1[0].name)
740
741
# Run every test case defined in this module when executed as a script.
if __name__ == "__main__":
  test.main()
744