• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for `tf.Module`."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import abc
22import collections
23import itertools
24
25from absl.testing import parameterized
26import six
27
28from tensorflow.python import tf2
29from tensorflow.python.distribute import ps_values
30from tensorflow.python.distribute import tpu_values
31from tensorflow.python.distribute import values as distributed_values
32from tensorflow.python.eager import context
33from tensorflow.python.eager import def_function
34from tensorflow.python.framework import composite_tensor
35from tensorflow.python.framework import ops
36from tensorflow.python.framework import test_util
37from tensorflow.python.framework import type_spec
38from tensorflow.python.module import module
39from tensorflow.python.ops import variables
40from tensorflow.python.platform import test
41
42
class TestModuleNaming(test_util.TensorFlowTestCase):
  """Checks how `tf.Module` names itself and when it enters its name scope."""

  def test_single_name(self):
    m = module.Module(name="simple")
    self.assertEqual(m.name, "simple")
    self.assertEqual(m.name_scope.name, "simple/")

  def test_construct_in_scope(self):
    # A surrounding scope prefixes the module's scope, but not its name.
    with ops.name_scope("foo", skip_on_eager=False):
      m = module.Module(name="bar")
    self.assertEqual(m.name, "bar")
    self.assertEqual(m.name_scope.name, "foo/bar/")

  def test_enters_name_scope_in_call(self):
    m = ReturnsNameScopeModule()
    for _ in range(3):
      observed = m()
      self.assertEqual(observed, m.name_scope.name)

  def test_enters_name_scope_in_other_method(self):
    m = ReturnsNameScopeModule()
    for _ in range(3):
      observed = m.alternative_forward()
      self.assertEqual(observed, m.name_scope.name)

  def test_subclassed_module(self):
    m = SubclassedReturnsNameScopeModule()
    for _ in range(3):
      observed = m.alternative_forward()
      self.assertEqual(observed, m.name_scope.name)
      observed = m.alternative_alternative_forward()
      self.assertEqual(observed, m.name_scope.name)

  def test_submodule_created_late(self):
    tree = TreeModule()
    self.assertEqual(tree.name, "tree_module")
    self.assertEqual(tree.name_scope.name, "tree_module/")
    # `new_leaf` runs under the parent's scope, so the leaf's scope nests.
    leaf = tree.new_leaf()
    self.assertEqual(leaf.name, "tree_module")
    self.assertEqual(leaf.name_scope.name, "tree_module/tree_module/")

  def test_does_not_evaluate_property_methods(self):
    m = PropertyThrowsWhenCalledModule()
    with self.assertRaises(AssertionError):
      m.raise_assertion_error  # pylint: disable=pointless-statement

  def test_overridden_name_scope(self):
    m = ModuleOverridingNameScope()
    observed = m()
    self.assertEqual(observed, m.name_scope.name)
    observed = m.alternative_forward()
    self.assertEqual(observed, m.name_scope.name)

  def test_patched_callable(self):
    with ops.name_scope("foo", skip_on_eager=False):
      m = module.Module(name="bar")
    # A plain function stored on the instance is not a method, so calling it
    # does not re-enter the module's name scope.
    m.foo = get_name_scope
    self.assertEqual(m.foo(), "")

  def test_property(self):
    m = PropertyModule()
    m.some_property = None, None  # None, None for the linter.
    getter_scope, setter_scope = m.some_property
    self.assertEqual(getter_scope, "property_module/")
    self.assertEqual(setter_scope, "property_module/")

  def test_property_no_name_scope(self):
    m = PropertyModule()
    m.no_name_scope_property = None, None  # None, None for the linter.
    getter_scope, setter_scope = m.no_name_scope_property
    self.assertEqual(getter_scope, "")
    self.assertEqual(setter_scope, "")

  def test_invalid_name(self):
    pattern = ".* is not a valid module name"
    with self.assertRaisesRegex(ValueError, pattern):
      module.Module(name="$Foo")

  @test_util.run_in_graph_and_eager_modes
  def test_modules_not_numbered_in_eager(self):
    if not context.executing_eagerly():
      self.skipTest("Eager specific")

    # In eager mode a second instance gets exactly the same scope names.
    for _ in range(2):
      m = RecursiveModule(2)
      self.assertEqual(m.name_scope.name, "badger/")
      self.assertEqual(m.child.name_scope.name, "badger/badger/")

  @test_util.run_in_graph_and_eager_modes
  def test_module_numbering_in_graph(self):
    if context.executing_eagerly():
      self.skipTest("Graph specific")

    # In a graph the second top-level instance gets a "_1" suffix.
    expected_scopes = (("badger/", "badger/badger/"),
                       ("badger_1/", "badger_1/badger/"))
    for outer, inner in expected_scopes:
      m = RecursiveModule(2)
      self.assertEqual(m.name_scope.name, outer)
      self.assertEqual(m.child.name_scope.name, inner)

  def test_ctor_error_closes_name_scope(self):
    with self.assertRaises(ErrorModuleError):
      # The base constructor runs (opening a name scope) and then the
      # constructor raises; the metaclass must close that scope before the
      # exception propagates.
      ErrorModule(call_super=True)

    self.assertEqual("", get_name_scope())

  def test_ctor_error_handles_ctor_not_opening_name_scope(self):
    with self.assertRaises(ErrorModuleError):
      # The base constructor is skipped, so no scope was opened; error handling
      # must not try to exit a scope that was never entered.
      ErrorModule(call_super=False)

    self.assertEqual("", get_name_scope())

  def test_forward_method_closes_name_scope(self):
    m = ErrorModule(call_super=True, raise_in_constructor=False)
    with self.assertRaises(ErrorModuleError):
      m()

    self.assertEqual("", get_name_scope())

  def test_get_attr_doesnt_enter_name_scope(self):
    seen = []

    class GetAttrModule(module.Module):

      def __getattr__(self, name):
        seen.append((name, get_name_scope()))
        return super(GetAttrModule, self).__getattr__(name)

    m = GetAttrModule()
    with self.assertRaises(AttributeError):
      m.does_not_exist  # pylint: disable=pointless-statement
    self.assertIn(("does_not_exist", ""), seen)

  def test_get_attribute_doesnt_enter_name_scope(self):
    seen = []

    class GetAttributeModule(module.Module):

      def __getattribute__(self, name):
        seen.append((name, get_name_scope()))
        return super(GetAttributeModule, self).__getattribute__(name)

    m = GetAttributeModule()
    with self.assertRaises(AttributeError):
      m.does_not_exist  # pylint: disable=pointless-statement
    self.assertIn(("does_not_exist", ""), seen)
