1# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for `tf.Module`.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import abc 22import collections 23import itertools 24 25from absl.testing import parameterized 26import six 27 28from tensorflow.python import tf2 29from tensorflow.python.distribute import ps_values 30from tensorflow.python.distribute import tpu_values 31from tensorflow.python.distribute import values as distributed_values 32from tensorflow.python.eager import context 33from tensorflow.python.eager import def_function 34from tensorflow.python.framework import composite_tensor 35from tensorflow.python.framework import ops 36from tensorflow.python.framework import test_util 37from tensorflow.python.framework import type_spec 38from tensorflow.python.module import module 39from tensorflow.python.ops import variables 40from tensorflow.python.platform import test 41 42 43class TestModuleNaming(test_util.TensorFlowTestCase): 44 45 def test_single_name(self): 46 mod = module.Module(name="simple") 47 self.assertEqual(mod.name, "simple") 48 self.assertEqual(mod.name_scope.name, "simple/") 49 50 def test_construct_in_scope(self): 51 with ops.name_scope("foo", skip_on_eager=False): 52 mod = module.Module(name="bar") 53 
self.assertEqual(mod.name, "bar") 54 self.assertEqual(mod.name_scope.name, "foo/bar/") 55 56 def test_enters_name_scope_in_call(self): 57 mod = ReturnsNameScopeModule() 58 for _ in range(3): 59 self.assertEqual(mod(), mod.name_scope.name) 60 61 def test_enters_name_scope_in_other_method(self): 62 mod = ReturnsNameScopeModule() 63 for _ in range(3): 64 self.assertEqual(mod.alternative_forward(), mod.name_scope.name) 65 66 def test_subclassed_module(self): 67 mod = SubclassedReturnsNameScopeModule() 68 for _ in range(3): 69 self.assertEqual(mod.alternative_forward(), mod.name_scope.name) 70 self.assertEqual(mod.alternative_alternative_forward(), 71 mod.name_scope.name) 72 73 def test_submodule_created_late(self): 74 m = TreeModule() 75 self.assertEqual(m.name, "tree_module") 76 self.assertEqual(m.name_scope.name, "tree_module/") 77 leaf1 = m.new_leaf() 78 self.assertEqual(leaf1.name, "tree_module") 79 self.assertEqual(leaf1.name_scope.name, "tree_module/tree_module/") 80 81 def test_does_not_evaluate_property_methods(self): 82 mod = PropertyThrowsWhenCalledModule() 83 with self.assertRaises(AssertionError): 84 mod.raise_assertion_error # pylint: disable=pointless-statement 85 86 def test_overridden_name_scope(self): 87 mod = ModuleOverridingNameScope() 88 self.assertEqual(mod(), mod.name_scope.name) 89 self.assertEqual(mod.alternative_forward(), mod.name_scope.name) 90 91 def test_patched_callable(self): 92 with ops.name_scope("foo", skip_on_eager=False): 93 mod = module.Module(name="bar") 94 mod.foo = get_name_scope 95 # `foo` is not a method so we do not re-enter the name scope. 96 self.assertEqual(mod.foo(), "") 97 98 def test_property(self): 99 mod = PropertyModule() 100 mod.some_property = None, None # None, None for the linter. 
101 getter_scope_name, setter_scope_name = mod.some_property 102 self.assertEqual(getter_scope_name, "property_module/") 103 self.assertEqual(setter_scope_name, "property_module/") 104 105 def test_property_no_name_scope(self): 106 mod = PropertyModule() 107 mod.no_name_scope_property = None, None # None, None for the linter. 108 getter_scope_name, setter_scope_name = mod.no_name_scope_property 109 self.assertEqual(getter_scope_name, "") 110 self.assertEqual(setter_scope_name, "") 111 112 def test_invalid_name(self): 113 msg = ".* is not a valid module name" 114 with self.assertRaisesRegex(ValueError, msg): 115 module.Module(name="$Foo") 116 117 @test_util.run_in_graph_and_eager_modes 118 def test_modules_not_numbered_in_eager(self): 119 if not context.executing_eagerly(): 120 self.skipTest("Eager specific") 121 122 mod = RecursiveModule(2) 123 self.assertEqual(mod.name_scope.name, "badger/") 124 self.assertEqual(mod.child.name_scope.name, "badger/badger/") 125 126 mod = RecursiveModule(2) 127 self.assertEqual(mod.name_scope.name, "badger/") 128 self.assertEqual(mod.child.name_scope.name, "badger/badger/") 129 130 @test_util.run_in_graph_and_eager_modes 131 def test_module_numbering_in_graph(self): 132 if context.executing_eagerly(): 133 self.skipTest("Graph specific") 134 135 mod = RecursiveModule(2) 136 self.assertEqual(mod.name_scope.name, "badger/") 137 self.assertEqual(mod.child.name_scope.name, "badger/badger/") 138 139 mod = RecursiveModule(2) 140 self.assertEqual(mod.name_scope.name, "badger_1/") 141 self.assertEqual(mod.child.name_scope.name, "badger_1/badger/") 142 143 def test_ctor_error_closes_name_scope(self): 144 with self.assertRaises(ErrorModuleError): 145 # If super constructor is called then a name scope is opened then an error 146 # is thrown. The metaclass should handle this and close the namescope 147 # before re-throwing the exception. 
148 ErrorModule(call_super=True) 149 150 self.assertEqual("", get_name_scope()) 151 152 def test_ctor_error_handles_ctor_not_opening_name_scope(self): 153 with self.assertRaises(ErrorModuleError): 154 # If super ctor is not called then the name scope isn't opened. We need to 155 # ensure that this doesn't trigger an exception (e.g. the metaclass trying 156 # to __exit__ a non-existent name scope). 157 ErrorModule(call_super=False) 158 159 self.assertEqual("", get_name_scope()) 160 161 def test_forward_method_closes_name_scope(self): 162 mod = ErrorModule(call_super=True, raise_in_constructor=False) 163 with self.assertRaises(ErrorModuleError): 164 mod() 165 166 self.assertEqual("", get_name_scope()) 167 168 def test_get_attr_doesnt_enter_name_scope(self): 169 scope_names = [] 170 171 class GetAttrModule(module.Module): 172 173 def __getattr__(self, name): 174 scope_names.append((name, get_name_scope())) 175 return super(GetAttrModule, self).__getattr__(name) 176 177 mod = GetAttrModule() 178 with self.assertRaises(AttributeError): 179 mod.does_not_exist # pylint: disable=pointless-statement 180 self.assertIn(("does_not_exist", ""), scope_names) 181 182 def test_get_attribute_doesnt_enter_name_scope(self): 183 scope_names = [] 184 185 class GetAttributeModule(module.Module): 186 187 def __getattribute__(self, name): 188 scope_names.append((name, get_name_scope())) 189 return super(GetAttributeModule, self).__getattribute__(name) 190 191 mod = GetAttributeModule() 192 with self.assertRaises(AttributeError): 193 mod.does_not_exist # pylint: disable=pointless-statement 194 self.assertIn(("does_not_exist", ""), scope_names) 195 196 197class VariableNamingTest(test_util.TensorFlowTestCase): 198 199 def test_variable_names(self): 200 mod = RecursiveModule(3) 201 self.assertEqual(mod.w.name, "badger/mushroom:0") 202 self.assertEqual(mod.child.w.name, "badger/badger/mushroom:0") 203 self.assertEqual(mod.child.child.w.name, "badger/badger/badger/mushroom:0") 204 205 206class 
NameScopeTest(test_util.TensorFlowTestCase): 207 208 @test_util.run_deprecated_v1 209 def test_not_memoized_in_tf1(self): 210 if tf2.enabled(): 211 self.skipTest("Requires TF1") 212 213 mod = module.Module(name="name") 214 name_scope_1 = mod.name_scope 215 name_scope_2 = mod.name_scope 216 self.assertIsNot(name_scope_1, name_scope_2) 217 self.assertEqual(name_scope_1.name, name_scope_2.name) 218 219 def test_memoized_in_tf2(self): 220 if not tf2.enabled(): 221 self.skipTest("Requires TF2") 222 223 mod = module.Module(name="name") 224 name_scope_1 = mod.name_scope 225 name_scope_2 = mod.name_scope 226 self.assertIs(name_scope_1, name_scope_2) 227 228 229class VariableTrackingTest(test_util.TensorFlowTestCase): 230 231 def test_variables(self): 232 m = RecursiveModule(3) 233 self.assertEqual(m.variables, (m.w, m.child.w, m.child.child.w)) 234 self.assertEqual(m.child.variables, (m.child.w, m.child.child.w)) 235 self.assertEqual(m.child.child.variables, (m.child.child.w,)) 236 237 def test_trainable_variables(self): 238 m = RecursiveModule(3) 239 self.assertEqual(m.trainable_variables, 240 (m.w, m.child.w, m.child.child.w)) 241 self.assertEqual(m.child.trainable_variables, 242 (m.child.w, m.child.child.w)) 243 self.assertEqual(m.child.child.trainable_variables, (m.child.child.w,)) 244 245 def test_trainable_variables_ignores_non_trainable(self): 246 m = RecursiveModule(3, trainable=False) 247 self.assertEqual(len(m.trainable_variables), 0) 248 self.assertEqual(len(m.child.trainable_variables), 0) 249 self.assertEqual(len(m.child.child.trainable_variables), 0) 250 251 def test_supports_distributed_variables(self): 252 mirrored = distributed_values.MirroredVariable( 253 None, [variables.Variable(1.)], variables.VariableAggregation.SUM) 254 tpu = tpu_values.TPUMirroredVariable( 255 strategy=None, values=[variables.Variable(42.)], aggregation=None) 256 aggregating = ps_values.AggregatingVariable( 257 strategy=None, v=variables.Variable(1.), aggregation=None) 258 259 m = 
module.Module() 260 m.a = mirrored 261 m.b = tpu 262 m.c = aggregating 263 self.assertEqual(m.variables, (mirrored, tpu, aggregating)) 264 265 def test_composite_variable(self): 266 267 class Spec(type_spec.TypeSpec): 268 269 value_type = property(lambda self: CompositeVariable) 270 271 def _component_specs(self): 272 pass 273 274 def _serialize(self): 275 pass 276 277 def _to_components(self, value): 278 return value._variables 279 280 def _from_components(self, variable_list): 281 return CompositeVariable(variable_list) 282 283 class CompositeVariable(composite_tensor.CompositeTensor): 284 285 def __init__(self, variable_list): 286 self._variables = variable_list 287 288 @property 289 def _type_spec(self): 290 return Spec() 291 292 m = module.Module() 293 m.a = CompositeVariable([variables.Variable(1.), variables.Variable(2.)]) 294 self.assertAllEqual(m.variables, m.a._variables) 295 296 297class ModuleTrackingTest(test_util.TensorFlowTestCase): 298 299 def test_submodules(self): 300 m = RecursiveModule(3) 301 self.assertEqual(list(m.submodules), [m.child, m.child.child]) 302 self.assertEqual(list(m.child.submodules), [m.child.child]) 303 self.assertEqual(list(m.child.child.submodules), []) 304 305 def test_non_ctor_submodule(self): 306 m = TreeModule() 307 leaf1 = m.new_leaf() 308 self.assertEqual(set(m.submodules), {leaf1}) 309 leaf2 = m.new_leaf() 310 self.assertEqual(set(m.submodules), {leaf1, leaf2}) 311 312 313class ForwardMethodsTest(test_util.TensorFlowTestCase): 314 315 def testFunctionType(self): 316 mod = ModuleWithFunctionAnnotatedCall() 317 self.assertIsInstance(mod.forward, def_function.Function) 318 self.assertIsInstance(mod.forward_ag, def_function.Function) 319 320 def testEntersNameScope_call(self): 321 mod = ModuleWithFunctionAnnotatedCall() 322 self.assertEqual(self.evaluate(mod.forward()), 323 b"module_with_function_annotated_call/") 324 self.assertEqual(self.evaluate(mod.forward_ag()), 325 b"module_with_function_annotated_call/") 326 327 def 
testEntersNameScope_concreteFunction(self): 328 mod = ModuleWithFunctionAnnotatedCall() 329 self.assertEqual(self.evaluate(mod.forward.get_concrete_function()()), 330 b"module_with_function_annotated_call/") 331 self.assertEqual(self.evaluate(mod.forward_ag.get_concrete_function()()), 332 b"module_with_function_annotated_call/") 333 334 335class AbcTest(test_util.TensorFlowTestCase): 336 337 def testAbstract(self): 338 msg = "Can't instantiate .* abstract methods" 339 with self.assertRaisesRegex(TypeError, msg): 340 AbstractModule() # pylint: disable=abstract-class-instantiated 341 342 def testConcrete(self): 343 mod = ConcreteModule() 344 x, scope_name = mod(2.) 345 self.assertEqual(x, 4.) 346 self.assertEqual(scope_name, "concrete_module/") 347 self.assertEqual(get_name_scope(), "") 348 349 350def get_name_scope(): 351 with ops.name_scope("x", skip_on_eager=False) as ns: 352 ns = "/".join(ns.split("/")[:-2]) 353 return ns + "/" if ns else "" 354 355 356class ErrorModuleError(Exception): 357 pass 358 359 360class ErrorModule(module.Module): 361 362 def __init__(self, call_super, raise_in_constructor=True): 363 if call_super: 364 super(ErrorModule, self).__init__() 365 if raise_in_constructor: 366 raise ErrorModuleError("Deliberate error!") 367 368 def __call__(self): 369 raise ErrorModuleError("Deliberate error!") 370 371 372class RecursiveModule(module.Module): 373 374 def __init__(self, depth, trainable=True): 375 super(RecursiveModule, self).__init__(name="badger") 376 with self.name_scope: 377 self.child = None 378 if depth > 1: 379 self.child = RecursiveModule(depth - 1, trainable=trainable) 380 self.w = variables.Variable(1.0, trainable=trainable, name="mushroom") 381 382 383@six.add_metaclass(abc.ABCMeta) 384class AbstractModule(module.Module): 385 386 @abc.abstractmethod 387 def __call__(self, x): 388 pass 389 390 391class ConcreteModule(AbstractModule): 392 393 @module.Module.with_name_scope 394 def __call__(self, x): 395 return x ** 2, get_name_scope() 


