• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Modules encapsulate building stateful components."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import re
22
23import six
24
25from tensorflow.python import tf2
26from tensorflow.python.framework import ops
27from tensorflow.python.ops import variables
28from tensorflow.python.training.tracking import tracking
29from tensorflow.python.util import nest
30from tensorflow.python.util import tf_decorator
31from tensorflow.python.util.tf_export import tf_export
32
33
@tf_export("Module")
class Module(tracking.AutoTrackable):
  """Base neural network module class.

  A module is a named container for `tf.Variable`s, other `tf.Module`s and
  functions which apply to user input. For example a dense layer in a neural
  network might be implemented as a `tf.Module`:

  >>> class Dense(tf.Module):
  ...   def __init__(self, input_dim, output_size, name=None):
  ...     super(Dense, self).__init__(name=name)
  ...     self.w = tf.Variable(
  ...       tf.random.normal([input_dim, output_size]), name='w')
  ...     self.b = tf.Variable(tf.zeros([output_size]), name='b')
  ...   def __call__(self, x):
  ...     y = tf.matmul(x, self.w) + self.b
  ...     return tf.nn.relu(y)

  You can use the Dense layer as you would expect:

  >>> d = Dense(input_dim=3, output_size=2)
  >>> d(tf.ones([1, 3]))
  <tf.Tensor: shape=(1, 2), dtype=float32, numpy=..., dtype=float32)>


  By subclassing `tf.Module` instead of `object` any `tf.Variable` or
  `tf.Module` instances assigned to object properties can be collected using
  the `variables`, `trainable_variables` or `submodules` property:

  >>> d.variables
      (<tf.Variable 'b:0' shape=(2,) dtype=float32, numpy=...,
      dtype=float32)>,
      <tf.Variable 'w:0' shape=(3, 2) dtype=float32, numpy=..., dtype=float32)>)


  Subclasses of `tf.Module` can also take advantage of the `_flatten` method
  which can be used to implement tracking of any other types.

  All `tf.Module` classes have an associated `tf.name_scope` which can be used
  to group operations in TensorBoard and create hierarchies for variable names
  which can help with debugging. We suggest using the name scope when creating
  nested submodules/parameters or for forward methods whose graph you might want
  to inspect in TensorBoard. You can enter the name scope explicitly using
  `with self.name_scope:` or you can annotate methods (apart from `__init__`)
  with `@tf.Module.with_name_scope`.

  >>> class MLP(tf.Module):
  ...   def __init__(self, input_size, sizes, name=None):
  ...     super(MLP, self).__init__(name=name)
  ...     self.layers = []
  ...     with self.name_scope:
  ...       for size in sizes:
  ...         self.layers.append(Dense(input_dim=input_size, output_size=size))
  ...         input_size = size
  ...   @tf.Module.with_name_scope
  ...   def __call__(self, x):
  ...     for layer in self.layers:
  ...       x = layer(x)
  ...     return x

  >>> module = MLP(input_size=5, sizes=[5, 5])
  >>> module.variables
  (<tf.Variable 'mlp/b:0' shape=(5,) dtype=float32, numpy=..., dtype=float32)>,
  <tf.Variable 'mlp/w:0' shape=(5, 5) dtype=float32, numpy=...,
     dtype=float32)>,
  <tf.Variable 'mlp/b:0' shape=(5,) dtype=float32, numpy=..., dtype=float32)>,
  <tf.Variable 'mlp/w:0' shape=(5, 5) dtype=float32, numpy=...,
     dtype=float32)>)
  """

  # AutoTrackable adds object attributes that users will not expect us to
  # include when flattening (these reference dependencies reachable via other
  # object attributes).
  _TF_MODULE_IGNORED_PROPERTIES = frozenset((
      "_self_unconditional_checkpoint_dependencies",
      "_self_unconditional_dependency_names"
  ))

  def __init__(self, name=None):
    """Constructs the module, deriving a default name if one is not given.

    Args:
      name: (Optional) Name for this module. If `None` the class name
        converted to snake_case is used. Must be a valid Python identifier.

    Raises:
      ValueError: If `name` is given but is not a valid Python identifier.
    """
    if name is None:
      name = camel_to_snake(type(self).__name__)
    elif not valid_identifier(name):
      raise ValueError(
          "%r is not a valid module name. Module names must be valid Python "
          "identifiers (e.g. a valid class name)." % name)

    self._name = name
    if tf2.enabled():
      # Enter the scope once to resolve the fully-qualified scope name (which
      # includes any enclosing scopes), then memoize a scope object for that
      # exact name so `self.name_scope` always re-enters the same scope.
      with ops.name_scope_v2(name) as scope_name:
        self._name_scope = ops.name_scope_v2(scope_name)
    else:
      # In TF1 only the resolved scope name string is stored; see the
      # `name_scope` property for why the scope object is rebuilt per access.
      with ops.name_scope(name, skip_on_eager=False) as scope_name:
        self._scope_name = scope_name

  @property
  def name(self):
    """Returns the name of this module as passed or determined in the ctor.

    NOTE: This is not the same as the `self.name_scope.name` which includes
    parent module names.
    """
    return self._name

  @property
  def name_scope(self):
    """Returns a `tf.name_scope` instance for this class."""
    if tf2.enabled():
      return self._name_scope
    else:
      # In TF1 name_scope is not re-entrant in eager so we cannot memoize it.
      return ops.name_scope(self._scope_name, skip_on_eager=False)

  @property
  def variables(self):
    """Sequence of variables owned by this module and its submodules.

    Note: this method uses reflection to find variables on the current instance
    and submodules. For performance reasons you may wish to cache the result
    of calling this method if you don't expect the return value to change.

    Returns:
      A sequence of variables for the current module (sorted by attribute
      name) followed by variables from all submodules recursively (breadth
      first).
    """
    return tuple(self._flatten(predicate=_is_variable, expand_composites=True))

  @property
  def trainable_variables(self):
    """Sequence of trainable variables owned by this module and its submodules.

    Note: this method uses reflection to find variables on the current instance
    and submodules. For performance reasons you may wish to cache the result
    of calling this method if you don't expect the return value to change.

    Returns:
      A sequence of variables for the current module (sorted by attribute
      name) followed by variables from all submodules recursively (breadth
      first).
    """
    return tuple(
        self._flatten(predicate=_is_trainable_variable, expand_composites=True))

  @property
  def non_trainable_variables(self):
    """Sequence of non-trainable variables owned by this module and its submodules.

    Note: this method uses reflection to find variables on the current instance
    and submodules. For performance reasons you may wish to cache the result
    of calling this method if you don't expect the return value to change.

    Returns:
      A sequence of variables for the current module (sorted by attribute
      name) followed by variables from all submodules recursively (breadth
      first).
    """
    # Pass `expand_composites=True` for consistency with `variables` and
    # `trainable_variables`; without it a variable wrapped inside a composite
    # tensor would appear in `variables` but be missing here, breaking the
    # invariant that `variables` == trainable + non-trainable.
    return tuple(
        self._flatten(
            predicate=_is_non_trainable_variable, expand_composites=True))

  @property
  def submodules(self):
    """Sequence of all sub-modules.

    Submodules are modules which are properties of this module, or found as
    properties of modules which are properties of this module (and so on).

    >>> a = tf.Module()
    >>> b = tf.Module()
    >>> c = tf.Module()
    >>> a.b = b
    >>> b.c = c
    >>> list(a.submodules) == [b, c]
    True
    >>> list(b.submodules) == [c]
    True
    >>> list(c.submodules) == []
    True

    Returns:
      A sequence of all submodules.
    """
    return tuple(self._flatten(predicate=_is_module))

  def _flatten(self,
               recursive=True,
               predicate=None,
               attribute_traversal_key=None,
               with_path=False,
               expand_composites=False):
    """Flattened attribute values in sorted order by attribute name.

    Modules are flattened by first walking their attributes in name order.
    Each attribute value is then flattened to find leaf values. If flatten is
    applied `recursive`ly and if the leaf is a `Module` it will also be
    flattened to find leaves. Finally every leaf value is optionally tested
    against the given `predicate` and finally yielded.

    ```
    class Foo(tf.Module):
      def __init__(self):
        super(Foo, self).__init__()
        self.x = [tf.constant('a'), tf.constant('b')]
        self.y = {'i': tf.constant('c'), 'j': tf.constant('d')}
        self.z = tf.constant('e')

      @property
      def tensors(self):
        return tuple(self._flatten(predicate=is_tensor, with_path=True))

    foo = Foo()
    foo.tensors
    # ==> ((('x', 0),   <tf.Tensor: ...'a'>),
    #     (('x', 1),   <tf.Tensor: ...'b'>),
    #     (('y', 'i'), <tf.Tensor: ...'c'>),
    #     (('y', 'j'), <tf.Tensor: ...'d'>),
    #     (('z',),     <tf.Tensor: ...'e'>))
    ```

    `attribute_traversal_key` controls the order object properties are visited.
    If not set objects are visited in ascending order by name.

    Args:
      recursive: Whether to recurse into child modules or not.
      predicate: (Optional) If set then only values matching predicate are
        yielded. A value of `None` (the default) means no items will be
        filtered.
      attribute_traversal_key: (Optional) Method to rekey object attributes
        before they are sorted. Contract is the same as `key` argument to
        builtin `sorted` and only applies to object properties.
      with_path: (Optional) Whether to include the path to the object as well
        as the object itself. If `with_path` is `True` then leaves will not be
        de-duplicated (e.g. if the same leaf instance is reachable via multiple
        modules then it will be yielded multiple times with different paths).
      expand_composites: If true, then composite tensors are expanded into their
        component tensors.

    Returns:
      Flat generator for leaves of the current module and optionally all
      submodules.
    """
    if predicate is None:
      predicate = lambda _: True

    return _flatten_module(
        self,
        recursive=recursive,
        predicate=predicate,
        attributes_to_ignore=self._TF_MODULE_IGNORED_PROPERTIES,
        attribute_traversal_key=attribute_traversal_key,
        with_path=with_path,
        expand_composites=expand_composites)

  @classmethod
  def with_name_scope(cls, method):
    """Decorator to automatically enter the module name scope.

    >>> class MyModule(tf.Module):
    ...   @tf.Module.with_name_scope
    ...   def __call__(self, x):
    ...     if not hasattr(self, 'w'):
    ...       self.w = tf.Variable(tf.random.normal([x.shape[1], 3]))
    ...     return tf.matmul(x, self.w)

    Using the above module would produce `tf.Variable`s and `tf.Tensor`s whose
    names included the module name:

    >>> mod = MyModule()
    >>> mod(tf.ones([1, 2]))
    <tf.Tensor: shape=(1, 3), dtype=float32, numpy=..., dtype=float32)>
    >>> mod.w
    <tf.Variable 'my_module/Variable:0' shape=(2, 3) dtype=float32,
    numpy=..., dtype=float32)>

    Args:
      method: The method to wrap.

    Returns:
      The original method wrapped such that it enters the module's name scope.
    """
    def method_with_name_scope(self, *args, **kwargs):
      with self.name_scope:
        return method(self, *args, **kwargs)

    return tf_decorator.make_decorator(method, method_with_name_scope)
