# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Registry for tensor conversion functions."""
# pylint: disable=g-bad-name
import collections
import threading

import numpy as np

from tensorflow.python.util import lazy_loader
from tensorflow.python.util.tf_export import tf_export

# Loaded lazily due to a circular dependency
# ops->tensor_conversion_registry->constant_op->ops.
constant_op = lazy_loader.LazyLoader(
    "constant_op", globals(),
    "tensorflow.python.framework.constant_op")


# Maps priority -> list of `(base_type, conversion_func)` pairs, kept in
# registration order.
_tensor_conversion_func_registry = collections.defaultdict(list)
# Maps a queried type -> the list of matching `(base_type, conversion_func)`
# pairs computed by `get`. Cleared whenever a new function is registered so
# stale results are never served.
_tensor_conversion_func_cache = {}
# Serializes writes to the registry and the cache. (`get` reads the cache
# without the lock on its fast path and re-checks under the lock on a miss.)
_tensor_conversion_func_lock = threading.Lock()

# Instances of these types are always converted using
# `_default_conversion_function`.
_UNCONVERTIBLE_TYPES = (
    int,
    float,
    np.generic,
    np.ndarray,
)


def _default_conversion_function(value, dtype, name, as_ref):
  """Converts `value` via `constant_op.constant`; `as_ref` is ignored."""
  del as_ref  # Unused.
  return constant_op.constant(value, dtype, name=name)


# TODO(josh11b): Add ctx argument to conversion_func() signature.
@tf_export("register_tensor_conversion_function")
def register_tensor_conversion_function(base_type,
                                        conversion_func,
                                        priority=100):
  """Registers a function for converting objects of `base_type` to `Tensor`.

  The conversion function must have the following signature:

  ```python
      def conversion_func(value, dtype=None, name=None, as_ref=False):
        # ...
  ```

  It must return a `Tensor` with the given `dtype` if specified. If the
  conversion function creates a new `Tensor`, it should use the given
  `name` if specified. All exceptions will be propagated to the caller.

  The conversion function may return `NotImplemented` for some
  inputs. In this case, the conversion process will continue to try
  subsequent conversion functions.

  If `as_ref` is true, the function must return a `Tensor` reference,
  such as a `Variable`.

  NOTE: The conversion functions will execute in order of priority,
  followed by order of registration. To ensure that a conversion function
  `F` runs before another conversion function `G`, ensure that `F` is
  registered with a smaller priority than `G`.

  Args:
    base_type: The base type or tuple of base types for all objects that
      `conversion_func` accepts.
    conversion_func: A function that converts instances of `base_type` to
      `Tensor`.
    priority: Optional integer that indicates the priority for applying this
      conversion function. Conversion functions with smaller priority values run
      earlier than conversion functions with larger priority values. Defaults to
      100.

  Raises:
    TypeError: If `base_type` is not a type or tuple of types, if any of the
      base types is a subclass of a Python numeric type or a NumPy scalar or
      array type (these always use the default conversion), or if
      `conversion_func` is not callable.
  """
  base_types = base_type if isinstance(base_type, tuple) else (base_type,)
  if any(not isinstance(x, type) for x in base_types):
    raise TypeError("Argument `base_type` must be a type or a tuple of types. "
                    f"Obtained: {base_type}")
  if any(issubclass(x, _UNCONVERTIBLE_TYPES) for x in base_types):
    raise TypeError("Cannot register conversions for Python numeric types and "
                    "NumPy scalars and arrays.")
  del base_types  # Only needed for validation.
  if not callable(conversion_func):
    raise TypeError("Argument `conversion_func` must be callable. Received "
                    f"{conversion_func}.")

  with _tensor_conversion_func_lock:
    _tensor_conversion_func_registry[priority].append(
        (base_type, conversion_func))
    # Any cached lookup may now be missing the new function; drop them all.
    _tensor_conversion_func_cache.clear()


def get(query):
  """Returns the conversion functions applicable to objects of type `query`.

  Args:
    query: The type to query for.

  Returns:
    A list of `(base_type, conversion_func)` pairs such that `query` is a
    subclass of `base_type`. The pairs are ordered by increasing priority
    value (the order in which they should be tried) and, within the same
    priority, by registration order. If `query` is a subclass of one of the
    Python numeric or NumPy types in `_UNCONVERTIBLE_TYPES`, the result is
    `[(query, _default_conversion_function)]`.

    For other types the returned list is the object stored in the module
    cache and is shared between calls; callers should not mutate it.

  Raises:
    TypeError: If `query` is not a class (raised by `issubclass`).
  """
  if issubclass(query, _UNCONVERTIBLE_TYPES):
    return [(query, _default_conversion_function)]

  # Fast path: read the cache without taking the lock.
  conversion_funcs = _tensor_conversion_func_cache.get(query)
  if conversion_funcs is None:
    with _tensor_conversion_func_lock:
      # Has another thread populated the cache in the meantime?
      conversion_funcs = _tensor_conversion_func_cache.get(query)
      if conversion_funcs is None:
        conversion_funcs = []
        for _, funcs_at_priority in sorted(
            _tensor_conversion_func_registry.items()):
          conversion_funcs.extend(
              (base_type, conversion_func)
              for base_type, conversion_func in funcs_at_priority
              if issubclass(query, base_type))
        _tensor_conversion_func_cache[query] = conversion_funcs
  return conversion_funcs