195
196
class VariableNamingTest(test_util.TensorFlowTestCase):
  """Checks that variables are named under their owning module's scope."""

  def test_variable_names(self):
    root = RecursiveModule(3)
    expectations = (
        (root, "badger/mushroom:0"),
        (root.child, "badger/badger/mushroom:0"),
        (root.child.child, "badger/badger/badger/mushroom:0"),
    )
    for owner, expected_name in expectations:
      self.assertEqual(owner.w.name, expected_name)
204
205
class NameScopeTest(test_util.TensorFlowTestCase):
  """Checks whether `Module.name_scope` hands back a cached object."""

  @test_util.run_deprecated_v1
  def test_not_memoized_in_tf1(self):
    if tf2.enabled():
      self.skipTest("Requires TF1")

    m = module.Module(name="name")
    first = m.name_scope
    second = m.name_scope
    # Under TF1 each access yields a distinct object carrying the same name.
    self.assertIsNot(first, second)
    self.assertEqual(first.name, second.name)

  def test_memoized_in_tf2(self):
    if not tf2.enabled():
      self.skipTest("Requires TF2")

    m = module.Module(name="name")
    first = m.name_scope
    second = m.name_scope
    # Under TF2 every access yields the very same object.
    self.assertIs(first, second)
227
228
class VariableTrackingTest(test_util.TensorFlowTestCase):
  """Checks which variables `Module.variables` and friends collect."""

  def test_variables(self):
    m = RecursiveModule(3)
    self.assertEqual(m.variables, (m.w, m.child.w, m.child.child.w))
    self.assertEqual(m.child.variables, (m.child.w, m.child.child.w))
    self.assertEqual(m.child.child.variables, (m.child.child.w,))

  def test_trainable_variables(self):
    m = RecursiveModule(3)
    self.assertEqual(m.trainable_variables,
                     (m.w, m.child.w, m.child.child.w))
    self.assertEqual(m.child.trainable_variables,
                     (m.child.w, m.child.child.w))
    self.assertEqual(m.child.child.trainable_variables, (m.child.child.w,))

  def test_trainable_variables_ignores_non_trainable(self):
    m = RecursiveModule(3, trainable=False)
    # assertEmpty reports the unexpected contents on failure, which a bare
    # `len(...) == 0` comparison does not.
    self.assertEmpty(m.trainable_variables)
    self.assertEmpty(m.child.trainable_variables)
    self.assertEmpty(m.child.child.trainable_variables)

  def test_supports_distributed_variables(self):
    mirrored = distributed_values.MirroredVariable(
        None, [variables.Variable(1.)], variables.VariableAggregation.SUM)
    tpu = tpu_values.TPUMirroredVariable(
        strategy=None, values=[variables.Variable(42.)], aggregation=None)
    aggregating = ps_values.AggregatingVariable(
        strategy=None, v=variables.Variable(1.), aggregation=None)

    m = module.Module()
    m.a = mirrored
    m.b = tpu
    m.c = aggregating
    self.assertEqual(m.variables, (mirrored, tpu, aggregating))

  def test_composite_variable(self):

    class Spec(type_spec.TypeSpec):
      # Minimal spec: the composite's components are its wrapped variables.

      value_type = property(lambda self: CompositeVariable)

      def _component_specs(self):
        pass

      def _serialize(self):
        pass

      def _to_components(self, value):
        return value._variables

      def _from_components(self, variable_list):
        return CompositeVariable(variable_list)

    class CompositeVariable(composite_tensor.CompositeTensor):
      # CompositeTensor whose only state is the list of variables it wraps.

      def __init__(self, variable_list):
        self._variables = variable_list

      @property
      def _type_spec(self):
        return Spec()

    m = module.Module()
    m.a = CompositeVariable([variables.Variable(1.), variables.Variable(2.)])
    # The module should surface the composite's component variables.
    self.assertAllEqual(m.variables, m.a._variables)