318
319
def _is_variable(obj):
  """Returns True iff `obj` is an instance of `tf.Variable`."""
  variable_type = variables.Variable
  return isinstance(obj, variable_type)
322
323
def _is_trainable_variable(obj):
  """Returns True iff `obj` is a variable marked as trainable."""
  if not _is_variable(obj):
    return False
  # Treat variables without a `trainable` attribute as non-trainable.
  return getattr(obj, "trainable", False)
326
327
def _is_non_trainable_variable(obj):
  """Returns True iff `obj` is a variable that is not marked trainable."""
  if not _is_variable(obj):
    return False
  # Variables without a `trainable` attribute count as non-trainable.
  return not getattr(obj, "trainable", False)
330
331
def _is_module(obj):
  """Returns True iff `obj` is a `tf.Module` (or a subclass) instance."""
  is_mod = isinstance(obj, Module)
  return is_mod
334
# Matches the boundary before an uppercase letter that follows a lowercase
# letter/digit, or an uppercase letter (not at the start) that begins a new
# lowercase word; used to turn CamelCase class names into snake_case.
_CAMEL_TO_SNAKE_R = re.compile(r"((?<=[a-z0-9])[A-Z]|(?!^)[A-Z](?=[a-z]))")
# Matches a valid Python identifier: a letter or underscore followed by any
# number of letters, digits or underscores.
_VALID_IDENTIFIER = re.compile(r"^[a-zA-Z_]([a-zA-Z0-9_])*$")