class TreeModule(module.Module):
  """Module whose submodules are created after construction via `new_leaf`."""

  def __init__(self, name=None):
    super(TreeModule, self).__init__(name=name)
    self._leaves = []

  @module.Module.with_name_scope
  def new_leaf(self, name=None):
    """Creates, records and returns a child `TreeModule` in this scope."""
    leaf = TreeModule(name=name)
    self._leaves.append(leaf)
    return leaf


class ReturnsNameScopeModule(module.Module):
  """Both `__call__` and `alternative_forward` return the active name scope."""

  @module.Module.with_name_scope
  def alternative_forward(self):
    return get_name_scope()

  @module.Module.with_name_scope
  def __call__(self):
    return get_name_scope()


class SubclassedReturnsNameScopeModule(ReturnsNameScopeModule):
  """Adds one more scope-returning method in a subclass."""

  @module.Module.with_name_scope
  def alternative_alternative_forward(self):
    return get_name_scope()


class PropertyThrowsWhenCalledModule(module.Module):
  """Has a property whose getter always raises `AssertionError`."""

  @property
  def raise_assertion_error(self):
    raise AssertionError


class ModuleOverridingNameScope(ReturnsNameScopeModule):
  """Replaces `name_scope` with a new "yolo/" scope on every access."""

  @property
  def name_scope(self):
    return ops.name_scope("yolo/", skip_on_eager=False)


