1# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Modules encapsulate building stateful components.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import re 22 23import six 24 25from tensorflow.python import tf2 26from tensorflow.python.framework import ops 27from tensorflow.python.ops import variables 28from tensorflow.python.training.tracking import tracking 29from tensorflow.python.util import nest 30from tensorflow.python.util import tf_decorator 31from tensorflow.python.util.tf_export import tf_export 32 33 34@tf_export("Module") 35class Module(tracking.AutoTrackable): 36 """Base neural network module class. 37 38 A module is a named container for `tf.Variable`s, other `tf.Module`s and 39 functions which apply to user input. For example a dense layer in a neural 40 network might be implemented as a `tf.Module`: 41 42 >>> class Dense(tf.Module): 43 ... def __init__(self, input_dim, output_size, name=None): 44 ... super(Dense, self).__init__(name=name) 45 ... self.w = tf.Variable( 46 ... tf.random.normal([input_dim, output_size]), name='w') 47 ... self.b = tf.Variable(tf.zeros([output_size]), name='b') 48 ... def __call__(self, x): 49 ... y = tf.matmul(x, self.w) + self.b 50 ... return tf.nn.relu(y) 51 52 You can use the Dense layer as you would expect: 53 54 >>> d = Dense(input_dim=3, output_size=2) 55 >>> d(tf.ones([1, 3])) 56 <tf.Tensor: shape=(1, 2), dtype=float32, numpy=..., dtype=float32)> 57 58 59 By subclassing `tf.Module` instead of `object` any `tf.Variable` or 60 `tf.Module` instances assigned to object properties can be collected using 61 the `variables`, `trainable_variables` or `submodules` property: 62 63 >>> d.variables 64 (<tf.Variable 'b:0' shape=(2,) dtype=float32, numpy=..., 65 dtype=float32)>, 66 <tf.Variable 'w:0' shape=(3, 2) dtype=float32, numpy=..., dtype=float32)>) 67 68 69 Subclasses of `tf.Module` can also take advantage of the `_flatten` method 70 which can be used to implement tracking of any other types. 71 72 All `tf.Module` classes have an associated `tf.name_scope` which can be used 73 to group operations in TensorBoard and create hierarchies for variable names 74 which can help with debugging. We suggest using the name scope when creating 75 nested submodules/parameters or for forward methods whose graph you might want 76 to inspect in TensorBoard. You can enter the name scope explicitly using 77 `with self.name_scope:` or you can annotate methods (apart from `__init__`) 78 with `@tf.Module.with_name_scope`. 79 80 >>> class MLP(tf.Module): 81 ... def __init__(self, input_size, sizes, name=None): 82 ... super(MLP, self).__init__(name=name) 83 ... self.layers = [] 84 ... with self.name_scope: 85 ... for size in sizes: 86 ... self.layers.append(Dense(input_dim=input_size, output_size=size)) 87 ... input_size = size 88 ... @tf.Module.with_name_scope 89 ... def __call__(self, x): 90 ... for layer in self.layers: 91 ... x = layer(x) 92 ... return x 93 94 >>> module = MLP(input_size=5, sizes=[5, 5]) 95 >>> module.variables 96 (<tf.Variable 'mlp/b:0' shape=(5,) dtype=float32, numpy=..., dtype=float32)>, 97 <tf.Variable 'mlp/w:0' shape=(5, 5) dtype=float32, numpy=..., 98 dtype=float32)>, 99 <tf.Variable 'mlp/b:0' shape=(5,) dtype=float32, numpy=..., dtype=float32)>, 100 <tf.Variable 'mlp/w:0' shape=(5, 5) dtype=float32, numpy=..., 101 dtype=float32)>) 102 """ 103 104 # AutoTrackable adds object attributes that users will not expect us to 105 # include when flattening (these reference dependencies reachable via other 106 # object attributes). 107 _TF_MODULE_IGNORED_PROPERTIES = frozenset(( 108 "_self_unconditional_checkpoint_dependencies", 109 "_self_unconditional_dependency_names" 110 )) 111 112 def __init__(self, name=None): 113 if name is None: 114 name = camel_to_snake(type(self).__name__) 115 else: 116 if not valid_identifier(name): 117 raise ValueError( 118 "%r is not a valid module name. Module names must be valid Python " 119 "identifiers (e.g. a valid class name)." % name) 120 121 self._name = name 122 if tf2.enabled(): 123 with ops.name_scope_v2(name) as scope_name: 124 self._name_scope = ops.name_scope_v2(scope_name) 125 else: 126 with ops.name_scope(name, skip_on_eager=False) as scope_name: 127 self._scope_name = scope_name 128 129 @property 130 def name(self): 131 """Returns the name of this module as passed or determined in the ctor. 132 133 NOTE: This is not the same as the `self.name_scope.name` which includes 134 parent module names. 135 """ 136 return self._name 137 138 @property 139 def name_scope(self): 140 """Returns a `tf.name_scope` instance for this class.""" 141 if tf2.enabled(): 142 return self._name_scope 143 else: 144 # In TF1 name_scope is not re-entrant in eager so we cannot memoize it. 145 return ops.name_scope(self._scope_name, skip_on_eager=False) 146 147 @property 148 def variables(self): 149 """Sequence of variables owned by this module and its submodules. 150 151 Note: this method uses reflection to find variables on the current instance 152 and submodules. For performance reasons you may wish to cache the result 153 of calling this method if you don't expect the return value to change. 154 155 Returns: 156 A sequence of variables for the current module (sorted by attribute 157 name) followed by variables from all submodules recursively (breadth 158 first). 159 """ 160 return tuple(self._flatten(predicate=_is_variable, expand_composites=True)) 161 162 @property 163 def trainable_variables(self): 164 """Sequence of trainable variables owned by this module and its submodules. 165 166 Note: this method uses reflection to find variables on the current instance 167 and submodules. For performance reasons you may wish to cache the result 168 of calling this method if you don't expect the return value to change. 169 170 Returns: 171 A sequence of variables for the current module (sorted by attribute 172 name) followed by variables from all submodules recursively (breadth 173 first). 174 """ 175 return tuple( 176 self._flatten(predicate=_is_trainable_variable, expand_composites=True)) 177 178 @property 179 def non_trainable_variables(self): 180 """Sequence of non-trainable variables owned by this module and its submodules. 181 182 Note: this method uses reflection to find variables on the current instance 183 and submodules. For performance reasons you may wish to cache the result 184 of calling this method if you don't expect the return value to change. 185 186 Returns: 187 A sequence of variables for the current module (sorted by attribute 188 name) followed by variables from all submodules recursively (breadth 189 first). 190 """ 191 return tuple(self._flatten(predicate=_is_non_trainable_variable)) 192 193 @property 194 def submodules(self): 195 """Sequence of all sub-modules. 196 197 Submodules are modules which are properties of this module, or found as 198 properties of modules which are properties of this module (and so on). 199 200 >>> a = tf.Module() 201 >>> b = tf.Module() 202 >>> c = tf.Module() 203 >>> a.b = b 204 >>> b.c = c 205 >>> list(a.submodules) == [b, c] 206 True 207 >>> list(b.submodules) == [c] 208 True 209 >>> list(c.submodules) == [] 210 True 211 212 Returns: 213 A sequence of all submodules. 214 """ 215 return tuple(self._flatten(predicate=_is_module)) 216 217 def _flatten(self, 218 recursive=True, 219 predicate=None, 220 attribute_traversal_key=None, 221 with_path=False, 222 expand_composites=False): 223 """Flattened attribute values in sorted order by attribute name. 224 225 Modules are flattened by first walking their attributes in name order. 226 Each attribute value is then flattened to find leaf values. If flatten is 227 applied `recursive`ly and if the leaf is a `Module` it will also be 228 flattened to find leaves. Finally every leaf value is optionally tested 229 against the given `predicate` and finally yielded. 230 231 ``` 232 class Foo(tf.Module): 233 def __init__(self): 234 super(Foo, self).__init__() 235 self.x = [tf.constant('a'), tf.constant('b')] 236 self.y = {'i': tf.constant('c'), 'j': tf.constant('d')} 237 self.z = tf.constant('e') 238 239 @property 240 def tensors(self): 241 return tuple(self._flatten(predicate=is_tensor, with_path=True)) 242 243 foo = Foo() 244 foo.tensors 245 # ==> ((('x', 0), <tf.Tensor: ...'a'>), 246 # (('x', 1), <tf.Tensor: ...'b'>), 247 # (('y', 'i'), <tf.Tensor: ...'c'>), 248 # (('y', 'j'), <tf.Tensor: ...'d'>), 249 # (('z',), <tf.Tensor: ...'e'>)) 250 ``` 251 252 `attribute_traversal_key` controls the order object properties are visited. 253 If not set objects are visited in ascending order by name. 254 255 Args: 256 recursive: Whether to recurse into child modules or not. 257 predicate: (Optional) If set then only values matching predicate are 258 yielded. A value of `None` (the default) means no items will be 259 filtered. 260 attribute_traversal_key: (Optional) Method to rekey object attributes 261 before they are sorted. Contract is the same as `key` argument to 262 builtin `sorted` and only applies to object properties. 263 with_path: (Optional) Whether to include the path to the object as well 264 as the object itself. If `with_path` is `True` then leaves will not be 265 de-duplicated (e.g. if the same leaf instance is reachable via multiple 266 modules then it will be yielded multiple times with different paths). 267 expand_composites: If true, then composite tensors are expanded into their 268 component tensors. 269 270 Returns: 271 Flat generator for leaves of the current module and optionally all 272 submodules. 273 """ 274 if predicate is None: 275 predicate = lambda _: True 276 277 return _flatten_module( 278 self, 279 recursive=recursive, 280 predicate=predicate, 281 attributes_to_ignore=self._TF_MODULE_IGNORED_PROPERTIES, 282 attribute_traversal_key=attribute_traversal_key, 283 with_path=with_path, 284 expand_composites=expand_composites) 285 286 @classmethod 287 def with_name_scope(cls, method): 288 """Decorator to automatically enter the module name scope. 289 290 >>> class MyModule(tf.Module): 291 ... @tf.Module.with_name_scope 292 ... def __call__(self, x): 293 ... if not hasattr(self, 'w'): 294 ... self.w = tf.Variable(tf.random.normal([x.shape[1], 3])) 295 ... return tf.matmul(x, self.w) 296 297 Using the above module would produce `tf.Variable`s and `tf.Tensor`s whose 298 names included the module name: 299 300 >>> mod = MyModule() 301 >>> mod(tf.ones([1, 2])) 302 <tf.Tensor: shape=(1, 3), dtype=float32, numpy=..., dtype=float32)> 303 >>> mod.w 304 <tf.Variable 'my_module/Variable:0' shape=(2, 3) dtype=float32, 305 numpy=..., dtype=float32)> 306 307 Args: 308 method: The method to wrap. 309 310 Returns: 311 The original method wrapped such that it enters the module's name scope. 312 """ 313 def method_with_name_scope(self, *args, **kwargs): 314 with self.name_scope: 315 return method(self, *args, **kwargs) 316 317 return tf_decorator.make_decorator(method, method_with_name_scope) 318 319 320def _is_variable(obj): 321 return isinstance(obj, variables.Variable) 322 323 324def _is_trainable_variable(obj): 325 return _is_variable(obj) and getattr(obj, "trainable", False) 326 327 328def _is_non_trainable_variable(obj): 329 return _is_variable(obj) and not getattr(obj, "trainable", False) 330 331 332def _is_module(obj): 333 return isinstance(obj, Module) 334 335_CAMEL_TO_SNAKE_R = re.compile(r"((?<=[a-z0-9])[A-Z]|(?!^)[A-Z](?=[a-z]))") 336_VALID_IDENTIFIER = re.compile(r"^[a-zA-Z_]([a-zA-Z0-9_])*$") 337 338 339def valid_identifier(name): 340 return bool(_VALID_IDENTIFIER.match(name)) 341 342 343def camel_to_snake(value): 344 return _CAMEL_TO_SNAKE_R.sub(r"_\1", value).lower() 345 346 347def _flatten_module(module, 348 recursive, 349 predicate, 350 attribute_traversal_key, 351 attributes_to_ignore, 352 with_path, 353 expand_composites, 354 module_path=(), 355 seen=None): 356 """Implementation of `flatten`.""" 357 if seen is None: 358 seen = set([id(module)]) 359 360 module_dict = vars(module) 361 submodules = [] 362 363 for key in sorted(module_dict, key=attribute_traversal_key): 364 if key in attributes_to_ignore: 365 continue 366 367 prop = module_dict[key] 368 try: 369 leaves = nest.flatten_with_tuple_paths( 370 prop, expand_composites=expand_composites) 371 except Exception as cause: # pylint: disable=broad-except 372 six.raise_from( 373 ValueError( 374 "Error processing property {!r} of {!r}".format(key, prop)), 375 cause) 376 377 for leaf_path, leaf in leaves: 378 leaf_path = (key,) + leaf_path 379 380 # TODO(tomhennigan) Handle cycles for `with_path=True` (e.g. `a.a = a`). 381 if not with_path: 382 leaf_id = id(leaf) 383 if leaf_id in seen: 384 continue 385 seen.add(leaf_id) 386 387 if predicate(leaf): 388 if with_path: 389 yield module_path + leaf_path, leaf 390 else: 391 yield leaf 392 393 if recursive and _is_module(leaf): 394 # Walk direct properties first then recurse. 395 submodules.append((module_path + leaf_path, leaf)) 396 397 for submodule_path, submodule in submodules: 398 subvalues = _flatten_module( 399 submodule, 400 recursive=recursive, 401 predicate=predicate, 402 attribute_traversal_key=attribute_traversal_key, 403 attributes_to_ignore=submodule._TF_MODULE_IGNORED_PROPERTIES, # pylint: disable=protected-access 404 with_path=with_path, 405 expand_composites=expand_composites, 406 module_path=submodule_path, 407 seen=seen) 408 409 for subvalue in subvalues: 410 # Predicate is already tested for these values. 411 yield subvalue 412