• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Utilities for exporting TensorFlow symbols to the API.
16
17Exporting a function or a class:
18
To export a function or a class, use the tf_export decorator. For example:
20```python
21@tf_export('foo', 'bar.foo')
22def foo(...):
23  ...
24```
25
26If a function is assigned to a variable, you can export it by calling
tf_export explicitly. For example:
28```python
29foo = get_foo(...)
30tf_export('foo', 'bar.foo')(foo)
31```
32
33
Exporting a constant:
35```python
36foo = 1
37tf_export('consts.foo').export_constant(__name__, 'foo')
38```
39"""
40import collections
41import functools
42import sys
43
44from tensorflow.python.util import tf_decorator
45from tensorflow.python.util import tf_inspect
46
ESTIMATOR_API_NAME = 'estimator'
KERAS_API_NAME = 'keras'
TENSORFLOW_API_NAME = 'tensorflow'

# Top-level subpackage names reserved for TensorFlow components. Symbols
# exported through tf_export must not be placed under any of these names
# (checked in api_export._validate_symbol_names).
SUBPACKAGE_NAMESPACES = [ESTIMATOR_API_NAME]

# Pair of attribute names used to record exports: `names` is set on an
# exported function/class, `constants` is set on the module defining an
# exported constant.
_Attributes = collections.namedtuple(
    'ExportedApiAttributes', 'names constants')

# V2 attribute names, keyed by API name. Every value must be unique to its
# API so that exports for different APIs never collide on one symbol.
API_ATTRS = {
    TENSORFLOW_API_NAME: _Attributes(
        names='_tf_api_names',
        constants='_tf_api_constants'),
    ESTIMATOR_API_NAME: _Attributes(
        names='_estimator_api_names',
        constants='_estimator_api_constants'),
    KERAS_API_NAME: _Attributes(
        names='_keras_api_names',
        constants='_keras_api_constants'),
}

# V1 attribute names; same layout as API_ATTRS with a `_v1` suffix.
API_ATTRS_V1 = {
    TENSORFLOW_API_NAME: _Attributes(
        names='_tf_api_names_v1',
        constants='_tf_api_constants_v1'),
    ESTIMATOR_API_NAME: _Attributes(
        names='_estimator_api_names_v1',
        constants='_estimator_api_constants_v1'),
    KERAS_API_NAME: _Attributes(
        names='_keras_api_names_v1',
        constants='_keras_api_constants_v1'),
}
82
83
class SymbolAlreadyExposedError(Exception):
  """Raised when adding API names to a symbol that already has API names."""
87
88
class InvalidSymbolNameError(Exception):
  """Raised when exporting a symbol under an invalid or disallowed name."""
92
# Exported API name -> exported symbol. Filled in by api_export.__call__;
# V1 names are stored with a 'compat.v1.' prefix.
_NAME_TO_SYMBOL_MAPPING = {}


def get_symbol_from_name(name):
  """Looks up the symbol that was exported under `name`.

  Args:
    name: Dotted API name, e.g. 'math.add' or 'compat.v1.math.add'.

  Returns:
    The registered symbol, or None if nothing was exported under `name`.
  """
  return _NAME_TO_SYMBOL_MAPPING.get(name, None)
98
99
def get_canonical_name_for_symbol(
    symbol, api_name=TENSORFLOW_API_NAME,
    add_prefix_to_v1_names=False):
  """Get canonical name for the API symbol.

  Example:
  ```python
  from tensorflow.python.util import tf_export
  cls = tf_export.get_symbol_from_name('keras.optimizers.Adam')

  # Gives `<class 'keras.optimizer_v2.adam.Adam'>`
  print(cls)

  # Gives `keras.optimizers.Adam`
  print(tf_export.get_canonical_name_for_symbol(cls, api_name='keras'))
  ```

  The symbol is unwrapped with `tf_decorator.unwrap` first. Only attributes
  stored directly on the undecorated symbol (its own `__dict__`) are used, so
  a subclass does not inherit its parent's API names. V2 names are preferred;
  V1 names are used only when no V2 canonical name exists. Names listed in
  `_tf_deprecated_api_names` are de-prioritized.

  Args:
    symbol: API function or class.
    api_name: API name, one of the keys of API_ATTRS (tensorflow, estimator
      or keras).
    add_prefix_to_v1_names: Specifies whether a name available only in V1
      should be prefixed with compat.v1.

  Returns:
    Canonical name for the API symbol (for e.g. initializers.zeros) if
    canonical name could be determined. Otherwise, returns None.
  """
  if not hasattr(symbol, '__dict__'):
    return None
  api_names_attr = API_ATTRS[api_name].names
  _, undecorated_symbol = tf_decorator.unwrap(symbol)
  symbol_dict = undecorated_symbol.__dict__
  if api_names_attr not in symbol_dict:
    return None
  api_names = getattr(undecorated_symbol, api_names_attr)
  deprecated_api_names = symbol_dict.get('_tf_deprecated_api_names', [])

  canonical_name = get_canonical_name(api_names, deprecated_api_names)
  if canonical_name:
    return canonical_name

  # If there is no V2 canonical name, get V1 canonical name. Read it from the
  # symbol's own __dict__, like the V2 check above: a missing attribute then
  # means "no V1 names" instead of raising AttributeError, and names
  # inherited from a base class are not picked up.
  api_names_attr_v1 = API_ATTRS_V1[api_name].names
  api_names_v1 = symbol_dict.get(api_names_attr_v1, [])
  v1_canonical_name = get_canonical_name(api_names_v1, deprecated_api_names)
  if v1_canonical_name is None:
    # Without this guard the prefix branch would return 'compat.v1.None'.
    return None
  if add_prefix_to_v1_names:
    return 'compat.v1.%s' % v1_canonical_name
  return v1_canonical_name
148
149
def get_canonical_name(api_names, deprecated_api_names):
  """Get preferred endpoint name.

  Args:
    api_names: API names iterable. Any iterable is accepted, including
      one-shot iterators/generators; it is consumed exactly once.
    deprecated_api_names: Deprecated API names. Any container supporting
      `in` is used as-is. An object without `__contains__` (e.g. a generator)
      is materialized first so that repeated membership tests are correct.
  Returns:
    Returns one of the following in decreasing preference:
    - first non-deprecated endpoint
    - first endpoint
    - None
  """
  # Materialize once: the names are both scanned and indexed below, which
  # would fail (or silently skip items) for a generator.
  api_names = list(api_names)
  if not hasattr(deprecated_api_names, '__contains__'):
    # `in` on an iterator consumes it, so a second test would see fewer items.
    deprecated_api_names = tuple(deprecated_api_names)

  non_deprecated_name = next(
      (name for name in api_names if name not in deprecated_api_names),
      None)
  if non_deprecated_name:
    return non_deprecated_name
  if api_names:
    return api_names[0]
  return None
170
171
def get_v1_names(symbol):
  """Collects the TF 1.* API names recorded on `symbol`.

  Names are read from the TensorFlow, Estimator and Keras V1 attributes, in
  that order. Only attributes stored directly in `symbol.__dict__` count, so
  names inherited from a base class are ignored. `symbol` is not unwrapped.

  Args:
    symbol: symbol to get API names for.

  Returns:
    List of all V1 API names for this symbol (TensorFlow, Estimator and
    Keras). Empty if `symbol` has no `__dict__`.
  """
  if not hasattr(symbol, '__dict__'):
    return []

  collected = []
  for api in (TENSORFLOW_API_NAME, ESTIMATOR_API_NAME, KERAS_API_NAME):
    attr_name = API_ATTRS_V1[api].names
    if attr_name in symbol.__dict__:
      collected.extend(getattr(symbol, attr_name))
  return collected
196
197
def get_v2_names(symbol):
  """Collects the TF 2.0 API names recorded on `symbol`.

  Names are read from the TensorFlow, Estimator and Keras V2 attributes, in
  that order. Only attributes stored directly in `symbol.__dict__` count, so
  names inherited from a base class are ignored. `symbol` is not unwrapped.

  Args:
    symbol: symbol to get API names for.

  Returns:
    List of all V2 API names for this symbol (TensorFlow, Estimator and
    Keras). Empty if `symbol` has no `__dict__`.
  """
  if not hasattr(symbol, '__dict__'):
    return []

  collected = []
  for api in (TENSORFLOW_API_NAME, ESTIMATOR_API_NAME, KERAS_API_NAME):
    attr_name = API_ATTRS[api].names
    if attr_name in symbol.__dict__:
      collected.extend(getattr(symbol, attr_name))
  return collected
222
223
def get_v1_constants(module):
  """Collects the TF 1.* constant exports recorded on `module`.

  Each entry is whatever `api_export.export_constant` appended, i.e. a
  `(api_names, constant_name)` tuple. TensorFlow entries come first, then
  Estimator entries.

  NOTE(review): unlike get_v1_names, Keras entries
  (`_keras_api_constants_v1`) are not collected here; confirm this is
  intended.

  Args:
    module: TensorFlow module.

  Returns:
    List of all API constants under the given module including TensorFlow and
    Estimator constants.
  """
  found = []
  for api in (TENSORFLOW_API_NAME, ESTIMATOR_API_NAME):
    attr_name = API_ATTRS_V1[api].constants
    if hasattr(module, attr_name):
      found.extend(getattr(module, attr_name))
  return found
243
244
def get_v2_constants(module):
  """Collects the TF 2.0 constant exports recorded on `module`.

  Each entry is whatever `api_export.export_constant` appended, i.e. a
  `(api_names, constant_name)` tuple. TensorFlow entries come first, then
  Estimator entries.

  NOTE(review): unlike get_v2_names, Keras entries (`_keras_api_constants`)
  are not collected here; confirm this is intended.

  Args:
    module: TensorFlow module.

  Returns:
    List of all API constants under the given module including TensorFlow and
    Estimator constants.
  """
  found = []
  for api in (TENSORFLOW_API_NAME, ESTIMATOR_API_NAME):
    attr_name = API_ATTRS[api].constants
    if hasattr(module, attr_name):
      found.extend(getattr(module, attr_name))
  return found
264
265
class api_export(object):  # pylint: disable=invalid-name
  """Provides ways to export symbols to the TensorFlow API.

  An instance stores the requested V2 names (positional args) and V1 names
  (`v1` kwarg). It applies them either to a function/class, by calling the
  instance as a decorator, or to a module-level constant via
  `export_constant`.
  """

  def __init__(self, *args, **kwargs):  # pylint: disable=g-doc-args
    """Export under the names *args (first one is considered canonical).

    Args:
      *args: API names in dot delimited format.
      **kwargs: Optional keyed arguments.
        v1: Names for the TensorFlow V1 API. If not set, we will use V2 API
          names both for TensorFlow V1 and V2 APIs.
        overrides: List of symbols that this is overriding
          (those overridden api exports will be removed). Note: passing
          overrides has no effect on exporting a constant.
        api_name: Name of the API you want to generate (e.g. `tensorflow` or
          `estimator`). Default is `tensorflow`.
        allow_multiple_exports: Allow symbol to be exported multiple times
          under different names.

    Raises:
      ValueError: If a `v2` keyword argument is passed.
      InvalidSymbolNameError: If a name lies outside the packages allowed for
        `api_name`.
    """
    self._names = args
    self._names_v1 = kwargs.get('v1', args)
    if 'v2' in kwargs:
      raise ValueError('You passed a "v2" argument to tf_export. This is not '
                       'what you want. Pass v2 names directly as positional '
                       'arguments instead.')
    self._api_name = kwargs.get('api_name', TENSORFLOW_API_NAME)
    self._overrides = kwargs.get('overrides', [])
    self._allow_multiple_exports = kwargs.get('allow_multiple_exports', False)

    self._validate_symbol_names()

  @staticmethod
  def _is_under_package(name, package):
    """Returns whether dotted `name` is `package` itself or a name inside it.

    A plain string-prefix test would wrongly treat e.g. 'estimatorfoo' as
    being under 'estimator'. Requiring a '.' boundary matches the `<pkg>.*`
    contract stated in the tf_export error message.
    """
    return name == package or name.startswith(package + '.')

  def _validate_symbol_names(self):
    """Validate you are exporting symbols under an allowed package.

    We need to ensure things exported by tf_export, estimator_export, etc.
    export symbols under disjoint top-level package names.

    For TensorFlow, we check that it does not export anything under subpackage
    names used by components (estimator, keras, etc.).

    For each component, we check that it exports everything under its own
    subpackage.

    Raises:
      InvalidSymbolNameError: If you try to export symbol under disallowed name.
    """
    all_symbol_names = set(self._names) | set(self._names_v1)
    if self._api_name == TENSORFLOW_API_NAME:
      for subpackage in SUBPACKAGE_NAMESPACES:
        if any(self._is_under_package(n, subpackage)
               for n in all_symbol_names):
          raise InvalidSymbolNameError(
              '@tf_export is not allowed to export symbols under %s.*' % (
                  subpackage))
    else:
      if not all(self._is_under_package(n, self._api_name)
                 for n in all_symbol_names):
        raise InvalidSymbolNameError(
            'Can only export symbols under package name of component. '
            'e.g. tensorflow_estimator must export all symbols under '
            'tf.estimator')

  def __call__(self, func):
    """Calls this decorator.

    Removes the API-name attributes from every symbol in `overrides`. Records
    the V2 and V1 names on the undecorated `func`. Registers every name in
    the module-level name -> symbol mapping, with V1 names under a
    'compat.v1.' prefix.

    Args:
      func: decorated symbol (function or class).

    Returns:
      The input `func` itself. The `_tf_api_names`-style attributes are set
      on its undecorated target.

    Raises:
      SymbolAlreadyExposedError: Raised when a symbol already has API names
        and kwarg `allow_multiple_exports` not set.
      AttributeError: If a symbol in `overrides` lacks either name attribute
        (raised by `delattr`).
    """
    api_names_attr = API_ATTRS[self._api_name].names
    api_names_attr_v1 = API_ATTRS_V1[self._api_name].names
    # Undecorate overridden names
    for f in self._overrides:
      _, undecorated_f = tf_decorator.unwrap(f)
      delattr(undecorated_f, api_names_attr)
      delattr(undecorated_f, api_names_attr_v1)

    _, undecorated_func = tf_decorator.unwrap(func)
    self.set_attr(undecorated_func, api_names_attr, self._names)
    self.set_attr(undecorated_func, api_names_attr_v1, self._names_v1)

    # The mapping stores the decorated `func`, not the undecorated target.
    for name in self._names:
      _NAME_TO_SYMBOL_MAPPING[name] = func
    for name_v1 in self._names_v1:
      _NAME_TO_SYMBOL_MAPPING['compat.v1.%s' % name_v1] = func
    return func

  def set_attr(self, func, api_names_attr, names):
    """Stores `names` on `func` under the attribute `api_names_attr`.

    Args:
      func: Function or class to annotate.
      api_names_attr: Attribute name to set, e.g. '_tf_api_names'.
      names: Sequence of API names to store.

    Raises:
      SymbolAlreadyExposedError: If `func` already defines `api_names_attr`
        in its own `__dict__` and `allow_multiple_exports` was not set.
    """
    # Check for an existing api. We check if attribute name is in
    # __dict__ instead of using hasattr to verify that subclasses have
    # their own _tf_api_names as opposed to just inheriting it.
    if api_names_attr in func.__dict__:
      if not self._allow_multiple_exports:
        raise SymbolAlreadyExposedError(
            'Symbol %s is already exposed as %s.' %
            (func.__name__, getattr(func, api_names_attr)))  # pylint: disable=protected-access
    setattr(func, api_names_attr, names)

  def export_constant(self, module_name, name):
    """Store export information for constants/string literals.

    Export information is stored in the module where constants/string literals
    are defined. A `(names, name)` tuple is appended to the module's V2
    constants list and a `(names_v1, name)` tuple to its V1 list; each list
    is created on first use.

    e.g.
    ```python
    foo = 1
    bar = 2
    tf_export("consts.foo").export_constant(__name__, 'foo')
    tf_export("consts.bar").export_constant(__name__, 'bar')
    ```

    Args:
      module_name: (string) Name of the module to store constant at.
      name: (string) Current constant name.
    """
    module = sys.modules[module_name]
    api_constants_attr = API_ATTRS[self._api_name].constants
    api_constants_attr_v1 = API_ATTRS_V1[self._api_name].constants

    if not hasattr(module, api_constants_attr):
      setattr(module, api_constants_attr, [])
    # pylint: disable=protected-access
    getattr(module, api_constants_attr).append(
        (self._names, name))

    if not hasattr(module, api_constants_attr_v1):
      setattr(module, api_constants_attr_v1, [])
    getattr(module, api_constants_attr_v1).append(
        (self._names_v1, name))
400
401
def kwarg_only(f):
  """Wraps `f` so that it may only be called with keyword arguments.

  Positional arguments are not discarded. A call that passes any of them
  raises TypeError listing the argument names of `f`. Keyword arguments are
  forwarded to `f` unchanged. The returned wrapper is built with
  `tf_decorator.make_decorator` and carries `f`'s argspec.

  Args:
    f: The function to wrap.

  Returns:
    The keyword-only wrapper around `f`.
  """
  argspec = tf_inspect.getargspec(f)

  def wrapper(*args, **kwargs):
    if not args:
      return f(**kwargs)
    raise TypeError(
        '{f} only takes keyword args (possible keys: {kwargs}). '
        'Please pass these args as kwargs instead.'
        .format(f=f.__name__, kwargs=argspec.args))

  return tf_decorator.make_decorator(f, wrapper, decorator_argspec=argspec)
415
416
# Per-API export decorators: api_export with `api_name` pre-bound.
tf_export, estimator_export, keras_export = (
    functools.partial(api_export, api_name=api)
    for api in (TENSORFLOW_API_NAME, ESTIMATOR_API_NAME, KERAS_API_NAME))
420