1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Utilities for describing the structure of a `tf.data` type.""" 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20import abc 21 22import six 23 24from tensorflow.python.data.util import nest 25from tensorflow.python.framework import dtypes 26from tensorflow.python.framework import ops 27from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib 28from tensorflow.python.framework import tensor_shape 29from tensorflow.python.framework import tensor_util 30from tensorflow.python.ops import sparse_ops 31from tensorflow.python.util.tf_export import tf_export 32 33 34_STRUCTURE_CONVERSION_FUNCTION_REGISTRY = {} 35 36 37@tf_export("data.experimental.Structure") 38@six.add_metaclass(abc.ABCMeta) 39class Structure(object): 40 """Represents structural information, such as type and shape, about a value. 41 42 A `Structure` generalizes the `tf.Tensor.dtype` and `tf.Tensor.shape` 43 properties, so that we can define generic containers of objects including: 44 45 * `tf.Tensor` 46 * `tf.SparseTensor` 47 * Nested structures of the above. 48 49 TODO(b/110122868): In the future, a single `Structure` will replace the 50 `tf.data.Dataset.output_types`, `tf.data.Dataset.output_shapes`, 51 and `tf.data.Dataset.output_classes`, and similar properties and arguments in 52 the `tf.data.Iterator` and `Optional` classes. 53 """ 54 55 @abc.abstractproperty 56 def _flat_shapes(self): 57 """A list of shapes matching the shapes of `self._to_tensor_list()`. 58 59 Returns: 60 A list of `tf.TensorShape` objects. 61 """ 62 raise NotImplementedError("Structure._flat_shapes") 63 64 @abc.abstractproperty 65 def _flat_types(self): 66 """A list of types matching the types of `self._to_tensor_list()`. 67 68 Returns: 69 A list of `tf.DType` objects. 70 """ 71 raise NotImplementedError("Structure._flat_shapes") 72 73 @abc.abstractmethod 74 def is_compatible_with(self, other): 75 """Returns `True` if `other` is compatible with this structure. 76 77 A structure `t` is a "subtype" of `s` if: 78 79 * `s` and `t` are instances of the same `Structure` subclass. 80 * The nested structures (if any) of `s` and `t` are the same, according to 81 `tf.contrib.framework.nest.assert_same_structure`, and each nested 82 structure of `t` is a "subtype" of the corresponding nested structure of 83 `s`. 84 * Any `tf.DType` components of `t` are the same as the corresponding 85 components in `s`. 86 * Any `tf.TensorShape` components of `t` are compatible with the 87 corresponding components in `s`, according to 88 `tf.TensorShape.is_compatible_with`. 89 90 Args: 91 other: A `Structure`. 92 93 Returns: 94 `True` if `other` is a subtype of this structure, otherwise `False`. 95 """ 96 raise NotImplementedError("Structure.is_compatible_with()") 97 98 @abc.abstractmethod 99 def _to_tensor_list(self, value): 100 """Returns a flat list of `tf.Tensor` representing `value`. 101 102 This method can be used, along with `self._flat_shapes` and 103 `self._flat_types` to represent structured values in lower level APIs 104 (such as plain TensorFlow operations) that do not understand structure. 105 106 Requires: `self.is_compatible_with(Structure.from_value(value))`. 107 108 Args: 109 value: A value with compatible structure. 110 111 Returns: 112 A flat list of `tf.Tensor` representing `value`. 113 """ 114 raise NotImplementedError("Structure._to_tensor_list()") 115 116 @abc.abstractmethod 117 def _to_batched_tensor_list(self, value): 118 """Returns a flat list of rank >= 1 `tf.Tensor` representing `value`. 119 120 This method can be used, along with `self._flat_shapes` and 121 `self._flat_types` to represent structured values in lower level APIs 122 (such as plain TensorFlow operations) that do not understand structure, 123 *and* that require that the plain tensors have a rank of at least one 124 (e.g. for the purpose of slicing the tensors). 125 126 Requires: `self.is_compatible_with(Structure.from_value(value))`. 127 128 Args: 129 value: A value with compatible structure. 130 131 Returns: 132 A flat list of `tf.Tensor` representing `value`. 133 """ 134 raise NotImplementedError("Structure._to_batched_tensor_list()") 135 136 @abc.abstractmethod 137 def _from_tensor_list(self, flat_value): 138 """Builds a flat list of `tf.Tensor` into a value matching this structure. 139 140 Args: 141 flat_value: A list of `tf.Tensor` with compatible flat structure. 142 143 Returns: 144 A structured object matching this structure. 145 146 Raises: 147 ValueError: If the shapes and types of the tensors in `flat_value` are not 148 compatible with `self._flat_shapes` and `self._flat_types` respectively. 149 """ 150 raise NotImplementedError("Structure._from_tensor_list()") 151 152 def _from_compatible_tensor_list(self, flat_value): 153 """A version of `_from_tensor_list()` that may avoid performing checks. 154 155 NOTE: This method should be used to avoid checks for performance reasons, 156 when the validity of `flat_value` has been validated by other means. 157 The shapes and types of the tensors in `flat_value` must be compatible with 158 `self._flat_shapes` and `self._flat_types` respectively. The behavior is 159 undefined if this requirement is not met. 160 161 Args: 162 flat_value: A list of `tf.Tensor` with compatible flat structure. 163 164 Returns: 165 A structured object matching this structure. 166 """ 167 return self._from_tensor_list(flat_value) 168 169 @abc.abstractmethod 170 def _batch(self, batch_size): 171 """Returns a structure representing a batch of objects with this structure. 172 173 Args: 174 batch_size: An `int` representing the number of elements in a batch, 175 or `None` if the batch size may vary. 176 177 Returns: 178 A `Structure` representing a batch of objects with this structure. 179 """ 180 raise NotImplementedError("Structure._batch()") 181 182 @abc.abstractmethod 183 def _unbatch(self): 184 raise NotImplementedError("Structure._unbatch()") 185 186 @staticmethod 187 def from_value(value): 188 """Returns a `Structure` that represents the given `value`. 189 190 Args: 191 value: A potentially structured value. 192 193 Returns: 194 A `Structure` that is compatible with `value`. 195 196 Raises: 197 TypeError: If a structure cannot be built for `value`, because its type 198 or one of its component types is not supported. 199 """ 200 # TODO(b/110122868): Add support for custom types and Dataset to this 201 # method. 202 if isinstance( 203 value, 204 (sparse_tensor_lib.SparseTensor, sparse_tensor_lib.SparseTensorValue)): 205 return SparseTensorStructure.from_value(value) 206 elif isinstance(value, (tuple, dict)): 207 return NestedStructure.from_value(value) 208 else: 209 for converter_type, converter_fn in ( 210 _STRUCTURE_CONVERSION_FUNCTION_REGISTRY.items()): 211 if isinstance(value, converter_type): 212 return converter_fn(value) 213 try: 214 tensor = ops.convert_to_tensor(value) 215 except (ValueError, TypeError): 216 raise TypeError("Could not build a structure for %r" % value) 217 return TensorStructure.from_value(tensor) 218 219 @staticmethod 220 def _register_custom_converter(type_object, converter_fn): 221 """Registers `converter_fn` for converting values of the given type. 222 223 Args: 224 type_object: A Python `type` object representing the type of values 225 accepted by `converter_fn`. 226 converter_fn: A function that takes one argument (an instance of the 227 type represented by `type_object`) and returns a `Structure`. 228 """ 229 _STRUCTURE_CONVERSION_FUNCTION_REGISTRY[type_object] = converter_fn 230 231 @abc.abstractmethod 232 def _to_legacy_output_types(self): 233 raise NotImplementedError("Structure._to_legacy_output_types()") 234 235 @abc.abstractmethod 236 def _to_legacy_output_shapes(self): 237 raise NotImplementedError("Structure._to_legacy_output_shapes()") 238 239 @abc.abstractmethod 240 def _to_legacy_output_classes(self): 241 raise NotImplementedError("Structure._to_legacy_output_classes()") 242 243 244def convert_legacy_structure(output_types, output_shapes, output_classes): 245 """Returns a `Structure` that represents the given legacy structure. 246 247 This method provides a way to convert from the existing `Dataset` and 248 `Iterator` structure-related properties to a `Structure` object. A "legacy" 249 structure is represented by the `tf.data.Dataset.output_types`, 250 `tf.data.Dataset.output_shapes`, and `tf.data.Dataset.output_classes` 251 properties. 252 253 TODO(b/110122868): Remove this function once `Structure` is used throughout 254 `tf.data`. 255 256 Args: 257 output_types: A nested structure of `tf.DType` objects corresponding to 258 each component of a structured value. 259 output_shapes: A nested structure of `tf.TensorShape` objects 260 corresponding to each component a structured value. 261 output_classes: A nested structure of Python `type` objects corresponding 262 to each component of a structured value. 263 264 Returns: 265 A `Structure`. 266 267 Raises: 268 TypeError: If a structure cannot be built from the arguments, because one of 269 the component classes in `output_classes` is not supported. 270 """ 271 flat_types = nest.flatten(output_types) 272 flat_shapes = nest.flatten(output_shapes) 273 flat_classes = nest.flatten(output_classes) 274 flat_ret = [] 275 for flat_type, flat_shape, flat_class in zip(flat_types, flat_shapes, 276 flat_classes): 277 if isinstance(flat_class, Structure): 278 flat_ret.append(flat_class) 279 elif issubclass(flat_class, sparse_tensor_lib.SparseTensor): 280 flat_ret.append(SparseTensorStructure(flat_type, flat_shape)) 281 elif issubclass(flat_class, ops.Tensor): 282 flat_ret.append(TensorStructure(flat_type, flat_shape)) 283 else: 284 # NOTE(mrry): Since legacy structures produced by iterators only 285 # comprise Tensors, SparseTensors, and nests, we do not need to 286 # support all structure types here. 287 raise TypeError( 288 "Could not build a structure for output class %r" % flat_type) 289 290 ret = nest.pack_sequence_as(output_classes, flat_ret) 291 if isinstance(ret, Structure): 292 return ret 293 else: 294 return NestedStructure(ret) 295 296 297# NOTE(mrry): The following classes make extensive use of non-public methods of 298# their base class, so we disable the protected-access lint warning once here. 299# pylint: disable=protected-access 300@tf_export("data.experimental.NestedStructure") 301class NestedStructure(Structure): 302 """Represents a nested structure in which each leaf is a `Structure`.""" 303 304 def __init__(self, nested_structure): 305 self._nested_structure = nested_structure 306 self._flat_nested_structure = nest.flatten(nested_structure) 307 self._flat_shapes_list = [] 308 self._flat_types_list = [] 309 for s in nest.flatten(nested_structure): 310 if not isinstance(s, Structure): 311 raise TypeError("nested_structure must be a (potentially nested) tuple " 312 "or dictionary of Structure objects.") 313 self._flat_shapes_list.extend(s._flat_shapes) 314 self._flat_types_list.extend(s._flat_types) 315 316 @property 317 def _flat_shapes(self): 318 return self._flat_shapes_list 319 320 @property 321 def _flat_types(self): 322 return self._flat_types_list 323 324 def is_compatible_with(self, other): 325 if not isinstance(other, NestedStructure): 326 return False 327 try: 328 # pylint: disable=protected-access 329 nest.assert_same_structure(self._nested_structure, 330 other._nested_structure) 331 except (ValueError, TypeError): 332 return False 333 334 return all( 335 substructure.is_compatible_with(other_substructure) 336 for substructure, other_substructure in zip( 337 nest.flatten(self._nested_structure), 338 nest.flatten(other._nested_structure))) 339 340 def _to_tensor_list(self, value): 341 ret = [] 342 343 try: 344 flat_value = nest.flatten_up_to(self._nested_structure, value) 345 except (ValueError, TypeError): 346 raise ValueError("The value %r is not compatible with the nested " 347 "structure %r." % (value, self._nested_structure)) 348 349 for sub_value, structure in zip(flat_value, self._flat_nested_structure): 350 if not structure.is_compatible_with(Structure.from_value(sub_value)): 351 raise ValueError("Component value %r is not compatible with the nested " 352 "structure %r." % (sub_value, structure)) 353 ret.extend(structure._to_tensor_list(sub_value)) 354 return ret 355 356 def _to_batched_tensor_list(self, value): 357 ret = [] 358 359 try: 360 flat_value = nest.flatten_up_to(self._nested_structure, value) 361 except (ValueError, TypeError): 362 raise ValueError("The value %r is not compatible with the nested " 363 "structure %r." % (value, self._nested_structure)) 364 365 for sub_value, structure in zip(flat_value, self._flat_nested_structure): 366 if not structure.is_compatible_with(Structure.from_value(sub_value)): 367 raise ValueError("Component value %r is not compatible with the nested " 368 "structure %r." % (sub_value, structure)) 369 ret.extend(structure._to_batched_tensor_list(sub_value)) 370 return ret 371 372 def _from_tensor_list(self, flat_value): 373 if len(flat_value) != len(self._flat_types): 374 raise ValueError("Expected %d flat values in NestedStructure but got %d." 375 % (len(self._flat_types), len(flat_value))) 376 377 flat_ret = [] 378 i = 0 379 for structure in self._flat_nested_structure: 380 num_flat_values = len(structure._flat_types) 381 sub_value = flat_value[i:i + num_flat_values] 382 flat_ret.append(structure._from_tensor_list(sub_value)) 383 i += num_flat_values 384 385 return nest.pack_sequence_as(self._nested_structure, flat_ret) 386 387 def _from_compatible_tensor_list(self, flat_value): 388 flat_ret = [] 389 i = 0 390 for structure in self._flat_nested_structure: 391 num_flat_values = len(structure._flat_types) 392 sub_value = flat_value[i:i + num_flat_values] 393 flat_ret.append(structure._from_compatible_tensor_list(sub_value)) 394 i += num_flat_values 395 396 return nest.pack_sequence_as(self._nested_structure, flat_ret) 397 398 @staticmethod 399 def from_value(value): 400 flat_nested_structure = [ 401 Structure.from_value(sub_value) for sub_value in nest.flatten(value) 402 ] 403 return NestedStructure(nest.pack_sequence_as(value, flat_nested_structure)) 404 405 def _to_legacy_output_types(self): 406 return nest.map_structure( 407 lambda s: s._to_legacy_output_types(), self._nested_structure) 408 409 def _to_legacy_output_shapes(self): 410 return nest.map_structure( 411 lambda s: s._to_legacy_output_shapes(), self._nested_structure) 412 413 def _to_legacy_output_classes(self): 414 return nest.map_structure( 415 lambda s: s._to_legacy_output_classes(), self._nested_structure) 416 417 def _batch(self, batch_size): 418 return NestedStructure(nest.map_structure( 419 lambda s: s._batch(batch_size), self._nested_structure)) 420 421 def _unbatch(self): 422 return NestedStructure(nest.map_structure( 423 lambda s: s._unbatch(), self._nested_structure)) 424 425 426@tf_export("data.experimental.TensorStructure") 427class TensorStructure(Structure): 428 """Represents structural information about a `tf.Tensor`.""" 429 430 def __init__(self, dtype, shape): 431 self._dtype = dtypes.as_dtype(dtype) 432 self._shape = tensor_shape.as_shape(shape) 433 434 @property 435 def _flat_shapes(self): 436 return [self._shape] 437 438 @property 439 def _flat_types(self): 440 return [self._dtype] 441 442 def is_compatible_with(self, other): 443 return (isinstance(other, TensorStructure) and 444 self._dtype.is_compatible_with(other._dtype) and 445 self._shape.is_compatible_with(other._shape)) 446 447 def _to_tensor_list(self, value): 448 if not self.is_compatible_with(Structure.from_value(value)): 449 raise ValueError("Value %r is not convertible to a tensor with dtype %s " 450 "and shape %s." % (value, self._dtype, self._shape)) 451 return [value] 452 453 def _to_batched_tensor_list(self, value): 454 if self._shape.merge_with(value.shape).ndims == 0: 455 raise ValueError("Unbatching a tensor is only supported for rank >= 1") 456 return [value] 457 458 def _from_tensor_list(self, flat_value): 459 if len(flat_value) != 1: 460 raise ValueError("TensorStructure corresponds to a single tf.Tensor.") 461 if not self.is_compatible_with(Structure.from_value(flat_value[0])): 462 raise ValueError("Cannot convert %r to a tensor with dtype %s and shape " 463 "%s." % (flat_value[0], self._dtype, self._shape)) 464 return self._from_compatible_tensor_list(flat_value) 465 466 def _from_compatible_tensor_list(self, flat_value): 467 # TODO(b/112266545): It would be cleaner to create a new `ensure_shape()` 468 # op here and return that, instead of mutating the input's shape using 469 # `Tensor.set_shape()`. However, that would add extra ops on the arguments 470 # of each `tf.data` function, which could impact performance. When this 471 # bug is resolved, we should be able to add the `ensure_shape()` ops and 472 # optimize them away using contextual shape information. 473 flat_value[0].set_shape(self._shape) 474 return flat_value[0] 475 476 @staticmethod 477 def from_value(value): 478 return TensorStructure(value.dtype, value.shape) 479 480 def _to_legacy_output_types(self): 481 return self._dtype 482 483 def _to_legacy_output_shapes(self): 484 return self._shape 485 486 def _to_legacy_output_classes(self): 487 return ops.Tensor 488 489 def _batch(self, batch_size): 490 return TensorStructure( 491 self._dtype, 492 tensor_shape.TensorShape([batch_size]).concatenate(self._shape)) 493 494 def _unbatch(self): 495 if self._shape.ndims == 0: 496 raise ValueError("Unbatching a tensor is only supported for rank >= 1") 497 return TensorStructure(self._dtype, self._shape[1:]) 498 499 500@tf_export("data.experimental.SparseTensorStructure") 501class SparseTensorStructure(Structure): 502 """Represents structural information about a `tf.SparseTensor`.""" 503 504 def __init__(self, dtype, dense_shape): 505 self._dtype = dtypes.as_dtype(dtype) 506 self._dense_shape = tensor_shape.as_shape(dense_shape) 507 508 @property 509 def _flat_shapes(self): 510 # NOTE(mrry): The default flat shape of a boxed `SparseTensor` is `(3,)`, 511 # but a `SparseTensorStructure` can also represent a batch of boxed 512 # `SparseTensor` objects with shape `(?, 3)` (and batches of batches, etc.), 513 # so the flat shape must be unknown. 514 return [tensor_shape.unknown_shape(None)] 515 516 @property 517 def _flat_types(self): 518 return [dtypes.variant] 519 520 def is_compatible_with(self, other): 521 return (isinstance(other, SparseTensorStructure) and 522 self._dtype.is_compatible_with(other._dtype) and 523 self._dense_shape.is_compatible_with(other._dense_shape)) 524 525 def _to_tensor_list(self, value): 526 return [sparse_ops.serialize_sparse(value, out_type=dtypes.variant)] 527 528 def _to_batched_tensor_list(self, value): 529 if self._dense_shape.merge_with( 530 tensor_util.constant_value_as_shape(value.dense_shape)).ndims == 0: 531 raise ValueError( 532 "Unbatching a sparse tensor is only supported for rank >= 1") 533 return [sparse_ops.serialize_many_sparse(value, out_type=dtypes.variant)] 534 535 def _from_tensor_list(self, flat_value): 536 if (len(flat_value) != 1 or flat_value[0].dtype != dtypes.variant or 537 not flat_value[0].shape.is_compatible_with(tensor_shape.vector(3))): 538 raise ValueError("SparseTensorStructure corresponds to a single " 539 "tf.variant vector of length 3.") 540 return self._from_compatible_tensor_list(flat_value) 541 542 def _from_compatible_tensor_list(self, flat_value): 543 ret = sparse_ops.deserialize_sparse( 544 flat_value[0], dtype=self._dtype, rank=self._dense_shape.ndims) 545 ret.indices.set_shape([None, self._dense_shape.ndims]) 546 ret.dense_shape.set_shape([self._dense_shape.ndims]) 547 return ret 548 549 @staticmethod 550 def from_value(value): 551 sparse_tensor = sparse_tensor_lib.SparseTensor.from_value(value) 552 return SparseTensorStructure( 553 sparse_tensor.dtype, 554 tensor_util.constant_value_as_shape(sparse_tensor.dense_shape)) 555 556 def _to_legacy_output_types(self): 557 return self._dtype 558 559 def _to_legacy_output_shapes(self): 560 return self._dense_shape 561 562 def _to_legacy_output_classes(self): 563 return sparse_tensor_lib.SparseTensor 564 565 def _batch(self, batch_size): 566 return SparseTensorStructure( 567 self._dtype, 568 tensor_shape.TensorShape([batch_size]).concatenate(self._dense_shape)) 569 570 def _unbatch(self): 571 if self._dense_shape.ndims == 0: 572 raise ValueError("Unbatching a tensor is only supported for rank >= 1") 573 return SparseTensorStructure(self._dtype, self._dense_shape[1:]) 574