• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for `tf.Module`."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import abc
22import collections
23
24from absl.testing import parameterized
25import six
26
27from tensorflow.python.compat import v2_compat
28from tensorflow.python.eager import def_function
29from tensorflow.python.framework import ops
30from tensorflow.python.module import module
31from tensorflow.python.ops import variables
32from tensorflow.python.platform import test
33
34
class TestModuleNaming(test.TestCase):
  """Tests for module naming and for entering a module's name scope."""

  def test_single_name(self):
    mod = module.Module(name="simple")
    self.assertEqual(mod.name, "simple")
    self.assertEqual(mod.name_scope.name, "simple/")

  def test_construct_in_scope(self):
    with ops.name_scope("foo"):
      mod = module.Module(name="bar")
    self.assertEqual(mod.name, "bar")
    self.assertEqual(mod.name_scope.name, "foo/bar/")

  def test_enters_name_scope_in_call(self):
    mod = ReturnsNameScopeModule()
    for _ in range(3):
      self.assertEqual(mod(), mod.name_scope.name)

  def test_enters_name_scope_in_other_method(self):
    mod = ReturnsNameScopeModule()
    for _ in range(3):
      self.assertEqual(mod.alternative_forward(), mod.name_scope.name)

  def test_subclassed_module(self):
    mod = SubclassedReturnsNameScopeModule()
    for _ in range(3):
      self.assertEqual(mod.alternative_forward(), mod.name_scope.name)
      self.assertEqual(mod.alternative_alternative_forward(),
                       mod.name_scope.name)

  def test_submodule_created_late(self):
    m = TreeModule()
    self.assertEqual(m.name, "tree_module")
    self.assertEqual(m.name_scope.name, "tree_module/")
    leaf1 = m.new_leaf()
    self.assertEqual(leaf1.name, "tree_module")
    self.assertEqual(leaf1.name_scope.name, "tree_module/tree_module/")

  def test_does_not_evaluate_property_methods(self):
    # Constructing the module must not evaluate the property; it only raises
    # once it is accessed explicitly below.
    mod = PropertyThrowsWhenCalledModule()
    with self.assertRaises(AssertionError):
      mod.raise_assertion_error  # pylint: disable=pointless-statement

  def test_overridden_name_scope(self):
    mod = ModuleOverridingNameScope()
    self.assertEqual(mod(), mod.name_scope.name)
    self.assertEqual(mod.alternative_forward(), mod.name_scope.name)

  def test_patched_callable(self):
    with ops.name_scope("foo"):
      mod = module.Module(name="bar")
    mod.foo = get_name_scope
    # `foo` is not a method so we do not re-enter the name scope.
    self.assertEqual(mod.foo(), "")

  def test_property(self):
    mod = PropertyModule()
    mod.some_property = None, None  # None, None for the linter.
    getter_scope_name, setter_scope_name = mod.some_property
    self.assertEqual(getter_scope_name, "property_module/")
    self.assertEqual(setter_scope_name, "property_module/")

  def test_property_no_name_scope(self):
    mod = PropertyModule()
    mod.no_name_scope_property = None, None  # None, None for the linter.
    getter_scope_name, setter_scope_name = mod.no_name_scope_property
    self.assertEqual(getter_scope_name, "")
    self.assertEqual(setter_scope_name, "")

  def test_invalid_name(self):
    msg = ".* is not a valid module name"
    # `assertRaisesRegexp` is a deprecated alias (removed in Python 3.12).
    with self.assertRaisesRegex(ValueError, msg):
      module.Module(name="$Foo")

  def test_modules_not_numbered_in_eager(self):
    mod = RecursiveModule(2)
    self.assertEqual(mod.name_scope.name, "badger/")
    self.assertEqual(mod.child.name_scope.name, "badger/badger/")

    mod = RecursiveModule(2)
    self.assertEqual(mod.name_scope.name, "badger/")
    self.assertEqual(mod.child.name_scope.name, "badger/badger/")

  def test_module_numbering_in_graph(self):
    with ops.Graph().as_default():
      mod = RecursiveModule(2)
      self.assertEqual(mod.name_scope.name, "badger/")
      self.assertEqual(mod.child.name_scope.name, "badger/badger/")

      mod = RecursiveModule(2)
      self.assertEqual(mod.name_scope.name, "badger_1/")
      self.assertEqual(mod.child.name_scope.name, "badger_1/badger/")

  def test_ctor_error_closes_name_scope(self):
    with self.assertRaises(ErrorModuleError):
      # If the super constructor is called, a name scope is opened and then an
      # error is raised. The metaclass should handle this and close the name
      # scope before re-raising the exception.
      ErrorModule(call_super=True)

    self.assertEqual("", get_name_scope())

  def test_ctor_error_handles_ctor_not_opening_name_scope(self):
    with self.assertRaises(ErrorModuleError):
      # If the super constructor is not called, the name scope is never
      # opened. Ensure this doesn't trigger an exception (e.g. the metaclass
      # trying to __exit__ a non-existent name scope).
      ErrorModule(call_super=False)

    self.assertEqual("", get_name_scope())

  def test_forward_method_closes_name_scope(self):
    mod = ErrorModule(call_super=True, raise_in_constructor=False)
    with self.assertRaises(ErrorModuleError):
      mod()

    self.assertEqual("", get_name_scope())

  def test_get_attr_doesnt_enter_name_scope(self):
    scope_names = []

    class GetAttrModule(module.Module):

      def __getattr__(self, name):
        scope_names.append((name, get_name_scope()))
        return super(GetAttrModule, self).__getattr__(name)

    mod = GetAttrModule()
    with self.assertRaises(AttributeError):
      mod.does_not_exist  # pylint: disable=pointless-statement
    self.assertIn(("does_not_exist", ""), scope_names)

  def test_get_attribute_doesnt_enter_name_scope(self):
    scope_names = []

    class GetAttributeModule(module.Module):

      def __getattribute__(self, name):
        scope_names.append((name, get_name_scope()))
        return super(GetAttributeModule, self).__getattribute__(name)

    mod = GetAttributeModule()
    with self.assertRaises(AttributeError):
      mod.does_not_exist  # pylint: disable=pointless-statement
    self.assertIn(("does_not_exist", ""), scope_names)
