1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Utilities for exporting TensorFlow symbols to the API. 16 17Exporting a function or a class: 18 19To export a function or a class use tf_export decorator. For e.g.: 20```python 21@tf_export('foo', 'bar.foo') 22def foo(...): 23 ... 24``` 25 26If a function is assigned to a variable, you can export it by calling 27tf_export explicitly. For e.g.: 28```python 29foo = get_foo(...) 30tf_export('foo', 'bar.foo')(foo) 31``` 32 33 34Exporting a constant 35```python 36foo = 1 37tf_export('consts.foo').export_constant(__name__, 'foo') 38``` 39""" 40from __future__ import absolute_import 41from __future__ import division 42from __future__ import print_function 43 44import collections 45import functools 46import sys 47 48from tensorflow.python.util import tf_decorator 49from tensorflow.python.util import tf_inspect 50 51ESTIMATOR_API_NAME = 'estimator' 52KERAS_API_NAME = 'keras' 53TENSORFLOW_API_NAME = 'tensorflow' 54 55# List of subpackage names used by TensorFlow components. Have to check that 56# TensorFlow core repo does not export any symbols under these names. 57SUBPACKAGE_NAMESPACES = [ESTIMATOR_API_NAME] 58 59_Attributes = collections.namedtuple( 60 'ExportedApiAttributes', ['names', 'constants']) 61 62# Attribute values must be unique to each API. 
# Attribute names under which TF 2.x (V2) export information is recorded,
# keyed by API name.
API_ATTRS = {
    TENSORFLOW_API_NAME: _Attributes(
        '_tf_api_names',
        '_tf_api_constants'),
    ESTIMATOR_API_NAME: _Attributes(
        '_estimator_api_names',
        '_estimator_api_constants'),
    KERAS_API_NAME: _Attributes(
        '_keras_api_names',
        '_keras_api_constants')
}

# Attribute names under which TF 1.x (V1) export information is recorded,
# keyed by API name.
API_ATTRS_V1 = {
    TENSORFLOW_API_NAME: _Attributes(
        '_tf_api_names_v1',
        '_tf_api_constants_v1'),
    ESTIMATOR_API_NAME: _Attributes(
        '_estimator_api_names_v1',
        '_estimator_api_constants_v1'),
    KERAS_API_NAME: _Attributes(
        '_keras_api_names_v1',
        '_keras_api_constants_v1')
}


class SymbolAlreadyExposedError(Exception):
  """Raised when adding API names to symbol that already has API names."""
  pass


class InvalidSymbolNameError(Exception):
  """Raised when trying to export symbol as an invalid or unallowed name."""
  pass


# Maps each exported name to the object that was passed to
# `api_export.__call__`. V2 names are stored as given; V1 names are stored
# with a 'compat.v1.' prefix. Names from every API share this one mapping.
_NAME_TO_SYMBOL_MAPPING = dict()


def get_symbol_from_name(name):
  """Returns the symbol exported under `name`, or None if there is none.

  Args:
    name: An exported name exactly as registered by `api_export.__call__`.
      TF 1.x names carry a 'compat.v1.' prefix (e.g. 'compat.v1.bar.foo').

  Returns:
    The exported symbol, or None if nothing was exported under `name`.
  """
  return _NAME_TO_SYMBOL_MAPPING.get(name)


def get_canonical_name_for_symbol(
    symbol, api_name=TENSORFLOW_API_NAME,
    add_prefix_to_v1_names=False):
  """Get canonical name for the API symbol.

  The canonical V2 name is preferred. If the symbol has no usable V2 name,
  its canonical V1 name is returned instead.

  Args:
    symbol: API function or class.
    api_name: API name (tensorflow, estimator or keras).
    add_prefix_to_v1_names: Specifies whether a name available only in V1
      should be prefixed with compat.v1.

  Returns:
    Canonical name for the API symbol (for example initializers.zeros) if
    canonical name could be determined. Otherwise, returns None.
  """
  if not hasattr(symbol, '__dict__'):
    return None
  api_names_attr = API_ATTRS[api_name].names
  _, undecorated_symbol = tf_decorator.unwrap(symbol)
  # Check __dict__ rather than hasattr so a subclass does not report the
  # export names it merely inherits from its base class.
  if api_names_attr not in undecorated_symbol.__dict__:
    return None
  api_names = getattr(undecorated_symbol, api_names_attr)
  deprecated_api_names = undecorated_symbol.__dict__.get(
      '_tf_deprecated_api_names', [])

  canonical_name = get_canonical_name(api_names, deprecated_api_names)
  if canonical_name:
    return canonical_name

  # If there is no V2 canonical name, get V1 canonical name. A missing V1
  # attribute is treated as "no V1 names" instead of raising AttributeError.
  api_names_attr_v1 = API_ATTRS_V1[api_name].names
  api_names_v1 = getattr(undecorated_symbol, api_names_attr_v1, ())
  v1_canonical_name = get_canonical_name(api_names_v1, deprecated_api_names)
  if v1_canonical_name is None:
    # Without this check the prefixed result would be 'compat.v1.None'.
    return None
  if add_prefix_to_v1_names:
    return 'compat.v1.%s' % v1_canonical_name
  return v1_canonical_name


def get_canonical_name(api_names, deprecated_api_names):
  """Get preferred endpoint name.

  Args:
    api_names: API names iterable.
    deprecated_api_names: Deprecated API names iterable.
  Returns:
    Returns one of the following in decreasing preference:
      - first non-deprecated endpoint
      - first endpoint
      - None
  """
  # Materialize once so one-shot iterables (generators) and unindexable
  # iterables (sets) work with both the scan and the [0] fallback below.
  api_names = tuple(api_names)
  non_deprecated_name = next(
      (name for name in api_names if name not in deprecated_api_names),
      None)
  if non_deprecated_name:
    return non_deprecated_name
  if api_names:
    return api_names[0]
  return None


def _collect_api_names(symbol, api_attrs):
  """Concatenates the names `symbol` itself was exported under, per API."""
  names = []
  if not hasattr(symbol, '__dict__'):
    return names
  for api in (TENSORFLOW_API_NAME, ESTIMATOR_API_NAME, KERAS_API_NAME):
    attr = api_attrs[api].names
    # Only attributes set directly on `symbol` count, not inherited ones.
    if attr in symbol.__dict__:
      names.extend(getattr(symbol, attr))
  return names


def get_v1_names(symbol):
  """Get a list of TF 1.* names for this symbol.

  Args:
    symbol: symbol to get API names for.

  Returns:
    List of all API names for this symbol including TensorFlow, Estimator
    and Keras names.
  """
  return _collect_api_names(symbol, API_ATTRS_V1)


def get_v2_names(symbol):
  """Get a list of TF 2.0 names for this symbol.

  Args:
    symbol: symbol to get API names for.

  Returns:
    List of all API names for this symbol including TensorFlow, Estimator
    and Keras names.
  """
  return _collect_api_names(symbol, API_ATTRS)


def _collect_api_constants(module, api_attrs):
  """Concatenates the TensorFlow and Estimator constants recorded on `module`."""
  constants = []
  # NOTE(review): Keras constants are not collected here, unlike names in
  # `_collect_api_names`; confirm this asymmetry is intentional.
  for api in (TENSORFLOW_API_NAME, ESTIMATOR_API_NAME):
    attr = api_attrs[api].constants
    if hasattr(module, attr):
      constants.extend(getattr(module, attr))
  return constants


def get_v1_constants(module):
  """Get a list of TF 1.* constants in this module.

  Args:
    module: TensorFlow module.

  Returns:
    List of all API constants under the given module including TensorFlow and
    Estimator constants. Each entry is a `(api_names, constant_name)` tuple as
    recorded by `api_export.export_constant`.
  """
  return _collect_api_constants(module, API_ATTRS_V1)


def get_v2_constants(module):
  """Get a list of TF 2.0 constants in this module.

  Args:
    module: TensorFlow module.

  Returns:
    List of all API constants under the given module including TensorFlow and
    Estimator constants. Each entry is a `(api_names, constant_name)` tuple as
    recorded by `api_export.export_constant`.
  """
  return _collect_api_constants(module, API_ATTRS)


class api_export(object):  # pylint: disable=invalid-name
  """Provides ways to export symbols to the TensorFlow API."""

  def __init__(self, *args, **kwargs):  # pylint: disable=g-doc-args
    """Export under the names *args (first one is considered canonical).

    Args:
      *args: API names in dot delimited format.
      **kwargs: Optional keyed arguments. Unrecognized keys are ignored.
        v1: Names for the TensorFlow V1 API. If not set, we will use V2 API
          names both for TensorFlow V1 and V2 APIs.
        overrides: List of symbols that this is overriding
          (those overridden API exports will be removed). Note: passing
          overrides has no effect on exporting a constant.
        api_name: Name of the API you want to generate (e.g. `tensorflow` or
          `estimator`). Default is `tensorflow`.
        allow_multiple_exports: Allow symbol to be exported multiple times
          under different names.

    Raises:
      ValueError: If a `v2` keyword argument is passed.
      InvalidSymbolNameError: If any name is not allowed for `api_name`.
    """
    self._names = args
    self._names_v1 = kwargs.get('v1', args)
    if 'v2' in kwargs:
      raise ValueError('You passed a "v2" argument to tf_export. This is not '
                       'what you want. Pass v2 names directly as positional '
                       'arguments instead.')
    self._api_name = kwargs.get('api_name', TENSORFLOW_API_NAME)
    self._overrides = kwargs.get('overrides', [])
    self._allow_multiple_exports = kwargs.get('allow_multiple_exports', False)

    self._validate_symbol_names()

  def _validate_symbol_names(self):
    """Validate you are exporting symbols under an allowed package.

    We need to ensure things exported by tf_export, estimator_export, etc.
    export symbols under disjoint top-level package names.

    For TensorFlow, we check that it does not export anything under subpackage
    names used by components (estimator, keras, etc.).

    For each component, we check that it exports everything under its own
    subpackage.

    Raises:
      InvalidSymbolNameError: If you try to export symbol under disallowed name.
    """
    all_symbol_names = set(self._names) | set(self._names_v1)
    # NOTE(review): these are plain string-prefix checks, so e.g.
    # 'estimator_foo' also counts as being under 'estimator'.
    if self._api_name == TENSORFLOW_API_NAME:
      for subpackage in SUBPACKAGE_NAMESPACES:
        if any(n.startswith(subpackage) for n in all_symbol_names):
          raise InvalidSymbolNameError(
              '@tf_export is not allowed to export symbols under %s.*' % (
                  subpackage))
    else:
      if not all(n.startswith(self._api_name) for n in all_symbol_names):
        raise InvalidSymbolNameError(
            'Can only export symbols under package name of component. '
            'e.g. tensorflow_estimator must export all symbols under '
            'tf.estimator')

  def __call__(self, func):
    """Calls this decorator.

    Sets the V2 and V1 API-name attributes on the undecorated target of
    `func`, removes those attributes from every symbol in `overrides`, and
    registers `func` in the module-level name-to-symbol mapping.

    Args:
      func: decorated symbol (function or class).

    Returns:
      The input `func` itself (the attributes are set on its undecorated
      target).

    Raises:
      SymbolAlreadyExposedError: Raised when a symbol already has API names
        and kwarg `allow_multiple_exports` not set.
    """
    api_names_attr = API_ATTRS[self._api_name].names
    api_names_attr_v1 = API_ATTRS_V1[self._api_name].names
    # Undecorate overridden names. delattr raises AttributeError if an
    # overridden symbol does not itself carry these attributes.
    for f in self._overrides:
      _, undecorated_f = tf_decorator.unwrap(f)
      delattr(undecorated_f, api_names_attr)
      delattr(undecorated_f, api_names_attr_v1)

    _, undecorated_func = tf_decorator.unwrap(func)
    self.set_attr(undecorated_func, api_names_attr, self._names)
    self.set_attr(undecorated_func, api_names_attr_v1, self._names_v1)

    for name in self._names:
      _NAME_TO_SYMBOL_MAPPING[name] = func
    for name_v1 in self._names_v1:
      _NAME_TO_SYMBOL_MAPPING['compat.v1.%s' % name_v1] = func
    return func

  def set_attr(self, func, api_names_attr, names):
    """Sets `func.<api_names_attr>` to `names`.

    Args:
      func: Undecorated function or class to annotate.
      api_names_attr: Name of the attribute to set.
      names: API names to store.

    Raises:
      SymbolAlreadyExposedError: If `func` already defines `api_names_attr`
        itself and multiple exports are not allowed.
    """
    # Check for an existing api. We check if attribute name is in
    # __dict__ instead of using hasattr to verify that subclasses have
    # their own _tf_api_names as opposed to just inheriting it.
    if api_names_attr in func.__dict__:
      if not self._allow_multiple_exports:
        raise SymbolAlreadyExposedError(
            'Symbol %s is already exposed as %s.' %
            (func.__name__, getattr(func, api_names_attr)))  # pylint: disable=protected-access
    setattr(func, api_names_attr, names)

  def export_constant(self, module_name, name):
    """Store export information for constants/string literals.

    Export information is stored in the module where constants/string literals
    are defined: a `(api_names, name)` tuple is appended to a list attribute on
    that module, once for the V2 names and once for the V1 names.

    e.g.
    ```python
    foo = 1
    bar = 2
    tf_export("consts.foo").export_constant(__name__, 'foo')
    tf_export("consts.bar").export_constant(__name__, 'bar')
    ```

    Args:
      module_name: (string) Name of the module to store constant at.
      name: (string) Current constant name.
    """
    module = sys.modules[module_name]
    api_constants_attr = API_ATTRS[self._api_name].constants
    api_constants_attr_v1 = API_ATTRS_V1[self._api_name].constants

    if not hasattr(module, api_constants_attr):
      setattr(module, api_constants_attr, [])
    # pylint: disable=protected-access
    getattr(module, api_constants_attr).append(
        (self._names, name))

    if not hasattr(module, api_constants_attr_v1):
      setattr(module, api_constants_attr_v1, [])
    getattr(module, api_constants_attr_v1).append(
        (self._names_v1, name))


def kwarg_only(f):
  """A wrapper that throws away all non-kwarg arguments.

  Args:
    f: Function to wrap.

  Returns:
    A decorated version of `f` that raises TypeError if called with any
    positional arguments and otherwise forwards its keyword arguments to `f`.
  """
  f_argspec = tf_inspect.getargspec(f)

  def wrapper(*args, **kwargs):
    if args:
      raise TypeError(
          '{f} only takes keyword args (possible keys: {kwargs}). '
          'Please pass these args as kwargs instead.'
          .format(f=f.__name__, kwargs=f_argspec.args))
    return f(**kwargs)

  return tf_decorator.make_decorator(f, wrapper, decorator_argspec=f_argspec)


tf_export = functools.partial(api_export, api_name=TENSORFLOW_API_NAME)
estimator_export = functools.partial(api_export, api_name=ESTIMATOR_API_NAME)
keras_export = functools.partial(api_export, api_name=KERAS_API_NAME)