• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
"""Metadata about fields for user-defined ExtensionType classes."""
16
17import collections
18import collections.abc
19import typing
20
21from tensorflow.python.framework import composite_tensor
22from tensorflow.python.framework import dtypes
23from tensorflow.python.framework import immutable_dict
24from tensorflow.python.framework import ops
25from tensorflow.python.framework import tensor_shape
26from tensorflow.python.framework import tensor_spec
27from tensorflow.python.framework import type_spec
28
# These names may not be used as the name for an ExtensionType field (to
# prevent name clashes).  All names beginning with `'_tf_extension_type'`
# (compared case-insensitively by `ExtensionTypeField.is_reserved_name`) are
# also reserved.
RESERVED_FIELD_NAMES = [
    'self',
    # Name of the nested TypeSpec class.
    'Spec',
    # Names defined by the CompositeTensor base class.
    '_type_spec',
    '_shape_invariant_to_type_spec',
    '_consumers',
    # Names defined by the TypeSpec base class.
    'value_type',
    'is_compatible_with',
    'most_specific_compatible_type',
    '_with_tensor_ranks_only',
    '_to_components',
    '_from_components',
    '_component_specs',
    '_to_tensor_list',
    '_from_tensor_list',
    '_from_compatible_tensor_list',
    '_flat_tensor_specs',
    '_serialize',
    '_deserialize',
    '_to_legacy_output_types',
    '_to_legacy_output_shapes',
    '_to_legacy_output_classes',
]
58
59
class Sentinel:
  """A named marker object meant to be distinguished by identity (`is`).

  Each instance is a distinct object, so it never coincides with a value
  supplied by a user.  The name given at construction time is returned
  verbatim by `repr()`.
  """

  def __init__(self, name):
    # Stored only so that `__repr__` can report it.
    self._name = name

  def __repr__(self):
    return self._name
68
69
70# ==============================================================================
71# ExtensionTypeField
72# ==============================================================================
class ExtensionTypeField(
    collections.namedtuple('ExtensionTypeField',
                           ['name', 'value_type', 'default'])):
  """Metadata about a single field in a `tf.ExtensionType` object."""

  # Marker used as the `default` of fields that have no default value.  It is
  # compared by identity, so `None` remains usable as a real default.
  NO_DEFAULT = Sentinel('ExtensionTypeField.NO_DEFAULT')

  def __new__(cls, name, value_type, default=NO_DEFAULT):
    """Constructs a new ExtensionTypeField containing metadata for a single field.

    Args:
      name: The name of the new field (`str`).  May not be a reserved name.
        (This constructor does not itself check for reserved names; see
        `is_reserved_name`.)
      value_type: A python type expression constraining what values this field
        can take.  Forward references (string literals) are accepted here.
      default: The default value for the new field, or `NO_DEFAULT` if this
        field has no default value.

    Returns:
      A new `ExtensionTypeField`.  If a default was given, the stored default
      is the value after conversion to `value_type`.

    Raises:
      TypeError: If the type described by `value_type` is not currently
          supported by `tf.ExtensionType`.
      TypeError: If `default` is specified and its type does not match
        `value_type`.
    """
    try:
      validate_field_value_type(value_type, allow_forward_references=True)
    except TypeError as e:
      # Prefix the message with the field name, and chain the original error
      # explicitly so that its traceback is reported as the direct cause.
      raise TypeError(f'In field {name!r}: {e}') from e

    if default is not cls.NO_DEFAULT:
      default = _convert_value(default, value_type,
                               (f'default value for {name}',))
    return super().__new__(cls, name, value_type, default)

  @staticmethod
  def is_reserved_name(name):
    """Returns true if `name` is a reserved name.

    A name is reserved if it is listed in `RESERVED_FIELD_NAMES`, or if it
    starts with `'_tf_extension_type'` (compared case-insensitively).

    Args:
      name: The candidate field name (`str`).
    """
    return name in RESERVED_FIELD_NAMES or name.lower().startswith(
        '_tf_extension_type')