180
181
class VariableNamingTest(test.TestCase):
  """Checks that variable names follow the nesting of module name scopes."""

  def test_variable_names(self):
    root = RecursiveModule(3)
    expectations = (
        (root.w, "badger/mushroom:0"),
        (root.child.w, "badger/badger/mushroom:0"),
        (root.child.child.w, "badger/badger/badger/mushroom:0"),
    )
    for var, expected_name in expectations:
      self.assertEqual(var.name, expected_name)
189
190
class VariableTrackingTest(test.TestCase):
  """Checks `variables` / `trainable_variables` collection over a chain."""

  def test_variables(self):
    root = RecursiveModule(3)
    mid, leaf = root.child, root.child.child
    self.assertEqual(root.variables, (root.w, mid.w, leaf.w))
    self.assertEqual(mid.variables, (mid.w, leaf.w))
    self.assertEqual(leaf.variables, (leaf.w,))

  def test_trainable_variables(self):
    root = RecursiveModule(3)
    mid, leaf = root.child, root.child.child
    self.assertEqual(root.trainable_variables, (root.w, mid.w, leaf.w))
    self.assertEqual(mid.trainable_variables, (mid.w, leaf.w))
    self.assertEqual(leaf.trainable_variables, (leaf.w,))

  def test_trainable_variables_ignores_non_trainable(self):
    root = RecursiveModule(3, trainable=False)
    for mod in (root, root.child, root.child.child):
      self.assertEqual(len(mod.trainable_variables), 0)
