1# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for `tf.Module`.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import abc 22import collections 23 24from absl.testing import parameterized 25import six 26 27from tensorflow.python.compat import v2_compat 28from tensorflow.python.eager import def_function 29from tensorflow.python.framework import ops 30from tensorflow.python.module import module 31from tensorflow.python.ops import variables 32from tensorflow.python.platform import test 33 34 35class TestModuleNaming(test.TestCase): 36 37 def test_single_name(self): 38 mod = module.Module(name="simple") 39 self.assertEqual(mod.name, "simple") 40 self.assertEqual(mod.name_scope.name, "simple/") 41 42 def test_construct_in_scope(self): 43 with ops.name_scope("foo"): 44 mod = module.Module(name="bar") 45 self.assertEqual(mod.name, "bar") 46 self.assertEqual(mod.name_scope.name, "foo/bar/") 47 48 def test_enters_name_scope_in_call(self): 49 mod = ReturnsNameScopeModule() 50 for _ in range(3): 51 self.assertEqual(mod(), mod.name_scope.name) 52 53 def test_enters_name_scope_in_other_method(self): 54 mod = ReturnsNameScopeModule() 55 for _ in range(3): 56 self.assertEqual(mod.alternative_forward(), mod.name_scope.name) 57 58 def test_subclassed_module(self): 59 mod = SubclassedReturnsNameScopeModule() 60 for _ in range(3): 61 self.assertEqual(mod.alternative_forward(), mod.name_scope.name) 62 self.assertEqual(mod.alternative_alternative_forward(), 63 mod.name_scope.name) 64 65 def test_submodule_created_late(self): 66 m = TreeModule() 67 self.assertEqual(m.name, "tree_module") 68 self.assertEqual(m.name_scope.name, "tree_module/") 69 leaf1 = m.new_leaf() 70 self.assertEqual(leaf1.name, "tree_module") 71 self.assertEqual(leaf1.name_scope.name, "tree_module/tree_module/") 72 73 def test_does_not_evaluate_property_methods(self): 74 mod = PropertyThrowsWhenCalledModule() 75 with self.assertRaises(AssertionError): 76 mod.raise_assertion_error # pylint: disable=pointless-statement 77 78 def test_overridden_name_scope(self): 79 mod = ModuleOverridingNameScope() 80 self.assertEqual(mod(), mod.name_scope.name) 81 self.assertEqual(mod.alternative_forward(), mod.name_scope.name) 82 83 def test_patched_callable(self): 84 with ops.name_scope("foo"): 85 mod = module.Module(name="bar") 86 mod.foo = get_name_scope 87 # `foo` is not a method so we do not re-enter the name scope. 88 self.assertEqual(mod.foo(), "") 89 90 def test_property(self): 91 mod = PropertyModule() 92 mod.some_property = None, None # None, None for the linter. 93 getter_scope_name, setter_scope_name = mod.some_property 94 self.assertEqual(getter_scope_name, "property_module/") 95 self.assertEqual(setter_scope_name, "property_module/") 96 97 def test_property_no_name_scope(self): 98 mod = PropertyModule() 99 mod.no_name_scope_property = None, None # None, None for the linter. 100 getter_scope_name, setter_scope_name = mod.no_name_scope_property 101 self.assertEqual(getter_scope_name, "") 102 self.assertEqual(setter_scope_name, "") 103 104 def test_invalid_name(self): 105 msg = ".* is not a valid module name" 106 with self.assertRaisesRegexp(ValueError, msg): 107 module.Module(name="$Foo") 108 109 def test_modules_not_numbered_in_eager(self): 110 mod = RecursiveModule(2) 111 self.assertEqual(mod.name_scope.name, "badger/") 112 self.assertEqual(mod.child.name_scope.name, "badger/badger/") 113 114 mod = RecursiveModule(2) 115 self.assertEqual(mod.name_scope.name, "badger/") 116 self.assertEqual(mod.child.name_scope.name, "badger/badger/") 117 118 def test_module_numbering_in_graph(self): 119 with ops.Graph().as_default(): 120 mod = RecursiveModule(2) 121 self.assertEqual(mod.name_scope.name, "badger/") 122 self.assertEqual(mod.child.name_scope.name, "badger/badger/") 123 124 mod = RecursiveModule(2) 125 self.assertEqual(mod.name_scope.name, "badger_1/") 126 self.assertEqual(mod.child.name_scope.name, "badger_1/badger/") 127 128 def test_ctor_error_closes_name_scope(self): 129 with self.assertRaises(ErrorModuleError): 130 # If super constructor is called then a name scope is opened then an error 131 # is thrown. The metaclass should handle this and close the namescope 132 # before re-throwing the exception. 133 ErrorModule(call_super=True) 134 135 self.assertEqual("", get_name_scope()) 136 137 def test_ctor_error_handles_ctor_not_opening_name_scope(self): 138 with self.assertRaises(ErrorModuleError): 139 # If super ctor is not called then the name scope isn't opened. We need to 140 # ensure that this doesn't trigger an exception (e.g. the metaclass trying 141 # to __exit__ a non-existant name scope). 142 ErrorModule(call_super=False) 143 144 self.assertEqual("", get_name_scope()) 145 146 def test_forward_method_closes_name_scope(self): 147 mod = ErrorModule(call_super=True, raise_in_constructor=False) 148 with self.assertRaises(ErrorModuleError): 149 mod() 150 151 self.assertEqual("", get_name_scope()) 152 153 def test_get_attr_doesnt_enter_name_scope(self): 154 scope_names = [] 155 156 class GetAttrModule(module.Module): 157 158 def __getattr__(self, name): 159 scope_names.append((name, get_name_scope())) 160 return super(GetAttrModule, self).__getattr__(name) 161 162 mod = GetAttrModule() 163 with self.assertRaises(AttributeError): 164 mod.does_not_exist # pylint: disable=pointless-statement 165 self.assertIn(("does_not_exist", ""), scope_names) 166 167 def test_get_attribute_doesnt_enter_name_scope(self): 168 scope_names = [] 169 170 class GetAttributeModule(module.Module): 171 172 def __getattribute__(self, name): 173 scope_names.append((name, get_name_scope())) 174 return super(GetAttributeModule, self).__getattribute__(name) 175 176 mod = GetAttributeModule() 177 with self.assertRaises(AttributeError): 178 mod.does_not_exist # pylint: disable=pointless-statement 179 self.assertIn(("does_not_exist", ""), scope_names) 180 181 182class VariableNamingTest(test.TestCase): 183 184 def test_variable_names(self): 185 mod = RecursiveModule(3) 186 self.assertEqual(mod.w.name, "badger/mushroom:0") 187 self.assertEqual(mod.child.w.name, "badger/badger/mushroom:0") 188 self.assertEqual(mod.child.child.w.name, "badger/badger/badger/mushroom:0") 189 190 191class VariableTrackingTest(test.TestCase): 192 193 def test_variables(self): 194 m = RecursiveModule(3) 195 self.assertEqual(m.variables, (m.w, m.child.w, m.child.child.w)) 196 self.assertEqual(m.child.variables, (m.child.w, m.child.child.w)) 197 self.assertEqual(m.child.child.variables, (m.child.child.w,)) 198 199 def test_trainable_variables(self): 200 m = RecursiveModule(3) 201 self.assertEqual(m.trainable_variables, 202 (m.w, m.child.w, m.child.child.w)) 203 self.assertEqual(m.child.trainable_variables, 204 (m.child.w, m.child.child.w)) 205 self.assertEqual(m.child.child.trainable_variables, (m.child.child.w,)) 206 207 def test_trainable_variables_ignores_non_trainable(self): 208 m = RecursiveModule(3, trainable=False) 209 self.assertEqual(len(m.trainable_variables), 0) 210 self.assertEqual(len(m.child.trainable_variables), 0) 211 self.assertEqual(len(m.child.child.trainable_variables), 0) 212 213 214class ModuleTrackingTest(test.TestCase): 215 216 def test_submodules(self): 217 m = RecursiveModule(3) 218 self.assertEqual(list(m.submodules), [m.child, m.child.child]) 219 self.assertEqual(list(m.child.submodules), [m.child.child]) 220 self.assertEqual(list(m.child.child.submodules), []) 221 222 def test_non_ctor_submodule(self): 223 m = TreeModule() 224 leaf1 = m.new_leaf() 225 self.assertEqual(set(m.submodules), {leaf1}) 226 leaf2 = m.new_leaf() 227 self.assertEqual(set(m.submodules), {leaf1, leaf2}) 228 229 230class ForwardMethodsTest(test.TestCase): 231 232 def testFunctionType(self): 233 mod = ModuleWithFunctionAnnotatedCall() 234 self.assertTrue(isinstance(mod.forward, def_function.Function)) 235 self.assertTrue(isinstance(mod.forward_ag, def_function.Function)) 236 237 def testEntersNameScope_call(self): 238 mod = ModuleWithFunctionAnnotatedCall() 239 self.assertEqual(mod.forward().numpy(), 240 b"module_with_function_annotated_call/") 241 self.assertEqual(mod.forward_ag().numpy(), 242 b"module_with_function_annotated_call/") 243 244 def testEntersNameScope_concreteFunction(self): 245 mod = ModuleWithFunctionAnnotatedCall() 246 self.assertEqual(mod.forward.get_concrete_function()().numpy(), 247 b"module_with_function_annotated_call/") 248 self.assertEqual(mod.forward_ag.get_concrete_function()().numpy(), 249 b"module_with_function_annotated_call/") 250 251 252class AbcTest(test.TestCase): 253 254 def testAbstract(self): 255 msg = "Can't instantiate .* abstract methods" 256 with self.assertRaisesRegexp(TypeError, msg): 257 AbstractModule() # pylint: disable=abstract-class-instantiated 258 259 def testConcrete(self): 260 mod = ConcreteModule() 261 x, scope_name = mod(2.) 262 self.assertEqual(x, 4.) 263 self.assertEqual(scope_name, "concrete_module/") 264 self.assertEqual(get_name_scope(), "") 265 266 267def get_name_scope(): 268 with ops.name_scope("x") as ns: 269 return ns[:-2] 270 271 272class ErrorModuleError(Exception): 273 pass 274 275 276class ErrorModule(module.Module): 277 278 def __init__(self, call_super, raise_in_constructor=True): 279 if call_super: 280 super(ErrorModule, self).__init__() 281 if raise_in_constructor: 282 raise ErrorModuleError("Deliberate error!") 283 284 def __call__(self): 285 raise ErrorModuleError("Deliberate error!") 286 287 288class RecursiveModule(module.Module): 289 290 def __init__(self, depth, trainable=True): 291 super(RecursiveModule, self).__init__(name="badger") 292 with self.name_scope: 293 self.child = None 294 if depth > 1: 295 self.child = RecursiveModule(depth - 1, trainable=trainable) 296 self.w = variables.Variable(1.0, trainable=trainable, name="mushroom") 297 298 299@six.add_metaclass(abc.ABCMeta) 300class AbstractModule(module.Module): 301 302 @abc.abstractmethod 303 def __call__(self, x): 304 pass 305 306 307class ConcreteModule(AbstractModule): 308 309 @module.Module.with_name_scope 310 def __call__(self, x): 311 return x ** 2, get_name_scope() 312 313 314class TreeModule(module.Module): 315 316 def __init__(self, name=None): 317 super(TreeModule, self).__init__(name=name) 318 self._leaves = [] 319 320 @module.Module.with_name_scope 321 def new_leaf(self, name=None): 322 leaf = TreeModule(name=name) 323 self._leaves.append(leaf) 324 return leaf 325 326 327class ReturnsNameScopeModule(module.Module): 328 329 @module.Module.with_name_scope 330 def alternative_forward(self): 331 return get_name_scope() 332 333 @module.Module.with_name_scope 334 def __call__(self): 335 return get_name_scope() 336 337 338class SubclassedReturnsNameScopeModule(ReturnsNameScopeModule): 339 340 @module.Module.with_name_scope 341 def alternative_alternative_forward(self): 342 return get_name_scope() 343 344 345class PropertyThrowsWhenCalledModule(module.Module): 346 347 @property 348 def raise_assertion_error(self): 349 raise AssertionError 350 351 352class ModuleOverridingNameScope(ReturnsNameScopeModule): 353 354 @property 355 def name_scope(self): 356 return ops.name_scope("yolo/") 357 358 359class ModuleWithFunctionAnnotatedCall(module.Module): 360 361 @def_function.function(autograph=False) 362 @module.Module.with_name_scope 363 def forward(self): 364 return get_name_scope() 365 366 @def_function.function(autograph=True) 367 @module.Module.with_name_scope 368 def forward_ag(self): 369 return get_name_scope() 370 371 372class PropertyModule(module.Module): 373 374 def __init__(self): 375 super(PropertyModule, self).__init__() 376 self._setter_scope_name = None 377 378 @property 379 @module.Module.with_name_scope 380 def some_property(self): 381 getter_scope_name = get_name_scope() 382 return getter_scope_name, self._setter_scope_name 383 384 @some_property.setter 385 @module.Module.with_name_scope 386 def some_property(self, my_property): 387 self._setter_scope_name = get_name_scope() 388 389 @property 390 def no_name_scope_property(self): 391 getter_scope_name = get_name_scope() 392 return getter_scope_name, self._setter_scope_name 393 394 @no_name_scope_property.setter 395 def no_name_scope_property(self, my_property): 396 self._setter_scope_name = get_name_scope() 397 398NamedPair = collections.namedtuple("NamedPair", ("first", "second")) 399mk_index_dict = lambda v: dict(enumerate(v)) 400 401 402class FlattenTest(parameterized.TestCase, test.TestCase): 403 404 @parameterized.parameters(lambda v: NamedPair(*v), list, tuple, mk_index_dict) 405 def test_flatten(self, container_type): 406 parent = SimpleModule(container_type=container_type) 407 child = parent.c 408 409 self.assertEqual( 410 list(parent._flatten(recursive=False, predicate=IS_MEMBER)), 411 [parent.a[0], parent.a[1], parent.z]) 412 413 self.assertEqual( 414 list(parent._flatten(predicate=IS_MEMBER)), 415 [parent.a[0], parent.a[1], parent.z, child.a[0], child.a[1], child.z]) 416 417 def test_attribute_traversal_key(self): 418 mod = LayerModule() 419 self.assertEqual( 420 mod.variables, 421 mod._trainable_variables + mod._non_trainable_variables + [mod._bonus]) 422 423 def test_with_path(self): 424 mod = module.Module() 425 mod.w = variables.Variable(1.) 426 mod.encoder = module.Module() 427 mod.encoder.w = [({"k": mod.w}, {"k": mod.w})] 428 mod.decoder = mod.encoder 429 430 state_dict = dict( 431 mod._flatten(with_path=True, predicate=module._IS_VARIABLE)) 432 433 self.assertEqual(state_dict, 434 {("w",): mod.w, 435 ("encoder", "w", 0, 0, "k"): mod.encoder.w[0][0]["k"], 436 ("encoder", "w", 0, 1, "k"): mod.encoder.w[0][1]["k"], 437 ("decoder", "w", 0, 0, "k"): mod.decoder.w[0][0]["k"], 438 ("decoder", "w", 0, 1, "k"): mod.decoder.w[0][1]["k"]},) 439 440 441class LayerModule(module.Module): 442 443 def __init__(self): 444 super(LayerModule, self).__init__() 445 self._trainable_variables = [ 446 variables.Variable(1., name="a"), 447 variables.Variable(2., name="b"), 448 ] 449 self._non_trainable_variables = [ 450 variables.Variable(3., name="c"), 451 variables.Variable(4., name="d"), 452 ] 453 self._bonus = variables.Variable(5., name="e") 454 455 @property 456 def variables(self): 457 def key_function(name): 458 indexes = {"_trainable_variables": 0, "_non_trainable_variables": 1} 459 return indexes.get(name, 2), name 460 461 return list(self._flatten(predicate=module._IS_VARIABLE, 462 attribute_traversal_key=key_function)) 463 464 465class MemberType(object): 466 """A simple type to search for.""" 467 pass 468 469 470class SimpleModule(module.Module): 471 472 def __init__(self, create_child=True, container_type=list): 473 super(SimpleModule, self).__init__() 474 self.z = MemberType() 475 self.a = container_type([MemberType(), MemberType()]) 476 if create_child: 477 self.c = SimpleModule(create_child=False) 478 479 480IS_MEMBER = lambda v: isinstance(v, MemberType) 481IS_MODULE = lambda v: isinstance(v, module.Module) 482 483if __name__ == "__main__": 484 v2_compat.enable_v2_behavior() 485 test.main() 486