1# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Type specifications for TensorFlow APIs.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import abc 22import collections 23import re 24 25import numpy as np 26import six 27 28from tensorflow.python.framework import composite_tensor 29from tensorflow.python.framework import dtypes 30from tensorflow.python.framework import tensor_shape 31from tensorflow.python.platform import tf_logging as logging 32from tensorflow.python.util import _pywrap_utils 33from tensorflow.python.util import compat 34from tensorflow.python.util import nest 35from tensorflow.python.util import tf_decorator 36from tensorflow.python.util.lazy_loader import LazyLoader 37from tensorflow.python.util.tf_export import tf_export 38 39# Use LazyLoader to avoid circular dependencies. 40tensor_spec = LazyLoader( 41 "tensor_spec", globals(), 42 "tensorflow.python.framework.tensor_spec") 43ops = LazyLoader( 44 "ops", globals(), 45 "tensorflow.python.framework.ops") 46 47 48@tf_export("TypeSpec", v1=["TypeSpec", "data.experimental.Structure"]) 49@six.add_metaclass(abc.ABCMeta) 50class TypeSpec(object): 51 """Specifies a TensorFlow value type. 52 53 A `tf.TypeSpec` provides metadata describing an object accepted or returned 54 by TensorFlow APIs. Concrete subclasses, such as `tf.TensorSpec` and 55 `tf.RaggedTensorSpec`, are used to describe different value types. 56 57 For example, `tf.function`'s `input_signature` argument accepts a list 58 (or nested structure) of `TypeSpec`s. 59 60 Creating new subclasses of `TypeSpec` (outside of TensorFlow core) is not 61 currently supported. In particular, we may make breaking changes to the 62 private methods and properties defined by this base class. 63 64 Example: 65 66 >>> spec = tf.RaggedTensorSpec(shape=[None, None], dtype=tf.int32) 67 >>> @tf.function(input_signature=[spec]) 68 ... def double(x): 69 ... return x * 2 70 >>> print(double(tf.ragged.constant([[1, 2], [3]]))) 71 <tf.RaggedTensor [[2, 4], [6]]> 72 """ 73 # === Subclassing === 74 # 75 # Each `TypeSpec` subclass must define: 76 # 77 # * A "component encoding" for values. 78 # * A "serialization" for types. 79 # 80 # The component encoding for a value is a nested structure of `tf.Tensor` 81 # or `CompositeTensor` that can be used by the `TypeSpec` to reconstruct 82 # the value. Each individual `TypeSpec` must use the same nested structure 83 # for all values -- this structure is defined by the `component_specs` 84 # attribute. Decomposing values into components, and reconstructing them 85 # from those components, should be inexpensive. In particular, it should 86 # *not* require any TensorFlow ops. 87 # 88 # The serialization for a `TypeSpec` is a nested tuple of values that can 89 # be used to reconstruct the `TypeSpec`. See the documentation for 90 # `_serialize()` for more information. 91 92 __slots__ = [] 93 94 @abc.abstractproperty 95 def value_type(self): 96 """The Python type for values that are compatible with this TypeSpec. 97 98 In particular, all values that are compatible with this TypeSpec must be an 99 instance of this type. 100 """ 101 raise NotImplementedError("%s.value_type" % type(self).__name__) 102 103 def is_compatible_with(self, spec_or_value): 104 """Returns true if `spec_or_value` is compatible with this TypeSpec.""" 105 # === Subclassing === 106 # If not overridden by subclasses, the default behavior is to convert 107 # `spec_or_value` to a `TypeSpec` (if it isn't already); and then to 108 # consider two `TypeSpec`s compatible if they have the same type, and 109 # the values returned by `_serialize` are compatible (where 110 # `tf.TensorShape`, `tf.TensorSpec`, and `tf.DType` are checked for 111 # compatibility using their `is_compatible_with` method; and all other 112 # types are considered compatible if they are equal). 113 if not isinstance(spec_or_value, TypeSpec): 114 spec_or_value = type_spec_from_value(spec_or_value) 115 if type(self) is not type(spec_or_value): 116 return False 117 return self.__is_compatible(self._serialize(), 118 spec_or_value._serialize()) # pylint: disable=protected-access 119 120 def most_specific_compatible_type(self, other): 121 """Returns the most specific TypeSpec compatible with `self` and `other`. 122 123 Args: 124 other: A `TypeSpec`. 125 126 Raises: 127 ValueError: If there is no TypeSpec that is compatible with both `self` 128 and `other`. 129 """ 130 # === Subclassing === 131 # If not overridden by a subclass, the default behavior is to raise a 132 # `ValueError` if `self` and `other` have different types, or if their type 133 # serializations differ by anything other than `TensorShape`s. Otherwise, 134 # the two type serializations are combined (using 135 # `most_specific_compatible_shape` to combine `TensorShape`s), and the 136 # result is used to construct and return a new `TypeSpec`. 137 if type(self) is not type(other): 138 raise ValueError("No TypeSpec is compatible with both %s and %s" % 139 (self, other)) 140 merged = self.__most_specific_compatible_type_serialization( 141 self._serialize(), other._serialize()) # pylint: disable=protected-access 142 return self._deserialize(merged) 143 144 def _with_tensor_ranks_only(self): 145 """Returns a TypeSpec compatible with `self`, with tensor shapes relaxed. 146 147 Returns: 148 A `TypeSpec` that is compatible with `self`, where any `TensorShape` 149 information has been relaxed to include only tensor rank (and not 150 the dimension sizes for individual axes). 151 """ 152 153 # === Subclassing === 154 # If not overridden by a subclass, the default behavior is to serialize 155 # this TypeSpec, relax any TensorSpec or TensorShape values, and 156 # deserialize the result. 157 158 def relax(value): 159 if isinstance(value, TypeSpec): 160 return value._with_tensor_ranks_only() # pylint: disable=protected-access 161 elif (isinstance(value, tensor_shape.TensorShape) and 162 value.rank is not None): 163 return tensor_shape.TensorShape([None] * value.rank) 164 else: 165 return value 166 167 return self._deserialize(nest.map_structure(relax, self._serialize())) 168 169 # === Component encoding for values === 170 171 @abc.abstractmethod 172 def _to_components(self, value): 173 """Encodes `value` as a nested structure of `Tensor` or `CompositeTensor`. 174 175 Args: 176 value: A value compatible with this `TypeSpec`. (Caller is responsible 177 for ensuring compatibility.) 178 179 Returns: 180 A nested structure of `tf.Tensor` or `tf.CompositeTensor` compatible with 181 `self._component_specs`, which can be used to reconstruct `value`. 182 """ 183 # === Subclassing === 184 # This method must be inexpensive (do not call TF ops). 185 raise NotImplementedError("%s._to_components()" % type(self).__name__) 186 187 @abc.abstractmethod 188 def _from_components(self, components): 189 """Reconstructs a value from a nested structure of Tensor/CompositeTensor. 190 191 Args: 192 components: A nested structure of `tf.Tensor` or `tf.CompositeTensor`, 193 compatible with `self._component_specs`. (Caller is responsible for 194 ensuring compatibility.) 195 196 Returns: 197 A value that is compatible with this `TypeSpec`. 198 """ 199 # === Subclassing === 200 # This method must be inexpensive (do not call TF ops). 201 raise NotImplementedError("%s._from_components()" % type(self).__name__) 202 203 @abc.abstractproperty 204 def _component_specs(self): 205 """A nested structure of TypeSpecs for this type's components. 206 207 Returns: 208 A nested structure describing the component encodings that are returned 209 by this TypeSpec's `_to_components` method. In particular, for a 210 TypeSpec `spec` and a compatible value `value`: 211 212 ``` 213 nest.map_structure(lambda t, c: assert t.is_compatible_with(c), 214 spec._component_specs, spec._to_components(value)) 215 ``` 216 """ 217 raise NotImplementedError("%s._component_specs()" % type(self).__name__) 218 219 # === Tensor list encoding for values === 220 221 def _to_tensor_list(self, value): 222 """Encodes `value` as a flat list of `tf.Tensor`. 223 224 By default, this just flattens `self._to_components(value)` using 225 `nest.flatten`. However, subclasses may override this to return a 226 different tensor encoding for values. In particular, some subclasses 227 of `BatchableTypeSpec` override this method to return a "boxed" encoding 228 for values, which then can be batched or unbatched. See 229 `BatchableTypeSpec` for more details. 230 231 Args: 232 value: A value with compatible this `TypeSpec`. (Caller is responsible 233 for ensuring compatibility.) 234 235 Returns: 236 A list of `tf.Tensor`, compatible with `self._flat_tensor_specs`, which 237 can be used to reconstruct `value`. 238 """ 239 return nest.flatten(self._to_components(value), expand_composites=True) 240 241 def _from_tensor_list(self, tensor_list): 242 """Reconstructs a value from a flat list of `tf.Tensor`. 243 244 Args: 245 tensor_list: A flat list of `tf.Tensor`, compatible with 246 `self._flat_tensor_specs`. 247 248 Returns: 249 A value that is compatible with this `TypeSpec`. 250 251 Raises: 252 ValueError: If `tensor_list` is not compatible with 253 `self._flat_tensor_specs`. 254 """ 255 self.__check_tensor_list(tensor_list) 256 return self._from_compatible_tensor_list(tensor_list) 257 258 def _from_compatible_tensor_list(self, tensor_list): 259 """Reconstructs a value from a compatible flat list of `tf.Tensor`. 260 261 Args: 262 tensor_list: A flat list of `tf.Tensor`, compatible with 263 `self._flat_tensor_specs`. (Caller is responsible for ensuring 264 compatibility.) 265 266 Returns: 267 A value that is compatible with this `TypeSpec`. 268 """ 269 return self._from_components(nest.pack_sequence_as( 270 self._component_specs, tensor_list, expand_composites=True)) 271 272 @property 273 def _flat_tensor_specs(self): 274 """A list of TensorSpecs compatible with self._to_tensor_list(v).""" 275 return nest.flatten(self._component_specs, expand_composites=True) 276 277 # === Serialization for types === 278 279 @abc.abstractmethod 280 def _serialize(self): 281 """Returns a nested tuple containing the state of this TypeSpec. 282 283 The serialization may contain the following value types: boolean, 284 integer, string, float, None, `TensorSpec`, `tf.TensorShape`, `tf.DType`, 285 `np.ndarray`, `TypeSpec`, and nested tuples, namedtuples, dicts, and 286 OrderedDicts of any of the above. 287 288 This method is used to provide default definitions for: equality 289 testing (__eq__, __ne__), hashing (__hash__), pickling (__reduce__), 290 string representation (__repr__), `self.is_compatible_with()`, 291 `self.most_specific_compatible_type()`, and protobuf serialization 292 (e.g. TensorInfo and StructuredValue). 293 """ 294 raise NotImplementedError("%s._serialize()" % type(self).__name__) 295 296 @classmethod 297 def _deserialize(cls, serialization): 298 """Reconstructs a TypeSpec from a value returned by `serialize`. 299 300 Args: 301 serialization: A value returned by _serialize. In some contexts, 302 `namedtuple`s in `serialization` may not have the identical type 303 that was returned by `_serialize` (but its type will still be a 304 `namedtuple` type with the same type name and field names). For 305 example, the code that loads a SavedModel does not have access to 306 the original `namedtuple` type, so it dynamically creates a new 307 `namedtuple` type with the same type name and field names as the 308 original one. If necessary, you can check `serialization` for these 309 duck-typed `nametuple` types, and restore them to the original type. 310 (E.g., this would be necessary if you rely on type checks such as 311 `isinstance` for this `TypeSpec`'s member variables). 312 313 Returns: 314 A `TypeSpec` of type `cls`. 315 """ 316 return cls(*serialization) 317 318 # === Operators === 319 320 def __eq__(self, other): 321 # pylint: disable=protected-access 322 return (type(other) is type(self) and 323 self.__get_cmp_key() == other.__get_cmp_key()) 324 325 def __ne__(self, other): 326 return not self == other 327 328 def __hash__(self): 329 return hash(self.__get_cmp_key()) 330 331 def __reduce__(self): 332 return type(self), self._serialize() 333 334 def __repr__(self): 335 return "%s%r" % (type(self).__name__, self._serialize()) 336 337 # === Legacy Output === 338 # TODO(b/133606651) Document and/or deprecate the legacy_output methods. 339 # (These are used by tf.data.) 340 341 def _to_legacy_output_types(self): 342 raise NotImplementedError("%s._to_legacy_output_types()" % 343 type(self).__name__) 344 345 def _to_legacy_output_shapes(self): 346 raise NotImplementedError("%s._to_legacy_output_shapes()" % 347 type(self).__name__) 348 349 def _to_legacy_output_classes(self): 350 return self.value_type 351 352 # === Private Helper Methods === 353 354 def __check_tensor_list(self, tensor_list): 355 expected = self._flat_tensor_specs 356 specs = [type_spec_from_value(t) for t in tensor_list] 357 if len(specs) != len(expected): 358 raise ValueError("Incompatible input: wrong number of tensors") 359 for i, (s1, s2) in enumerate(zip(specs, expected)): 360 if not s1.is_compatible_with(s2): 361 raise ValueError("Incompatible input: tensor %d (%s) is incompatible " 362 "with %s" % (i, tensor_list[i], s2)) 363 364 def __get_cmp_key(self): 365 """Returns a hashable eq-comparable key for `self`.""" 366 # TODO(b/133606651): Decide whether to cache this value. 367 return (type(self), self.__make_cmp_key(self._serialize())) 368 369 def __make_cmp_key(self, value): 370 """Converts `value` to a hashable key.""" 371 if isinstance(value, 372 (int, float, bool, np.generic, dtypes.DType, TypeSpec)): 373 return value 374 if isinstance(value, compat.bytes_or_text_types): 375 return value 376 if value is None: 377 return value 378 if isinstance(value, dict): 379 return tuple([ 380 tuple([self.__make_cmp_key(key), 381 self.__make_cmp_key(value[key])]) 382 for key in sorted(value.keys()) 383 ]) 384 if isinstance(value, tuple): 385 return tuple([self.__make_cmp_key(v) for v in value]) 386 if isinstance(value, list): 387 return (list, tuple([self.__make_cmp_key(v) for v in value])) 388 if isinstance(value, tensor_shape.TensorShape): 389 if value.ndims is None: 390 # Note: we include a type object in the tuple, to ensure we can't get 391 # false-positive matches (since users can't include type objects). 392 return (tensor_shape.TensorShape, None) 393 return (tensor_shape.TensorShape, tuple(value.as_list())) 394 if isinstance(value, np.ndarray): 395 return (np.ndarray, value.shape, 396 TypeSpec.__nested_list_to_tuple(value.tolist())) 397 raise ValueError("Unsupported value type %s returned by " 398 "%s._serialize" % 399 (type(value).__name__, type(self).__name__)) 400 401 @staticmethod 402 def __nested_list_to_tuple(value): 403 """Converts a nested list to a corresponding nested tuple.""" 404 if isinstance(value, list): 405 return tuple(TypeSpec.__nested_list_to_tuple(v) for v in value) 406 return value 407 408 @staticmethod 409 def __same_types(a, b): 410 """Returns whether a and b have the same type, up to namedtuple equivalence. 411 412 Consistent with tf.nest.assert_same_structure(), two namedtuple types 413 are considered the same iff they agree in their class name (without 414 qualification by module name) and in their sequence of field names. 415 This makes namedtuples recreated by StructureCoder compatible with their 416 original Python definition. 417 418 Args: 419 a: a Python object. 420 b: a Python object. 421 422 Returns: 423 A boolean that is true iff type(a) and type(b) are the same object 424 or equivalent namedtuple types. 425 """ 426 if nest.is_namedtuple(a) and nest.is_namedtuple(b): 427 return nest.same_namedtuples(a, b) 428 else: 429 return type(a) is type(b) 430 431 @staticmethod 432 def __is_compatible(a, b): 433 """Returns true if the given type serializations compatible.""" 434 if isinstance(a, TypeSpec): 435 return a.is_compatible_with(b) 436 if not TypeSpec.__same_types(a, b): 437 return False 438 if isinstance(a, (list, tuple)): 439 return (len(a) == len(b) and 440 all(TypeSpec.__is_compatible(x, y) for (x, y) in zip(a, b))) 441 if isinstance(a, dict): 442 return (len(a) == len(b) and sorted(a.keys()) == sorted(b.keys()) and all( 443 TypeSpec.__is_compatible(a[k], b[k]) for k in a.keys())) 444 if isinstance(a, (tensor_shape.TensorShape, dtypes.DType)): 445 return a.is_compatible_with(b) 446 return a == b 447 448 @staticmethod 449 def __most_specific_compatible_type_serialization(a, b): 450 """Helper for most_specific_compatible_type. 451 452 Combines two type serializations as follows: 453 454 * If they are both tuples of the same length, then recursively combine 455 the respective tuple elements. 456 * If they are both dicts with the same keys, then recursively combine 457 the respective dict elements. 458 * If they are both TypeSpecs, then combine using 459 TypeSpec.most_specific_compatible_type. 460 * If they are both TensorShapes, then combine using 461 TensorShape.most_specific_compatible_shape. 462 * If they are both TensorSpecs with the same dtype, then combine using 463 TensorShape.most_specific_compatible_shape to combine shapes. 464 * If they are equal, then return a. 465 * If none of the above, then raise a ValueError. 466 467 Args: 468 a: A serialized TypeSpec or nested component from a serialized TypeSpec. 469 b: A serialized TypeSpec or nested component from a serialized TypeSpec. 470 471 Returns: 472 A value with the same type and structure as `a` and `b`. 473 474 Raises: 475 ValueError: If `a` and `b` are incompatible. 476 """ 477 if not TypeSpec.__same_types(a, b): 478 raise ValueError("Types are not compatible: %r vs %r" % (a, b)) 479 if nest.is_namedtuple(a): 480 assert a._fields == b._fields # Implied by __same_types(a, b). 481 return type(a)(*[ 482 TypeSpec.__most_specific_compatible_type_serialization(x, y) 483 for (x, y) in zip(a, b)]) 484 if isinstance(a, (list, tuple)): 485 if len(a) != len(b): 486 raise ValueError("Types are not compatible: %r vs %r" % (a, b)) 487 return tuple(TypeSpec.__most_specific_compatible_type_serialization(x, y) 488 for (x, y) in zip(a, b)) 489 if isinstance(a, collections.OrderedDict): 490 a_keys, b_keys = a.keys(), b.keys() 491 if len(a) != len(b) or a_keys != b_keys: 492 raise ValueError("Types are not compatible: %r vs %r" % (a, b)) 493 return collections.OrderedDict([ 494 (k, 495 TypeSpec.__most_specific_compatible_type_serialization(a[k], b[k])) 496 for k in a_keys 497 ]) 498 if isinstance(a, dict): 499 a_keys, b_keys = sorted(a.keys()), sorted(b.keys()) 500 if len(a) != len(b) or a_keys != b_keys: 501 raise ValueError("Types are not compatible: %r vs %r" % (a, b)) 502 return { 503 k: TypeSpec.__most_specific_compatible_type_serialization(a[k], b[k]) 504 for k in a_keys 505 } 506 if isinstance(a, tensor_shape.TensorShape): 507 return a.most_specific_compatible_shape(b) 508 if isinstance(a, list): 509 raise AssertionError("_serialize() should not return list values.") 510 if isinstance(a, TypeSpec): 511 return a.most_specific_compatible_type(b) 512 if a != b: 513 raise ValueError("Types are not compatible: %r vs %r" % (a, b)) 514 return a 515 516 517class BatchableTypeSpec(TypeSpec): 518 """TypeSpec with a batchable tensor encoding. 519 520 The batchable tensor encoding is a list of `tf.Tensor`s that supports 521 batching and unbatching. In particular, stacking (or unstacking) 522 values with the same `TypeSpec` must be equivalent to stacking (or 523 unstacking) each of their tensor lists. Unlike the component encoding 524 (returned by `self._to_components)`, the batchable tensor encoding 525 may require using encoding/decoding ops. 526 527 If a subclass's batchable tensor encoding is not simply a flattened version 528 of the component encoding, then the subclass must override `_to_tensor_list`, 529 `_from_tensor_list`, and _flat_tensor_specs`. 530 """ 531 532 __slots__ = [] 533 534 @abc.abstractmethod 535 def _batch(self, batch_size): 536 """Returns a TypeSpec representing a batch of objects with this TypeSpec. 537 538 Args: 539 batch_size: An `int` representing the number of elements in a batch, 540 or `None` if the batch size may vary. 541 542 Returns: 543 A `TypeSpec` representing a batch of objects with this TypeSpec. 544 """ 545 raise NotImplementedError("%s._batch" % type(self).__name__) 546 547 @abc.abstractmethod 548 def _unbatch(self): 549 """Returns a TypeSpec representing a single element this TypeSpec. 550 551 Returns: 552 A `TypeSpec` representing a single element of objects with this TypeSpec. 553 """ 554 raise NotImplementedError("%s._unbatch" % type(self).__name__) 555 556 def _to_batched_tensor_list(self, value): 557 """Returns a tensor list encoding for value with rank>0.""" 558 tensor_list = self._to_tensor_list(value) 559 if any(t.shape.ndims == 0 for t in tensor_list): 560 raise ValueError("Value %s has insufficient rank for batching." % value) 561 return tensor_list 562 563 564@tf_export("type_spec_from_value") 565def type_spec_from_value(value): 566 """Returns a `tf.TypeSpec` that represents the given `value`. 567 568 Examples: 569 570 >>> tf.type_spec_from_value(tf.constant([1, 2, 3])) 571 TensorSpec(shape=(3,), dtype=tf.int32, name=None) 572 >>> tf.type_spec_from_value(np.array([4.0, 5.0], np.float64)) 573 TensorSpec(shape=(2,), dtype=tf.float64, name=None) 574 >>> tf.type_spec_from_value(tf.ragged.constant([[1, 2], [3, 4, 5]])) 575 RaggedTensorSpec(TensorShape([2, None]), tf.int32, 1, tf.int64) 576 577 >>> example_input = tf.ragged.constant([[1, 2], [3]]) 578 >>> @tf.function(input_signature=[tf.type_spec_from_value(example_input)]) 579 ... def f(x): 580 ... return tf.reduce_sum(x, axis=1) 581 582 Args: 583 value: A value that can be accepted or returned by TensorFlow APIs. 584 Accepted types for `value` include `tf.Tensor`, any value that can be 585 converted to `tf.Tensor` using `tf.convert_to_tensor`, and any subclass 586 of `CompositeTensor` (such as `tf.RaggedTensor`). 587 588 Returns: 589 A `TypeSpec` that is compatible with `value`. 590 591 Raises: 592 TypeError: If a TypeSpec cannot be built for `value`, because its type 593 is not supported. 594 """ 595 spec = _type_spec_from_value(value) 596 if spec is not None: 597 return spec 598 599 # Fallback: try converting value to a tensor. 600 try: 601 tensor = ops.convert_to_tensor(value) 602 spec = _type_spec_from_value(tensor) 603 if spec is not None: 604 return spec 605 except (ValueError, TypeError) as e: 606 logging.vlog( 607 3, "Failed to convert %r to tensor: %s" % (type(value).__name__, e)) 608 609 raise TypeError("Could not build a TypeSpec for %r with type %s" % 610 (value, type(value).__name__)) 611 612 613def _type_spec_from_value(value): 614 """Returns a `TypeSpec` that represents the given `value`.""" 615 if isinstance(value, ops.Tensor): 616 # Note: we do not include Tensor names when constructing TypeSpecs. 617 return tensor_spec.TensorSpec(value.shape, value.dtype) 618 619 if isinstance(value, composite_tensor.CompositeTensor): 620 return value._type_spec # pylint: disable=protected-access 621 622 # If `value` is a list and all of its elements can be represented by the same 623 # batchable type spec, then we can represent the entire list using a single 624 # type spec that captures the type accurately (unlike the `convert_to_tensor` 625 # fallback). 626 if isinstance(value, list) and value: 627 subspecs = [_type_spec_from_value(v) for v in value] 628 if isinstance(subspecs[0], BatchableTypeSpec): 629 merged_subspec = subspecs[0] 630 try: 631 for subspec in subspecs[1:]: 632 merged_subspec = merged_subspec.most_specific_compatible_type(subspec) 633 return merged_subspec._batch(len(subspecs)) # pylint: disable=protected-access 634 except (ValueError, TypeError): 635 pass # incompatible subspecs 636 637 for entry in reversed(_TYPE_CONVERSION_FUNCTION_REGISTRY): 638 type_object, converter_fn, allow_subclass = entry 639 if ((type(value) is type_object) or # pylint: disable=unidiomatic-typecheck 640 (allow_subclass and isinstance(value, type_object))): 641 return converter_fn(value) 642 643 return None 644 645_TYPE_CONVERSION_FUNCTION_REGISTRY = [] 646 647 648def register_type_spec_from_value_converter(type_object, converter_fn, 649 allow_subclass=False): 650 """Registers a function for converting values with a given type to TypeSpecs. 651 652 If multiple registered `type_object`s match a value, then the most recent 653 registration takes precedence. Custom converters should not be defined for 654 `CompositeTensor`s; use `CompositeTensor._type_spec` instead. 655 656 Args: 657 type_object: A Python `type` object representing the type of values 658 accepted by `converter_fn`. 659 converter_fn: A function that takes one argument (an instance of the 660 type represented by `type_object`) and returns a `TypeSpec`. 661 allow_subclass: If true, then use `isinstance(value, type_object)` to 662 check for matches. If false, then use `type(value) is type_object`. 663 """ 664 _, type_object = tf_decorator.unwrap(type_object) 665 _TYPE_CONVERSION_FUNCTION_REGISTRY.append( 666 (type_object, converter_fn, allow_subclass)) 667 668 669_pywrap_utils.RegisterType("TypeSpec", TypeSpec) 670 671 672_TYPE_SPEC_TO_NAME = {} 673_NAME_TO_TYPE_SPEC = {} 674 675 676# Regular expression for valid TypeSpec names. 677_REGISTERED_NAME_RE = re.compile(r"^(\w+\.)+\w+$") 678 679 680# TODO(b/173744905) tf_export this as "tf.register_type_spec". (And add a 681# usage example to the docstring, once the API is public.) 682# 683# TODO(b/173744905) Update this decorator to apply to ExtensionType rather than 684# TypeSpec (once we do refactoring to move to_components/from_components from 685# TypeSpec to ExtensionType). 686def register(name): 687 """Decorator used to register a globally unique name for a TypeSpec subclass. 688 689 Args: 690 name: The name of the type spec. Must be globally unique. Must have 691 the form `"{project_name}.{type_name}"`. E.g. `"my_project.MyTypeSpec"`. 692 693 Returns: 694 A class decorator that registers the decorated class with the given name. 695 """ 696 if not isinstance(name, str): 697 raise TypeError("Expected `name` to be a string; got %r" % (name,)) 698 if not _REGISTERED_NAME_RE.match(name): 699 raise ValueError( 700 "Registered name must have the form '{project_name}.{type_name}' " 701 "(e.g. 'my_project.MyTypeSpec'); got %r." % name) 702 703 def decorator_fn(cls): 704 if not (isinstance(cls, type) and issubclass(cls, TypeSpec)): 705 raise TypeError("Expected `cls` to be a TypeSpec; got %r" % (cls,)) 706 if cls in _TYPE_SPEC_TO_NAME: 707 raise ValueError("Class %s.%s has already been registered with name %s." 708 % (cls.__module__, cls.__name__, 709 _TYPE_SPEC_TO_NAME[cls])) 710 if name in _NAME_TO_TYPE_SPEC: 711 raise ValueError("Name %s has already been registered for class %s.%s." 712 % (name, _NAME_TO_TYPE_SPEC[name].__module__, 713 _NAME_TO_TYPE_SPEC[name].__name__)) 714 _TYPE_SPEC_TO_NAME[cls] = name 715 _NAME_TO_TYPE_SPEC[name] = cls 716 return cls 717 718 return decorator_fn 719 720 721# TODO(edloper) tf_export this as "tf.get_type_spec_name" (or some similar name) 722def get_name(cls): 723 """Returns the registered name for TypeSpec `cls`.""" 724 if not (isinstance(cls, type) and issubclass(cls, TypeSpec)): 725 raise TypeError("Expected `cls` to be a TypeSpec; got %r" % (cls,)) 726 if cls not in _TYPE_SPEC_TO_NAME: 727 raise ValueError("TypeSpec %s.%s has not been registered." % 728 (cls.__module__, cls.__name__)) 729 return _TYPE_SPEC_TO_NAME[cls] 730 731 732# TODO(edloper) tf_export this as "tf.lookup_type_spec" (or some similar name) 733def lookup(name): 734 """Returns the TypeSpec that has been registered with name `name`.""" 735 if not isinstance(name, str): 736 raise TypeError("Expected `name` to be a string; got %r" % (name,)) 737 if name not in _NAME_TO_TYPE_SPEC: 738 raise ValueError("No TypeSpec has been registered with name %r" % (name,)) 739 return _NAME_TO_TYPE_SPEC[name] 740