class ModuleWithFunctionAnnotatedCall(module.Module):
  """Scope-returning methods wrapped in `tf.function` (with/without autograph)."""

  @def_function.function(autograph=False)
  @module.Module.with_name_scope
  def forward(self):
    return get_name_scope()

  @def_function.function(autograph=True)
  @module.Module.with_name_scope
  def forward_ag(self):
    return get_name_scope()


class PropertyModule(module.Module):
  """Properties recording the name scope active in their getter and setter.

  `some_property` is wrapped with `with_name_scope`; `no_name_scope_property`
  is not. Each getter returns `(getter_scope, last_setter_scope)`.
  """

  def __init__(self):
    super(PropertyModule, self).__init__()
    self._setter_scope_name = None

  @property
  @module.Module.with_name_scope
  def some_property(self):
    getter_scope_name = get_name_scope()
    return getter_scope_name, self._setter_scope_name

  @some_property.setter
  @module.Module.with_name_scope
  def some_property(self, my_property):
    del my_property  # Unused; only the active scope is recorded.
    self._setter_scope_name = get_name_scope()

  @property
  def no_name_scope_property(self):
    getter_scope_name = get_name_scope()
    return getter_scope_name, self._setter_scope_name

  @no_name_scope_property.setter
  def no_name_scope_property(self, my_property):
    del my_property  # Unused; only the active scope is recorded.
    self._setter_scope_name = get_name_scope()


