# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Modules encapsulate building stateful components."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import re

from tensorflow.python.framework import ops
from tensorflow.python.ops import variables
from tensorflow.python.training.tracking import tracking
from tensorflow.python.util import nest
from tensorflow.python.util import tf_decorator
from tensorflow.python.util.tf_export import tf_export


@tf_export("Module")
class Module(tracking.AutoTrackable):
  """Base neural network module class.

  A module is a named container for `tf.Variable`s, other `tf.Module`s and
  functions which apply to user input. For example a dense layer in a neural
  network might be implemented as a `tf.Module`:

  >>> class Dense(tf.Module):
  ...   def __init__(self, input_features, output_features, name=None):
  ...     super(Dense, self).__init__(name=name)
  ...     self.w = tf.Variable(
  ...         tf.random.normal([input_features, output_features]), name='w')
  ...     self.b = tf.Variable(tf.zeros([output_features]), name='b')
  ...
  ...   def __call__(self, x):
  ...     y = tf.matmul(x, self.w) + self.b
  ...     return tf.nn.relu(y)

  You can use the Dense layer as you would expect:

  >>> d = Dense(input_features=64, output_features=10)
  >>> d(tf.ones([100, 64]))
  <tf.Tensor: ...>

  By subclassing `tf.Module` instead of `object`, any `tf.Variable` or
  `tf.Module` instances assigned to object properties can be collected using
  the `variables`, `trainable_variables` or `submodules` property:

  >>> d.variables
  (<tf.Variable 'b:0' ...>, <tf.Variable 'w:0' ...>)

  Subclasses of `tf.Module` can also take advantage of the `_flatten` method
  which can be used to implement tracking of any other types.

  All `tf.Module` classes have an associated `tf.name_scope` which can be used
  to group operations in TensorBoard and create hierarchies for variable names
  which can help with debugging. We suggest using the name scope when creating
  nested submodules/parameters or for forward methods whose graph you might
  want to inspect in TensorBoard. You can enter the name scope explicitly using
  `with self.name_scope:` or you can annotate methods (apart from `__init__`)
  with `@tf.Module.with_name_scope`.

  >>> class MLP(tf.Module):
  ...   def __init__(self, input_size, sizes, name=None):
  ...     super(MLP, self).__init__(name=name)
  ...     self.layers = []
  ...     with self.name_scope:
  ...       for size in sizes:
  ...         self.layers.append(
  ...             Dense(input_features=input_size, output_features=size))
  ...         input_size = size
  ...
  ...   @tf.Module.with_name_scope
  ...   def __call__(self, x):
  ...     for layer in self.layers:
  ...       x = layer(x)
  ...     return x
  """

  def __init__(self, name=None):
    if name is None:
      name = camel_to_snake(type(self).__name__)
    else:
      if not valid_identifier(name):
        raise ValueError(
            "%r is not a valid module name. Module names must be valid Python "
            "identifiers (e.g. a valid class name)." % name)

    self._name = name
    with ops.name_scope(name) as scope_name:
      self._scope_name = scope_name

  @property
  def name(self):
    """Returns the name of this module as passed or determined in the ctor.

    NOTE: This is not the same as `self.name_scope.name`, which includes
    parent module names.
    """
    return self._name

  @property
  def name_scope(self):
    """Returns a `tf.name_scope` instance for this class."""
    # TODO(tomhennigan) Memoize once name scopes are re-entrant.
    return ops.name_scope(self._scope_name)

  @property
  def variables(self):
    """Sequence of variables owned by this module and its submodules.

    Note: this method uses reflection to find variables on the current instance
    and submodules. For performance reasons you may wish to cache the result
    of calling this method if you don't expect the return value to change.
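
    For example (a small illustrative sketch; the attribute names used here
    are arbitrary), a variable marked as non-trainable is reported by
    `variables` but not by `trainable_variables`:

    >>> m = tf.Module()
    >>> m.v = tf.Variable(1.)
    >>> m.child = tf.Module()
    >>> m.child.w = tf.Variable(2., trainable=False)
    >>> len(m.variables), len(m.trainable_variables)
    (2, 1)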

    Returns:
      A sequence of variables for the current module (sorted by attribute
      name) followed by variables from all submodules recursively (breadth
      first).
    """
    return tuple(self._flatten(predicate=_IS_VARIABLE))

  @property
  def trainable_variables(self):
    """Sequence of trainable variables owned by this module and its submodules.

    Note: this method uses reflection to find variables on the current instance
    and submodules. For performance reasons you may wish to cache the result
    of calling this method if you don't expect the return value to change.

    Returns:
      A sequence of variables for the current module (sorted by attribute
      name) followed by variables from all submodules recursively (breadth
      first).
    """
    return tuple(self._flatten(predicate=_IS_TRAINABLE_VARIABLE))

  @property
  def submodules(self):
    """Sequence of all sub-modules.

    Submodules are modules which are properties of this module, or found as
    properties of modules which are properties of this module (and so on).

    >>> a = tf.Module()
    >>> b = tf.Module()
    >>> c = tf.Module()
    >>> a.b = b
    >>> b.c = c
    >>> assert list(a.submodules) == [b, c]
    >>> assert list(b.submodules) == [c]
    >>> assert list(c.submodules) == []

    Returns:
      A sequence of all submodules.
    """
    return tuple(self._flatten(predicate=_IS_MODULE))

  def _flatten(self,
               recursive=True,
               predicate=None,
               attribute_traversal_key=None,
               with_path=False):
    """Flattened attribute values in sorted order by attribute name.

    Modules are flattened by first walking their attributes in name order.
    Each attribute value is then flattened to find leaf values. If flattening
    is applied `recursive`ly and a leaf is itself a `Module` then it will also
    be flattened to find leaves. Finally every leaf value is optionally tested
    against the given `predicate` and then yielded.

    >>> class Foo(tf.Module):
    ...   def __init__(self):
    ...     super(Foo, self).__init__()
    ...     self.x = [tf.constant('a'), tf.constant('b')]
    ...     self.y = {'i': tf.constant('c'), 'j': tf.constant('d')}
    ...     self.z = tf.constant('e')
    ...
    ...   @property
    ...   def tensors(self):
    ...     return tuple(self._flatten(predicate=tf.is_tensor, with_path=True))

    >>> foo = Foo()
    >>> foo.tensors
    ((('x', 0), <tf.Tensor: ...'a'>),
     (('x', 1), <tf.Tensor: ...'b'>),
     (('y', 'i'), <tf.Tensor: ...'c'>),
     (('y', 'j'), <tf.Tensor: ...'d'>),
     (('z',), <tf.Tensor: ...'e'>))

    `attribute_traversal_key` controls the order object properties are visited.
    If not set objects are visited in ascending order by name.

    Args:
      recursive: Whether to recurse into child modules or not.
      predicate: (Optional) If set then only values matching predicate are
        yielded. A value of `None` (the default) means no items will be
        filtered.
      attribute_traversal_key: (Optional) Method to rekey object attributes
        before they are sorted. Contract is the same as the `key` argument to
        builtin `sorted` and only applies to object properties.
      with_path: (Optional) Whether to include the path to the object as well
        as the object itself. If `with_path` is `True` then leaves will not be
        de-duplicated (e.g. if the same leaf instance is reachable via multiple
        modules then it will be yielded multiple times with different paths).

    Returns:
      Flat generator for leaves of the current module and optionally all
      submodules.
    """
    if predicate is None:
      predicate = lambda _: True

    return _flatten_module(
        self,
        recursive=recursive,
        predicate=predicate,
        attribute_traversal_key=attribute_traversal_key,
        with_path=with_path)

  @classmethod
  def with_name_scope(cls, method):
    """Decorator to automatically enter the module name scope.

    >>> class MyModule(tf.Module):
    ...   @tf.Module.with_name_scope
    ...   def __call__(self, x):
    ...     if not hasattr(self, 'w'):
    ...       self.w = tf.Variable(tf.random.normal([x.shape[1], 64]))
    ...     return tf.matmul(x, self.w)

    Using the above module would produce `tf.Variable`s and `tf.Tensor`s whose
    names include the module name:

    >>> mod = MyModule()
    >>> mod(tf.ones([8, 32]))
    <tf.Tensor: ...>
    >>> mod.w
    <tf.Variable ...'my_module/w:0'>

    Args:
      method: The method to wrap.

    Returns:
      The original method wrapped such that it enters the module's name scope.
    """
    def method_with_name_scope(self, *args, **kwargs):
      with self.name_scope:
        return method(self, *args, **kwargs)

    return tf_decorator.make_decorator(method, method_with_name_scope)


_IS_VARIABLE = lambda o: isinstance(o, variables.Variable)
_IS_TRAINABLE_VARIABLE = lambda o: (_IS_VARIABLE(o) and o.trainable)
_IS_MODULE = lambda o: isinstance(o, Module)
_CAMEL_TO_SNAKE_R = re.compile(r"((?<=[a-z0-9])[A-Z]|(?!^)[A-Z](?=[a-z]))")
_VALID_IDENTIFIER = re.compile(r"^[a-zA-Z_]([a-zA-Z0-9_])*$")


def valid_identifier(name):
  return bool(_VALID_IDENTIFIER.match(name))


def camel_to_snake(value):
  return _CAMEL_TO_SNAKE_R.sub(r"_\1", value).lower()
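

# A rough illustration of how default module names are derived from class
# names (illustrative examples only):
#   camel_to_snake("MyDenseLayer") == "my_dense_layer"
#   camel_to_snake("MLP") == "mlp"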


# AutoTrackable adds object attributes that users will not expect us to
# include when flattening (these reference dependencies reachable via other
# object attributes).
AUTO_CHECKPOINTABLE_ATTRS = ("_unconditional_checkpoint_dependencies",
                             "_unconditional_dependency_names")


def _flatten_module(module,
                    recursive,
                    predicate,
                    attribute_traversal_key,
                    with_path,
                    module_path=(),
                    seen=None):
  """Implementation of `_flatten`."""
  if seen is None:
    seen = set([id(module)])

  module_dict = vars(module)
  submodules = []

  for key in sorted(module_dict, key=attribute_traversal_key):
    if key in AUTO_CHECKPOINTABLE_ATTRS:
      continue

    for leaf_path, leaf in nest.flatten_with_tuple_paths(module_dict[key]):
      leaf_path = (key,) + leaf_path

      # TODO(tomhennigan) Handle cycles for `with_path=True` (e.g. `a.a = a`).
      if not with_path:
        leaf_id = id(leaf)
        if leaf_id in seen:
          continue
        seen.add(leaf_id)

      if predicate(leaf):
        if with_path:
          yield module_path + leaf_path, leaf
        else:
          yield leaf

      if recursive and isinstance(leaf, Module):
        # Walk direct properties first then recurse.
        submodules.append((module_path + leaf_path, leaf))

  for submodule_path, submodule in submodules:
    subvalues = _flatten_module(
        submodule,
        recursive=recursive,
        predicate=predicate,
        attribute_traversal_key=attribute_traversal_key,
        with_path=with_path,
        module_path=submodule_path,
        seen=seen)

    for subvalue in subvalues:
      # Predicate is already tested for these values.
      yield subvalue