115
116
def validate_field_value_type(value_type,
                              in_mapping_key=False,
                              allow_forward_references=False):
  """Verifies that a type annotation is built only from supported pieces.

  Accepted annotations are: a handful of simple Python types (plus `None` and
  `DType`), tensor-like types (`Tensor`, `TensorShape`, `TypeSpec` instances
  and `CompositeTensor` subclasses), and parameterized tuples, unions and
  mappings whose type arguments are themselves accepted.

  Args:
    value_type: The type annotation to check.
    in_mapping_key: True if `value_type` is nested in the key of a mapping.
      Tensor-like types are rejected there.
    allow_forward_references: If false, then raise an exception if
      `value_type` contains a forward reference (i.e., a string literal).

  Raises:
    TypeError: If `value_type` contains an unsupported type annotation.
  """
  # Forward references cannot be inspected any further.
  if isinstance(value_type, str) or is_forward_ref(value_type):
    if not allow_forward_references:
      raise TypeError(f'Unresolved forward reference {value_type!r}')
    return

  # Simple types are always fine, including as mapping keys.
  if value_type in (int, float, str, bytes, bool, None, _NoneType,
                    dtypes.DType):
    return

  is_tensor_like = (
      value_type in (ops.Tensor, tensor_shape.TensorShape) or
      isinstance(value_type, type_spec.TypeSpec) or
      (isinstance(value_type, type) and
       issubclass(value_type, composite_tensor.CompositeTensor)))
  if is_tensor_like:
    if in_mapping_key:
      raise TypeError('Key must be hashable.')
    return

  if is_generic_tuple(value_type) or is_generic_union(value_type):
    type_args = get_generic_type_args(value_type)
    is_variable_length_tuple = (
        len(type_args) == 2 and type_args[1] is Ellipsis and
        is_generic_tuple(value_type))  # `Tuple[X, ...]`
    if is_variable_length_tuple:
      # Only `X` is a type; the trailing `...` is not.
      type_args = type_args[:1]
    for type_arg in type_args:
      validate_field_value_type(type_arg, in_mapping_key,
                                allow_forward_references)
    return

  if is_generic_mapping(value_type):
    key_type, item_type = get_generic_type_args(value_type)
    validate_field_value_type(key_type, True, allow_forward_references)
    validate_field_value_type(item_type, in_mapping_key,
                              allow_forward_references)
    return

  if isinstance(value_type, type):
    raise TypeError(f'Unsupported type annotation `{value_type.__name__}`')
  raise TypeError(f'Unsupported type annotation {value_type!r}')
164
165
166# ==============================================================================
167# Type-checking & conversion for ExtensionTypeField values
168# ==============================================================================
169
170
def convert_fields(fields, field_values):
  """Validates and converts the values in `field_values`, updating it in place.

  Tensor-like fields are expected to hold actual values (not TypeSpecs).

  Args:
    fields: A list of `ExtensionTypeField` objects.
    field_values: A `dict` mapping field names to values.  It must have exactly
      one entry per field, i.e. `set(field_values.keys())` must equal
      `set([f.name for f in fields])`.

  Raises:
    ValueError: If the keys of `field_values` do not match the names of
      the fields in `fields`.
    TypeError: If any value in `field_values` does not have the type indicated
      by the corresponding `ExtensionTypeField` object.
  """
  _convert_fields(fields=fields, field_values=field_values, for_spec=False)
187
188
def convert_fields_for_spec(fields, field_values):
  """Validates and converts field values for a TypeSpec, updating in place.

  This mirrors `convert_fields`, except that tensor-like fields are expected
  to hold a TypeSpec.  That is, if the `value_type` of a field specifies a
  tensor-like type (tf.Tensor, CompositeTensor, or TypeSpec), then the
  corresponding entry of `field_values` must be a TypeSpec (rather than a
  value described by that TypeSpec).

  Args:
    fields: A list of `ExtensionTypeField` objects.
    field_values: A `dict` mapping field names to values.  It must have exactly
      one entry per field, i.e. `set(field_values.keys())` must equal
      `set([f.name for f in fields])`.

  Raises:
    ValueError: If the keys of `field_values` do not match the names of
      the fields in `fields`.
    TypeError: If any value in `field_values` does not have the type indicated
      by the corresponding `ExtensionTypeField` object.
  """
  _convert_fields(fields=fields, field_values=field_values, for_spec=True)
211
212
def _convert_fields(fields, field_values, for_spec):
  """Shared implementation of `convert_fields` and `convert_fields_for_spec`.

  Every value is converted before `field_values` is touched, so a failure
  part-way through leaves `field_values` unmodified.

  Args:
    fields: A list of `ExtensionTypeField` objects.
    field_values: A `dict` mapping field names to values.  It must have exactly
      one entry per field, i.e. `set(field_values.keys())` must equal
      `set([f.name for f in fields])`.
    for_spec: If false, then expect a value for tensor-like types; if true, then
      expect a TypeSpec for tensor-like types.

  Raises:
    ValueError: If the keys of `field_values` do not match the names of
      the fields in `fields`.
    TypeError: If any value in `field_values` does not have the type indicated
      by the corresponding `ExtensionTypeField` object.
  """
  if len(fields) != len(field_values):
    _report_field_mismatches(fields, field_values)

  results = {}
  for spec in fields:
    if spec.name not in field_values:
      _report_field_mismatches(fields, field_values)
    raw_value = field_values[spec.name]
    results[spec.name] = _convert_value(raw_value, spec.value_type,
                                        (spec.name,), for_spec)

  # Commit all conversions at once.
  field_values.update(results)
