# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for exporting TensorFlow symbols to the API.

Exporting a function or a class:

Decorate it with `tf_export`, passing every API name it should appear under,
e.g.:
```python
@tf_export('foo', 'bar.foo')
def foo(...):
  ...
```

A function that is bound to a variable (rather than defined with `def`) can be
exported by calling `tf_export` on it explicitly, e.g.:
```python
foo = get_foo(...)
tf_export('foo', 'bar.foo')(foo)
```


Exporting a constant:
```python
foo = 1
tf_export('consts.foo').export_constant(__name__, 'foo')
```
"""
import collections
import functools
import sys

from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_inspect

# Identifiers of the APIs that symbols can be exported to.
TENSORFLOW_API_NAME = 'tensorflow'
ESTIMATOR_API_NAME = 'estimator'
KERAS_API_NAME = 'keras'

# Top-level subpackage names owned by separate TensorFlow components. The
# TensorFlow core repo must not export any symbol under one of these names.
SUBPACKAGE_NAMESPACES = [ESTIMATOR_API_NAME]

# Names of the attributes that hold one API's exported names (set on the
# symbol) and exported constants (set on the defining module).
_Attributes = collections.namedtuple(
    'ExportedApiAttributes', ('names', 'constants'))

# Attribute values must be unique to each API.
API_ATTRS = {
    TENSORFLOW_API_NAME: _Attributes(
        '_tf_api_names',
        '_tf_api_constants'),
    ESTIMATOR_API_NAME: _Attributes(
        '_estimator_api_names',
        '_estimator_api_constants'),
    KERAS_API_NAME: _Attributes(
        '_keras_api_names',
        '_keras_api_constants')
}

# Same layout as API_ATTRS, but holding the TensorFlow 1.x (compat.v1) names
# and constants.
API_ATTRS_V1 = {
    TENSORFLOW_API_NAME: _Attributes(
        '_tf_api_names_v1',
        '_tf_api_constants_v1'),
    ESTIMATOR_API_NAME: _Attributes(
        '_estimator_api_names_v1',
        '_estimator_api_constants_v1'),
    KERAS_API_NAME: _Attributes(
        '_keras_api_names_v1',
        '_keras_api_constants_v1')
}


class SymbolAlreadyExposedError(Exception):
  """Raised when adding API names to symbol that already has API names."""


class InvalidSymbolNameError(Exception):
  """Raised when trying to export symbol as an invalid or unallowed name."""

# Maps each exported name to the object passed to api_export.__call__.
# V1 names are stored with a 'compat.v1.' prefix.
_NAME_TO_SYMBOL_MAPPING = dict()


def get_symbol_from_name(name):
  """Returns the symbol exported under `name`.

  Args:
    name: Exported API name, e.g. 'foo.bar', or 'compat.v1.foo.bar' for a
      name registered through the V1 API.

  Returns:
    The exported symbol, or None if nothing was exported under `name`.
  """
  return _NAME_TO_SYMBOL_MAPPING.get(name)


def get_canonical_name_for_symbol(
    symbol, api_name=TENSORFLOW_API_NAME,
    add_prefix_to_v1_names=False):
  """Get canonical name for the API symbol.

  Example:
  ```python
  from tensorflow.python.util import tf_export
  cls = tf_export.get_symbol_from_name('keras.optimizers.Adam')

  # Gives `<class 'keras.optimizer_v2.adam.Adam'>`
  print(cls)

  # Gives `keras.optimizers.Adam`
  print(tf_export.get_canonical_name_for_symbol(cls, api_name='keras'))
  ```

  Args:
    symbol: API function or class.
    api_name: API name (`tensorflow`, `estimator` or `keras`).
    add_prefix_to_v1_names: Specifies whether a name available only in V1
      should be prefixed with compat.v1.

  Returns:
    Canonical name for the API symbol (for e.g. initializers.zeros) if
    canonical name could be determined. Otherwise, returns None.
  """
  if not hasattr(symbol, '__dict__'):
    return None
  api_names_attr = API_ATTRS[api_name].names
  _, undecorated_symbol = tf_decorator.unwrap(symbol)
  # Check __dict__ (not hasattr) so a subclass does not report the names it
  # inherited from an exported base class.
  if api_names_attr not in undecorated_symbol.__dict__:
    return None
  api_names = getattr(undecorated_symbol, api_names_attr)
  deprecated_api_names = undecorated_symbol.__dict__.get(
      '_tf_deprecated_api_names', [])

  canonical_name = get_canonical_name(api_names, deprecated_api_names)
  if canonical_name:
    return canonical_name

  # If there is no V2 canonical name, get V1 canonical name. The V1 attribute
  # may be missing; treat that as "no V1 names" instead of raising.
  api_names_attr_v1 = API_ATTRS_V1[api_name].names
  api_names_v1 = getattr(undecorated_symbol, api_names_attr_v1, ())
  v1_canonical_name = get_canonical_name(api_names_v1, deprecated_api_names)
  # Only prefix a name that exists; otherwise this would return the bogus
  # string 'compat.v1.None'.
  if v1_canonical_name is not None and add_prefix_to_v1_names:
    return 'compat.v1.%s' % v1_canonical_name
  return v1_canonical_name


def get_canonical_name(api_names, deprecated_api_names):
  """Get preferred endpoint name.

  Args:
    api_names: Sequence of API names (must support indexing).
    deprecated_api_names: Container of deprecated API names.

  Returns:
    Returns one of the following in decreasing preference:
    - first non-deprecated endpoint
    - first endpoint
    - None
  """
  non_deprecated_name = next(
      (name for name in api_names if name not in deprecated_api_names),
      None)
  if non_deprecated_name:
    return non_deprecated_name
  if api_names:
    return api_names[0]
  return None


def _get_names_for_apis(symbol, api_attrs):
  """Collects the API names stored on `symbol` for every known API.

  Args:
    symbol: symbol to get API names for.
    api_attrs: Either API_ATTRS or API_ATTRS_V1.

  Returns:
    List of names in TensorFlow, Estimator, Keras order. Only attributes
    defined on the symbol itself (in its `__dict__`) count, so a subclass does
    not pick up its base class's names.
  """
  names = []
  if not hasattr(symbol, '__dict__'):
    return names
  for api_name in (TENSORFLOW_API_NAME, ESTIMATOR_API_NAME, KERAS_API_NAME):
    names_attr = api_attrs[api_name].names
    if names_attr in symbol.__dict__:
      names.extend(getattr(symbol, names_attr))
  return names


def get_v1_names(symbol):
  """Get a list of TF 1.* names for this symbol.

  Args:
    symbol: symbol to get API names for.

  Returns:
    List of all API names for this symbol including TensorFlow,
    Estimator and Keras names.
  """
  return _get_names_for_apis(symbol, API_ATTRS_V1)


def get_v2_names(symbol):
  """Get a list of TF 2.0 names for this symbol.

  Args:
    symbol: symbol to get API names for.

  Returns:
    List of all API names for this symbol including TensorFlow,
    Estimator and Keras names.
  """
  return _get_names_for_apis(symbol, API_ATTRS)


def _get_constants_for_apis(module, api_attrs):
  """Collects exported-constant records stored on `module`.

  Args:
    module: TensorFlow module.
    api_attrs: Either API_ATTRS or API_ATTRS_V1.

  Returns:
    List of `(names, constant_name)` tuples as appended by
    `api_export.export_constant`, TensorFlow entries first, then Estimator.
  """
  constants = []
  # NOTE(review): unlike the name getters, Keras constants are not collected
  # here. Kept as-is to preserve existing behavior; confirm it is intentional.
  for api_name in (TENSORFLOW_API_NAME, ESTIMATOR_API_NAME):
    constants_attr = api_attrs[api_name].constants
    if hasattr(module, constants_attr):
      constants.extend(getattr(module, constants_attr))
  return constants


def get_v1_constants(module):
  """Get a list of TF 1.* constants in this module.

  Args:
    module: TensorFlow module.

  Returns:
    List of all API constants under the given module including TensorFlow and
    Estimator constants.
  """
  return _get_constants_for_apis(module, API_ATTRS_V1)


def get_v2_constants(module):
  """Get a list of TF 2.0 constants in this module.

  Args:
    module: TensorFlow module.

  Returns:
    List of all API constants under the given module including TensorFlow and
    Estimator constants.
  """
  return _get_constants_for_apis(module, API_ATTRS)


class api_export(object):  # pylint: disable=invalid-name
  """Provides ways to export symbols to the TensorFlow API."""

  def __init__(self, *args, **kwargs):  # pylint: disable=g-doc-args
    """Export under the names *args (first one is considered canonical).

    Args:
      *args: API names in dot delimited format.
      **kwargs: Optional keyed arguments.
        v1: Names for the TensorFlow V1 API. If not set, we will use V2 API
          names both for TensorFlow V1 and V2 APIs.
        overrides: List of symbols that this is overriding
          (those overridden api exports will be removed). Note: passing
          overrides has no effect on exporting a constant.
        api_name: Name of the API you want to generate (e.g. `tensorflow` or
          `estimator`). Default is `tensorflow`.
        allow_multiple_exports: Allow symbol to be exported multiple times
          under different names.

    Raises:
      ValueError: If a `v2` keyword argument is passed.
      InvalidSymbolNameError: If a name is not allowed for `api_name`.
    """
    self._names = args
    self._names_v1 = kwargs.get('v1', args)
    if 'v2' in kwargs:
      raise ValueError('You passed a "v2" argument to tf_export. This is not '
                       'what you want. Pass v2 names directly as positional '
                       'arguments instead.')
    self._api_name = kwargs.get('api_name', TENSORFLOW_API_NAME)
    self._overrides = kwargs.get('overrides', [])
    self._allow_multiple_exports = kwargs.get('allow_multiple_exports', False)

    self._validate_symbol_names()

  def _validate_symbol_names(self):
    """Validate you are exporting symbols under an allowed package.

    We need to ensure things exported by tf_export, estimator_export, etc.
    export symbols under disjoint top-level package names.

    For TensorFlow, we check that it does not export anything under subpackage
    names used by components (the names in SUBPACKAGE_NAMESPACES).

    For each component, we check that it exports everything under its own
    subpackage.

    Raises:
      InvalidSymbolNameError: If you try to export symbol under disallowed name.
    """
    all_symbol_names = set(self._names) | set(self._names_v1)
    if self._api_name == TENSORFLOW_API_NAME:
      for subpackage in SUBPACKAGE_NAMESPACES:
        # Plain prefix match (no '.' required after the subpackage name).
        if any(n.startswith(subpackage) for n in all_symbol_names):
          raise InvalidSymbolNameError(
              '@tf_export is not allowed to export symbols under %s.*' % (
                  subpackage))
    else:
      if not all(n.startswith(self._api_name) for n in all_symbol_names):
        raise InvalidSymbolNameError(
            'Can only export symbols under package name of component. '
            'e.g. tensorflow_estimator must export all symbols under '
            'tf.estimator')

  def __call__(self, func):
    """Calls this decorator.

    Args:
      func: decorated symbol (function or class).

    Returns:
      `func` itself. The API-name attributes are set on its undecorated
      target (as returned by `tf_decorator.unwrap`).

    Raises:
      SymbolAlreadyExposedError: Raised when a symbol already has API names
        and kwarg `allow_multiple_exports` not set.
    """
    api_names_attr = API_ATTRS[self._api_name].names
    api_names_attr_v1 = API_ATTRS_V1[self._api_name].names
    # Undecorate overridden names. delattr raises AttributeError if an
    # overridden symbol has no exports for this API.
    for f in self._overrides:
      _, undecorated_f = tf_decorator.unwrap(f)
      delattr(undecorated_f, api_names_attr)
      delattr(undecorated_f, api_names_attr_v1)

    _, undecorated_func = tf_decorator.unwrap(func)
    self.set_attr(undecorated_func, api_names_attr, self._names)
    self.set_attr(undecorated_func, api_names_attr_v1, self._names_v1)

    # The name -> symbol registry stores the (possibly decorated) object that
    # was passed in, not the unwrapped target.
    for name in self._names:
      _NAME_TO_SYMBOL_MAPPING[name] = func
    for name_v1 in self._names_v1:
      _NAME_TO_SYMBOL_MAPPING['compat.v1.%s' % name_v1] = func
    return func

  def set_attr(self, func, api_names_attr, names):
    """Sets `func.<api_names_attr> = names`.

    Args:
      func: Undecorated function or class to annotate.
      api_names_attr: Attribute name, e.g. '_tf_api_names'.
      names: API names to store.

    Raises:
      SymbolAlreadyExposedError: If `func` already defines `api_names_attr`
        and `allow_multiple_exports` was not set.
    """
    # Check for an existing api. We check if attribute name is in
    # __dict__ instead of using hasattr to verify that subclasses have
    # their own _tf_api_names as opposed to just inheriting it.
    if api_names_attr in func.__dict__:
      if not self._allow_multiple_exports:
        raise SymbolAlreadyExposedError(
            'Symbol %s is already exposed as %s.' %
            (func.__name__, getattr(func, api_names_attr)))
    setattr(func, api_names_attr, names)

  def export_constant(self, module_name, name):
    """Store export information for constants/string literals.

    Export information is stored in the module where constants/string literals
    are defined.

    e.g.
    ```python
    foo = 1
    bar = 2
    tf_export("consts.foo").export_constant(__name__, 'foo')
    tf_export("consts.bar").export_constant(__name__, 'bar')
    ```

    Args:
      module_name: (string) Name of the module to store constant at.
      name: (string) Current constant name.
    """
    module = sys.modules[module_name]
    api_constants_attr = API_ATTRS[self._api_name].constants
    api_constants_attr_v1 = API_ATTRS_V1[self._api_name].constants

    # Each module keeps a list of (export names, constant name) records per
    # API, created on first use.
    if not hasattr(module, api_constants_attr):
      setattr(module, api_constants_attr, [])
    # pylint: disable=protected-access
    getattr(module, api_constants_attr).append(
        (self._names, name))

    if not hasattr(module, api_constants_attr_v1):
      setattr(module, api_constants_attr_v1, [])
    getattr(module, api_constants_attr_v1).append(
        (self._names_v1, name))


def kwarg_only(f):
  """A wrapper that rejects positional arguments and forwards only kwargs.

  Args:
    f: Function to wrap.

  Returns:
    A decorated `f` that raises TypeError when called with positional args.
  """
  f_argspec = tf_inspect.getargspec(f)

  def wrapper(*args, **kwargs):
    if args:
      raise TypeError(
          '{f} only takes keyword args (possible keys: {kwargs}). '
          'Please pass these args as kwargs instead.'
          .format(f=f.__name__, kwargs=f_argspec.args))
    return f(**kwargs)

  return tf_decorator.make_decorator(f, wrapper, decorator_argspec=f_argspec)


tf_export = functools.partial(api_export, api_name=TENSORFLOW_API_NAME)
estimator_export = functools.partial(api_export, api_name=ESTIMATOR_API_NAME)
keras_export = functools.partial(api_export, api_name=KERAS_API_NAME)