# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for make_template."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools
import traceback

from tensorflow.python.client import session
from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import template
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
import tensorflow.python.ops.nn_grad  # pylint: disable=unused-import
from tensorflow.python.platform import test
from tensorflow.python.training import gradient_descent


def variable_scoped_function(trainable=True):
  """Gets the zero-initialized variable "dummy" of shape [1].

  Args:
    trainable: Forwarded to `get_variable` as its `trainable` argument.

  Returns:
    The result of `variable_scope.get_variable` for the name "dummy".
  """
  return variable_scope.get_variable(
      "dummy", shape=[1], trainable=trainable,
      initializer=init_ops.zeros_initializer())


def internally_variable_scoped_function(scope_name):
  """Gets the variable "dummy" from inside a variable scope it opens itself.

  Args:
    scope_name: Name of the variable scope entered before the lookup.

  Returns:
    The result of `variable_scope.get_variable` for the name "dummy"
    (shape [1], zero-initialized), requested within `scope_name`.
  """
  with variable_scope.variable_scope(scope_name):
    return variable_scope.get_variable(
        "dummy", shape=[1], initializer=init_ops.zeros_initializer())


def function_with_create(trainable):
  """Creates a variable as a side effect using tf.Variable.

  Args:
    trainable: Passed as `trainable` to the side-effect `variables.Variable`.

  Returns:
    The result of `get_variable` for "dummy" (shape [1], zero-initialized).
  """
  variables.Variable(0, trainable=trainable)
  return variable_scope.get_variable(
      "dummy", shape=[1], initializer=init_ops.zeros_initializer())


def function_with_side_create(trainable, name="side"):
  """Creates a variable as a side effect using tf.get_variable.

  Args:
    trainable: Passed as `trainable` to the side-effect `get_variable` call.
    name: Name of the side-effect variable.

  Returns:
    The result of `get_variable` for "dummy" (shape [1], zero-initialized).
  """
  variable_scope.get_variable(name, shape=[1], trainable=trainable)
  return variable_scope.get_variable(
      "dummy", shape=[1], initializer=init_ops.zeros_initializer())


def variable_scoped_function_with_local_variable():
  """Requests the local variable "local", then returns the variable "dummy".

  Returns:
    The result of `get_variable` for "dummy" (shape [1], zero-initialized).
  """
  variable_scope.get_local_variable(
      "local", shape=[1], initializer=init_ops.zeros_initializer())
  return variable_scope.get_variable(
      "dummy", shape=[1], initializer=init_ops.zeros_initializer())


