# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Type specifications for TensorFlow APIs."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc

import numpy as np
import six

from tensorflow.python import _pywrap_utils
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_shape
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import compat
from tensorflow.python.util import nest
from tensorflow.python.util import tf_decorator
from tensorflow.python.util.lazy_loader import LazyLoader
from tensorflow.python.util.tf_export import tf_export

# Use LazyLoader to avoid circular dependencies.
tensor_spec = LazyLoader(
    "tensor_spec", globals(),
    "tensorflow.python.framework.tensor_spec")
ops = LazyLoader(
    "ops", globals(),
    "tensorflow.python.framework.ops")


@tf_export("TypeSpec", v1=["TypeSpec", "data.experimental.Structure"])
@six.add_metaclass(abc.ABCMeta)
class TypeSpec(object):
  """Specifies a TensorFlow value type.

  A `tf.TypeSpec` provides metadata describing an object accepted or returned
  by TensorFlow APIs. Concrete subclasses, such as `tf.TensorSpec` and
  `tf.RaggedTensorSpec`, are used to describe different value types.

  For example, `tf.function`'s `input_signature` argument accepts a list
  (or nested structure) of `TypeSpec`s.

  Creating new subclasses of TypeSpec (outside of TensorFlow core) is not
  currently supported. In particular, we may make breaking changes to the
  private methods and properties defined by this base class.
  """
  # === Subclassing ===
  #
  # Each `TypeSpec` subclass must define:
  #
  # * A "component encoding" for values.
  # * A "serialization" for types.
  #
  # The component encoding for a value is a nested structure of `tf.Tensor`
  # or `CompositeTensor` that can be used by the `TypeSpec` to reconstruct
  # the value.  Each individual `TypeSpec` must use the same nested structure
  # for all values -- this structure is defined by the `component_specs`
  # attribute.  Decomposing values into components, and reconstructing them
  # from those components, should be inexpensive.  In particular, it should
  # *not* require any TensorFlow ops.
  #
  # The serialization for a `TypeSpec` is a nested tuple of values that can
  # be used to reconstruct the `TypeSpec`.  See the documentation for
  # `_serialize()` for more information.

  __slots__ = []

  @abc.abstractproperty
  def value_type(self):
    """The Python type for values that are compatible with this TypeSpec."""
    raise NotImplementedError("%s.value_type" % type(self).__name__)

  def is_compatible_with(self, spec_or_value):
    """Returns true if `spec_or_value` is compatible with this TypeSpec."""
    # === Subclassing ===
    # If not overridden by subclasses, the default behavior is to convert
    # `spec_or_value` to a `TypeSpec` (if it isn't already); and then to
    # consider two `TypeSpec`s compatible if they have the same type, and
    # the values returned by `_serialize` are compatible (where
    # `tf.TensorShape`, `tf.TensorSpec`, and `tf.DType` are checked for
    # compatibility using their `is_compatible_with` method; and all other
    # types are considered compatible if they are equal).
    if not isinstance(spec_or_value, TypeSpec):
      spec_or_value = type_spec_from_value(spec_or_value)
    if type(self) is not type(spec_or_value):
      return False
    return self.__is_compatible(self._serialize(),
                                spec_or_value._serialize())  # pylint: disable=protected-access

  def most_specific_compatible_type(self, other):
    """Returns the most specific TypeSpec compatible with `self` and `other`.

    Args:
      other: A `TypeSpec`.

    Returns:
      A `TypeSpec` that is compatible with both `self` and `other`.

    Raises:
      ValueError: If there is no TypeSpec that is compatible with both `self`
        and `other`.
    """
    # === Subclassing ===
    # If not overridden by a subclass, the default behavior is to raise a
    # `ValueError` if `self` and `other` have different types, or if their type
    # serializations differ by anything other than `TensorShape`s.  Otherwise,
    # the two type serializations are combined (using
    # `most_specific_compatible_shape` to combine `TensorShape`s), and the
    # result is used to construct and return a new `TypeSpec`.
    if type(self) is not type(other):
      raise ValueError("No TypeSpec is compatible with both %s and %s" %
                       (self, other))
    merged = self.__most_specific_compatible_type_serialization(
        self._serialize(), other._serialize())  # pylint: disable=protected-access
    return self._deserialize(merged)

  # === Component encoding for values ===

  @abc.abstractmethod
  def _to_components(self, value):
    """Encodes `value` as a nested structure of `Tensor` or `CompositeTensor`.

    Args:
      value: A value compatible with this `TypeSpec`.  (Caller is responsible
        for ensuring compatibility.)

    Returns:
      A nested structure of `tf.Tensor` or `tf.CompositeTensor` compatible with
      `self._component_specs`, which can be used to reconstruct `value`.
    """
    # === Subclassing ===
    # This method must be inexpensive (do not call TF ops).
    raise NotImplementedError("%s._to_components()" % type(self).__name__)

  @abc.abstractmethod
  def _from_components(self, components):
    """Reconstructs a value from a nested structure of Tensor/CompositeTensor.

    Args:
      components: A nested structure of `tf.Tensor` or `tf.CompositeTensor`,
        compatible with `self._component_specs`.  (Caller is responsible for
        ensuring compatibility.)

    Returns:
      A value that is compatible with this `TypeSpec`.
    """
    # === Subclassing ===
    # This method must be inexpensive (do not call TF ops).
    raise NotImplementedError("%s._from_components()" % type(self).__name__)

  @abc.abstractproperty
  def _component_specs(self):
    """A nested structure of TypeSpecs for this type's components.

    Returns:
      A nested structure describing the component encodings that are returned
      by this TypeSpec's `_to_components` method.  In particular, for a
      TypeSpec `spec` and a compatible value `value`:

      ```
      nest.map_structure(lambda t, c: assert t.is_compatible_with(c),
                         spec._component_specs, spec._to_components(value))
      ```
    """
    raise NotImplementedError("%s._component_specs()" % type(self).__name__)

  # === Tensor list encoding for values ===

  def _to_tensor_list(self, value):
    """Encodes `value` as a flat list of `tf.Tensor`.

    By default, this just flattens `self._to_components(value)` using
    `nest.flatten`.  However, subclasses may override this to return a
    different tensor encoding for values.  In particular, some subclasses
    of `BatchableTypeSpec` override this method to return a "boxed" encoding
    for values, which then can be batched or unbatched.  See
    `BatchableTypeSpec` for more details.

    Args:
      value: A value compatible with this `TypeSpec`.  (Caller is responsible
        for ensuring compatibility.)

    Returns:
      A list of `tf.Tensor`, compatible with `self._flat_tensor_specs`, which
      can be used to reconstruct `value`.
    """
    return nest.flatten(self._to_components(value), expand_composites=True)

  def _from_tensor_list(self, tensor_list):
    """Reconstructs a value from a flat list of `tf.Tensor`.

    Args:
      tensor_list: A flat list of `tf.Tensor`, compatible with
        `self._flat_tensor_specs`.

    Returns:
      A value that is compatible with this `TypeSpec`.

    Raises:
      ValueError: If `tensor_list` is not compatible with
        `self._flat_tensor_specs`.
    """
    self.__check_tensor_list(tensor_list)
    return self._from_compatible_tensor_list(tensor_list)

  def _from_compatible_tensor_list(self, tensor_list):
    """Reconstructs a value from a compatible flat list of `tf.Tensor`.

    Args:
      tensor_list: A flat list of `tf.Tensor`, compatible with
        `self._flat_tensor_specs`.  (Caller is responsible for ensuring
        compatibility.)

    Returns:
      A value that is compatible with this `TypeSpec`.
    """
    return self._from_components(nest.pack_sequence_as(
        self._component_specs, tensor_list, expand_composites=True))

  @property
  def _flat_tensor_specs(self):
    """A list of TensorSpecs compatible with self._to_tensor_list(v)."""
    return nest.flatten(self._component_specs, expand_composites=True)

  # === Serialization for types ===

  @abc.abstractmethod
  def _serialize(self):
    """Returns a nested tuple containing the state of this TypeSpec.

    The serialization may contain the following value types: boolean,
    integer, string, float, None, `TensorSpec`, `tf.TensorShape`, `tf.DType`,
    `np.ndarray`, `TypeSpec`, and nested tuples, namedtuples, dicts, and
    OrderedDicts of any of the above.

    This method is used to provide default definitions for: equality
    testing (__eq__, __ne__), hashing (__hash__), pickling (__reduce__),
    string representation (__repr__), `self.is_compatible_with()`,
    `self.most_specific_compatible_type()`, and protobuf serialization
    (e.g. TensorInfo and StructuredValue).
    """
    raise NotImplementedError("%s._serialize()" % type(self).__name__)

  @classmethod
  def _deserialize(cls, serialization):
    """Reconstructs a TypeSpec from a value returned by `_serialize`."""
    return cls(*serialization)

  # === Operators ===

  def __eq__(self, other):
    # pylint: disable=protected-access
    return (type(other) is type(self) and
            self.__get_cmp_key() == other.__get_cmp_key())

  def __ne__(self, other):
    return not self == other

  def __hash__(self):
    return hash(self.__get_cmp_key())

  def __reduce__(self):
    return type(self), self._serialize()

  def __repr__(self):
    return "%s%r" % (type(self).__name__, self._serialize())

  # === Legacy Output ===
  # TODO(b/133606651) Document and/or deprecate the legacy_output methods.
  # (These are used by tf.data.)

  def _to_legacy_output_types(self):
    raise NotImplementedError("%s._to_legacy_output_types()" %
                              type(self).__name__)

  def _to_legacy_output_shapes(self):
    raise NotImplementedError("%s._to_legacy_output_shapes()" %
                              type(self).__name__)

  def _to_legacy_output_classes(self):
    return self.value_type

  # === Private Helper Methods ===

  def __check_tensor_list(self, tensor_list):
    """Raises ValueError if `tensor_list` does not match _flat_tensor_specs."""
    expected = self._flat_tensor_specs
    specs = [type_spec_from_value(t) for t in tensor_list]
    if len(specs) != len(expected):
      raise ValueError("Incompatible input: wrong number of tensors")
    for i, (s1, s2) in enumerate(zip(specs, expected)):
      if not s1.is_compatible_with(s2):
        raise ValueError("Incompatible input: tensor %d (%s) is incompatible "
                         "with %s" % (i, tensor_list[i], s2))

  def __get_cmp_key(self):
    """Returns a hashable eq-comparable key for `self`."""
    # TODO(b/133606651): Decide whether to cache this value.
    return (type(self), self.__make_cmp_key(self._serialize()))

  def __make_cmp_key(self, value):
    """Converts `value` to a hashable key."""
    if isinstance(value, (int, float, bool, dtypes.DType, TypeSpec)):
      return value
    if isinstance(value, compat.bytes_or_text_types):
      return value
    if value is None:
      return value
    if isinstance(value, dict):
      return tuple([
          tuple([self.__make_cmp_key(key),
                 self.__make_cmp_key(value[key])])
          for key in sorted(value.keys())
      ])
    if isinstance(value, tuple):
      return tuple([self.__make_cmp_key(v) for v in value])
    if isinstance(value, list):
      return (list, tuple([self.__make_cmp_key(v) for v in value]))
    if isinstance(value, tensor_shape.TensorShape):
      if value.ndims is None:
        # Note: we include a type object in the tuple, to ensure we can't get
        # false-positive matches (since users can't include type objects).
        return (tensor_shape.TensorShape, None)
      return (tensor_shape.TensorShape, tuple(value.as_list()))
    if isinstance(value, np.ndarray):
      return (np.ndarray, value.shape,
              TypeSpec.__nested_list_to_tuple(value.tolist()))
    raise ValueError("Unsupported value type %s returned by "
                     "%s._serialize" %
                     (type(value).__name__, type(self).__name__))

  @staticmethod
  def __nested_list_to_tuple(value):
    """Converts a nested list to a corresponding nested tuple."""
    if isinstance(value, list):
      return tuple(TypeSpec.__nested_list_to_tuple(v) for v in value)
    return value

  @staticmethod
  def __is_compatible(a, b):
    """Returns true if the given type serializations are compatible."""
    if type(a) is not type(b):
      return False
    if isinstance(a, (list, tuple)):
      return (len(a) == len(b) and
              all(TypeSpec.__is_compatible(x, y) for (x, y) in zip(a, b)))
    if isinstance(a, dict):
      return (len(a) == len(b) and sorted(a.keys()) == sorted(b.keys()) and all(
          TypeSpec.__is_compatible(a[k], b[k]) for k in a.keys()))
    if isinstance(a, (TypeSpec, tensor_shape.TensorShape, dtypes.DType)):
      return a.is_compatible_with(b)
    return a == b

  @staticmethod
  def __most_specific_compatible_type_serialization(a, b):
    """Helper for most_specific_compatible_type.

    Combines two type serializations as follows:

    * If they are both tuples of the same length, then recursively combine
      the respective tuple elements.
    * If they are both dicts with the same keys, then recursively combine
      the respective dict elements.
    * If they are both TypeSpecs, then combine using
      TypeSpec.most_specific_compatible_type.
    * If they are both TensorShapes, then combine using
      TensorShape.most_specific_compatible_shape.
    * If they are both TensorSpecs with the same dtype, then combine using
      TensorShape.most_specific_compatible_shape to combine shapes.
    * If they are equal, then return a.
    * If none of the above, then raise a ValueError.

    Args:
      a: A serialized TypeSpec or nested component from a serialized TypeSpec.
      b: A serialized TypeSpec or nested component from a serialized TypeSpec.

    Returns:
      A value with the same type and structure as `a` and `b`.

    Raises:
      ValueError: If `a` and `b` are incompatible.
    """
    if type(a) is not type(b):
      raise ValueError("Types are not compatible: %r vs %r" % (a, b))
    # NOTE: lists are merged elementwise here too; the result is a tuple
    # (matching the tuple-based serialization format described in
    # `_serialize`).  A previous `isinstance(a, list)` assertion below this
    # branch was unreachable, since lists always take this branch first.
    if isinstance(a, (list, tuple)):
      if len(a) != len(b):
        raise ValueError("Types are not compatible: %r vs %r" % (a, b))
      return tuple(TypeSpec.__most_specific_compatible_type_serialization(x, y)
                   for (x, y) in zip(a, b))
    if isinstance(a, dict):
      a_keys, b_keys = sorted(a.keys()), sorted(b.keys())
      if len(a) != len(b) or a_keys != b_keys:
        raise ValueError("Types are not compatible: %r vs %r" % (a, b))
      return {
          k: TypeSpec.__most_specific_compatible_type_serialization(a[k], b[k])
          for k in a_keys
      }
    if isinstance(a, tensor_shape.TensorShape):
      return a.most_specific_compatible_shape(b)
    if isinstance(a, TypeSpec):
      return a.most_specific_compatible_type(b)
    if a != b:
      raise ValueError("Types are not compatible: %r vs %r" % (a, b))
    return a