295
296
class ModuleTrackingTest(test_util.TensorFlowTestCase):
  """Checks discovery of child modules via `Module.submodules`."""

  def test_submodules(self):
    root = RecursiveModule(3)
    middle = root.child
    leaf = middle.child
    self.assertEqual(list(root.submodules), [middle, leaf])
    self.assertEqual(list(middle.submodules), [leaf])
    self.assertEqual(list(leaf.submodules), [])

  def test_non_ctor_submodule(self):
    tree = TreeModule()
    # Leaves attached after construction must still be discovered.
    first = tree.new_leaf()
    self.assertEqual(set(tree.submodules), {first})
    second = tree.new_leaf()
    self.assertEqual(set(tree.submodules), {first, second})
311
312
class ForwardMethodsTest(test_util.TensorFlowTestCase):
  """Checks `with_name_scope` methods that are also wrapped by `function`."""

  def testFunctionType(self):
    mod = ModuleWithFunctionAnnotatedCall()
    for method_name in ("forward", "forward_ag"):
      self.assertIsInstance(getattr(mod, method_name), def_function.Function)

  def testEntersNameScope_call(self):
    mod = ModuleWithFunctionAnnotatedCall()
    for method_name in ("forward", "forward_ag"):
      scope = self.evaluate(getattr(mod, method_name)())
      self.assertEqual(scope, b"module_with_function_annotated_call/")

  def testEntersNameScope_concreteFunction(self):
    mod = ModuleWithFunctionAnnotatedCall()
    for method_name in ("forward", "forward_ag"):
      concrete = getattr(mod, method_name).get_concrete_function()
      scope = self.evaluate(concrete())
      self.assertEqual(scope, b"module_with_function_annotated_call/")