240
241
def _convert_value(value, expected_type, path, for_spec=False):
  """Checks `value` against `expected_type` and converts it if necessary.

  Dispatches to a type-specific helper based on `expected_type`.

  Args:
    value: The value to type-check.
    expected_type: The expected type for the value.  `None` is treated as
      `type(None)`.
    path: Tuple of `str` naming the value (used for exception messages).
    for_spec: If false, then expect a value for tensor-like types; if true, then
      expect a TypeSpec for tensor-like types.

  Returns:
    The converted value.  For the simple types handled directly here, `value`
    itself is returned unchanged.

  Raises:
    TypeError: If `value` can not be converted to the expected type.
  """
  assert isinstance(path, tuple)

  if expected_type is None:
    expected_type = _NoneType

  if expected_type is ops.Tensor:
    return _convert_tensor(value, path, for_spec)

  # `TensorSpec` must be tested before the more general `TypeSpec`.
  if isinstance(expected_type, tensor_spec.TensorSpec):
    return _convert_tensor_spec(value, expected_type, path, for_spec)

  if isinstance(expected_type, type_spec.TypeSpec):
    return _convert_type_spec(value, expected_type, path, for_spec)

  if (isinstance(expected_type, type) and
      issubclass(expected_type, composite_tensor.CompositeTensor)):
    return _convert_composite_tensor(value, expected_type, path, for_spec)

  if expected_type in (int, float, bool, str, bytes, _NoneType, dtypes.DType,
                       tensor_shape.TensorShape):
    # Simple types are checked with `isinstance` and never converted.
    if isinstance(value, expected_type):
      return value
    raise TypeError(f'{"".join(path)}: expected '
                    f'{expected_type.__name__}, got {value!r}')

  if is_generic_tuple(expected_type):
    return _convert_tuple(value, expected_type, path, for_spec)

  if is_generic_mapping(expected_type):
    return _convert_mapping(value, expected_type, path, for_spec)

  if is_generic_union(expected_type):
    return _convert_union(value, expected_type, path, for_spec)

  raise TypeError(f'{"".join(path)}: Unsupported type annotation '
                  f'{expected_type!r}')
287
288
def _convert_tensor(value, path, for_spec):
  """Converts `value` to a `Tensor` (or checks a `TensorSpec` if `for_spec`).

  Args:
    value: The value to convert.
    path: Tuple of `str` naming the value (used for exception messages).
    for_spec: If true, `value` must already be a `TensorSpec` and is returned
      as-is; if false, `value` is converted with `ops.convert_to_tensor`
      unless it is already a `Tensor`.

  Returns:
    A `Tensor`, or the given `TensorSpec` when `for_spec` is true.

  Raises:
    TypeError: If `value` is not a `TensorSpec` (when `for_spec`), or can not
      be converted to a `Tensor`.
  """
  if for_spec:
    if isinstance(value, tensor_spec.TensorSpec):
      return value
    raise TypeError(f'{"".join(path)}: expected a TensorSpec, got {value!r}')

  if isinstance(value, ops.Tensor):
    return value

  try:
    return ops.convert_to_tensor(value)
  except (ValueError, TypeError) as e:
    raise TypeError(f'{"".join(path)}: expected a Tensor, '
                    f'got {value!r}') from e
