1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Utilities for exporting TensorFlow symbols to the API. 16 17Exporting a function or a class: 18 19To export a function or a class use tf_export decorator. For e.g.: 20```python 21@tf_export('foo', 'bar.foo') 22def foo(...): 23 ... 24``` 25 26If a function is assigned to a variable, you can export it by calling 27tf_export explicitly. For e.g.: 28```python 29foo = get_foo(...) 30tf_export('foo', 'bar.foo')(foo) 31``` 32 33 34Exporting a constant 35```python 36foo = 1 37tf_export('consts.foo').export_constant(__name__, 'foo') 38``` 39""" 40from __future__ import absolute_import 41from __future__ import division 42from __future__ import print_function 43 44import collections 45import functools 46import sys 47 48from tensorflow.python.util import tf_decorator 49from tensorflow.python.util import tf_inspect 50 51ESTIMATOR_API_NAME = 'estimator' 52KERAS_API_NAME = 'keras' 53TENSORFLOW_API_NAME = 'tensorflow' 54 55# List of subpackage names used by TensorFlow components. Have to check that 56# TensorFlow core repo does not export any symbols under these names. 57SUBPACKAGE_NAMESPACES = [ESTIMATOR_API_NAME] 58 59_Attributes = collections.namedtuple( 60 'ExportedApiAttributes', ['names', 'constants']) 61 62# Attribute values must be unique to each API. 
API_ATTRS = {
    TENSORFLOW_API_NAME: _Attributes(
        '_tf_api_names',
        '_tf_api_constants'),
    ESTIMATOR_API_NAME: _Attributes(
        '_estimator_api_names',
        '_estimator_api_constants'),
    KERAS_API_NAME: _Attributes(
        '_keras_api_names',
        '_keras_api_constants')
}

API_ATTRS_V1 = {
    TENSORFLOW_API_NAME: _Attributes(
        '_tf_api_names_v1',
        '_tf_api_constants_v1'),
    ESTIMATOR_API_NAME: _Attributes(
        '_estimator_api_names_v1',
        '_estimator_api_constants_v1'),
    KERAS_API_NAME: _Attributes(
        '_keras_api_names_v1',
        '_keras_api_constants_v1')
}

# APIs whose names are collected by get_v1_names / get_v2_names, in the order
# they are concatenated. An explicit tuple (rather than iterating the dicts
# above) keeps that order fixed even on interpreters with unordered dicts.
_NAMES_COLLECTION_ORDER = (
    TENSORFLOW_API_NAME, ESTIMATOR_API_NAME, KERAS_API_NAME)

# APIs whose constants are collected by get_v1_constants / get_v2_constants.
# NOTE(review): Keras is not listed although keras_export(...).export_constant
# records constants under API_ATTRS[KERAS_API_NAME].constants; this preserves
# the existing behavior -- confirm whether the omission is intentional.
_CONSTANTS_COLLECTION_ORDER = (TENSORFLOW_API_NAME, ESTIMATOR_API_NAME)


class SymbolAlreadyExposedError(Exception):
  """Raised when adding API names to symbol that already has API names."""
  pass


class InvalidSymbolNameError(Exception):
  """Raised when trying to export symbol as an invalid or unallowed name."""
  pass


def get_canonical_name_for_symbol(
    symbol, api_name=TENSORFLOW_API_NAME,
    add_prefix_to_v1_names=False):
  """Get canonical name for the API symbol.

  V2 names are preferred; V1 names are consulted only when no V2 canonical
  name can be determined.

  Args:
    symbol: API function or class.
    api_name: API name, one of the keys of API_ATTRS (e.g. tensorflow,
      estimator or keras).
    add_prefix_to_v1_names: Specifies whether a name available only in V1
      should be prefixed with compat.v1.

  Returns:
    Canonical name for the API symbol (e.g. initializers.zeros) if a
    canonical name could be determined. Otherwise, returns None.
  """
  if not hasattr(symbol, '__dict__'):
    return None
  api_names_attr = API_ATTRS[api_name].names
  _, undecorated_symbol = tf_decorator.unwrap(symbol)
  # The unwrapped target may differ from `symbol`, so re-check for __dict__.
  # Looking in __dict__ (not hasattr) ignores names merely inherited from an
  # exported base class.
  symbol_dict = getattr(undecorated_symbol, '__dict__', None)
  if symbol_dict is None or api_names_attr not in symbol_dict:
    return None
  api_names = getattr(undecorated_symbol, api_names_attr)
  deprecated_api_names = symbol_dict.get('_tf_deprecated_api_names', [])

  canonical_name = get_canonical_name(api_names, deprecated_api_names)
  if canonical_name:
    return canonical_name

  # If there is no V2 canonical name, get V1 canonical name. A missing V1
  # attribute is treated as "no V1 names" instead of raising AttributeError.
  api_names_v1 = getattr(
      undecorated_symbol, API_ATTRS_V1[api_name].names, ())
  v1_canonical_name = get_canonical_name(api_names_v1, deprecated_api_names)
  if v1_canonical_name is None:
    # Without this guard the prefix branch would yield 'compat.v1.None'.
    return None
  if add_prefix_to_v1_names:
    return 'compat.v1.%s' % v1_canonical_name
  return v1_canonical_name


def get_canonical_name(api_names, deprecated_api_names):
  """Get preferred endpoint name.

  Args:
    api_names: Sequence of API names (indexed, so it must support `[0]`).
    deprecated_api_names: Container of deprecated API names.
  Returns:
    Returns one of the following in decreasing preference:
    - first non-deprecated endpoint
    - first endpoint
    - None
  """
  non_deprecated_name = next(
      (name for name in api_names if name not in deprecated_api_names),
      None)
  if non_deprecated_name:
    return non_deprecated_name
  if api_names:
    return api_names[0]
  return None


def _collect_api_names(symbol, api_attrs):
  """Concatenates the names `symbol` itself stores for each API, in order."""
  names = []
  if not hasattr(symbol, '__dict__'):
    return names
  for api in _NAMES_COLLECTION_ORDER:
    attr = api_attrs[api].names
    # Only names set directly on `symbol` count, not inherited ones.
    if attr in symbol.__dict__:
      names.extend(getattr(symbol, attr))
  return names


def get_v1_names(symbol):
  """Get a list of TF 1.* names for this symbol.

  Args:
    symbol: symbol to get API names for.

  Returns:
    List of all API names for this symbol including TensorFlow, Estimator
    and Keras names (in that order).
  """
  return _collect_api_names(symbol, API_ATTRS_V1)


def get_v2_names(symbol):
  """Get a list of TF 2.0 names for this symbol.

  Args:
    symbol: symbol to get API names for.

  Returns:
    List of all API names for this symbol including TensorFlow, Estimator
    and Keras names (in that order).
  """
  return _collect_api_names(symbol, API_ATTRS)


def _collect_api_constants(module, api_attrs):
  """Concatenates the constant records stored on `module` for each API."""
  constants = []
  for api in _CONSTANTS_COLLECTION_ORDER:
    attr = api_attrs[api].constants
    if hasattr(module, attr):
      constants.extend(getattr(module, attr))
  return constants


def get_v1_constants(module):
  """Get a list of TF 1.* constants in this module.

  Args:
    module: TensorFlow module.

  Returns:
    List of all API constants under the given module including TensorFlow and
    Estimator constants. Each entry is a `(names, constant_name)` tuple as
    recorded by `api_export.export_constant`.
  """
  return _collect_api_constants(module, API_ATTRS_V1)


def get_v2_constants(module):
  """Get a list of TF 2.0 constants in this module.

  Args:
    module: TensorFlow module.

  Returns:
    List of all API constants under the given module including TensorFlow and
    Estimator constants. Each entry is a `(names, constant_name)` tuple as
    recorded by `api_export.export_constant`.
  """
  return _collect_api_constants(module, API_ATTRS)


def _is_under_package(name, package):
  """Returns True if dotted `name` is `package` itself or lies beneath it."""
  # A bare startswith would wrongly treat e.g. 'kerasx.foo' as under 'keras'.
  return name == package or name.startswith(package + '.')


class api_export(object):  # pylint: disable=invalid-name
  """Provides ways to export symbols to the TensorFlow API."""

  def __init__(self, *args, **kwargs):  # pylint: disable=g-doc-args
    """Export under the names *args (first one is considered canonical).

    Args:
      *args: API names in dot delimited format.
      **kwargs: Optional keyed arguments.
        v1: Names for the TensorFlow V1 API. If not set, we will use V2 API
          names both for TensorFlow V1 and V2 APIs.
        overrides: List of symbols that this is overriding
          (those overridden api exports will be removed). Note: passing
          overrides has no effect on exporting a constant.
        api_name: Name of the API you want to generate (e.g. `tensorflow` or
          `estimator`). Default is `tensorflow`.
        allow_multiple_exports: Allow symbol to be exported multiple times
          under different names.

    Raises:
      ValueError: If a `v2` keyword argument is passed.
      InvalidSymbolNameError: If any name is outside the allowed package.
    """
    self._names = args
    self._names_v1 = kwargs.get('v1', args)
    if 'v2' in kwargs:
      raise ValueError('You passed a "v2" argument to tf_export. This is not '
                       'what you want. Pass v2 names directly as positional '
                       'arguments instead.')
    self._api_name = kwargs.get('api_name', TENSORFLOW_API_NAME)
    self._overrides = kwargs.get('overrides', [])
    self._allow_multiple_exports = kwargs.get('allow_multiple_exports', False)

    self._validate_symbol_names()

  def _validate_symbol_names(self):
    """Validate you are exporting symbols under an allowed package.

    We need to ensure things exported by tf_export, estimator_export, etc.
    export symbols under disjoint top-level package names.

    For TensorFlow, we check that it does not export anything under subpackage
    names used by components (estimator, keras, etc.).

    For each component, we check that it exports everything under its own
    subpackage.

    A name is "under" a package when it equals the package name or starts
    with the package name followed by a dot.

    Raises:
      InvalidSymbolNameError: If you try to export symbol under disallowed name.
    """
    all_symbol_names = set(self._names) | set(self._names_v1)
    if self._api_name == TENSORFLOW_API_NAME:
      for subpackage in SUBPACKAGE_NAMESPACES:
        if any(_is_under_package(n, subpackage) for n in all_symbol_names):
          raise InvalidSymbolNameError(
              '@tf_export is not allowed to export symbols under %s.*' % (
                  subpackage))
    else:
      if not all(_is_under_package(n, self._api_name)
                 for n in all_symbol_names):
        raise InvalidSymbolNameError(
            'Can only export symbols under package name of component. '
            'e.g. tensorflow_estimator must export all symbols under '
            'tf.estimator')

  def __call__(self, func):
    """Calls this decorator.

    Args:
      func: decorated symbol (function or class).

    Returns:
      The input function with _tf_api_names attribute set.

    Raises:
      SymbolAlreadyExposedError: Raised when a symbol already has API names
        and kwarg `allow_multiple_exports` not set.
    """
    api_names_attr = API_ATTRS[self._api_name].names
    api_names_attr_v1 = API_ATTRS_V1[self._api_name].names
    # Undecorate overridden names: strip both V2 and V1 names for this API
    # from every overridden symbol before exporting `func`.
    for f in self._overrides:
      _, undecorated_f = tf_decorator.unwrap(f)
      delattr(undecorated_f, api_names_attr)
      delattr(undecorated_f, api_names_attr_v1)

    # Names are stored on the innermost (undecorated) target, but the
    # original, possibly decorated, object is returned unchanged.
    _, undecorated_func = tf_decorator.unwrap(func)
    self.set_attr(undecorated_func, api_names_attr, self._names)
    self.set_attr(undecorated_func, api_names_attr_v1, self._names_v1)
    return func

  def set_attr(self, func, api_names_attr, names):
    """Stores `names` on `func` under the attribute `api_names_attr`.

    Args:
      func: Object (function or class) to annotate.
      api_names_attr: Attribute name to set.
      names: API names to store.

    Raises:
      SymbolAlreadyExposedError: If `func` already has its own
        `api_names_attr` and `allow_multiple_exports` was not set.
    """
    # Check for an existing api. We check if attribute name is in
    # __dict__ instead of using hasattr to verify that subclasses have
    # their own _tf_api_names as opposed to just inheriting it.
    if api_names_attr in func.__dict__:
      if not self._allow_multiple_exports:
        # Callables such as functools.partial objects have no __name__; fall
        # back to the object itself so the intended error is still raised.
        raise SymbolAlreadyExposedError(
            'Symbol %s is already exposed as %s.' %
            (getattr(func, '__name__', func), getattr(func, api_names_attr)))
    setattr(func, api_names_attr, names)

  def export_constant(self, module_name, name):
    """Store export information for constants/string literals.

    Export information is stored in the module where constants/string literals
    are defined, as a list of `(api_names, name)` tuples (one list for V2 names
    and one for V1 names).

    e.g.
    ```python
    foo = 1
    bar = 2
    tf_export("consts.foo").export_constant(__name__, 'foo')
    tf_export("consts.bar").export_constant(__name__, 'bar')
    ```

    Args:
      module_name: (string) Name of the module to store constant at.
      name: (string) Current constant name.
    """
    module = sys.modules[module_name]
    api_constants_attr = API_ATTRS[self._api_name].constants
    api_constants_attr_v1 = API_ATTRS_V1[self._api_name].constants

    if not hasattr(module, api_constants_attr):
      setattr(module, api_constants_attr, [])
    getattr(module, api_constants_attr).append(
        (self._names, name))

    if not hasattr(module, api_constants_attr_v1):
      setattr(module, api_constants_attr_v1, [])
    getattr(module, api_constants_attr_v1).append(
        (self._names_v1, name))


def kwarg_only(f):
  """A wrapper that rejects positional arguments.

  Args:
    f: Function to wrap.

  Returns:
    A decorator around `f` that raises TypeError if called with any
    positional argument and otherwise forwards keyword arguments to `f`.
  """
  f_argspec = tf_inspect.getargspec(f)

  def wrapper(*args, **kwargs):
    if args:
      raise TypeError(
          '{f} only takes keyword args (possible keys: {kwargs}). '
          'Please pass these args as kwargs instead.'
          .format(f=f.__name__, kwargs=f_argspec.args))
    return f(**kwargs)

  return tf_decorator.make_decorator(f, wrapper, decorator_argspec=f_argspec)


tf_export = functools.partial(api_export, api_name=TENSORFLOW_API_NAME)
estimator_export = functools.partial(api_export, api_name=ESTIMATOR_API_NAME)
keras_export = functools.partial(api_export, api_name=KERAS_API_NAME)