206
207  def test_trainable_variables_ignores_non_trainable(self):
208    m = RecursiveModule(3, trainable=False)
209    self.assertEqual(len(m.trainable_variables), 0)
210    self.assertEqual(len(m.child.trainable_variables), 0)
211    self.assertEqual(len(m.child.child.trainable_variables), 0)
212
213
class ModuleTrackingTest(test.TestCase):
  """Checks submodule discovery, including submodules created after init."""

  def test_submodules(self):
    root = RecursiveModule(3)
    mid = root.child
    leaf = mid.child
    self.assertEqual(list(root.submodules), [mid, leaf])
    self.assertEqual(list(mid.submodules), [leaf])
    self.assertEqual(list(leaf.submodules), [])

  def test_non_ctor_submodule(self):
    tree = TreeModule()
    first = tree.new_leaf()
    self.assertEqual(set(tree.submodules), {first})
    second = tree.new_leaf()
    self.assertEqual(set(tree.submodules), {first, second})
228
229
class ForwardMethodsTest(test.TestCase):
  """Tests name-scoped methods that are additionally wrapped in a function."""

  def testFunctionType(self):
    mod = ModuleWithFunctionAnnotatedCall()
    # assertIsInstance reports the actual type on failure, unlike
    # assertTrue(isinstance(...)) which only reports "False is not true".
    self.assertIsInstance(mod.forward, def_function.Function)
    self.assertIsInstance(mod.forward_ag, def_function.Function)

  def testEntersNameScope_call(self):
    mod = ModuleWithFunctionAnnotatedCall()
    self.assertEqual(mod.forward().numpy(),
                     b"module_with_function_annotated_call/")
    self.assertEqual(mod.forward_ag().numpy(),
                     b"module_with_function_annotated_call/")

  def testEntersNameScope_concreteFunction(self):
    mod = ModuleWithFunctionAnnotatedCall()
    self.assertEqual(mod.forward.get_concrete_function()().numpy(),
                     b"module_with_function_annotated_call/")
    self.assertEqual(mod.forward_ag.get_concrete_function()().numpy(),
                     b"module_with_function_annotated_call/")
250
251
class AbcTest(test.TestCase):
  """Tests that `abc` abstract methods work with `module.Module`."""

  def testAbstract(self):
    # The wording of this TypeError varies across Python versions: "abstract
    # methods" (older), "abstract method" when there is exactly one (3.9+),
    # and "... for abstract method '__call__'" (3.12+). Match the common part.
    msg = "Can't instantiate .* abstract method"
    # `assertRaisesRegexp` is a deprecated alias (removed in Python 3.12).
    with self.assertRaisesRegex(TypeError, msg):
      AbstractModule()  # pylint: disable=abstract-class-instantiated

  def testConcrete(self):
    mod = ConcreteModule()
    x, scope_name = mod(2.)
    self.assertEqual(x, 4.)
    self.assertEqual(scope_name, "concrete_module/")
    # The module's name scope must be closed again after the call.
    self.assertEqual(get_name_scope(), "")
265
266
def get_name_scope():
  """Returns the name scope active at the call site, e.g. "foo/" or "".

  Opens a throwaway child scope named "x" and strips that last component.
  In graph mode the probe may be uniquified (e.g. to "x_1/"), so the last
  path component is removed rather than a fixed two characters.
  """
  with ops.name_scope("x") as ns:
    parent, sep, _ = ns.rstrip("/").rpartition("/")
    # `sep` is "/" when there is an enclosing scope and "" at the top level,
    # which keeps the trailing slash convention ("foo/" vs "").
    return parent + sep
270
271
class ErrorModuleError(Exception):
  """Raised deliberately by `ErrorModule` so tests can catch it by type."""
274
275
class ErrorModule(module.Module):
  """Module that raises `ErrorModuleError` on construction and/or call.

  Constructor args:
    call_super: whether `module.Module.__init__` runs before (possibly)
      raising.
    raise_in_constructor: whether the constructor itself raises.
  """

  def __init__(self, call_super, raise_in_constructor=True):
    if call_super:
      super(ErrorModule, self).__init__()
    if not raise_in_constructor:
      return
    raise ErrorModuleError("Deliberate error!")

  def __call__(self):
    raise ErrorModuleError("Deliberate error!")