303
304
def _convert_tensor_spec(value, expected_type, path, for_spec):
  """Converts `value` to a Tensor compatible with TensorSpec `expected_type`.

  Args:
    value: The value to convert.
    expected_type: The `TensorSpec` that the result must be compatible with.
    path: Tuple of `str` naming the value (used for exception messages).
    for_spec: If true, `value` must be a `TensorSpec` that `expected_type` is
      compatible with, and is returned as-is; if false, `value` is converted to
      a `Tensor` of dtype `expected_type.dtype` unless it already is a
      `Tensor`.

  Returns:
    A `Tensor`, or the given `TensorSpec` when `for_spec` is true.

  Raises:
    TypeError: If `value` can not be converted, or the result is not compatible
      with `expected_type`.
  """
  if for_spec:
    if not (isinstance(value, tensor_spec.TensorSpec) and
            expected_type.is_compatible_with(value)):
      raise TypeError(f'{"".join(path)}: expected a TensorSpec compatible '
                      f'with {expected_type}, got {value!r}')
    return value

  if not isinstance(value, ops.Tensor):
    try:
      value = ops.convert_to_tensor(value, expected_type.dtype)
    except (ValueError, TypeError) as e:
      # Chain the conversion error explicitly, as `_convert_tensor` does, so
      # the underlying failure is reported as the direct cause.
      raise TypeError(f'{"".join(path)}: expected a {expected_type.dtype!r} '
                      f'Tensor, got {value!r}') from e
  # Checked for both pre-existing and freshly converted tensors.
  if not expected_type.is_compatible_with(value):
    raise TypeError(f'{"".join(path)}: expected a Tensor compatible with '
                    f'{expected_type}, got {value!r}')
  return value
324
325
def _convert_type_spec(value, expected_type, path, for_spec):
  """Checks that `value` is compatible with TypeSpec `expected_type`.

  No conversion is performed; `value` is returned unchanged when it passes.

  Args:
    value: The value to check.
    expected_type: The `TypeSpec` that `value` must be compatible with.
    path: Tuple of `str` naming the value (used for exception messages).
    for_spec: If true, `value` must itself be a `TypeSpec`; if false, `value`
      must not be a `TypeSpec`.

  Returns:
    `value`.

  Raises:
    TypeError: If `value` fails the check described above.
  """
  value_is_spec = isinstance(value, type_spec.TypeSpec)

  if for_spec:
    if value_is_spec and expected_type.is_compatible_with(value):
      return value
    raise TypeError(f'{"".join(path)}: expected a TypeSpec compatible '
                    f'with {expected_type}, got {value!r}')

  if not value_is_spec and expected_type.is_compatible_with(value):
    return value
  raise TypeError(f'{"".join(path)}: expected {expected_type!r}, '
                  f'got {value!r}')
340
341
def _convert_composite_tensor(value, expected_type, path, for_spec):
  """Checks that `value` is an instance of (or a spec for) `expected_type`.

  No conversion is performed; `value` is returned unchanged when it passes.

  Args:
    value: The value to check.
    expected_type: The class that `value` must be an instance of.
    path: Tuple of `str` naming the value (used for exception messages).
    for_spec: If true, `value` must be a `TypeSpec` whose `value_type` is a
      subclass of `expected_type`; if false, `value` must be an instance of
      `expected_type`.

  Returns:
    `value`.

  Raises:
    TypeError: If `value` fails the check described above.
  """
  if for_spec:
    if (isinstance(value, type_spec.TypeSpec) and
        issubclass(value.value_type, expected_type)):
      return value
    raise TypeError(f'{"".join(path)}: expected a TypeSpec for '
                    f'{expected_type.__name__}, got {value!r}')

  if isinstance(value, expected_type):
    return value
  raise TypeError(f'{"".join(path)}: expected {expected_type.__name__}, '
                  f'got {value!r}')
355
356
def _convert_tuple(value, expected_type, path, for_spec):
  """Converts `value` to a tuple with type `expected_type`.

  Args:
    value: The value to convert.  Any sequence is accepted.
    expected_type: A parameterized tuple annotation, either fixed-length
      (`Tuple[A, B]`) or variable-length (`Tuple[X, ...]`).
    path: Tuple of `str` naming the value (used for exception messages).
    for_spec: Passed through to `_convert_value` for each element.

  Returns:
    A `tuple` holding the converted elements of `value`.

  Raises:
    TypeError: If `value` is not a sequence, has the wrong length for a
      fixed-length annotation, or contains an element that can not be
      converted.
  """
  if not isinstance(value, collections.abc.Sequence):
    raise TypeError(f'{"".join(path)}: expected tuple, got {value!r}')

  element_types = get_generic_type_args(expected_type)
  if len(element_types) == 2 and element_types[1] is Ellipsis:
    # `Tuple[X, ...]`: every element shares the single type `X`.
    return tuple(
        _convert_value(v, element_types[0], path + (f'[{i}]',), for_spec)
        for (i, v) in enumerate(value))

  if len(value) != len(element_types):
    raise TypeError(f'{"".join(path)}: expected tuple with length '
                    f'{len(element_types)}, got {value!r}')
  return tuple(
      _convert_value(v, t, path + (f'[{i}]',), for_spec)
      for (i, (v, t)) in enumerate(zip(value, element_types)))