333
334
class AbcTest(test_util.TensorFlowTestCase):
  """Checks that `tf.Module` subclasses work with `abc.ABCMeta`."""

  def testAbstract(self):
    pattern = "Can't instantiate .* abstract methods"
    with self.assertRaisesRegex(TypeError, pattern):
      AbstractModule()  # pylint: disable=abstract-class-instantiated

  def testConcrete(self):
    concrete = ConcreteModule()
    squared, scope_name = concrete(2.)
    self.assertEqual(squared, 4.)
    self.assertEqual(scope_name, "concrete_module/")
    # The scope entered by the call must be closed again afterwards.
    self.assertEqual(get_name_scope(), "")
348
349
def get_name_scope():
  """Returns the enclosing name scope with a trailing "/", or "" at top level.

  Works by opening a throwaway "x" scope and stripping the last two
  "/"-separated pieces from the resulting scope string. This assumes that
  string ends with "/" (e.g. "a/b/x/"), so the pieces removed are "x" and
  the trailing empty string.
  """
  with ops.name_scope("x", skip_on_eager=False) as probe:
    pieces = probe.split("/")
    enclosing = "/".join(pieces[:-2])
    if not enclosing:
      return ""
    return enclosing + "/"
354
355
class ErrorModuleError(Exception):
  """Exception raised on purpose by `ErrorModule`."""
358
359
class ErrorModule(module.Module):
  """Module that raises `ErrorModuleError` from its constructor or call.

  Args:
    call_super: whether to run the base `Module` constructor first. When it
      is skipped, no name scope is opened before the error.
    raise_in_constructor: whether `__init__` itself raises.
  """

  def __init__(self, call_super, raise_in_constructor=True):
    if call_super:
      super(ErrorModule, self).__init__()
    if not raise_in_constructor:
      return
    raise ErrorModuleError("Deliberate error!")

  def __call__(self):
    raise ErrorModuleError("Deliberate error!")
370
371
class RecursiveModule(module.Module):
  """Chain of `depth` nested modules named "badger", each with one variable.

  Args:
    depth: number of modules in the chain, including this one.
    trainable: passed to every `Variable` created in the chain.
  """

  def __init__(self, depth, trainable=True):
    super(RecursiveModule, self).__init__(name="badger")
    # Build the child and the variable inside this module's scope so their
    # names nest under it.
    with self.name_scope:
      self.child = (RecursiveModule(depth - 1, trainable=trainable)
                    if depth > 1 else None)
      self.w = variables.Variable(1.0, trainable=trainable, name="mushroom")
381
382
@six.add_metaclass(abc.ABCMeta)
class AbstractModule(module.Module):
  """Module with an abstract `__call__`; it cannot be instantiated directly."""

  @abc.abstractmethod
  def __call__(self, x):
    """Forward computation that subclasses must provide."""
389
390
class ConcreteModule(AbstractModule):
  """Concrete `AbstractModule` returning `x ** 2` and its active name scope."""

  @module.Module.with_name_scope
  def __call__(self, x):
    squared = x ** 2
    return squared, get_name_scope()
