# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Registry for tensor conversion functions."""
# pylint: disable=g-bad-name
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import threading

import numpy as np
import six

from tensorflow.python.util import lazy_loader
from tensorflow.python.util.tf_export import tf_export

# Loaded lazily due to a circular dependency
# ops->tensor_conversion_registry->constant_op->ops.
constant_op = lazy_loader.LazyLoader(
    "constant_op", globals(),
    "tensorflow.python.framework.constant_op")


# Maps priority -> list of (base_type, conversion_func) pairs, kept in
# registration order within each priority.
_tensor_conversion_func_registry = collections.defaultdict(list)
# Maps a queried type -> the list of applicable (base_type, conversion_func)
# pairs computed from the registry; cleared whenever a new function is
# registered.
_tensor_conversion_func_cache = {}
# Guards mutation of the registry and population/clearing of the cache.
_tensor_conversion_func_lock = threading.Lock()

# Instances of these types are always converted using
# `_default_conversion_function`; registering custom conversions for them
# (or their subclasses) is rejected.
_UNCONVERTIBLE_TYPES = tuple(six.integer_types) + (float, np.generic,
                                                   np.ndarray)


def _default_conversion_function(value, dtype, name, as_ref):
  """Fallback converter: builds a constant tensor from `value`.

  Args:
    value: The object to convert.
    dtype: The requested dtype, forwarded to `constant_op.constant`.
    name: The name for the new tensor, forwarded to `constant_op.constant`.
    as_ref: Accepted only to match the conversion-function signature; it has
      no effect here.

  Returns:
    The result of `constant_op.constant(value, dtype, name=name)`.
  """
  _ = as_ref  # Intentionally ignored.
  result = constant_op.constant(value, dtype, name=name)
  return result


# TODO(josh11b): Add ctx argument to conversion_func() signature.
@tf_export("register_tensor_conversion_function")
def register_tensor_conversion_function(base_type,
                                        conversion_func,
                                        priority=100):
  """Registers a function for converting objects of `base_type` to `Tensor`.

  The conversion function must have the following signature:

  ```python
      def conversion_func(value, dtype=None, name=None, as_ref=False):
        # ...
  ```

  It must return a `Tensor` with the given `dtype` if specified. If the
  conversion function creates a new `Tensor`, it should use the given
  `name` if specified. All exceptions will be propagated to the caller.

  The conversion function may return `NotImplemented` for some
  inputs. In this case, the conversion process will continue to try
  subsequent conversion functions.

  If `as_ref` is true, the function must return a `Tensor` reference,
  such as a `Variable`.

  NOTE: The conversion functions will execute in order of priority,
  followed by order of registration. To ensure that a conversion function
  `F` runs before another conversion function `G`, ensure that `F` is
  registered with a smaller priority than `G`.

  Args:
    base_type: The base type or tuple of base types for all objects that
      `conversion_func` accepts.
    conversion_func: A function that converts instances of `base_type` to
      `Tensor`.
    priority: Optional integer that indicates the priority for applying this
      conversion function. Conversion functions with smaller priority values run
      earlier than conversion functions with larger priority values. Defaults to
      100. Priorities are sorted when conversion functions are looked up, so all
      registered priorities must be mutually comparable.

  Raises:
    TypeError: If the arguments do not have the appropriate type, i.e. if
      `base_type` is not a type or tuple of types, if any of those types is a
      Python integer or float type or a NumPy scalar or array type (or a
      subclass of one), or if `conversion_func` is not callable.
  """
  base_types = base_type if isinstance(base_type, tuple) else (base_type,)
  if any(not isinstance(x, type) for x in base_types):
    raise TypeError("base_type must be a type or a tuple of types.")
  # Instances of these types always use the default constant conversion in
  # `get`, so a registration for them could never take effect.
  if any(issubclass(x, _UNCONVERTIBLE_TYPES) for x in base_types):
    raise TypeError("Cannot register conversions for Python numeric types and "
                    "NumPy scalars and arrays.")
  del base_types  # Only needed for validation.
  if not callable(conversion_func):
    raise TypeError("conversion_func must be callable.")

  with _tensor_conversion_func_lock:
    _tensor_conversion_func_registry[priority].append(
        (base_type, conversion_func))
    # Previously cached per-type lookups may now be missing this function.
    _tensor_conversion_func_cache.clear()


def get(query):
  """Returns the conversion functions that apply to objects of type `query`.

  Instances of Python integer/float types and NumPy scalar/array types always
  use the default constant conversion and bypass the registry.

  Args:
    query: The type to query for.

  Returns:
    A list of `(base_type, conversion_func)` pairs whose `base_type` matches
    `query` (via `issubclass`), ordered by increasing priority value and then
    by registration order. For the always-default types above, this is the
    single pair `(query, _default_conversion_function)`. Otherwise the list is
    cached per `query` and the cached object itself is returned, so callers
    must not mutate it.
  """
  if issubclass(query, _UNCONVERTIBLE_TYPES):
    return [(query, _default_conversion_function)]

  conversion_funcs = _tensor_conversion_func_cache.get(query)
  if conversion_funcs is None:
    with _tensor_conversion_func_lock:
      # Has another thread populated the cache in the meantime?
      conversion_funcs = _tensor_conversion_func_cache.get(query)
      if conversion_funcs is None:
        conversion_funcs = []
        # Sort only the priority keys so the per-priority function lists are
        # never compared with each other.
        for priority in sorted(_tensor_conversion_func_registry):
          funcs_at_priority = _tensor_conversion_func_registry[priority]
          conversion_funcs.extend(
              (base_type, conversion_func)
              for base_type, conversion_func in funcs_at_priority
              if issubclass(query, base_type))
        _tensor_conversion_func_cache[query] = conversion_funcs
  return conversion_funcs