1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Utilities for describing the structure of a `tf.data` type.""" 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20import collections 21import functools 22 23import six 24import wrapt 25 26from tensorflow.python.data.util import nest 27from tensorflow.python.framework import composite_tensor 28from tensorflow.python.framework import ops 29from tensorflow.python.framework import sparse_tensor 30from tensorflow.python.framework import tensor_shape 31from tensorflow.python.framework import tensor_spec 32from tensorflow.python.framework import type_spec 33from tensorflow.python.ops import tensor_array_ops 34from tensorflow.python.ops.ragged import ragged_tensor 35from tensorflow.python.platform import tf_logging as logging 36from tensorflow.python.util import deprecation 37from tensorflow.python.util.compat import collections_abc 38from tensorflow.python.util.tf_export import tf_export 39 40 41# pylint: disable=invalid-name 42@tf_export(v1=["data.experimental.TensorStructure"]) 43@deprecation.deprecated(None, "Use `tf.TensorSpec` instead.") 44def _TensorStructure(dtype, shape): 45 return tensor_spec.TensorSpec(shape, dtype) 46 47 48@tf_export(v1=["data.experimental.SparseTensorStructure"]) 
49@deprecation.deprecated(None, "Use `tf.SparseTensorSpec` instead.") 50def _SparseTensorStructure(dtype, shape): 51 return sparse_tensor.SparseTensorSpec(shape, dtype) 52 53 54@tf_export(v1=["data.experimental.TensorArrayStructure"]) 55@deprecation.deprecated(None, "Use `tf.TensorArraySpec` instead.") 56def _TensorArrayStructure(dtype, element_shape, dynamic_size, infer_shape): 57 return tensor_array_ops.TensorArraySpec(element_shape, dtype, 58 dynamic_size, infer_shape) 59 60 61@tf_export(v1=["data.experimental.RaggedTensorStructure"]) 62@deprecation.deprecated(None, "Use `tf.RaggedTensorSpec` instead.") 63def _RaggedTensorStructure(dtype, shape, ragged_rank): 64 return ragged_tensor.RaggedTensorSpec(shape, dtype, ragged_rank) 65# pylint: enable=invalid-name 66 67 68# TODO(jsimsa): Remove the special-case for `TensorArray` pass-through once 69# it is a subclass of `CompositeTensor`. 70def normalize_element(element, element_signature=None): 71 """Normalizes a nested structure of element components. 72 73 * Components matching `SparseTensorSpec` are converted to `SparseTensor`. 74 * Components matching `RaggedTensorSpec` are converted to `RaggedTensor`. 75 * Components matching `DatasetSpec` or `TensorArraySpec` are passed through. 76 * `CompositeTensor` components are passed through. 77 * All other components are converted to `Tensor`. 78 79 Args: 80 element: A nested structure of individual components. 81 element_signature: (Optional.) A nested structure of `tf.DType` objects 82 corresponding to each component of `element`. If specified, it will be 83 used to set the exact type of output tensor when converting input 84 components which are not tensors themselves (e.g. numpy arrays, native 85 python types, etc.) 86 87 Returns: 88 A nested structure of `Tensor`, `Dataset`, `SparseTensor`, `RaggedTensor`, 89 or `TensorArray` objects. 
90 """ 91 normalized_components = [] 92 if element_signature is None: 93 components = nest.flatten(element) 94 flattened_signature = [None] * len(components) 95 pack_as = element 96 else: 97 flattened_signature = nest.flatten(element_signature) 98 components = nest.flatten_up_to(element_signature, element) 99 pack_as = element_signature 100 with ops.name_scope("normalize_element"): 101 # Imported here to avoid circular dependency. 102 from tensorflow.python.data.ops import dataset_ops # pylint: disable=g-import-not-at-top 103 for i, (t, spec) in enumerate(zip(components, flattened_signature)): 104 try: 105 if spec is None: 106 spec = type_spec_from_value(t, use_fallback=False) 107 except TypeError: 108 # TypeError indicates it was not possible to compute a `TypeSpec` for 109 # the value. As a fallback try converting the value to a tensor. 110 normalized_components.append( 111 ops.convert_to_tensor(t, name="component_%d" % i)) 112 else: 113 if isinstance(spec, sparse_tensor.SparseTensorSpec): 114 normalized_components.append(sparse_tensor.SparseTensor.from_value(t)) 115 elif isinstance(spec, ragged_tensor.RaggedTensorSpec): 116 normalized_components.append( 117 ragged_tensor.convert_to_tensor_or_ragged_tensor( 118 t, name="component_%d" % i)) 119 elif isinstance( 120 spec, (tensor_array_ops.TensorArraySpec, dataset_ops.DatasetSpec)): 121 normalized_components.append(t) 122 elif isinstance(spec, NoneTensorSpec): 123 normalized_components.append(NoneTensor()) 124 elif isinstance(t, composite_tensor.CompositeTensor): 125 normalized_components.append(t) 126 else: 127 dtype = getattr(spec, "dtype", None) 128 normalized_components.append( 129 ops.convert_to_tensor(t, name="component_%d" % i, dtype=dtype)) 130 return nest.pack_sequence_as(pack_as, normalized_components) 131 132 133def convert_legacy_structure(output_types, output_shapes, output_classes): 134 """Returns a `Structure` that represents the given legacy structure. 
135 136 This method provides a way to convert from the existing `Dataset` and 137 `Iterator` structure-related properties to a `Structure` object. A "legacy" 138 structure is represented by the `tf.data.Dataset.output_types`, 139 `tf.data.Dataset.output_shapes`, and `tf.data.Dataset.output_classes` 140 properties. 141 142 TODO(b/110122868): Remove this function once `Structure` is used throughout 143 `tf.data`. 144 145 Args: 146 output_types: A nested structure of `tf.DType` objects corresponding to 147 each component of a structured value. 148 output_shapes: A nested structure of `tf.TensorShape` objects 149 corresponding to each component a structured value. 150 output_classes: A nested structure of Python `type` objects corresponding 151 to each component of a structured value. 152 153 Returns: 154 A `Structure`. 155 156 Raises: 157 TypeError: If a structure cannot be built from the arguments, because one of 158 the component classes in `output_classes` is not supported. 159 """ 160 flat_types = nest.flatten(output_types) 161 flat_shapes = nest.flatten(output_shapes) 162 flat_classes = nest.flatten(output_classes) 163 flat_ret = [] 164 for flat_type, flat_shape, flat_class in zip(flat_types, flat_shapes, 165 flat_classes): 166 if isinstance(flat_class, type_spec.TypeSpec): 167 flat_ret.append(flat_class) 168 elif issubclass(flat_class, sparse_tensor.SparseTensor): 169 flat_ret.append(sparse_tensor.SparseTensorSpec(flat_shape, flat_type)) 170 elif issubclass(flat_class, ops.Tensor): 171 flat_ret.append(tensor_spec.TensorSpec(flat_shape, flat_type)) 172 elif issubclass(flat_class, tensor_array_ops.TensorArray): 173 # We sneaked the dynamic_size and infer_shape into the legacy shape. 
174 flat_ret.append( 175 tensor_array_ops.TensorArraySpec( 176 flat_shape[2:], flat_type, 177 dynamic_size=tensor_shape.dimension_value(flat_shape[0]), 178 infer_shape=tensor_shape.dimension_value(flat_shape[1]))) 179 else: 180 # NOTE(mrry): Since legacy structures produced by iterators only 181 # comprise Tensors, SparseTensors, and nests, we do not need to 182 # support all structure types here. 183 raise TypeError( 184 "Could not build a structure for output class %r" % (flat_class,)) 185 186 return nest.pack_sequence_as(output_classes, flat_ret) 187 188 189def _from_tensor_list_helper(decode_fn, element_spec, tensor_list): 190 """Returns an element constructed from the given spec and tensor list. 191 192 Args: 193 decode_fn: Method that constructs an element component from the element spec 194 component and a tensor list. 195 element_spec: A nested structure of `tf.TypeSpec` objects representing to 196 element type specification. 197 tensor_list: A list of tensors to use for constructing the value. 198 199 Returns: 200 An element constructed from the given spec and tensor list. 201 202 Raises: 203 ValueError: If the number of tensors needed to construct an element for 204 the given spec does not match the given number of tensors. 205 """ 206 207 # pylint: disable=protected-access 208 209 flat_specs = nest.flatten(element_spec) 210 flat_spec_lengths = [len(spec._flat_tensor_specs) for spec in flat_specs] 211 if sum(flat_spec_lengths) != len(tensor_list): 212 raise ValueError("Expected %d tensors but got %d." 
% 213 (sum(flat_spec_lengths), len(tensor_list))) 214 215 i = 0 216 flat_ret = [] 217 for (component_spec, num_flat_values) in zip(flat_specs, flat_spec_lengths): 218 value = tensor_list[i:i + num_flat_values] 219 flat_ret.append(decode_fn(component_spec, value)) 220 i += num_flat_values 221 return nest.pack_sequence_as(element_spec, flat_ret) 222 223 224def from_compatible_tensor_list(element_spec, tensor_list): 225 """Returns an element constructed from the given spec and tensor list. 226 227 Args: 228 element_spec: A nested structure of `tf.TypeSpec` objects representing to 229 element type specification. 230 tensor_list: A list of tensors to use for constructing the value. 231 232 Returns: 233 An element constructed from the given spec and tensor list. 234 235 Raises: 236 ValueError: If the number of tensors needed to construct an element for 237 the given spec does not match the given number of tensors. 238 """ 239 240 # pylint: disable=protected-access 241 # pylint: disable=g-long-lambda 242 return _from_tensor_list_helper( 243 lambda spec, value: spec._from_compatible_tensor_list(value), 244 element_spec, tensor_list) 245 246 247def from_tensor_list(element_spec, tensor_list): 248 """Returns an element constructed from the given spec and tensor list. 249 250 Args: 251 element_spec: A nested structure of `tf.TypeSpec` objects representing to 252 element type specification. 253 tensor_list: A list of tensors to use for constructing the value. 254 255 Returns: 256 An element constructed from the given spec and tensor list. 257 258 Raises: 259 ValueError: If the number of tensors needed to construct an element for 260 the given spec does not match the given number of tensors or the given 261 spec is not compatible with the tensor list. 
262 """ 263 264 # pylint: disable=protected-access 265 # pylint: disable=g-long-lambda 266 return _from_tensor_list_helper( 267 lambda spec, value: spec._from_tensor_list(value), element_spec, 268 tensor_list) 269 270 271def get_flat_tensor_specs(element_spec): 272 """Returns a list `tf.TypeSpec`s for the element tensor representation. 273 274 Args: 275 element_spec: A nested structure of `tf.TypeSpec` objects representing to 276 element type specification. 277 278 Returns: 279 A list `tf.TypeSpec`s for the element tensor representation. 280 """ 281 282 # pylint: disable=protected-access 283 return functools.reduce(lambda state, value: state + value._flat_tensor_specs, 284 nest.flatten(element_spec), []) 285 286 287def get_flat_tensor_shapes(element_spec): 288 """Returns a list `tf.TensorShapes`s for the element tensor representation. 289 290 Args: 291 element_spec: A nested structure of `tf.TypeSpec` objects representing to 292 element type specification. 293 294 Returns: 295 A list `tf.TensorShapes`s for the element tensor representation. 296 """ 297 return [spec.shape for spec in get_flat_tensor_specs(element_spec)] 298 299 300def get_flat_tensor_types(element_spec): 301 """Returns a list `tf.DType`s for the element tensor representation. 302 303 Args: 304 element_spec: A nested structure of `tf.TypeSpec` objects representing to 305 element type specification. 306 307 Returns: 308 A list `tf.DType`s for the element tensor representation. 309 """ 310 return [spec.dtype for spec in get_flat_tensor_specs(element_spec)] 311 312 313def _to_tensor_list_helper(encode_fn, element_spec, element): 314 """Returns a tensor list representation of the element. 315 316 Args: 317 encode_fn: Method that constructs a tensor list representation from the 318 given element spec and element. 319 element_spec: A nested structure of `tf.TypeSpec` objects representing to 320 element type specification. 321 element: The element to convert to tensor list representation. 
322 323 Returns: 324 A tensor list representation of `element`. 325 326 Raises: 327 ValueError: If `element_spec` and `element` do not have the same number of 328 elements or if the two structures are not nested in the same way. 329 TypeError: If `element_spec` and `element` differ in the type of sequence 330 in any of their substructures. 331 """ 332 333 nest.assert_same_structure(element_spec, element) 334 335 def reduce_fn(state, value): 336 spec, component = value 337 return encode_fn(state, spec, component) 338 339 return functools.reduce( 340 reduce_fn, zip(nest.flatten(element_spec), nest.flatten(element)), []) 341 342 343def to_batched_tensor_list(element_spec, element): 344 """Returns a tensor list representation of the element. 345 346 Args: 347 element_spec: A nested structure of `tf.TypeSpec` objects representing to 348 element type specification. 349 element: The element to convert to tensor list representation. 350 351 Returns: 352 A tensor list representation of `element`. 353 354 Raises: 355 ValueError: If `element_spec` and `element` do not have the same number of 356 elements or if the two structures are not nested in the same way or the 357 rank of any of the tensors in the tensor list representation is 0. 358 TypeError: If `element_spec` and `element` differ in the type of sequence 359 in any of their substructures. 360 """ 361 362 # pylint: disable=protected-access 363 # pylint: disable=g-long-lambda 364 return _to_tensor_list_helper( 365 lambda state, spec, component: state + spec._to_batched_tensor_list( 366 component), element_spec, element) 367 368 369def to_tensor_list(element_spec, element): 370 """Returns a tensor list representation of the element. 371 372 Args: 373 element_spec: A nested structure of `tf.TypeSpec` objects representing to 374 element type specification. 375 element: The element to convert to tensor list representation. 376 377 Returns: 378 A tensor list representation of `element`. 
379 380 Raises: 381 ValueError: If `element_spec` and `element` do not have the same number of 382 elements or if the two structures are not nested in the same way. 383 TypeError: If `element_spec` and `element` differ in the type of sequence 384 in any of their substructures. 385 """ 386 387 # pylint: disable=protected-access 388 # pylint: disable=g-long-lambda 389 return _to_tensor_list_helper( 390 lambda state, spec, component: state + spec._to_tensor_list(component), 391 element_spec, element) 392 393 394def are_compatible(spec1, spec2): 395 """Indicates whether two type specifications are compatible. 396 397 Two type specifications are compatible if they have the same nested structure 398 and the their individual components are pair-wise compatible. 399 400 Args: 401 spec1: A `tf.TypeSpec` object to compare. 402 spec2: A `tf.TypeSpec` object to compare. 403 404 Returns: 405 `True` if the two type specifications are compatible and `False` otherwise. 406 """ 407 408 try: 409 nest.assert_same_structure(spec1, spec2) 410 except TypeError: 411 return False 412 except ValueError: 413 return False 414 415 for s1, s2 in zip(nest.flatten(spec1), nest.flatten(spec2)): 416 if not s1.is_compatible_with(s2) or not s2.is_compatible_with(s1): 417 return False 418 return True 419 420 421def type_spec_from_value(element, use_fallback=True): 422 """Creates a type specification for the given value. 423 424 Args: 425 element: The element to create the type specification for. 426 use_fallback: Whether to fall back to converting the element to a tensor 427 in order to compute its `TypeSpec`. 428 429 Returns: 430 A nested structure of `TypeSpec`s that represents the type specification 431 of `element`. 432 433 Raises: 434 TypeError: If a `TypeSpec` cannot be built for `element`, because its type 435 is not supported. 
436 """ 437 spec = type_spec._type_spec_from_value(element) # pylint: disable=protected-access 438 if spec is not None: 439 return spec 440 441 if isinstance(element, collections_abc.Mapping): 442 # We create a shallow copy in an attempt to preserve the key order. 443 # 444 # Note that we do not guarantee that the key order is preserved, which is 445 # a limitation inherited from `copy()`. As a consequence, callers of 446 # `type_spec_from_value` should not assume that the key order of a `dict` 447 # in the returned nested structure matches the key order of the 448 # corresponding `dict` in the input value. 449 if isinstance(element, collections.defaultdict): 450 ctor = lambda items: type(element)(element.default_factory, items) 451 else: 452 ctor = type(element) 453 return ctor([(k, type_spec_from_value(v)) for k, v in element.items()]) 454 455 if isinstance(element, tuple): 456 if hasattr(element, "_fields") and isinstance( 457 element._fields, collections_abc.Sequence) and all( 458 isinstance(f, six.string_types) for f in element._fields): 459 if isinstance(element, wrapt.ObjectProxy): 460 element_type = type(element.__wrapped__) 461 else: 462 element_type = type(element) 463 # `element` is a namedtuple 464 return element_type(*[type_spec_from_value(v) for v in element]) 465 # `element` is not a namedtuple 466 return tuple([type_spec_from_value(v) for v in element]) 467 468 if use_fallback: 469 # As a fallback try converting the element to a tensor. 470 try: 471 tensor = ops.convert_to_tensor(element) 472 spec = type_spec_from_value(tensor) 473 if spec is not None: 474 return spec 475 except (ValueError, TypeError) as e: 476 logging.vlog( 477 3, "Failed to convert %r to tensor: %s" % (type(element).__name__, e)) 478 479 raise TypeError("Could not build a TypeSpec for %r with type %s" % 480 (element, type(element).__name__)) 481 482 483# TODO(b/149584798): Move this to framework and add tests for non-tf.data 484# functionality. 
class NoneTensor(composite_tensor.CompositeTensor):
  """Composite tensor that stands in for a Python `None` value.

  It holds no component tensors; its type specification is a freshly built
  `NoneTensorSpec` on every access.
  """

  @property
  def _type_spec(self):
    spec = NoneTensorSpec()
    return spec


# TODO(b/149584798): Move this to framework and add tests for non-tf.data
# functionality.
@type_spec.register("tf.NoneTensorSpec")
class NoneTensorSpec(type_spec.BatchableTypeSpec):
  """Type specification for a `None` value.

  A `None` value has no tensor components: every encoding (components, tensor
  list, batched tensor list) is empty, decoding yields `None`, and batching or
  unbatching produces another `NoneTensorSpec`.
  """

  @staticmethod
  def from_value(value):
    """Returns a `NoneTensorSpec`; the argument is not inspected."""
    del value  # The spec of `None` does not depend on the value.
    return NoneTensorSpec()

  @property
  def value_type(self):
    return NoneTensor

  def _serialize(self):
    # No fields distinguish one `NoneTensorSpec` from another.
    return ()

  # ---- Component / tensor-list encoding: always empty. ----

  @property
  def _component_specs(self):
    return []

  def _to_components(self, value):
    return []

  def _from_components(self, components):
    return None

  def _to_tensor_list(self, value):
    return []

  def _to_batched_tensor_list(self, value):
    return []

  # ---- Batching: the spec is unchanged by batching or unbatching. ----

  def _batch(self, batch_size):
    return NoneTensorSpec()

  def _unbatch(self):
    return NoneTensorSpec()

  # ---- Legacy `output_types` / `output_shapes` / `output_classes`. ----

  def _to_legacy_output_types(self):
    return self

  def _to_legacy_output_shapes(self):
    return self

  def _to_legacy_output_classes(self):
    return self

  # NOTE(review): the `tf.TypeSpec` hook is named
  # `most_specific_compatible_type`; this method keeps its existing name for
  # compatibility with any callers of it.
  def most_specific_compatible_shape(self, other):
    if type(other) is type(self):
      return self
    raise ValueError("No TypeSpec is compatible with both %s and %s" %
                     (self, other))


# Let `type_spec_from_value(None)` resolve to a `NoneTensorSpec`.
type_spec.register_type_spec_from_value_converter(
    type(None), NoneTensorSpec.from_value)