# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for exporting TensorFlow symbols to the API.

Exporting a function or a class:

To export a function or a class use the tf_export decorator. For example:
```python
@tf_export('foo', 'bar.foo')
def foo(...):
  ...
```

If a function is assigned to a variable, you can export it by calling
tf_export explicitly. For example:
```python
foo = get_foo(...)
tf_export('foo', 'bar.foo')(foo)
```


Exporting a constant:
```python
foo = 1
tf_export("consts.foo").export_constant(__name__, 'foo')
```
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import sys

from tensorflow.python.util import tf_decorator


class SymbolAlreadyExposedError(Exception):
  """Raised when adding API names to symbol that already has API names."""


class tf_export(object):  # pylint: disable=invalid-name
  """Provides ways to export symbols to the TensorFlow API."""

  def __init__(self, *args, **kwargs):
    """Export under the names *args (first one is considered canonical).

    Args:
      *args: API names in dot delimited format.
      **kwargs: Optional keyed arguments. Currently only supports 'overrides'
        argument. overrides: List of symbols that this is overriding
        (those overridden api exports will be removed). Note: passing
        overrides has no effect on exporting a constant.
    """
    self._names = args
    # Only 'overrides' is read; any other keyword arguments are ignored.
    self._overrides = kwargs.get('overrides', [])

  def __call__(self, func):
    """Calls this decorator.

    All validation happens before any symbol is modified, so a failed export
    leaves both `func` and the overridden symbols untouched.

    Args:
      func: decorated symbol (function or class).

    Returns:
      The input function with _tf_api_names attribute set.

    Raises:
      SymbolAlreadyExposedError: Raised when a symbol already has API names
        and is not itself listed in `overrides`.
      AttributeError: If an overridden symbol has no `_tf_api_names` of its
        own to delete.
    """
    _, undecorated_func = tf_decorator.unwrap(func)
    # Unwrap every override up front so that no deletion happens before the
    # conflict check below.
    undecorated_overrides = [
        tf_decorator.unwrap(f)[1] for f in self._overrides]

    # Check for an existing api. We check if attribute name is in
    # __dict__ instead of using hasattr to verify that subclasses have
    # their own _tf_api_names as opposed to just inheriting it.
    # A symbol that lists itself in `overrides` is exempt, because its names
    # are removed below before being set again.
    overrides_itself = any(
        undecorated_func is overridden for overridden in undecorated_overrides)
    if ('_tf_api_names' in undecorated_func.__dict__ and
        not overrides_itself):
      # pylint: disable=protected-access
      raise SymbolAlreadyExposedError(
          'Symbol %s is already exposed as %s.' %
          (undecorated_func.__name__, undecorated_func._tf_api_names))
      # pylint: enable=protected-access

    # Validation passed: undecorate overridden names.
    for undecorated_f in undecorated_overrides:
      del undecorated_f._tf_api_names  # pylint: disable=protected-access

    # Complete the export by creating/overriding attribute
    # pylint: disable=protected-access
    undecorated_func._tf_api_names = self._names
    # pylint: enable=protected-access
    return func

  def export_constant(self, module_name, name):
    """Store export information for constants/string literals.

    Export information is stored in the module where constants/string literals
    are defined, as `(api_names, name)` tuples appended to that module's
    `_tf_api_constants` list. Unlike `__call__`, no duplicate-export check is
    performed and `overrides` is not consulted.

    e.g.
    ```python
    foo = 1
    bar = 2
    tf_export("consts.foo").export_constant(__name__, 'foo')
    tf_export("consts.bar").export_constant(__name__, 'bar')
    ```

    Args:
      module_name: (string) Name of the module to store constant at.
      name: (string) Current constant name.

    Raises:
      KeyError: If `module_name` is not present in `sys.modules`.
    """
    module = sys.modules[module_name]
    # Create the per-module registry on the first export from that module.
    if not hasattr(module, '_tf_api_constants'):
      module._tf_api_constants = []  # pylint: disable=protected-access
    # pylint: disable=protected-access
    module._tf_api_constants.append((self._names, name))
    # pylint: enable=protected-access