286
287
class RecursiveModule(module.Module):
  """Chain of nested "badger" modules, each owning one "mushroom" variable.

  Args:
    depth: number of modules in the chain, counting this one; the innermost
      module has `child` set to None.
    trainable: forwarded to every variable in the chain.
  """

  def __init__(self, depth, trainable=True):
    super(RecursiveModule, self).__init__(name="badger")
    with self.name_scope:
      # Both the child module and the variable are created inside this
      # module's name scope, so their names nest under "badger/".
      self.child = (RecursiveModule(depth - 1, trainable=trainable)
                    if depth > 1 else None)
      self.w = variables.Variable(1.0, trainable=trainable, name="mushroom")
297
298
@six.add_metaclass(abc.ABCMeta)
class AbstractModule(module.Module):
  """Module with an abstract `__call__`; it cannot be instantiated directly."""

  @abc.abstractmethod
  def __call__(self, x):
    """Subclasses must implement the forward pass."""
305
306
class ConcreteModule(AbstractModule):
  """Implements `AbstractModule.__call__` inside the module's name scope."""

  @module.Module.with_name_scope
  def __call__(self, x):
    """Returns `(x ** 2, name scope active inside the call)`."""
    squared = x ** 2
    return squared, get_name_scope()
312
313
class TreeModule(module.Module):
  """Module that grows child `TreeModule`s on demand via `new_leaf`."""

  def __init__(self, name=None):
    super(TreeModule, self).__init__(name=name)
    self._leaves = []

  @module.Module.with_name_scope
  def new_leaf(self, name=None):
    """Creates, records and returns a child built in this module's scope."""
    child = TreeModule(name=name)
    self._leaves.append(child)
    return child
325
326
class ReturnsNameScopeModule(module.Module):
  """Module whose name-scoped methods report the scope they run in."""

  @module.Module.with_name_scope
  def __call__(self):
    return get_name_scope()

  @module.Module.with_name_scope
  def alternative_forward(self):
    return get_name_scope()
336
337
class SubclassedReturnsNameScopeModule(ReturnsNameScopeModule):
  """Subclass adding one more name-scoped method on top of the inherited."""

  @module.Module.with_name_scope
  def alternative_alternative_forward(self):
    current_scope = get_name_scope()
    return current_scope
343
344
class PropertyThrowsWhenCalledModule(module.Module):
  """Module whose only property raises `AssertionError` when evaluated."""

  @property
  def raise_assertion_error(self):
    raise AssertionError()
350
351
class ModuleOverridingNameScope(ReturnsNameScopeModule):
  """Replaces the inherited `name_scope` with a new "yolo/" scope per access."""

  @property
  def name_scope(self):
    fixed_scope = "yolo/"
    return ops.name_scope(fixed_scope)
357
358
class ModuleWithFunctionAnnotatedCall(module.Module):
  """Module whose name-scoped methods are also wrapped in `function`.

  `forward` is wrapped with `autograph=False` and `forward_ag` with
  `autograph=True`; both return the name scope seen inside the method body.
  """

  @def_function.function(autograph=False)
  @module.Module.with_name_scope
  def forward(self):
    scope = get_name_scope()
    return scope

  @def_function.function(autograph=True)
  @module.Module.with_name_scope
  def forward_ag(self):
    scope = get_name_scope()
    return scope
370
371
class PropertyModule(module.Module):
  """Module exposing properties with and without `with_name_scope`.

  Each getter returns `(scope seen by the getter, scope recorded by the most
  recent setter call)`; setters ignore the assigned value.
  """

  def __init__(self):
    super(PropertyModule, self).__init__()
    self._setter_scope_name = None

  @property
  @module.Module.with_name_scope
  def some_property(self):
    return get_name_scope(), self._setter_scope_name

  @some_property.setter
  @module.Module.with_name_scope
  def some_property(self, unused_value):
    self._setter_scope_name = get_name_scope()

  @property
  def no_name_scope_property(self):
    return get_name_scope(), self._setter_scope_name

  @no_name_scope_property.setter
  def no_name_scope_property(self, unused_value):
    self._setter_scope_name = get_name_scope()
397
# Two-field namedtuple used as one of the container types in FlattenTest.
NamedPair = collections.namedtuple("NamedPair", ("first", "second"))


def mk_index_dict(v):
  """Returns a dict mapping each position in `v` to the element there."""
  return dict(enumerate(v))