class BatchableTypeSpec(TypeSpec):
  """TypeSpec with a batchable tensor encoding.

  The batchable tensor encoding is a list of `tf.Tensor`s that supports
  batching and unbatching.  In particular, stacking (or unstacking)
  values with the same `TypeSpec` must be equivalent to stacking (or
  unstacking) each of their tensor lists.  Unlike the component encoding
  (returned by `self._to_components()`), the batchable tensor encoding
  may require using encoding/decoding ops.

  If a subclass's batchable tensor encoding is not simply a flattened version
  of the component encoding, then the subclass must override `_to_tensor_list`,
  `_from_tensor_list`, and `_flat_tensor_specs`.
  """

  __slots__ = []

  @abc.abstractmethod
  def _batch(self, batch_size):
    """Returns a TypeSpec representing a batch of objects with this TypeSpec.

    Args:
      batch_size: An `int` representing the number of elements in a batch,
        or `None` if the batch size may vary.

    Returns:
      A `TypeSpec` representing a batch of objects with this TypeSpec.
    """
    raise NotImplementedError("%s._batch" % type(self).__name__)

  @abc.abstractmethod
  def _unbatch(self):
    """Returns a TypeSpec representing a single element of this TypeSpec.

    Returns:
      A `TypeSpec` representing a single element of objects with this TypeSpec.
    """
    raise NotImplementedError("%s._unbatch" % type(self).__name__)

  def _to_batched_tensor_list(self, value):
    """Returns a tensor list encoding for value with rank>0."""
    tensor_list = self._to_tensor_list(value)
    if any(t.shape.ndims == 0 for t in tensor_list):
      raise ValueError("Value %s has insufficient rank for batching." % value)
    return tensor_list


