# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Type specifications for TensorFlow APIs."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc
import collections
import re

import numpy as np
import six

from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_shape
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import _pywrap_utils
from tensorflow.python.util import compat
from tensorflow.python.util import nest
from tensorflow.python.util import tf_decorator
from tensorflow.python.util.compat import collections_abc
from tensorflow.python.util.lazy_loader import LazyLoader
from tensorflow.python.util.tf_export import tf_export

# Use LazyLoader to avoid circular dependencies.
tensor_spec = LazyLoader(
    "tensor_spec", globals(),
    "tensorflow.python.framework.tensor_spec")
ops = LazyLoader(
    "ops", globals(),
    "tensorflow.python.framework.ops")


@tf_export("TypeSpec", v1=["TypeSpec", "data.experimental.Structure"])
@six.add_metaclass(abc.ABCMeta)
class TypeSpec(object):
  """Specifies a TensorFlow value type.

  A `tf.TypeSpec` provides metadata describing an object accepted or returned
  by TensorFlow APIs. Concrete subclasses, such as `tf.TensorSpec` and
  `tf.RaggedTensorSpec`, are used to describe different value types.

  For example, `tf.function`'s `input_signature` argument accepts a list
  (or nested structure) of `TypeSpec`s.

  Creating new subclasses of `TypeSpec` (outside of TensorFlow core) is not
  currently supported. In particular, we may make breaking changes to the
  private methods and properties defined by this base class.

  Example:

  >>> spec = tf.RaggedTensorSpec(shape=[None, None], dtype=tf.int32)
  >>> @tf.function(input_signature=[spec])
  ... def double(x):
  ...   return x * 2
  >>> print(double(tf.ragged.constant([[1, 2], [3]])))
  <tf.RaggedTensor [[2, 4], [6]]>
  """
  # === Subclassing ===
  #
  # Each `TypeSpec` subclass must define:
  #
  # * A "component encoding" for values.
  # * A "serialization" for types.
  #
  # The component encoding for a value is a nested structure of `tf.Tensor`
  # or `CompositeTensor` that can be used by the `TypeSpec` to reconstruct
  # the value. Each individual `TypeSpec` must use the same nested structure
  # for all values -- this structure is defined by the `component_specs`
  # attribute. Decomposing values into components, and reconstructing them
  # from those components, should be inexpensive. In particular, it should
  # *not* require any TensorFlow ops.
  #
  # The serialization for a `TypeSpec` is a nested tuple of values that can
  # be used to reconstruct the `TypeSpec`. See the documentation for
  # `_serialize()` for more information.
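  #
  # For orientation only, an informal sketch of what these two encodings look
  # like for an existing spec (`tf.RaggedTensorSpec`); the exact structures
  # are private implementation details and may differ between versions:
  #
  #   rt = tf.ragged.constant([[1, 2], [3]])
  #   spec = tf.type_spec_from_value(rt)
  #   # Component encoding: a nested structure of Tensors/CompositeTensors
  #   # (for RaggedTensor, roughly the flat values plus row-partition info).
  #   components = spec._to_components(rt)
  #   rt2 = spec._from_components(components)   # reconstructs the value
  #   # Serialization: a nested tuple of plain values (shapes, dtypes, ...)
  #   # from which the spec itself can be rebuilt.
  #   state = spec._serialize()
  #   spec2 = type(spec)._deserialize(state)     # reconstructs the spec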

  __slots__ = []

  @abc.abstractproperty
  def value_type(self):
    """The Python type for values that are compatible with this TypeSpec.

    In particular, all values that are compatible with this TypeSpec must be
    instances of this type.
    """
    raise NotImplementedError("%s.value_type" % type(self).__name__)

  def is_compatible_with(self, spec_or_value):
    """Returns true if `spec_or_value` is compatible with this TypeSpec."""
    # === Subclassing ===
    # If not overridden by subclasses, the default behavior is to convert
    # `spec_or_value` to a `TypeSpec` (if it isn't already); and then to
    # consider two `TypeSpec`s compatible if they have the same type, and
    # the values returned by `_serialize` are compatible (where
    # `tf.TensorShape`, `tf.TensorSpec`, and `tf.DType` are checked for
    # compatibility using their `is_compatible_with` method; and all other
    # types are considered compatible if they are equal).
    if not isinstance(spec_or_value, TypeSpec):
      spec_or_value = type_spec_from_value(spec_or_value)
    if type(self) is not type(spec_or_value):
      return False
    return self.__is_compatible(self._serialize(),
                                spec_or_value._serialize())  # pylint: disable=protected-access

  def most_specific_compatible_type(self, other):
    """Returns the most specific TypeSpec compatible with `self` and `other`.

    Args:
      other: A `TypeSpec`.

    Raises:
      ValueError: If there is no TypeSpec that is compatible with both `self`
        and `other`.
    """
    # === Subclassing ===
    # If not overridden by a subclass, the default behavior is to raise a
    # `ValueError` if `self` and `other` have different types, or if their type
    # serializations differ by anything other than `TensorShape`s. Otherwise,
    # the two type serializations are combined (using
    # `most_specific_compatible_shape` to combine `TensorShape`s), and the
    # result is used to construct and return a new `TypeSpec`.
    if type(self) is not type(other):
      raise ValueError("No TypeSpec is compatible with both %s and %s" %
                       (self, other))
    merged = self.__most_specific_compatible_type_serialization(
        self._serialize(), other._serialize())  # pylint: disable=protected-access
    return self._deserialize(merged)

  def _with_tensor_ranks_only(self):
    """Returns a TypeSpec compatible with `self`, with tensor shapes relaxed.

    Returns:
      A `TypeSpec` that is compatible with `self`, where any `TensorShape`
      information has been relaxed to include only tensor rank (and not
      the dimension sizes for individual axes).
    """

    # === Subclassing ===
    # If not overridden by a subclass, the default behavior is to serialize
    # this TypeSpec, relax any TensorSpec or TensorShape values, and
    # deserialize the result.

    def relax(value):
      if isinstance(value, TypeSpec):
        return value._with_tensor_ranks_only()  # pylint: disable=protected-access
      elif (isinstance(value, tensor_shape.TensorShape) and
            value.rank is not None):
        return tensor_shape.TensorShape([None] * value.rank)
      else:
        return value

    return self._deserialize(nest.map_structure(relax, self._serialize()))
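
  # An informal sketch of the default behaviors above, using `tf.TensorSpec`
  # (the shapes shown are the expected results, not verbatim output, and
  # `_with_tensor_ranks_only` is a private method):
  #
  #   a = tf.TensorSpec([None, 3], tf.float32)
  #   b = tf.TensorSpec([5, 3], tf.float32)
  #   a.is_compatible_with(b)                                  # True
  #   a.is_compatible_with(tf.TensorSpec([5, 4], tf.float32))  # False
  #   a.most_specific_compatible_type(b)    # TensorSpec([None, 3], ...)
  #   b._with_tensor_ranks_only()           # TensorSpec([None, None], ...)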

  # === Component encoding for values ===

  @abc.abstractmethod
  def _to_components(self, value):
    """Encodes `value` as a nested structure of `Tensor` or `CompositeTensor`.

    Args:
      value: A value compatible with this `TypeSpec`. (Caller is responsible
        for ensuring compatibility.)

    Returns:
      A nested structure of `tf.Tensor` or `tf.CompositeTensor` compatible with
      `self._component_specs`, which can be used to reconstruct `value`.
    """
    # === Subclassing ===
    # This method must be inexpensive (do not call TF ops).
    raise NotImplementedError("%s._to_components()" % type(self).__name__)

  @abc.abstractmethod
  def _from_components(self, components):
    """Reconstructs a value from a nested structure of Tensor/CompositeTensor.

    Args:
      components: A nested structure of `tf.Tensor` or `tf.CompositeTensor`,
        compatible with `self._component_specs`. (Caller is responsible for
        ensuring compatibility.)

    Returns:
      A value that is compatible with this `TypeSpec`.
    """
    # === Subclassing ===
    # This method must be inexpensive (do not call TF ops).
    raise NotImplementedError("%s._from_components()" % type(self).__name__)

  @abc.abstractproperty
  def _component_specs(self):
    """A nested structure of TypeSpecs for this type's components.

    Returns:
      A nested structure describing the component encodings that are returned
      by this TypeSpec's `_to_components` method. In particular, for a
      TypeSpec `spec` and a compatible value `value`:

      ```
      nest.map_structure(lambda t, c: assert t.is_compatible_with(c),
                         spec._component_specs, spec._to_components(value))
      ```
    """
    raise NotImplementedError("%s._component_specs()" % type(self).__name__)

  # === Tensor list encoding for values ===

  def _to_tensor_list(self, value):
    """Encodes `value` as a flat list of `tf.Tensor`.

    By default, this just flattens `self._to_components(value)` using
    `nest.flatten`. However, subclasses may override this to return a
    different tensor encoding for values. In particular, some subclasses
    of `BatchableTypeSpec` override this method to return a "boxed" encoding
    for values, which then can be batched or unbatched. See
    `BatchableTypeSpec` for more details.

    Args:
      value: A value compatible with this `TypeSpec`. (Caller is responsible
        for ensuring compatibility.)

    Returns:
      A list of `tf.Tensor`, compatible with `self._flat_tensor_specs`, which
      can be used to reconstruct `value`.
    """
    return nest.flatten(self._to_components(value), expand_composites=True)

  def _from_tensor_list(self, tensor_list):
    """Reconstructs a value from a flat list of `tf.Tensor`.

    Args:
      tensor_list: A flat list of `tf.Tensor`, compatible with
        `self._flat_tensor_specs`.

    Returns:
      A value that is compatible with this `TypeSpec`.

    Raises:
      ValueError: If `tensor_list` is not compatible with
        `self._flat_tensor_specs`.
    """
    self.__check_tensor_list(tensor_list)
    return self._from_compatible_tensor_list(tensor_list)

  def _from_compatible_tensor_list(self, tensor_list):
    """Reconstructs a value from a compatible flat list of `tf.Tensor`.

    Args:
      tensor_list: A flat list of `tf.Tensor`, compatible with
        `self._flat_tensor_specs`. (Caller is responsible for ensuring
        compatibility.)

    Returns:
      A value that is compatible with this `TypeSpec`.
    """
    return self._from_components(nest.pack_sequence_as(
        self._component_specs, tensor_list, expand_composites=True))

  @property
  def _flat_tensor_specs(self):
    """A list of TensorSpecs compatible with self._to_tensor_list(v)."""
    return nest.flatten(self._component_specs, expand_composites=True)
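
  # Informal sketch of the tensor-list round trip documented above (these are
  # private methods; only the contract between them is illustrated here):
  #
  #   rt = tf.ragged.constant([[1, 2], [3]])
  #   spec = tf.type_spec_from_value(rt)
  #   tensors = spec._to_tensor_list(rt)          # flat list of tf.Tensor
  #   assert len(tensors) == len(spec._flat_tensor_specs)
  #   rt2 = spec._from_tensor_list(tensors)       # reconstructs the value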

  # === Serialization for types ===

  @abc.abstractmethod
  def _serialize(self):
    """Returns a nested tuple containing the state of this TypeSpec.

    The serialization may contain the following value types: boolean,
    integer, string, float, None, `TensorSpec`, `tf.TensorShape`, `tf.DType`,
    `np.ndarray`, `TypeSpec`, and nested tuples, namedtuples, dicts, and
    OrderedDicts of any of the above.

    This method is used to provide default definitions for: equality
    testing (__eq__, __ne__), hashing (__hash__), pickling (__reduce__),
    string representation (__repr__), `self.is_compatible_with()`,
    `self.most_specific_compatible_type()`, and protobuf serialization
    (e.g. TensorInfo and StructuredValue).
    """
    raise NotImplementedError("%s._serialize()" % type(self).__name__)

  @classmethod
  def _deserialize(cls, serialization):
    """Reconstructs a TypeSpec from a value returned by `_serialize`."""
    return cls(*serialization)

  # === Operators ===

  def __eq__(self, other):
    # pylint: disable=protected-access
    return (type(other) is type(self) and
            self.__get_cmp_key() == other.__get_cmp_key())

  def __ne__(self, other):
    return not self == other

  def __hash__(self):
    return hash(self.__get_cmp_key())

  def __reduce__(self):
    return type(self), self._serialize()

  def __repr__(self):
    return "%s%r" % (type(self).__name__, self._serialize())
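
  # Because the operators above are derived from `_serialize()`, two specs of
  # the same class with equal serializations behave interchangeably. An
  # informal illustration for a subclass that relies on these defaults:
  #
  #   import pickle
  #   a = tf.RaggedTensorSpec([None, None], tf.int32)
  #   b = tf.RaggedTensorSpec([None, None], tf.int32)
  #   assert a == b and hash(a) == hash(b)
  #   assert pickle.loads(pickle.dumps(a)) == a    # uses __reduce__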

  # === Legacy Output ===
  # TODO(b/133606651) Document and/or deprecate the legacy_output methods.
  # (These are used by tf.data.)

  def _to_legacy_output_types(self):
    raise NotImplementedError("%s._to_legacy_output_types()" %
                              type(self).__name__)

  def _to_legacy_output_shapes(self):
    raise NotImplementedError("%s._to_legacy_output_shapes()" %
                              type(self).__name__)

  def _to_legacy_output_classes(self):
    return self.value_type

  # === Private Helper Methods ===

  def __check_tensor_list(self, tensor_list):
    expected = self._flat_tensor_specs
    specs = [type_spec_from_value(t) for t in tensor_list]
    if len(specs) != len(expected):
      raise ValueError("Incompatible input: wrong number of tensors")
    for i, (s1, s2) in enumerate(zip(specs, expected)):
      if not s1.is_compatible_with(s2):
        raise ValueError("Incompatible input: tensor %d (%s) is incompatible "
                         "with %s" % (i, tensor_list[i], s2))

  def __get_cmp_key(self):
    """Returns a hashable eq-comparable key for `self`."""
    # TODO(b/133606651): Decide whether to cache this value.
    return (type(self), self.__make_cmp_key(self._serialize()))

  def __make_cmp_key(self, value):
    """Converts `value` to a hashable key."""
    if isinstance(value,
                  (int, float, bool, np.generic, dtypes.DType, TypeSpec)):
      return value
    if isinstance(value, compat.bytes_or_text_types):
      return value
    if value is None:
      return value
    if isinstance(value, dict):
      return tuple([
          tuple([self.__make_cmp_key(key),
                 self.__make_cmp_key(value[key])])
          for key in sorted(value.keys())
      ])
    if isinstance(value, tuple):
      return tuple([self.__make_cmp_key(v) for v in value])
    if isinstance(value, list):
      return (list, tuple([self.__make_cmp_key(v) for v in value]))
    if isinstance(value, tensor_shape.TensorShape):
      if value.ndims is None:
        # Note: we include a type object in the tuple, to ensure we can't get
        # false-positive matches (since users can't include type objects).
        return (tensor_shape.TensorShape, None)
      return (tensor_shape.TensorShape, tuple(value.as_list()))
    if isinstance(value, np.ndarray):
      return (np.ndarray, value.shape,
              TypeSpec.__nested_list_to_tuple(value.tolist()))
    raise ValueError("Unsupported value type %s returned by "
                     "%s._serialize" %
                     (type(value).__name__, type(self).__name__))

  @staticmethod
  def __nested_list_to_tuple(value):
    """Converts a nested list to a corresponding nested tuple."""
    if isinstance(value, list):
      return tuple(TypeSpec.__nested_list_to_tuple(v) for v in value)
    return value

  @staticmethod
  def __is_compatible(a, b):
    """Returns true if the given type serializations are compatible."""
    if isinstance(a, TypeSpec):
      return a.is_compatible_with(b)
    if type(a) is not type(b):
      return False
    if isinstance(a, (list, tuple)):
      return (len(a) == len(b) and
              all(TypeSpec.__is_compatible(x, y) for (x, y) in zip(a, b)))
    if isinstance(a, dict):
      return (len(a) == len(b) and sorted(a.keys()) == sorted(b.keys()) and all(
          TypeSpec.__is_compatible(a[k], b[k]) for k in a.keys()))
    if isinstance(a, (tensor_shape.TensorShape, dtypes.DType)):
      return a.is_compatible_with(b)
    return a == b

  @staticmethod
  def __is_named_tuple(t):
    """Returns true if the given tuple t is a namedtuple."""
    return (hasattr(t, "_fields") and
            isinstance(t._fields, collections_abc.Sequence) and
            all(isinstance(f, six.string_types) for f in t._fields))

  @staticmethod
  def __most_specific_compatible_type_serialization(a, b):
    """Helper for most_specific_compatible_type.

    Combines two type serializations as follows:

    * If they are both tuples of the same length, then recursively combine
      the respective tuple elements.
    * If they are both dicts with the same keys, then recursively combine
      the respective dict elements.
    * If they are both TypeSpecs, then combine using
      TypeSpec.most_specific_compatible_type.
    * If they are both TensorShapes, then combine using
      TensorShape.most_specific_compatible_shape.
    * If they are both TensorSpecs with the same dtype, then combine using
      TensorShape.most_specific_compatible_shape to combine shapes.
    * If they are equal, then return a.
    * If none of the above, then raise a ValueError.

    Args:
      a: A serialized TypeSpec or nested component from a serialized TypeSpec.
      b: A serialized TypeSpec or nested component from a serialized TypeSpec.

    Returns:
      A value with the same type and structure as `a` and `b`.

    Raises:
      ValueError: If `a` and `b` are incompatible.
    """
    if type(a) is not type(b):
      raise ValueError("Types are not compatible: %r vs %r" % (a, b))
    if isinstance(a, (list, tuple)):
      if len(a) != len(b):
        raise ValueError("Types are not compatible: %r vs %r" % (a, b))
      if TypeSpec.__is_named_tuple(a):
        if not hasattr(b, "_fields") or not isinstance(
            b._fields, collections_abc.Sequence) or a._fields != b._fields:
          raise ValueError("Types are not compatible: %r vs %r" % (a, b))
        return type(a)(*[
            TypeSpec.__most_specific_compatible_type_serialization(x, y)
            for (x, y) in zip(a, b)])
      return tuple(TypeSpec.__most_specific_compatible_type_serialization(x, y)
                   for (x, y) in zip(a, b))
    if isinstance(a, collections.OrderedDict):
      a_keys, b_keys = a.keys(), b.keys()
      if len(a) != len(b) or a_keys != b_keys:
        raise ValueError("Types are not compatible: %r vs %r" % (a, b))
      return collections.OrderedDict([
          (k,
           TypeSpec.__most_specific_compatible_type_serialization(a[k], b[k]))
          for k in a_keys
      ])
    if isinstance(a, dict):
      a_keys, b_keys = sorted(a.keys()), sorted(b.keys())
      if len(a) != len(b) or a_keys != b_keys:
        raise ValueError("Types are not compatible: %r vs %r" % (a, b))
      return {
          k: TypeSpec.__most_specific_compatible_type_serialization(a[k], b[k])
          for k in a_keys
      }
    if isinstance(a, tensor_shape.TensorShape):
      return a.most_specific_compatible_shape(b)
    if isinstance(a, list):
      raise AssertionError("_serialize() should not return list values.")
    if isinstance(a, TypeSpec):
      return a.most_specific_compatible_type(b)
    if a != b:
      raise ValueError("Types are not compatible: %r vs %r" % (a, b))
    return a
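

# Illustrative sketch only: creating new `TypeSpec` subclasses outside of
# TensorFlow core is not currently supported (see the class docstring), and
# the private methods shown here may change. The hypothetical `PairSpec`
# below is not part of TensorFlow; it simply shows how the pieces documented
# above (component encoding, component specs, and serialization) fit together
# for a value that is just a pair of tensors:
#
#   class PairSpec(TypeSpec):
#     """Hypothetical spec for a (first, second) pair of tensors."""
#
#     def __init__(self, first_spec, second_spec):
#       self._first_spec = first_spec
#       self._second_spec = second_spec
#
#     @property
#     def value_type(self):
#       return tuple  # values are plain (first, second) tuples
#
#     def _to_components(self, value):
#       return (value[0], value[1])
#
#     def _from_components(self, components):
#       return (components[0], components[1])
#
#     @property
#     def _component_specs(self):
#       return (self._first_spec, self._second_spec)
#
#     def _serialize(self):
#       # A nested tuple of plain values / TypeSpecs, matching __init__'s
#       # argument order so the default `_deserialize` can rebuild the spec.
#       return (self._first_spec, self._second_spec)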


class BatchableTypeSpec(TypeSpec):
  """TypeSpec with a batchable tensor encoding.

  The batchable tensor encoding is a list of `tf.Tensor`s that supports
  batching and unbatching. In particular, stacking (or unstacking)
  values with the same `TypeSpec` must be equivalent to stacking (or
  unstacking) each of their tensor lists. Unlike the component encoding
  (returned by `self._to_components()`), the batchable tensor encoding
  may require using encoding/decoding ops.

  If a subclass's batchable tensor encoding is not simply a flattened version
  of the component encoding, then the subclass must override `_to_tensor_list`,
  `_from_tensor_list`, and `_flat_tensor_specs`.
  """

  __slots__ = []

  @abc.abstractmethod
  def _batch(self, batch_size):
    """Returns a TypeSpec representing a batch of objects with this TypeSpec.

    Args:
      batch_size: An `int` representing the number of elements in a batch,
        or `None` if the batch size may vary.

    Returns:
      A `TypeSpec` representing a batch of objects with this TypeSpec.
    """
    raise NotImplementedError("%s._batch" % type(self).__name__)

  @abc.abstractmethod
  def _unbatch(self):
    """Returns a TypeSpec representing a single element of this TypeSpec.

    Returns:
      A `TypeSpec` representing a single element of objects with this TypeSpec.
    """
    raise NotImplementedError("%s._unbatch" % type(self).__name__)

  def _to_batched_tensor_list(self, value):
    """Returns a tensor list encoding for `value` with rank > 0."""
    tensor_list = self._to_tensor_list(value)
    if any(t.shape.ndims == 0 for t in tensor_list):
      raise ValueError("Value %s has insufficient rank for batching." % value)
    return tensor_list
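

# Informal illustration of the batch/unbatch contract, using `tf.TensorSpec`
# (a batchable spec); `_batch` and `_unbatch` are private, and the results
# shown are conceptual rather than verbatim output:
#
#   spec = tf.TensorSpec([3], tf.int32)
#   spec._batch(10)              # conceptually: TensorSpec([10, 3], tf.int32)
#   spec._batch(None)            # conceptually: TensorSpec([None, 3], tf.int32)
#   spec._batch(10)._unbatch()   # back to TensorSpec([3], tf.int32)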


@tf_export("type_spec_from_value")
def type_spec_from_value(value):
  """Returns a `tf.TypeSpec` that represents the given `value`.

  Examples:

  >>> tf.type_spec_from_value(tf.constant([1, 2, 3]))
  TensorSpec(shape=(3,), dtype=tf.int32, name=None)
  >>> tf.type_spec_from_value(np.array([4.0, 5.0], np.float64))
  TensorSpec(shape=(2,), dtype=tf.float64, name=None)
  >>> tf.type_spec_from_value(tf.ragged.constant([[1, 2], [3, 4, 5]]))
  RaggedTensorSpec(TensorShape([2, None]), tf.int32, 1, tf.int64)

  >>> example_input = tf.ragged.constant([[1, 2], [3]])
  >>> @tf.function(input_signature=[tf.type_spec_from_value(example_input)])
  ... def f(x):
  ...   return tf.reduce_sum(x, axis=1)

  Args:
    value: A value that can be accepted or returned by TensorFlow APIs.
      Accepted types for `value` include `tf.Tensor`, any value that can be
      converted to `tf.Tensor` using `tf.convert_to_tensor`, and any subclass
      of `CompositeTensor` (such as `tf.RaggedTensor`).

  Returns:
    A `TypeSpec` that is compatible with `value`.

  Raises:
    TypeError: If a TypeSpec cannot be built for `value`, because its type
      is not supported.
  """
  spec = _type_spec_from_value(value)
  if spec is not None:
    return spec

  # Fallback: try converting value to a tensor.
  try:
    tensor = ops.convert_to_tensor(value)
    spec = _type_spec_from_value(tensor)
    if spec is not None:
      return spec
  except (ValueError, TypeError) as e:
    logging.vlog(
        3, "Failed to convert %r to tensor: %s" % (type(value).__name__, e))

  raise TypeError("Could not build a TypeSpec for %r with type %s" %
                  (value, type(value).__name__))


def _type_spec_from_value(value):
  """Returns a `TypeSpec` that represents the given `value`."""
  if isinstance(value, ops.Tensor):
    # Note: we do not include Tensor names when constructing TypeSpecs.
    return tensor_spec.TensorSpec(value.shape, value.dtype)

  if isinstance(value, composite_tensor.CompositeTensor):
    return value._type_spec  # pylint: disable=protected-access

  # If `value` is a list and all of its elements can be represented by the same
  # batchable type spec, then we can represent the entire list using a single
  # type spec that captures the type accurately (unlike the `convert_to_tensor`
  # fallback).
  if isinstance(value, list) and value:
    subspecs = [_type_spec_from_value(v) for v in value]
    if isinstance(subspecs[0], BatchableTypeSpec):
      merged_subspec = subspecs[0]
      try:
        for subspec in subspecs[1:]:
          merged_subspec = merged_subspec.most_specific_compatible_type(subspec)
        return merged_subspec._batch(len(subspecs))  # pylint: disable=protected-access
      except (ValueError, TypeError):
        pass  # incompatible subspecs

  for entry in reversed(_TYPE_CONVERSION_FUNCTION_REGISTRY):
    type_object, converter_fn, allow_subclass = entry
    if ((type(value) is type_object) or  # pylint: disable=unidiomatic-typecheck
        (allow_subclass and isinstance(value, type_object))):
      return converter_fn(value)

  return None
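

# Informal illustration of the list handling above (expected result, not
# verbatim output): a homogeneous list of values with a common batchable spec
# is described by batching that spec, rather than falling back to
# `convert_to_tensor`:
#
#   rts = [tf.ragged.constant([[1, 2], [3]]),
#          tf.ragged.constant([[4], [5, 6]])]
#   tf.type_spec_from_value(rts)
#   # -> a RaggedTensorSpec describing a batch of 2 elements
#   #    (shape [2, 2, None] for this example)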


_TYPE_CONVERSION_FUNCTION_REGISTRY = []


def register_type_spec_from_value_converter(type_object, converter_fn,
                                            allow_subclass=False):
  """Registers a function for converting values with a given type to TypeSpecs.

  If multiple registered `type_object`s match a value, then the most recent
  registration takes precedence. Custom converters should not be defined for
  `CompositeTensor`s; use `CompositeTensor._type_spec` instead.

  Args:
    type_object: A Python `type` object representing the type of values
      accepted by `converter_fn`.
    converter_fn: A function that takes one argument (an instance of the
      type represented by `type_object`) and returns a `TypeSpec`.
    allow_subclass: If true, then use `isinstance(value, type_object)` to
      check for matches. If false, then use `type(value) is type_object`.
  """
  _, type_object = tf_decorator.unwrap(type_object)
  _TYPE_CONVERSION_FUNCTION_REGISTRY.append(
      (type_object, converter_fn, allow_subclass))
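

# Informal usage sketch for the registration hook above. `MyMaskedValue` and
# `MyMaskedValueSpec` are hypothetical names used only for illustration; they
# are not part of TensorFlow:
#
#   register_type_spec_from_value_converter(
#       MyMaskedValue,
#       lambda value: MyMaskedValueSpec.from_value(value),
#       allow_subclass=True)
#
# After this registration, `type_spec_from_value(MyMaskedValue(...))` returns
# the spec produced by the converter instead of falling back to
# `tf.convert_to_tensor`.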


_pywrap_utils.RegisterType("TypeSpec", TypeSpec)


_TYPE_SPEC_TO_NAME = {}
_NAME_TO_TYPE_SPEC = {}


# Regular expression for valid TypeSpec names.
_REGISTERED_NAME_RE = re.compile(r"^(\w+\.)+\w+$")


# TODO(b/173744905) tf_export this as "tf.register_type_spec". (And add a
# usage example to the docstring, once the API is public.)
#
# TODO(b/173744905) Update this decorator to apply to ExtensionType rather than
# TypeSpec (once we do refactoring to move to_components/from_components from
# TypeSpec to ExtensionType).
def register(name):
  """Decorator used to register a globally unique name for a TypeSpec subclass.

  Args:
    name: The name of the type spec. Must be globally unique. Must have
      the form `"{project_name}.{type_name}"`. E.g. `"my_project.MyTypeSpec"`.

  Returns:
    A class decorator that registers the decorated class with the given name.
  """
  if not isinstance(name, str):
    raise TypeError("Expected `name` to be a string; got %r" % (name,))
  if not _REGISTERED_NAME_RE.match(name):
    raise ValueError(
        "Registered name must have the form '{project_name}.{type_name}' "
        "(e.g. 'my_project.MyTypeSpec'); got %r." % name)

  def decorator_fn(cls):
    if not (isinstance(cls, type) and issubclass(cls, TypeSpec)):
      raise TypeError("Expected `cls` to be a TypeSpec; got %r" % (cls,))
    if cls in _TYPE_SPEC_TO_NAME:
      raise ValueError("Class %s.%s has already been registered with name %s."
                       % (cls.__module__, cls.__name__,
                          _TYPE_SPEC_TO_NAME[cls]))
    if name in _NAME_TO_TYPE_SPEC:
      raise ValueError("Name %s has already been registered for class %s.%s."
                       % (name, _NAME_TO_TYPE_SPEC[name].__module__,
                          _NAME_TO_TYPE_SPEC[name].__name__))
    _TYPE_SPEC_TO_NAME[cls] = name
    _NAME_TO_TYPE_SPEC[name] = cls
    return cls

  return decorator_fn


# TODO(edloper) tf_export this as "tf.get_type_spec_name" (or some similar name)
def get_name(cls):
  """Returns the registered name for TypeSpec `cls`."""
  if not (isinstance(cls, type) and issubclass(cls, TypeSpec)):
    raise TypeError("Expected `cls` to be a TypeSpec; got %r" % (cls,))
  if cls not in _TYPE_SPEC_TO_NAME:
    raise ValueError("TypeSpec %s.%s has not been registered." %
                     (cls.__module__, cls.__name__))
  return _TYPE_SPEC_TO_NAME[cls]


# TODO(edloper) tf_export this as "tf.lookup_type_spec" (or some similar name)
def lookup(name):
  """Returns the TypeSpec that has been registered with name `name`."""
  if not isinstance(name, str):
    raise TypeError("Expected `name` to be a string; got %r" % (name,))
  if name not in _NAME_TO_TYPE_SPEC:
    raise ValueError("No TypeSpec has been registered with name %r" % (name,))
  return _NAME_TO_TYPE_SPEC[name]
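

# Informal usage sketch for the registration helpers above. `MaskedTensorSpec`
# is a hypothetical TypeSpec subclass used only for illustration:
#
#   @register("my_project.MaskedTensorSpec")
#   class MaskedTensorSpec(TypeSpec):
#     ...
#
#   get_name(MaskedTensorSpec)              # -> "my_project.MaskedTensorSpec"
#   lookup("my_project.MaskedTensorSpec")   # -> MaskedTensorSpec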