• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Utilities for exporting TensorFlow symbols to the API.
16
17Exporting a function or a class:
18
To export a function or a class, use the tf_export decorator. For example:
20```python
21@tf_export('foo', 'bar.foo')
22def foo(...):
23  ...
24```
25
If a function is assigned to a variable, you can export it by calling
tf_export explicitly. For example:
28```python
29foo = get_foo(...)
30tf_export('foo', 'bar.foo')(foo)
31```
32
33
Exporting a constant:
35```python
36foo = 1
37tf_export("consts.foo").export_constant(__name__, 'foo')
38```
39"""
40from __future__ import absolute_import
41from __future__ import division
42from __future__ import print_function
43
44import sys
45
46from tensorflow.python.util import tf_decorator
47
48
class SymbolAlreadyExposedError(Exception):
  """Signals an attempt to export a symbol that already carries API names.

  `tf_export.__call__` raises this when the undecorated target already has
  its own `_tf_api_names` attribute.
  """
52
53
class tf_export(object):  # pylint: disable=invalid-name
  """Provides ways to export symbols to the TensorFlow API."""

  def __init__(self, *args, **kwargs):
    """Export under the names *args (first one is considered canonical).

    Args:
      *args: API names in dot delimited format.
      **kwargs: Optional keyword arguments. Currently only 'overrides' is
        supported: a list of symbols that this export overrides (the API
        names of those symbols are removed when the decorator is applied).
        Any other keyword argument is ignored. Note: passing overrides has
        no effect on exporting a constant.
    """
    self._names = args
    self._overrides = kwargs.get('overrides', [])

  def __call__(self, func):
    """Calls this decorator.

    Args:
      func: decorated symbol (function or class).

    Returns:
      `func` itself; the `_tf_api_names` attribute is set on the object
      returned by `tf_decorator.unwrap(func)`.

    Raises:
      SymbolAlreadyExposedError: Raised when a symbol already has API names.
        The check runs before any overridden symbol is modified, so a failed
        export leaves every symbol untouched.
    """
    _, undecorated_func = tf_decorator.unwrap(func)
    # Resolve the overridden symbols up front; nothing is mutated yet.
    undecorated_overrides = [
        tf_decorator.unwrap(f)[1] for f in self._overrides]

    # Check for an existing api. We check if attribute name is in
    # __dict__ instead of using hasattr to verify that subclasses have
    # their own _tf_api_names as opposed to just inheriting it.
    # A symbol listed among its own overrides is about to lose its names,
    # so it does not count as already exposed.
    already_exposed = (
        '_tf_api_names' in undecorated_func.__dict__ and
        not any(o is undecorated_func for o in undecorated_overrides))
    if already_exposed:
      # Not every exportable object has __name__ (e.g. functools.partial);
      # fall back to repr so the intended error is not masked.
      symbol_name = getattr(undecorated_func, '__name__', None)
      if symbol_name is None:
        symbol_name = repr(undecorated_func)
      # pylint: disable=protected-access
      raise SymbolAlreadyExposedError(
          'Symbol %s is already exposed as %s.' %
          (symbol_name, undecorated_func._tf_api_names))
      # pylint: enable=protected-access

    # Undecorate overridden names only after validation succeeded.
    for undecorated_f in undecorated_overrides:
      del undecorated_f._tf_api_names  # pylint: disable=protected-access

    # Complete the export by creating/overriding attribute
    # pylint: disable=protected-access
    undecorated_func._tf_api_names = self._names
    # pylint: enable=protected-access
    return func

  def export_constant(self, module_name, name):
    """Store export information for constants/string literals.

    Export information is stored in the module where constants/string literals
    are defined, as `(api_names, name)` tuples appended to the module-level
    list `_tf_api_constants` (created on first use).

    e.g.
    ```python
    foo = 1
    bar = 2
    tf_export("consts.foo").export_constant(__name__, 'foo')
    tf_export("consts.bar").export_constant(__name__, 'bar')
    ```

    Args:
      module_name: (string) Name of the module to store constant at. The
        module must already be present in `sys.modules`.
      name: (string) Current constant name.
    """
    module = sys.modules[module_name]
    # pylint: disable=protected-access
    if not hasattr(module, '_tf_api_constants'):
      module._tf_api_constants = []
    module._tf_api_constants.append((self._names, name))
    # pylint: enable=protected-access
128
129