class TemplateTest(test.TestCase):
  """Tests for `template.make_template` and `template.make_template_internal`."""

  @test_util.run_deprecated_v1
  def test_end_to_end(self):
    """This test shows a very simple line model with test_loss.

    The template is used to share parameters between a training and test model.
    """
    # y = 2x + 1
    training_input, training_output = ([1., 2., 3., 4.], [2.8, 5.1, 7.2, 8.7])
    test_input, test_output = ([5., 6., 7., 8.], [11, 13, 15, 17])

    random_seed.set_random_seed(1234)

    def test_line(x):
      m = variable_scope.get_variable(
          "w", shape=[], initializer=init_ops.truncated_normal_initializer())
      b = variable_scope.get_variable(
          "b", shape=[], initializer=init_ops.truncated_normal_initializer())
      return x * m + b

    line_template = template.make_template("line", test_line)

    train_prediction = line_template(training_input)
    test_prediction = line_template(test_input)

    train_loss = math_ops.reduce_mean(
        math_ops.square(train_prediction - training_output))
    test_loss = math_ops.reduce_mean(
        math_ops.square(test_prediction - test_output))

    optimizer = gradient_descent.GradientDescentOptimizer(0.1)
    train_op = optimizer.minimize(train_loss)

    # The session object itself is never referenced, so it is not bound.
    with session.Session():
      self.evaluate(variables.global_variables_initializer())
      initial_test_loss = self.evaluate(test_loss)
      self.evaluate(train_op)
      final_test_loss = self.evaluate(test_loss)

    # Parameters are tied, so the loss should have gone down when we trained it.
    self.assertLess(final_test_loss, initial_test_loss)

  def test_end_to_end_eager(self):
    """This test shows a very simple line model with test_loss in eager mode.

    The template is used to share parameters between a training and test model.
    """
    with context.eager_mode():
      # y = 2x + 1
      training_input, training_output = ([1., 2., 3., 4.], [2.8, 5.1, 7.2, 8.7])
      test_input, test_output = ([5., 6., 7., 8.], [11, 13, 15, 17])

      random_seed.set_random_seed(1234)

      def test_line(x):
        m = variable_scope.get_variable(
            "w", shape=[], initializer=init_ops.truncated_normal_initializer())
        b = variable_scope.get_variable(
            "b", shape=[], initializer=init_ops.truncated_normal_initializer())
        return x * m + b

      line_template = template.make_template("line", test_line)

      def train_loss():
        train_prediction = line_template(training_input)
        return math_ops.reduce_mean(
            math_ops.square(train_prediction - training_output))

      def test_loss():
        test_prediction = line_template(test_input)
        return math_ops.reduce_mean(
            math_ops.square(test_prediction - test_output))

      optimizer = gradient_descent.GradientDescentOptimizer(0.1)
      initial_test_loss = test_loss()
      optimizer.minimize(train_loss)
      final_test_loss = test_loss()

      # Parameters are tied, so the loss should have gone down after training.
      self.assertLess(final_test_loss.numpy(), initial_test_loss.numpy())

  @test_util.run_in_graph_and_eager_modes
  def test_skip_stack_frames(self):
    """Two stacks captured on adjacent lines reduce to a single frame."""
    first = traceback.format_stack()
    second = traceback.format_stack()
    result = template._skip_common_stack_elements(first, second)
    self.assertEqual(1, len(result))
    self.assertNotEqual(len(first), len(result))

  @test_util.run_in_graph_and_eager_modes
  def test_template_with_empty_name(self):
    """A template named "" still gets a scope under the enclosing scope."""
    tpl = template.make_template("", variable_scoped_function)
    with variable_scope.variable_scope("outer"):
      x = variable_scope.get_variable("x", [])
      v = tpl()
    self.assertEqual("outer/", tpl.variable_scope_name)
    self.assertEqual("outer//dummy:0", v.name)
    if context.executing_eagerly():
      # In eager mode `x` is not visible to the template since the template does
      # not rely on global collections.
      self.assertEqual(1, len(tpl.variables))
      self.assertIs(v, tpl.variables[0])
    else:
      self.assertEqual([x, v], tpl.variables)

  @test_util.run_in_graph_and_eager_modes
  def test_template_with_name(self):
    """Same-named templates get uniquified scopes and separate variables."""
    tmpl1 = template.make_template("s1", variable_scoped_function)
    tmpl2 = template.make_template("s1", variable_scoped_function)

    v1 = tmpl1()
    v2 = tmpl1()
    v3 = tmpl2()
    self.assertIs(v1, v2)
    self.assertIsNot(v1, v3)
    self.assertEqual("s1/dummy:0", v1.name)
    self.assertEqual("s1_1/dummy:0", v3.name)

  @test_util.run_deprecated_v1
  def test_same_unique_name_raise_error(self):
    """Reusing a `unique_name_` without scope reuse raises ValueError."""
    tmpl1 = template.make_template(
        "_", variable_scoped_function, unique_name_="s1")
    tmpl1()
    tmpl2 = template.make_template(
        "_", variable_scoped_function, unique_name_="s1")
    with self.assertRaisesRegex(
        ValueError, "Variable s1/dummy already exists, disallowed.*"):
      tmpl2()

  def test_unique_name_raise_error_in_eager(self):
    """Passing `unique_name_` under eager execution raises ValueError."""
    with context.eager_mode():
      with self.assertRaisesRegex(
          ValueError,
          "unique_name_ cannot be used when eager execution is enabled."):
        template.make_template(
            "_", variable_scoped_function, unique_name_="s1")

  @test_util.run_deprecated_v1
  def test_unique_name_and_reuse(self):
    """With scope reuse on, a repeated `unique_name_` shares the variable."""
    tmpl1 = template.make_template(
        "_", variable_scoped_function, unique_name_="s1")
    v1 = tmpl1()
    v2 = tmpl1()

    variable_scope.get_variable_scope().reuse_variables()
    tmpl2 = template.make_template(
        "_", variable_scoped_function, unique_name_="s1")
    v3 = tmpl2()

    self.assertIs(v1, v2)
    self.assertIs(v1, v3)
    self.assertEqual("s1/dummy:0", v1.name)

  @test_util.run_in_graph_and_eager_modes
  def test_template_in_scope(self):
    """A template keeps the scope of its first call on later calls."""
    tmpl1 = template.make_template("s1", variable_scoped_function)
    tmpl2 = template.make_template("s1", variable_scoped_function)

    with variable_scope.variable_scope("scope"):
      v1 = tmpl1()
      v3 = tmpl2()

    # The template contract requires the following to ignore scope2.
    with variable_scope.variable_scope("scope2"):
      v2 = tmpl1()
    self.assertIs(v1, v2)
    self.assertIsNot(v1, v3)
    self.assertEqual("scope/s1/dummy:0", v1.name)
    self.assertEqual("scope/s1_1/dummy:0", v3.name)

  @test_util.run_in_graph_and_eager_modes
  def test_template_with_internal_reuse(self):
    """Scopes opened inside the templated function nest under the template."""
    tmpl1 = template.make_template("s1", internally_variable_scoped_function)
    tmpl2 = template.make_template("s1", internally_variable_scoped_function)

    v1 = tmpl1("test")
    v2 = tmpl1("test")
    v3 = tmpl2("test")
    self.assertIs(v1, v2)
    self.assertIsNot(v1, v3)
    self.assertEqual("s1/test/dummy:0", v1.name)
    self.assertEqual("s1_1/test/dummy:0", v3.name)

    # A later call that asks for a different inner scope must be rejected.
    with self.assertRaises(ValueError):
      tmpl1("not_test")

  @test_util.run_in_graph_and_eager_modes
  def test_template_without_name(self):
    """A None template name is rejected with ValueError."""
    with self.assertRaisesRegex(ValueError, "name cannot be None."):
      template.make_template(None, variable_scoped_function)

  @test_util.run_in_graph_and_eager_modes
  def test_make_template(self):
    """Function arguments can be bound through make_template keywords."""
    # `scope_name` is bound at construction, so the templates take no args.
    tmpl1 = template.make_template(
        "s1", internally_variable_scoped_function, scope_name="test")
    tmpl2 = template.make_template(
        "s1", internally_variable_scoped_function, scope_name="test")

    v1 = tmpl1()
    v2 = tmpl1()
    v3 = tmpl2()
    self.assertIs(v1, v2)
    self.assertIsNot(v1, v3)
    self.assertEqual("s1/test/dummy:0", v1.name)
    self.assertEqual("s1_1/test/dummy:0", v3.name)

  @test_util.run_deprecated_v1
  def test_enforces_no_extra_trainable_variables(self):
    """A second call that creates a trainable tf.Variable raises ValueError."""
    tmpl = template.make_template("s", function_with_create, trainable=True)

    tmpl()
    with self.assertRaises(ValueError):
      tmpl()

  @test_util.run_in_graph_and_eager_modes
  def test_enforces_no_extra_trainable_variables_eager(self):
    """A later call requesting a new trainable variable raises ValueError."""
    tmpl = template.make_template("s",
                                  function_with_side_create,
                                  trainable=True)

    tmpl(name="1")
    with self.assertRaises(ValueError):
      tmpl(name="2")

  def test_permits_extra_non_trainable_variables(self):
    """Extra non-trainable tf.Variables on later calls are allowed."""
    tmpl = template.make_template("s", function_with_create, trainable=False)
    self.assertIs(tmpl(), tmpl())

  def test_permits_extra_non_trainable_variables_eager(self):
    """Extra non-trainable get_variable calls are allowed in eager mode."""
    with context.eager_mode():
      tmpl = template.make_template("s",
                                    function_with_side_create,
                                    trainable=False)
      self.assertIs(tmpl(name="1"), tmpl(name="2"))

  @test_util.run_in_graph_and_eager_modes
  def test_internal_variable_reuse(self):
    """Explicit reuse inside the templated function keeps working."""

    def nested():
      with variable_scope.variable_scope("nested") as vs:
        v1 = variable_scope.get_variable(
            "x", initializer=init_ops.zeros_initializer(), shape=[])
      with variable_scope.variable_scope(vs, reuse=True):
        v2 = variable_scope.get_variable("x")
      self.assertIs(v1, v2)
      return v1

    tmpl1 = template.make_template("s1", nested)
    tmpl2 = template.make_template("s1", nested)

    v1 = tmpl1()
    v2 = tmpl1()
    v3 = tmpl2()
    self.assertIs(v1, v2)
    self.assertIsNot(v1, v3)
    self.assertEqual("s1/nested/x:0", v1.name)
    self.assertEqual("s1_1/nested/x:0", v3.name)

  @test_util.run_in_graph_and_eager_modes
  def test_nested_templates(self):
    """Templates created inside a template are isolated from each other."""

    def nested_template():
      nested1 = template.make_template("nested", variable_scoped_function)
      nested2 = template.make_template("nested", variable_scoped_function)
      v1 = nested1()
      v2 = nested2()

      # nested1 and nested2 should not share variables
      self.assertIsNot(v1, v2)

      # Variables created by nested1 should be isolated from variables
      # created by nested2.
      self.assertEqual(1, len(nested1.variables))
      self.assertEqual(1, len(nested2.variables))
      self.assertIs(nested1.variables[0], v1)
      self.assertIs(nested2.variables[0], v2)
      self.assertEqual(1, len(nested1.trainable_variables))
      self.assertEqual(1, len(nested2.trainable_variables))
      self.assertIs(nested1.trainable_variables[0], v1)
      self.assertIs(nested2.trainable_variables[0], v2)
      self.assertEqual(len(nested1.non_trainable_variables), 0)
      self.assertEqual(len(nested2.non_trainable_variables), 0)
      return v1, v2

    tmpl1 = template.make_template("s1", nested_template)
    tmpl2 = template.make_template("s1", nested_template)

    v1, v2 = tmpl1()
    v3, v4 = tmpl1()
    v5, v6 = tmpl2()

    # The second invocation of tmpl1 should reuse the variables
    # created in the first invocation.
    self.assertIs(v1, v3)
    self.assertIs(v2, v4)
    for v, w in zip(tmpl1.variables, [v1, v2]):
      self.assertIs(v, w)
    for v, w in zip(tmpl1.trainable_variables, [v1, v2]):
      self.assertIs(v, w)
    self.assertEqual(len(tmpl1.non_trainable_variables), 0)

    # tmpl1 and tmpl2 should not share variables.
    self.assertIsNot(v1, v5)
    self.assertIsNot(v2, v6)
    for v, w in zip(tmpl2.variables, [v5, v6]):
      self.assertIs(v, w)
    for v, w in zip(tmpl2.trainable_variables, [v5, v6]):
      self.assertIs(v, w)
    self.assertEqual(len(tmpl2.non_trainable_variables), 0)
    self.assertEqual("s1/nested/dummy:0", v1.name)
    self.assertEqual("s1/nested_1/dummy:0", v2.name)
    self.assertEqual("s1_1/nested/dummy:0", v5.name)
    self.assertEqual("s1_1/nested_1/dummy:0", v6.name)

    self.assertEqual(2, len(tmpl1._checkpoint_dependencies))
    self.assertEqual("nested", tmpl1._checkpoint_dependencies[0].name)
    self.assertEqual("nested_1", tmpl1._checkpoint_dependencies[1].name)

  @test_util.run_in_graph_and_eager_modes
  def test_nested_templates_with_defun(self):
    """Nested templates built with `create_graph_function_=True` stay isolated."""

    def variable_scoped_function_no_return_value(trainable=True):
      # defun cannot compile functions that return non-Tensor objects
      _ = variable_scope.get_variable(
          "dummy",
          shape=[1],
          trainable=trainable,
          initializer=init_ops.zeros_initializer())

    def nested_template():
      nested1 = template.make_template_internal(
          "nested",
          variable_scoped_function_no_return_value,
          create_graph_function_=True)
      nested2 = template.make_template_internal(
          "nested",
          variable_scoped_function_no_return_value,
          create_graph_function_=True)
      nested1()
      nested2()
      v1 = nested1.variables
      v2 = nested2.variables

      self.assertEqual(len(v1), 1)
      self.assertEqual(len(v2), 1)

      # nested1 and nested2 should not share variables
      self.assertIsNot(v1[0], v2[0])
      self.assertIs(nested1.trainable_variables[0], v1[0])
      self.assertIs(nested2.trainable_variables[0], v2[0])
      self.assertEqual(len(nested1.non_trainable_variables), 0)
      self.assertEqual(len(nested2.non_trainable_variables), 0)

    tmpl1 = template.make_template("s1", nested_template)
    tmpl2 = template.make_template("s1", nested_template)

    tmpl1()
    v1 = tmpl1.variables
    tmpl1()
    v2 = tmpl1.variables
    tmpl2()
    v3 = tmpl2.variables

    # The second invocation of tmpl1 should reuse the variables
    # created in the first invocation.
    for v, w in zip(v1, v2):
      self.assertIs(v, w)

    # tmpl1 and tmpl2 should not share variables.
    for v, w in zip(v1, v3):
      self.assertIsNot(v, w)

    self.assertEqual("s1/nested/dummy:0", v1[0].name)
    self.assertEqual("s1/nested_1/dummy:0", v1[1].name)
    self.assertEqual("s1_1/nested/dummy:0", v3[0].name)
    self.assertEqual("s1_1/nested_1/dummy:0", v3[1].name)

  def test_graph_function_no_name(self):
    """A `functools.partial` can be templated as a graph function."""
    with context.eager_mode():

      def f(_, y):
        return y + 1

      partial = functools.partial(f, 1.0)
      tmpl = template.make_template_internal(
          "a", partial, create_graph_function_=True)
      self.assertAllEqual(tmpl(ops.convert_to_tensor(1.0)), 2.0)

  @test_util.run_in_graph_and_eager_modes
  def test_immediate_scope_creation(self):
    """Checks when the variable scope is captured for each template flavor."""
    # Create templates in scope "ctor_scope", then first call them in scope
    # "call_scope". The template built with True must keep the construction
    # scope; the one built with False must take the scope of its first call.
    with variable_scope.variable_scope("ctor_scope"):
      # Create scope here:
      tmpl_immed = template.make_template("a", variable_scoped_function,
                                          True)
      # default: create scope at __call__
      tmpl_defer = template.make_template(
          "b", variable_scoped_function, False)
    with variable_scope.variable_scope("call_scope"):
      inner_imm_var = tmpl_immed()
      inner_defer_var = tmpl_defer()
    outer_imm_var = tmpl_immed()
    outer_defer_var = tmpl_defer()

    self.assertIsNot(inner_imm_var, inner_defer_var)
    self.assertIs(outer_imm_var, inner_imm_var)
    self.assertIs(outer_defer_var, inner_defer_var)

    self.assertEqual("ctor_scope/a/dummy:0", inner_imm_var.name)
    self.assertEqual("call_scope/b/dummy:0", inner_defer_var.name)

  @test_util.run_in_graph_and_eager_modes
  def test_scope_access(self):
    """The template's `variable_scope` property exposes the actual scope."""
    # Ensure that we can access the scope inside the template, because the name
    # of that scope may be different from the name we pass to make_template, due
    # to having been made unique by variable_scope.
    with variable_scope.variable_scope("foo"):
      # Create two templates with the same name, ensure scopes are made unique.
      ta = template.make_template("bar", variable_scoped_function, True)
      tb = template.make_template("bar", variable_scoped_function, True)

    # Ensure we can get the scopes before either template is actually called.
    self.assertEqual(ta.variable_scope.name, "foo/bar")
    self.assertEqual(tb.variable_scope.name, "foo/bar_1")

    with variable_scope.variable_scope("foo_2"):
      # Create a template which defers scope creation.
      tc = template.make_template("blah", variable_scoped_function, False)

    # Before we call the template, the scope property will be set to None.
    self.assertIsNone(tc.variable_scope)
    tc()

    # Template is called at the top level, so there is no preceding "foo_2".
    self.assertEqual(tc.variable_scope.name, "blah")

  @test_util.run_in_graph_and_eager_modes
  def test_custom_getter(self):
    """The custom getter runs once per template call."""
    # Custom getter that maintains call count and forwards to true getter
    custom_getter_count = [0]

    def custom_getter(getter, name, *args, **kwargs):
      custom_getter_count[0] += 1
      return getter(name, *args, **kwargs)

    # Test that custom getter is called both when variables are created and
    # subsequently accessed
    tmpl1 = template.make_template(
        "s1", variable_scoped_function, custom_getter_=custom_getter)
    self.assertEqual(custom_getter_count[0], 0)
    tmpl1()
    self.assertEqual(custom_getter_count[0], 1)
    tmpl1()
    self.assertEqual(custom_getter_count[0], 2)

    # Test that custom getter is called when the variable scope is created
    # during construction
    custom_getter_count[0] = 0
    tmpl2 = template.make_template(
        "s2",
        variable_scoped_function,
        custom_getter_=custom_getter,
        create_scope_now_=True)
    self.assertEqual(custom_getter_count[0], 0)
    tmpl2()
    self.assertEqual(custom_getter_count[0], 1)
    tmpl2()
    self.assertEqual(custom_getter_count[0], 2)

  @test_util.run_in_graph_and_eager_modes
  def test_fails_gracefully(self):
    """A failed first call must leave the template usable afterwards."""
    for create_scope_now in [True, False]:
      def module_function_with_one_arg(inputs):
        w = variable_scope.get_variable(
            "w", shape=[1], initializer=init_ops.zeros_initializer())
        return inputs * w

      templatized_function = template.make_template(
          "f1", module_function_with_one_arg,
          create_scope_now_=create_scope_now)
      data = array_ops.zeros([1])
      # Assert the failure instead of swallowing it, so the test cannot pass
      # vacuously when the call unexpectedly succeeds.
      with self.assertRaises(TypeError):
        # Try to connect with a kwarg which is unsupported.
        templatized_function(data, is_training=True)

      # The failed __call__ hasn't modified the inner state.
      self.assertFalse(templatized_function._variables_created)
      templatized_function(data)
      self.assertTrue(templatized_function._variables_created)

  @test_util.run_in_graph_and_eager_modes
  def test_name_scopes_for_variable_scopes(self):
    """Checks op name scopes across repeated template applications."""
    # Test that name scopes are not unnecessarily uniquified (but are
    # still uniquified when necessary).
    def linear_module(x, output_size):
      w = variable_scope.get_variable(
          "w", shape=[x.get_shape()[1], output_size],
          initializer=init_ops.zeros_initializer())
      b = variable_scope.get_variable(
          "b", shape=[output_size],
          initializer=init_ops.zeros_initializer())
      return (math_ops.matmul(x, w) + b), w

    def make_linear_module(output_size, name):
      return template.make_template(
          name,
          linear_module,
          output_size=output_size,
          create_scope_now_=True)

    inputs = array_ops.ones((3, 4))

    linear1 = make_linear_module(output_size=2, name="foo")
    outputs_a, w1 = linear1(inputs)
    outputs_b, _ = linear1(inputs)
    self.assertEqual("foo", linear1.variable_scope.name)
    self.assertEqual("foo/w:0", w1.name)
    # Tensor names are only checked when not executing eagerly.
    if not context.executing_eagerly():
      self.assertEqual(
          "foo/add:0", outputs_a.name,
          "First application of template should get "
          "same name scope as variables.")
      self.assertEqual(
          "foo_1/add:0", outputs_b.name,
          "Second application of template should get "
          "a freshly uniquified name scope.")

    linear2 = make_linear_module(output_size=2, name="foo")
    outputs_c, w2 = linear2(inputs)
    outputs_d, _ = linear2(inputs)
    self.assertEqual(
        "foo_1", linear2.variable_scope.name,
        "New template gets a freshly uniquified variable scope "
        "because 'foo' is already taken.")
    self.assertEqual("foo_1/w:0", w2.name)
    if not context.executing_eagerly():
      self.assertEqual(
          "foo_1_1/add:0", outputs_c.name,
          "First application of template would get "
          "same name scope as variables, but 'foo_1' is already "
          "a name scope.")
      self.assertEqual(
          "foo_1_2/add:0", outputs_d.name,
          "Second application of template should also get "
          "a freshly uniquified name scope.")

  @test_util.run_in_graph_and_eager_modes
  def test_global_variables(self):
    """`global_variables` is empty before the first call and filled after."""
    # Make sure global_variables are created.
    with variable_scope.variable_scope("foo"):
      # Create two templates; the second one's function also creates a
      # non-trainable side variable in addition to "dummy".
      ta = template.make_template("bar", variable_scoped_function, True)
      if context.executing_eagerly():
        tb = template.make_template("s", function_with_side_create,
                                    trainable=False)
      else:
        tb = template.make_template("s", function_with_create, trainable=False)

    # Initially no variables have been created.
    self.assertEqual([], list(ta.global_variables))
    self.assertEqual([], list(tb.global_variables))
    # After calling there are variables created.
    ta()
    tb()
    # ta owns only "dummy"; tb owns "dummy" plus its side variable.
    self.assertEqual(1, len(ta.global_variables))
    self.assertEqual(2, len(tb.global_variables))

  @test_util.run_in_graph_and_eager_modes
  def test_trainable_variables(self):
    """`trainable_variables` is empty before the first call and filled after."""
    # Make sure trainable_variables are created.
    with variable_scope.variable_scope("foo2"):
      # Create two templates with the same name, ensure scopes are made unique.
      ta = template.make_template("bar", variable_scoped_function, True)
      tb = template.make_template("bar", variable_scoped_function, True)

    # Initially no variables have been created.
    self.assertEqual([], list(ta.trainable_variables))
    self.assertEqual([], list(tb.trainable_variables))
    # After calling there are variables created.
    ta()
    tb()
    # Each template now owns exactly one trainable variable.
    self.assertEqual(1, len(ta.trainable_variables))
    self.assertEqual(1, len(tb.trainable_variables))
    # No non-trainable variables were created.
    self.assertEqual([], list(ta.non_trainable_variables))
    self.assertEqual([], list(tb.non_trainable_variables))
    # Ensure variables returns all the variables.
    self.assertEqual(1, len(ta.variables))
    self.assertEqual(1, len(tb.variables))

  @test_util.run_in_graph_and_eager_modes
  def test_non_trainable_variables(self):
    """Variables are split by the `trainable` flag bound at construction."""
    # Make sure non_trainable_variables are created.
    with variable_scope.variable_scope("foo2"):
      ta = template.make_template("a", variable_scoped_function,
                                  trainable=True)
      tb = template.make_template("b", variable_scoped_function,
                                  trainable=False)
    # Initially no variables have been created.
    self.assertEqual([], list(ta.variables))
    self.assertEqual([], list(tb.variables))
    # After calling there are variables created.
    ta()
    tb()
    # Check the trainable and non_trainable variables.
    self.assertEqual(1, len(ta.trainable_variables))
    self.assertEqual([], list(ta.non_trainable_variables))

    self.assertEqual([], list(tb.trainable_variables))
    self.assertEqual(1, len(tb.non_trainable_variables))
    # Ensure variables returns all the variables.
    self.assertEqual(1, len(ta.variables))
    self.assertEqual(1, len(tb.variables))

  # TODO(apassos) handle local variables in Eager
  @test_util.run_deprecated_v1
  def test_local_variables(self):
    """`local_variables` lists only variables made via get_local_variable."""
    # Make sure local_variables are created.
    with variable_scope.variable_scope("foo3"):
      # Create two templates with the same name, ensure scopes are made unique.
      ta = template.make_template("bar", variable_scoped_function, True)
      tb = template.make_template("bar",
                                  variable_scoped_function_with_local_variable)

    # Initially no variables have been created.
    self.assertEqual([], list(ta.local_variables))
    self.assertEqual([], list(tb.local_variables))
    # After calling there are variables created.
    ta()
    tb()
    # Only tb's function requests a local variable.
    self.assertEqual(0, len(ta.local_variables))
    self.assertEqual(1, len(tb.local_variables))

  @test_util.run_in_graph_and_eager_modes
  def test_make_template_with_defun(self):
    """Repeated calls of a graph-function template reuse its variables."""

    def variable_scoped_function_no_return_value(scope_name):
      # defun cannot compile functions that return non-Tensor objects
      with variable_scope.variable_scope(scope_name):
        _ = variable_scope.get_variable(
            "dummy", shape=[1], initializer=init_ops.zeros_initializer())

    tmpl = template.make_template_internal(
        "s1",
        variable_scoped_function_no_return_value,
        create_graph_function_=True,
        scope_name="test")

    # The first invocation of tmpl creates variables, the second should
    # be executed as a graph function.
    tmpl()
    v1 = tmpl.variables
    tmpl()
    v2 = tmpl.variables

    self.assertEqual(len(v1), len(v2))
    for v, w in zip(v1, v2):
      self.assertIs(v, w)
    self.assertEqual("s1/test/dummy:0", v1[0].name)


if __name__ == "__main__":
  test.main()