375
376
def _convert_mapping(value, expected_type, path, for_spec):
  """Converts `value` to an `ImmutableDict` with type `expected_type`.

  Args:
    value: The mapping to convert.
    expected_type: A parameterized mapping annotation (`Mapping[K, V]`).
    path: Tuple of `str` naming the value (used for exception messages).
    for_spec: Passed through to `_convert_value` for each key and value.

  Returns:
    An `immutable_dict.ImmutableDict` built from the converted keys and values.

  Raises:
    TypeError: If `value` is not a mapping, or a key or value can not be
      converted.
  """
  if not isinstance(value, typing.Mapping):
    raise TypeError(f'{"".join(path)}: expected mapping, got {value!r}')

  key_type, value_type = get_generic_type_args(expected_type)
  key_path = path + ('[<key>]',)

  converted_items = []
  for key, item in value.items():
    new_key = _convert_value(key, key_type, key_path, for_spec)
    # The path for the value is labelled with the original (unconverted) key.
    new_item = _convert_value(item, value_type, path + (f'[{key!r}]',),
                              for_spec)
    converted_items.append((new_key, new_item))
  return immutable_dict.ImmutableDict(converted_items)
387
388
def _convert_union(value, expected_type, path, for_spec):
  """Converts `value` using the first matching option of a union annotation.

  The options of `expected_type` are tried in order; the first one for which
  `_convert_value` succeeds determines the result.

  Args:
    value: The value to convert.
    expected_type: A parameterized union annotation.
    path: Tuple of `str` naming the value (used for exception messages).
    for_spec: Passed through to `_convert_value`.

  Returns:
    The value converted according to the first option that accepts it.

  Raises:
    TypeError: If none of the options accepts `value`.
  """
  for candidate in get_generic_type_args(expected_type):
    try:
      converted = _convert_value(value, candidate, path, for_spec)
    except TypeError:
      # This option does not match; fall through to the next one.
      continue
    else:
      return converted
  raise TypeError(f'{"".join(path)}: expected {expected_type}, got {value!r}')
397
398
def _report_field_mismatches(fields, field_values):
  """Raises a `ValueError` describing how `field_values` differs from `fields`.

  Unexpected names are reported in preference to missing ones.  Returns
  normally (without raising) if the two sets of names are equal.

  Args:
    fields: An iterable of objects with a `name` attribute.
    field_values: An iterable of field names (e.g. a `dict` keyed by name).

  Raises:
    ValueError: If `field_values` has a name not in `fields`, or lacks a name
      that is in `fields`.
  """
  expected_names = {field.name for field in fields}
  actual_names = set(field_values)

  unexpected = actual_names.difference(expected_names)
  if unexpected:
    raise ValueError(f'Got unexpected fields: {unexpected}')

  absent = expected_names.difference(actual_names)
  if absent:
    raise ValueError(f'Missing required fields: {absent}')
409
410
411# ==============================================================================
412# Utilities for accessing Python generic type annotations (typing.*)
413# ==============================================================================
def is_generic_union(tp):
  """Returns true if `tp` is a parameterized typing.Union value."""
  # The bare, unparameterized `typing.Union` does not count.
  if tp is typing.Union:
    return False
  return getattr(tp, '__origin__', None) is typing.Union
418
419
def is_generic_tuple(tp):
  """Returns true if `tp` is a parameterized typing.Tuple value."""
  tuple_types = (tuple, typing.Tuple)
  # The bare, unparameterized tuple types do not count.
  if tp in tuple_types:
    return False
  return getattr(tp, '__origin__', None) in tuple_types
424
425
def is_generic_mapping(tp):
  """Returns true if `tp` is a parameterized typing.Mapping value."""
  mapping_types = (collections.abc.Mapping, typing.Mapping)
  # The bare, unparameterized mapping types do not count.
  if tp in mapping_types:
    return False
  return getattr(tp, '__origin__', None) in mapping_types
430
431
def is_forward_ref(tp):
  """Returns true if `tp` is a typing forward reference."""
  # The forward-reference class is looked up under its public name first,
  # then under the private name, using whichever `typing` provides.
  for class_name in ('ForwardRef', '_ForwardRef'):
    if hasattr(typing, class_name):
      return isinstance(tp, getattr(typing, class_name))
  return False
440
441
# Note: typing.get_args was added in Python 3.8.  When it is unavailable, fall
# back to reading the annotation's `__args__` attribute directly.
get_generic_type_args = getattr(typing, 'get_args', lambda tp: tp.__args__)

# The type of `None`; `None` used as an annotation is treated as this type.
_NoneType = type(None)
449