400
401
class FlattenTest(parameterized.TestCase, test.TestCase):
  """Tests for `Module._flatten`: recursion, traversal order and paths."""

  @parameterized.parameters(
      lambda v: NamedPair(*v),
      list,
      tuple,
      mk_index_dict,
  )
  def test_flatten(self, container_type):
    outer = SimpleModule(container_type=container_type)
    inner = outer.c
    outer_members = [outer.a[0], outer.a[1], outer.z]
    inner_members = [inner.a[0], inner.a[1], inner.z]

    shallow = list(outer._flatten(recursive=False, predicate=IS_MEMBER))
    self.assertEqual(shallow, outer_members)

    deep = list(outer._flatten(predicate=IS_MEMBER))
    self.assertEqual(deep, outer_members + inner_members)

  def test_attribute_traversal_key(self):
    layer = LayerModule()
    expected = (layer._trainable_variables +
                layer._non_trainable_variables +
                [layer._bonus])
    self.assertEqual(layer.variables, expected)

  def test_with_path(self):
    mod = module.Module()
    mod.w = variables.Variable(1.)
    mod.encoder = module.Module()
    mod.encoder.w = [({"k": mod.w}, {"k": mod.w})]
    mod.decoder = mod.encoder

    flat = mod._flatten(with_path=True, predicate=module._IS_VARIABLE)
    state_dict = dict(flat)

    expected = {
        ("w",): mod.w,
        ("encoder", "w", 0, 0, "k"): mod.encoder.w[0][0]["k"],
        ("encoder", "w", 0, 1, "k"): mod.encoder.w[0][1]["k"],
        ("decoder", "w", 0, 0, "k"): mod.decoder.w[0][0]["k"],
        ("decoder", "w", 0, 1, "k"): mod.decoder.w[0][1]["k"],
    }
    self.assertEqual(state_dict, expected)
439
440
class LayerModule(module.Module):
  """Module whose `variables` override customizes attribute traversal order.

  `_trainable_variables` is visited first, then `_non_trainable_variables`,
  then every other attribute (e.g. `_bonus`), ordered by name.
  """

  def __init__(self):
    super(LayerModule, self).__init__()
    self._trainable_variables = [variables.Variable(1., name="a"),
                                 variables.Variable(2., name="b")]
    self._non_trainable_variables = [variables.Variable(3., name="c"),
                                     variables.Variable(4., name="d")]
    self._bonus = variables.Variable(5., name="e")

  @property
  def variables(self):
    priority = {"_trainable_variables": 0, "_non_trainable_variables": 1}

    def by_priority_then_name(attr_name):
      return priority.get(attr_name, 2), attr_name

    return list(self._flatten(predicate=module._IS_VARIABLE,
                              attribute_traversal_key=by_priority_then_name))
463
464
class MemberType(object):
  """Marker type that the flatten tests search for via `IS_MEMBER`."""
468
469
class SimpleModule(module.Module):
  """Module holding `MemberType`s directly (`z`) and in a container (`a`).

  Args:
    create_child: if True, also attach a child `SimpleModule` (without a
      child of its own) as `c`.
    container_type: callable applied to a two-element list of `MemberType`s
      to build `a`. The child is always built with the default (`list`).
  """

  def __init__(self, create_child=True, container_type=list):
    super(SimpleModule, self).__init__()
    self.z = MemberType()
    self.a = container_type([MemberType(), MemberType()])
    if not create_child:
      return
    self.c = SimpleModule(create_child=False)
478
479
def IS_MEMBER(v):  # pylint: disable=invalid-name
  """Predicate for `_flatten`: True for `MemberType` instances."""
  return isinstance(v, MemberType)


def IS_MODULE(v):  # pylint: disable=invalid-name
  """Predicate: True for `module.Module` instances."""
  return isinstance(v, module.Module)
482
if __name__ == "__main__":
  # Tests above rely on TF2 behavior (e.g. the `.numpy()` calls in
  # ForwardMethodsTest and test_modules_not_numbered_in_eager), so enable it
  # before the test runner starts.
  v2_compat.enable_v2_behavior()
  test.main()
486