NamedPair = collections.namedtuple("NamedPair", ("first", "second"))


def mk_index_dict(v):
  """Returns a dict mapping each index of `v` to the element at that index."""
  return dict(enumerate(v))


class FlattenTest(parameterized.TestCase, test_util.TensorFlowTestCase):
  """Checks `Module._flatten` traversal order, filtering and paths."""

  @parameterized.parameters(lambda v: NamedPair(*v), list, tuple, mk_index_dict)
  def test_flatten(self, container_type):
    parent = SimpleModule(container_type=container_type)
    child = parent.c

    self.assertEqual(
        list(parent._flatten(recursive=False, predicate=is_member)),
        [parent.a[0], parent.a[1], parent.z])

    self.assertEqual(
        list(parent._flatten(predicate=is_member)),
        [parent.a[0], parent.a[1], parent.z, child.a[0], child.a[1], child.z])

  def test_attribute_traversal_key(self):
    mod = LayerModule()
    self.assertEqual(
        mod.variables,
        mod._trainable_variables + mod._non_trainable_variables + [mod._bonus])

  def test_attributes_to_ignore(self):

    class DangerousModule(module.Module):
      _TF_MODULE_IGNORED_PROPERTIES = frozenset(itertools.chain(
          ("dangerous_submodule", "dangerous_variable"),
          module.Module._TF_MODULE_IGNORED_PROPERTIES
      ))

    mod = DangerousModule()
    mod.dangerous_submodule = module.Module()
    mod.dangerous_variable = variables.Variable(1.)
    mod.normal_variable = variables.Variable(2.)

    self.assertEmpty(mod.submodules)
    self.assertLen(mod.variables, 1)
    self.assertEqual(mod.variables[0], mod.normal_variable)

  def test_with_path(self):
    mod = module.Module()
    mod.w = variables.Variable(1.)
    mod.encoder = module.Module()
    mod.encoder.w = [({"k": mod.w}, {"k": mod.w})]
    # The same submodule reachable under two attributes yields both paths.
    mod.decoder = mod.encoder

    state_dict = dict(
        mod._flatten(with_path=True, predicate=module._is_variable))

    self.assertEqual(state_dict,
                     {("w",): mod.w,
                      ("encoder", "w", 0, 0, "k"): mod.encoder.w[0][0]["k"],
                      ("encoder", "w", 0, 1, "k"): mod.encoder.w[0][1]["k"],
                      ("decoder", "w", 0, 0, "k"): mod.decoder.w[0][0]["k"],
                      ("decoder", "w", 0, 1, "k"): mod.decoder.w[0][1]["k"]})

  def test_raises_error_with_path(self):
    if six.PY2:
      # Python 2 orders arbitrary objects, so explicitly disable `<`.
      class NonOrderable(object):
        __lt__ = None

      non_orderable = NonOrderable
    else:
      non_orderable = object

    m = module.Module()
    m.layers = {non_orderable(): None, non_orderable(): None}
    with self.assertRaisesRegex(ValueError,
                                "Error processing property 'layers'"):
      m.variables  # pylint: disable=pointless-statement