def valid_identifier(name):
  """Returns True iff `name` is a valid Python identifier."""
  return _VALID_IDENTIFIER.match(name) is not None


def camel_to_snake(value):
  """Converts a CamelCase string into its snake_case equivalent."""
  with_underscores = _CAMEL_TO_SNAKE_R.sub(r"_\1", value)
  return with_underscores.lower()
345
346
def _flatten_module(module,
                    recursive,
                    predicate,
                    attribute_traversal_key,
                    attributes_to_ignore,
                    with_path,
                    expand_composites,
                    module_path=(),
                    seen=None):
  """Implementation of `Module._flatten`.

  Walks `module`'s instance attributes in sorted name order, flattens each
  value with `tf.nest` to find leaves, and yields leaves matching `predicate`.
  Submodules found among the leaves are collected and only recursed into after
  all of the current module's direct attributes have been processed, so direct
  leaves are yielded before any submodule's leaves.

  Args:
    module: A `Module` instance whose attributes should be walked.
    recursive: Whether to recurse into leaves that are themselves modules.
    predicate: Unary callable; only leaves for which it returns true are
      yielded.
    attribute_traversal_key: `key` function used when sorting attribute names
      (same contract as the `key` argument to builtin `sorted`).
    attributes_to_ignore: Collection of attribute names skipped entirely.
    with_path: If True, `(path, leaf)` tuples are yielded and identity-based
      de-duplication is disabled (the same leaf may be yielded once per path).
    expand_composites: Whether `tf.nest` should expand composite tensors into
      their component tensors.
    module_path: Tuple path of `module` relative to the root of the flatten;
      used as a prefix for leaf paths during recursion.
    seen: Set of `id()`s already yielded, shared (mutated in place) across the
      whole recursion; `None` only on the outermost call.

  Yields:
    Leaf values matching `predicate` (or `(path, leaf)` pairs when
    `with_path=True`).

  Raises:
    ValueError: If `tf.nest` fails to flatten an attribute value.
  """
  if seen is None:
    # Seed with the root module itself so a self-reference (e.g. `m.m = m`)
    # is skipped by the de-dup check below rather than recursing forever
    # (only effective when `with_path=False`; see TODO further down).
    seen = set([id(module)])

  module_dict = vars(module)
  submodules = []

  for key in sorted(module_dict, key=attribute_traversal_key):
    if key in attributes_to_ignore:
      continue

    prop = module_dict[key]
    try:
      leaves = nest.flatten_with_tuple_paths(
          prop, expand_composites=expand_composites)
    except Exception as cause:  # pylint: disable=broad-except
      # Re-raise with the failing attribute name for context, chaining the
      # original exception (Python 2/3 compatible via six).
      six.raise_from(
          ValueError(
              "Error processing property {!r} of {!r}".format(key, prop)),
          cause)

    for leaf_path, leaf in leaves:
      # Prefix the path within `prop` with the attribute name it came from.
      leaf_path = (key,) + leaf_path

      # TODO(tomhennigan) Handle cycles for `with_path=True` (e.g. `a.a = a`).
      if not with_path:
        # De-duplicate by object identity: a leaf reachable via several
        # attributes or submodules is yielded only once.
        leaf_id = id(leaf)
        if leaf_id in seen:
          continue
        seen.add(leaf_id)

      if predicate(leaf):
        if with_path:
          yield module_path + leaf_path, leaf
        else:
          yield leaf

      if recursive and _is_module(leaf):
        # Walk direct properties first then recurse.
        submodules.append((module_path + leaf_path, leaf))

  for submodule_path, submodule in submodules:
    # Each submodule supplies its own ignore-list so subclasses can extend
    # `_TF_MODULE_IGNORED_PROPERTIES`.
    subvalues = _flatten_module(
        submodule,
        recursive=recursive,
        predicate=predicate,
        attribute_traversal_key=attribute_traversal_key,
        attributes_to_ignore=submodule._TF_MODULE_IGNORED_PROPERTIES,  # pylint: disable=protected-access
        with_path=with_path,
        expand_composites=expand_composites,
        module_path=submodule_path,
        seen=seen)

    for subvalue in subvalues:
      # Predicate is already tested for these values.
      yield subvalue
412