• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Registry for tensor conversion functions."""
16# pylint: disable=g-bad-name
17import collections
18import threading
19
20import numpy as np
21
22from tensorflow.python.util import lazy_loader
23from tensorflow.python.util.tf_export import tf_export
24
25# Loaded lazily due to a circular dependency
26# ops->tensor_conversion_registry->constant_op->ops.
27constant_op = lazy_loader.LazyLoader(
28    "constant_op", globals(),
29    "tensorflow.python.framework.constant_op")
30
31
# Maps each priority value to the list of `(base_type, conversion_func)` pairs
# registered at that priority, in registration order.
_tensor_conversion_func_registry = collections.defaultdict(list)
# Memoizes, per queried type, the list of `(base_type, conversion_func)` pairs
# computed by `get`. Cleared whenever a new conversion function is registered.
_tensor_conversion_func_cache = {}
# Held while appending to the registry and while populating or clearing the
# cache.
_tensor_conversion_func_lock = threading.Lock()

# Instances of these types are always converted using
# `_default_conversion_function`.
# Registering a conversion function for any subclass of these types is
# rejected with a `TypeError`.
_UNCONVERTIBLE_TYPES = (
    int,
    float,
    np.generic,
    np.ndarray,
)
44
45
def _default_conversion_function(value, dtype, name, as_ref):
  """Converts `value` by building a constant via `constant_op.constant`.

  Args:
    value: The object to convert.
    dtype: Forwarded as the second positional argument of
      `constant_op.constant`.
    name: Forwarded as the `name` keyword argument of `constant_op.constant`.
    as_ref: Accepted so the signature matches other conversion functions;
      ignored.

  Returns:
    The result of `constant_op.constant(value, dtype, name=name)`.
  """
  del as_ref  # Deliberately ignored.
  converted = constant_op.constant(value, dtype, name=name)
  return converted
49
50
# TODO(josh11b): Add ctx argument to conversion_func() signature.
@tf_export("register_tensor_conversion_function")
def register_tensor_conversion_function(base_type,
                                        conversion_func,
                                        priority=100):
  """Registers a function that converts `base_type` objects to `Tensor`.

  The conversion function must have the following signature:

  ```python
      def conversion_func(value, dtype=None, name=None, as_ref=False):
        # ...
  ```

  It must return a `Tensor` with the given `dtype` if specified. If the
  conversion function creates a new `Tensor`, it should use the given
  `name` if specified. All exceptions will be propagated to the caller.

  The conversion function may return `NotImplemented` for some
  inputs. In this case, the conversion process will continue to try
  subsequent conversion functions.

  If `as_ref` is true, the function must return a `Tensor` reference,
  such as a `Variable`.

  NOTE: Conversion functions execute in order of priority, and within one
  priority in order of registration. To make a conversion function `F` run
  before another conversion function `G`, register `F` with a smaller
  priority than `G`.

  Args:
    base_type: The base type or tuple of base types for all objects that
      `conversion_func` accepts.
    conversion_func: A function that converts instances of `base_type` to
      `Tensor`.
    priority: Optional integer that indicates the priority for applying this
      conversion function. Conversion functions with smaller priority values run
      earlier than conversion functions with larger priority values. Defaults to
      100.

  Raises:
    TypeError: If `base_type` is not a type or a tuple of types, if any of the
      given types is a subclass of a Python numeric type or a NumPy scalar or
      array type, or if `conversion_func` is not callable.
  """
  # Normalize to a tuple purely for validation; the registry stores
  # `base_type` exactly as the caller supplied it.
  if isinstance(base_type, tuple):
    candidate_types = base_type
  else:
    candidate_types = (base_type,)

  # First pass: every entry must be a class before `issubclass` is applied.
  if not all(isinstance(candidate, type) for candidate in candidate_types):
    raise TypeError("Argument `base_type` must be a type or a tuple of types. "
                    f"Obtained: {base_type}")

  # Second pass: types that always use the default conversion cannot be
  # overridden.
  for candidate in candidate_types:
    if issubclass(candidate, _UNCONVERTIBLE_TYPES):
      raise TypeError("Cannot register conversions for Python numeric types and "
                      "NumPy scalars and arrays.")

  if not callable(conversion_func):
    raise TypeError("Argument `conversion_func` must be callable. Received "
                    f"{conversion_func}.")

  registration = (base_type, conversion_func)
  with _tensor_conversion_func_lock:
    _tensor_conversion_func_registry[priority].append(registration)
    # Previously cached lookups may be missing the new function.
    _tensor_conversion_func_cache.clear()
110
111
def get(query):
  """Returns the conversion functions applicable to the type `query`.

  Args:
    query: The type to look up conversion functions for.

  Returns:
    A list of `(base_type, conversion_func)` pairs for which `query` is a
    subclass of `base_type`, ordered by increasing priority value and, within
    a priority, by registration order. For subclasses of the types in
    `_UNCONVERTIBLE_TYPES`, the list holds the single pair
    `(query, _default_conversion_function)`.
  """
  if issubclass(query, _UNCONVERTIBLE_TYPES):
    return [(query, _default_conversion_function)]

  # Fast path: look in the cache without taking the lock.
  matches = _tensor_conversion_func_cache.get(query)
  if matches is not None:
    return matches

  with _tensor_conversion_func_lock:
    # Re-check under the lock in case another thread filled the entry while
    # this one was waiting.
    matches = _tensor_conversion_func_cache.get(query)
    if matches is None:
      matches = [
          (registered_type, func)
          for _, entries in sorted(_tensor_conversion_func_registry.items())
          for registered_type, func in entries
          if issubclass(query, registered_type)
      ]
      _tensor_conversion_func_cache[query] = matches
  return matches
139