class LayerModule(module.Module):
  """Overrides `variables` to order attributes with a custom traversal key."""

  def __init__(self):
    super(LayerModule, self).__init__()
    self._trainable_variables = [
        variables.Variable(1., name="a"),
        variables.Variable(2., name="b"),
    ]
    self._non_trainable_variables = [
        variables.Variable(3., name="c"),
        variables.Variable(4., name="d"),
    ]
    self._bonus = variables.Variable(5., name="e")

  @property
  def variables(self):
    def key_function(name):
      # Visit trainable first, then non-trainable, then everything else
      # (ties broken by attribute name).
      indexes = {"_trainable_variables": 0, "_non_trainable_variables": 1}
      return indexes.get(name, 2), name

    return list(
        self._flatten(
            predicate=module._is_variable,
            attribute_traversal_key=key_function))


class MemberType(object):
  """A simple type to search for."""


class SimpleModule(module.Module):
  """Holds `MemberType` leaves in `z` and `a`, plus an optional child in `c`.

  Args:
    create_child: whether to attach a nested `SimpleModule` (itself childless,
      with the default `list` container) as `c`.
    container_type: callable applied to a two-element list to build `a`.
  """

  def __init__(self, create_child=True, container_type=list):
    super(SimpleModule, self).__init__()
    self.z = MemberType()
    self.a = container_type([MemberType(), MemberType()])
    if create_child:
      self.c = SimpleModule(create_child=False)


def is_member(v):
  """Predicate for `_flatten`: True iff `v` is a `MemberType` instance."""
  return isinstance(v, MemberType)


if __name__ == "__main__":
  test.main()