def type_spec_from_value(value):
  """Returns a `TypeSpec` that represents the given `value`.

  Args:
    value: A value that can be accepted or returned by TensorFlow APIs.

  Returns:
    A `TypeSpec` that is compatible with `value`.

  Raises:
    TypeError: If a TypeSpec cannot be built for `value`, because its type
      is not supported.
  """
  spec = _type_spec_from_value(value)
  if spec is not None:
    return spec

  # Fallback: try converting value to a tensor.
  try:
    tensor = ops.convert_to_tensor(value)
    spec = _type_spec_from_value(tensor)
    if spec is not None:
      return spec
  except (ValueError, TypeError) as e:
    logging.vlog(
        3, "Failed to convert %r to tensor: %s" % (type(value).__name__, e))

  raise TypeError("Could not build a TypeSpec for %r with type %s" %
                  (value, type(value).__name__))


def _type_spec_from_value(value):
  """Returns a `TypeSpec` that represents the given `value`."""
  if isinstance(value, ops.Tensor):
    # Note: we do not include Tensor names when constructing TypeSpecs.
    return tensor_spec.TensorSpec(value.shape, value.dtype)

  if isinstance(value, composite_tensor.CompositeTensor):
    return value._type_spec  # pylint: disable=protected-access

  # If `value` is a list and all of its elements can be represented by the same
  # batchable type spec, then we can represent the entire list using a single
  # type spec that captures the type accurately (unlike the `convert_to_tensor`
  # fallback).
  if isinstance(value, list) and value:
    subspecs = [_type_spec_from_value(v) for v in value]
    if isinstance(subspecs[0], BatchableTypeSpec):
      merged_subspec = subspecs[0]
      try:
        for subspec in subspecs[1:]:
          merged_subspec = merged_subspec.most_specific_compatible_type(subspec)
        return merged_subspec._batch(len(subspecs))  # pylint: disable=protected-access
      except (ValueError, TypeError):
        pass  # incompatible subspecs

  for entry in reversed(_TYPE_CONVERSION_FUNCTION_REGISTRY):
    type_object, converter_fn, allow_subclass = entry
    if ((type(value) is type_object) or  # pylint: disable=unidiomatic-typecheck
        (allow_subclass and isinstance(value, type_object))):
      return converter_fn(value)

  return None

_TYPE_CONVERSION_FUNCTION_REGISTRY = []


def register_type_spec_from_value_converter(type_object, converter_fn,
                                            allow_subclass=False):
  """Registers a function for converting values with a given type to TypeSpecs.

  If multiple registered `type_object`s match a value, then the most recent
  registration takes precedence.  Custom converters should not be defined for
  `CompositeTensor`s; use `CompositeTensor._type_spec` instead.

  Args:
    type_object: A Python `type` object representing the type of values
      accepted by `converter_fn`.
    converter_fn: A function that takes one argument (an instance of the
      type represented by `type_object`) and returns a `TypeSpec`.
    allow_subclass: If true, then use `isinstance(value, type_object)`
      to check for matches.  If false, then use `type(value) is type_object`.
  """
  _, type_object = tf_decorator.unwrap(type_object)
  _TYPE_CONVERSION_FUNCTION_REGISTRY.append(
      (type_object, converter_fn, allow_subclass))


_pywrap_utils.RegisterType("TypeSpec", TypeSpec)