396
397
class TreeModule(module.Module):
  """Module that can grow child `TreeModule`s after construction."""

  def __init__(self, name=None):
    super(TreeModule, self).__init__(name=name)
    # Children created by `new_leaf`, in creation order.
    self._leaves = []

  @module.Module.with_name_scope
  def new_leaf(self, name=None):
    """Creates, records and returns a child built inside this scope."""
    child = TreeModule(name=name)
    self._leaves.append(child)
    return child
409
410
class ReturnsNameScopeModule(module.Module):
  """Module whose scoped methods report the name scope they execute in."""

  @module.Module.with_name_scope
  def __call__(self):
    return get_name_scope()

  @module.Module.with_name_scope
  def alternative_forward(self):
    return get_name_scope()
420
421
class SubclassedReturnsNameScopeModule(ReturnsNameScopeModule):
  """Subclass adding one more scoped method on top of the inherited ones."""

  @module.Module.with_name_scope
  def alternative_alternative_forward(self):
    return get_name_scope()
427
428
class PropertyThrowsWhenCalledModule(module.Module):
  """Module whose property fails loudly if anything evaluates it."""

  @property
  def raise_assertion_error(self):
    raise AssertionError()
434
435
class ModuleOverridingNameScope(ReturnsNameScopeModule):
  """Replaces the inherited `name_scope` with a fixed "yolo/" scope."""

  @property
  def name_scope(self):
    # A new scope object is constructed on every access.
    return ops.name_scope("yolo/", skip_on_eager=False)
441
442
class ModuleWithFunctionAnnotatedCall(module.Module):
  """Scoped methods wrapped in `function`, with and without autograph."""

  # `function` wraps the already-scoped method, so the scope is entered
  # while the function is traced.
  @def_function.function(autograph=False)
  @module.Module.with_name_scope
  def forward(self):
    return get_name_scope()

  @def_function.function(autograph=True)
  @module.Module.with_name_scope
  def forward_ag(self):
    return get_name_scope()
454
455
class PropertyModule(module.Module):
  """Module exposing properties with and without `with_name_scope`."""

  def __init__(self):
    super(PropertyModule, self).__init__()
    # Scope observed by the most recent setter call; None until one happens.
    self._setter_scope_name = None

  @property
  @module.Module.with_name_scope
  def some_property(self):
    """Returns (getter scope, last setter scope) for the scoped property."""
    return get_name_scope(), self._setter_scope_name

  @some_property.setter
  @module.Module.with_name_scope
  def some_property(self, my_property):
    del my_property  # Only the active scope is recorded; the value is unused.
    self._setter_scope_name = get_name_scope()

  @property
  def no_name_scope_property(self):
    """Returns (getter scope, last setter scope) for the unscoped property."""
    return get_name_scope(), self._setter_scope_name

  @no_name_scope_property.setter
  def no_name_scope_property(self, my_property):
    del my_property  # Only the active scope is recorded; the value is unused.
    self._setter_scope_name = get_name_scope()
481
NamedPair = collections.namedtuple("NamedPair", ("first", "second"))


def mk_index_dict(v):
  """Returns a dict mapping each position in `v` to the element there."""
  return {index: item for index, item in enumerate(v)}
