# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Modules encapsulate building stateful components."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import re

from tensorflow.python.framework import ops
from tensorflow.python.ops import variables
from tensorflow.python.training.tracking import tracking
from tensorflow.python.util import nest
from tensorflow.python.util import tf_decorator
from tensorflow.python.util.tf_export import tf_export


@tf_export("Module")
class Module(tracking.AutoTrackable):
33  """Base neural network module class.
34
35  A module is a named container for `tf.Variable`s, other `tf.Module`s and
36  functions which apply to user input. For example a dense layer in a neural
37  network might be implemented as a `tf.Module`:
38
39  >>> class Dense(tf.Module):
40  ...   def __init__(self, in_features, output_features, name=None):
41  ...     super(Dense, self).__init__(name=name)
42  ...     self.w = tf.Variable(
43  ...         tf.random_normal([input_features, output_features]), name='w')
44  ...     self.b = tf.Variable(tf.zeros([output_features]), name='b')
45  ...
46  ...   def __call__(self, x):
47  ...     y = tf.matmul(x, self.w) + self.b
48  ...     return tf.nn.relu(y)
49
50  You can use the Dense layer as you would expect:
51
52  >>> d = Dense(input_features=64, output_features=10)
53  >>> d(tf.ones([100, 64]))
54  <tf.Tensor: ...>

  By subclassing `tf.Module` instead of `object`, any `tf.Variable` or
  `tf.Module` instances assigned to object properties can be collected using
  the `variables`, `trainable_variables` or `submodules` property:

  >>> d.variables
  (<tf.Variable 'b:0' ...>, <tf.Variable 'w:0' ...>)

  Subclasses of `tf.Module` can also take advantage of the `_flatten` method
  which can be used to implement tracking of any other types.

  All `tf.Module` classes have an associated `tf.name_scope` which can be used
  to group operations in TensorBoard and create hierarchies for variable names
  which can help with debugging. We suggest using the name scope when creating
  nested submodules/parameters or for forward methods whose graph you might want
  to inspect in TensorBoard. You can enter the name scope explicitly using
  `with self.name_scope:` or you can annotate methods (apart from `__init__`)
  with `@tf.Module.with_name_scope`.

  >>> class MLP(tf.Module):
  ...   def __init__(self, input_size, sizes, name=None):
  ...     super(MLP, self).__init__(name=name)
  ...     self.layers = []
  ...     with self.name_scope:
  ...       for size in sizes:
  ...         self.layers.append(
  ...             Dense(input_features=input_size, output_features=size))
  ...         input_size = size
  ...
  ...   @tf.Module.with_name_scope
  ...   def __call__(self, x):
  ...     for layer in self.layers:
  ...       x = layer(x)
  ...     return x
  """

  def __init__(self, name=None):
    if name is None:
      name = camel_to_snake(type(self).__name__)
    else:
      if not valid_identifier(name):
        raise ValueError(
            "%r is not a valid module name. Module names must be valid Python "
            "identifiers (e.g. a valid class name)." % name)

    self._name = name
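    # Enter the name scope once here so that the full scope path (including
    # any enclosing scopes) is captured; `self.name_scope` re-enters exactly
    # this scope later.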
    with ops.name_scope(name) as scope_name:
      self._scope_name = scope_name

  @property
  def name(self):
    """Returns the name of this module as passed or determined in the ctor.

    NOTE: This is not the same as the `self.name_scope.name` which includes
    parent module names.
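
    For example:

    >>> m = tf.Module(name='my_module')
    >>> m.name
    'my_module'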
    """
    return self._name

  @property
  def name_scope(self):
    """Returns a `tf.name_scope` instance for this class."""
    # TODO(tomhennigan) Memoize once name scopes are re-entrant.
    return ops.name_scope(self._scope_name)

  @property
  def variables(self):
    """Sequence of variables owned by this module and its submodules.

    Note: this method uses reflection to find variables on the current instance
    and submodules. For performance reasons you may wish to cache the result
    of calling this method if you don't expect the return value to change.
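
    A minimal illustration (any `tf.Variable` assigned to an attribute is
    collected):

    >>> m = tf.Module()
    >>> m.v = tf.Variable(1.0)
    >>> len(m.variables)
    1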

    Returns:
      A sequence of variables for the current module (sorted by attribute
      name) followed by variables from all submodules recursively (breadth
      first).
    """
    return tuple(self._flatten(predicate=_IS_VARIABLE))

  @property
  def trainable_variables(self):
    """Sequence of trainable variables owned by this module and its submodules.

    Note: this method uses reflection to find variables on the current instance
    and submodules. For performance reasons you may wish to cache the result
    of calling this method if you don't expect the return value to change.

    Returns:
      A sequence of trainable variables for the current module (sorted by
      attribute name) followed by trainable variables from all submodules
      recursively (breadth first).
    """
    return tuple(self._flatten(predicate=_IS_TRAINABLE_VARIABLE))

  @property
  def submodules(self):
    """Sequence of all sub-modules.

    Submodules are modules which are properties of this module, or found as
    properties of modules which are properties of this module (and so on).

    >>> a = tf.Module()
    >>> b = tf.Module()
    >>> c = tf.Module()
    >>> a.b = b
    >>> b.c = c
    >>> assert list(a.submodules) == [b, c]
    >>> assert list(b.submodules) == [c]
    >>> assert list(c.submodules) == []

    Returns:
      A sequence of all submodules.
    """
    return tuple(self._flatten(predicate=_IS_MODULE))

  def _flatten(self,
               recursive=True,
               predicate=None,
               attribute_traversal_key=None,
               with_path=False):
    """Flattened attribute values in sorted order by attribute name.

    Modules are flattened by first walking their attributes in name order.
    Each attribute value is then flattened to find leaf values. If flatten is
    applied `recursive`ly and a leaf value is itself a `Module`, that module is
    also flattened to find its leaves. Finally, every leaf value is optionally
    tested against the given `predicate` before being yielded.

    >>> class Foo(tf.Module):
    ...   def __init__(self):
    ...     super(Foo, self).__init__()
    ...     self.x = [tf.constant('a'), tf.constant('b')]
    ...     self.y = {'i': tf.constant('c'), 'j': tf.constant('d')}
    ...     self.z = tf.constant('e')
    ...
    ...   @property
    ...   def tensors(self):
    ...     return tuple(self._flatten(predicate=tf.is_tensor, with_path=True))

    >>> foo = Foo()
    >>> foo.tensors
    ((('x', 0),   <tf.Tensor: ...'a'>),
     (('x', 1),   <tf.Tensor: ...'b'>),
     (('y', 'i'), <tf.Tensor: ...'c'>),
     (('y', 'j'), <tf.Tensor: ...'d'>),
     (('z',),     <tf.Tensor: ...'e'>))

    `attribute_traversal_key` controls the order in which object properties are
    visited. If not set, properties are visited in ascending order by name.

    Args:
      recursive: Whether to recurse into child modules or not.
      predicate: (Optional) If set then only values matching predicate are
        yielded. A value of `None` (the default) means no items will be
        filtered.
      attribute_traversal_key: (Optional) Method to rekey object attributes
        before they are sorted. Contract is the same as `key` argument to
        builtin `sorted` and only applies to object properties.
      with_path: (Optional) Whether to include the path to the object as well
        as the object itself. If `with_path` is `True` then leaves will not be
        de-duplicated (e.g. if the same leaf instance is reachable via multiple
        modules then it will be yielded multiple times with different paths).

    Returns:
      Flat generator for leaves of the current module and optionally all
      submodules.
    """
    if predicate is None:
      predicate = lambda _: True

    return _flatten_module(
        self,
        recursive=recursive,
        predicate=predicate,
        attribute_traversal_key=attribute_traversal_key,
        with_path=with_path)

  @classmethod
  def with_name_scope(cls, method):
    """Decorator to automatically enter the module name scope.

    >>> class MyModule(tf.Module):
    ...   @tf.Module.with_name_scope
    ...   def __call__(self, x):
    ...     if not hasattr(self, 'w'):
    ...       self.w = tf.Variable(tf.random.normal([x.shape[1], 64]))
    ...     return tf.matmul(x, self.w)

    Using the above module would produce `tf.Variable`s and `tf.Tensor`s whose
    names include the module name:

    >>> mod = MyModule()
    >>> mod(tf.ones([8, 32]))
    <tf.Tensor: ...>
    >>> mod.w
    <tf.Variable ...'my_module/w:0'>

    Args:
      method: The method to wrap.

    Returns:
      The original method wrapped such that it enters the module's name scope.
    """
    def method_with_name_scope(self, *args, **kwargs):
      with self.name_scope:
        return method(self, *args, **kwargs)

    return tf_decorator.make_decorator(method, method_with_name_scope)


_IS_VARIABLE = lambda o: isinstance(o, variables.Variable)
_IS_TRAINABLE_VARIABLE = lambda o: (_IS_VARIABLE(o) and o.trainable)
_IS_MODULE = lambda o: isinstance(o, Module)
_CAMEL_TO_SNAKE_R = re.compile(r"((?<=[a-z0-9])[A-Z]|(?!^)[A-Z](?=[a-z]))")
_VALID_IDENTIFIER = re.compile(r"^[a-zA-Z_]([a-zA-Z0-9_])*$")


def valid_identifier(name):
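  """Returns True if `name` is a valid Python-style identifier."""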
  return bool(_VALID_IDENTIFIER.match(name))


def camel_to_snake(value):
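  """Converts CamelCase to snake_case, e.g. "MyModule" -> "my_module"."""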
  return _CAMEL_TO_SNAKE_R.sub(r"_\1", value).lower()


# AutoTrackable adds object attributes that users will not expect us to
# include when flattening (these reference dependencies reachable via other
# object attributes).
AUTO_CHECKPOINTABLE_ATTRS = ("_unconditional_checkpoint_dependencies",
                             "_unconditional_dependency_names")


def _flatten_module(module,
                    recursive,
                    predicate,
                    attribute_traversal_key,
                    with_path,
                    module_path=(),
                    seen=None):
  """Implementation of `flatten`."""
  if seen is None:
    seen = set([id(module)])
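  # When `with_path=False`, `seen` tracks object ids so shared leaves (and the
  # root module itself) are yielded at most once; this also guards against
  # reference cycles between modules.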

  module_dict = vars(module)
  submodules = []

  for key in sorted(module_dict, key=attribute_traversal_key):
    if key in AUTO_CHECKPOINTABLE_ATTRS:
      continue

    for leaf_path, leaf in nest.flatten_with_tuple_paths(module_dict[key]):
      leaf_path = (key,) + leaf_path

      # TODO(tomhennigan) Handle cycles for `with_path=True` (e.g. `a.a = a`).
      if not with_path:
        leaf_id = id(leaf)
        if leaf_id in seen:
          continue
        seen.add(leaf_id)

      if predicate(leaf):
        if with_path:
          yield module_path + leaf_path, leaf
        else:
          yield leaf

      if recursive and isinstance(leaf, Module):
        # Walk direct properties first then recurse.
        submodules.append((module_path + leaf_path, leaf))

  for submodule_path, submodule in submodules:
    subvalues = _flatten_module(
        submodule,
        recursive=recursive,
        predicate=predicate,
        attribute_traversal_key=attribute_traversal_key,
        with_path=with_path,
        module_path=submodule_path,
        seen=seen)

    for subvalue in subvalues:
      # Predicate is already tested for these values.
      yield subvalue
337