484
485
class FlattenTest(parameterized.TestCase, test_util.TensorFlowTestCase):
  """Checks `Module._flatten`: recursion, ordering, ignored attrs and paths."""

  @parameterized.parameters(lambda v: NamedPair(*v), list, tuple, mk_index_dict)
  def test_flatten(self, container_type):
    parent = SimpleModule(container_type=container_type)
    child = parent.c

    shallow = list(parent._flatten(recursive=False, predicate=is_member))
    own = [parent.a[0], parent.a[1], parent.z]
    self.assertEqual(shallow, own)

    # Without `recursive=False`, the child's members follow the parent's.
    deep = list(parent._flatten(predicate=is_member))
    self.assertEqual(deep, own + [child.a[0], child.a[1], child.z])

  def test_attribute_traversal_key(self):
    layer = LayerModule()
    observed = layer.variables
    expected = (layer._trainable_variables + layer._non_trainable_variables +
                [layer._bonus])
    self.assertEqual(observed, expected)

  def test_attributes_to_ignore(self):

    class DangerousModule(module.Module):
      _TF_MODULE_IGNORED_PROPERTIES = frozenset(itertools.chain(
          ("dangerous_submodule", "dangerous_variable"),
          module.Module._TF_MODULE_IGNORED_PROPERTIES
      ))

    m = DangerousModule()
    m.dangerous_submodule = module.Module()
    m.dangerous_variable = variables.Variable(1.)
    m.normal_variable = variables.Variable(2.)

    # Only the attribute that is not listed as ignored should be visible.
    self.assertEmpty(m.submodules)
    self.assertLen(m.variables, 1)
    self.assertEqual(m.variables[0], m.normal_variable)

  def test_with_path(self):
    root = module.Module()
    root.w = variables.Variable(1.)
    root.encoder = module.Module()
    root.encoder.w = [({"k": root.w}, {"k": root.w})]
    # The same submodule is reachable under a second attribute name.
    root.decoder = root.encoder

    flat = dict(
        root._flatten(with_path=True, predicate=module._is_variable))

    expected = {
        ("w",): root.w,
        ("encoder", "w", 0, 0, "k"): root.encoder.w[0][0]["k"],
        ("encoder", "w", 0, 1, "k"): root.encoder.w[0][1]["k"],
        ("decoder", "w", 0, 0, "k"): root.decoder.w[0][0]["k"],
        ("decoder", "w", 0, 1, "k"): root.decoder.w[0][1]["k"],
    }
    self.assertEqual(flat, expected)

  def test_raises_error_with_path(self):
    non_orderable = object
    if six.PY2:
      # Python 2 can compare arbitrary objects, so disable `<` explicitly.
      class NonOrderable(object):
        __lt__ = None

      non_orderable = NonOrderable

    m = module.Module()
    m.layers = {non_orderable(): None, non_orderable(): None}
    with self.assertRaisesRegex(ValueError,
                                "Error processing property 'layers'"):
      m.variables  # pylint: disable=pointless-statement
554
555
class LayerModule(module.Module):
  """Module whose `variables` property uses a custom attribute order."""

  def __init__(self):
    super(LayerModule, self).__init__()
    self._trainable_variables = [
        variables.Variable(1., name="a"),
        variables.Variable(2., name="b"),
    ]
    self._non_trainable_variables = [
        variables.Variable(3., name="c"),
        variables.Variable(4., name="d"),
    ]
    # Belongs to neither list above.
    self._bonus = variables.Variable(5., name="e")

  @property
  def variables(self):
    """Variables ordered by (attribute rank, attribute name)."""
    rank_by_attr = {"_trainable_variables": 0, "_non_trainable_variables": 1}

    def traversal_key(attr_name):
      # Attributes not listed share rank 2 and are tie-broken by name.
      return rank_by_attr.get(attr_name, 2), attr_name

    flat = self._flatten(predicate=module._is_variable,
                         attribute_traversal_key=traversal_key)
    return list(flat)
580
581
class MemberType(object):
  """Marker type that the flatten tests look for."""
585
586
class SimpleModule(module.Module):
  """Holds `MemberType`s directly, in a container, and optionally in a child.

  Args:
    create_child: whether to attach a nested `SimpleModule` as `c`.
    container_type: callable that wraps a two-element list of members.
  """

  def __init__(self, create_child=True, container_type=list):
    super(SimpleModule, self).__init__()
    self.z = MemberType()
    members = [MemberType(), MemberType()]
    self.a = container_type(members)
    # Only one level of nesting: the child gets no child of its own.
    if create_child:
      self.c = SimpleModule(create_child=False)
595
def is_member(v):
  """Predicate for `_flatten`: True for `MemberType` instances."""
  return isinstance(v, MemberType)


if __name__ == "__main__":
  test.main()
600