1# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Feature configuration for tf.io.parse_example.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import collections 22import re 23 24from tensorflow.python.framework import constant_op 25from tensorflow.python.framework import dtypes 26from tensorflow.python.framework import ops 27from tensorflow.python.framework import tensor_shape 28from tensorflow.python.ops import array_ops 29from tensorflow.python.ops import check_ops 30from tensorflow.python.ops import math_ops 31from tensorflow.python.ops import sparse_ops 32from tensorflow.python.ops.ragged import ragged_math_ops 33from tensorflow.python.ops.ragged import ragged_tensor 34from tensorflow.python.platform import tf_logging 35from tensorflow.python.util.tf_export import tf_export 36 37 38# TODO(b/122887740) Refactor code: 39# * Move input verification to feature configuration objects (e.g., 40# VarLenFeature should check that dtype is a valid dtype). 41# * Add an _add_feature() method to each feature configuration object 42# (rather than using a dispatch table in _ParseOpParams._add_feature). 43# * Update _construct_tensors_for_composite_features() to call a method 44# on the feature object (rather than using dispatch). 45 46 47@tf_export("io.VarLenFeature", v1=["VarLenFeature", "io.VarLenFeature"]) 48class VarLenFeature(collections.namedtuple("VarLenFeature", ["dtype"])): 49 """Configuration for parsing a variable-length input feature. 50 51 Fields: 52 dtype: Data type of input. 53 """ 54 pass 55 56 57@tf_export("io.RaggedFeature") 58class RaggedFeature( 59 collections.namedtuple( 60 "RaggedFeature", 61 ["dtype", "value_key", "partitions", "row_splits_dtype", "validate"])): 62 """Configuration for passing a RaggedTensor input feature. 63 64 `value_key` specifies the feature key for a variable-length list of values; 65 and `partitions` specifies zero or more feature keys for partitioning those 66 values into higher dimensions. Each element of `partitions` must be one of 67 the following: 68 69 * `tf.io.RaggedFeature.RowSplits(key: string)` 70 * `tf.io.RaggedFeature.RowLengths(key: string)` 71 * `tf.io.RaggedFeature.RowStarts(key: string)` 72 * `tf.io.RaggedFeature.RowLimits(key: string)` 73 * `tf.io.RaggedFeature.ValueRowIds(key: string)` 74 * `tf.io.RaggedFeature.UniformRowLength(length: int)`. 75 76 Where `key` is a feature key whose values are used to partition the values. 77 Partitions are listed from outermost to innermost. 78 79 * If `len(partitions) == 0` (the default), then: 80 81 * A feature from a single `tf.Example` is parsed into a 1D `tf.Tensor`. 82 * A feature from a batch of `tf.Example`s is parsed into a 2D 83 `tf.RaggedTensor`, where the outer dimension is the batch dimension, and 84 the inner (ragged) dimension is the feature length in each example. 85 86 * If `len(partitions) == 1`, then: 87 88 * A feature from a single `tf.Example` is parsed into a 2D 89 `tf.RaggedTensor`, where the values taken from the `value_key` are 90 separated into rows using the partition key. 91 * A feature from a batch of `tf.Example`s is parsed into a 3D 92 `tf.RaggedTensor`, where the outer dimension is the batch dimension, 93 the two inner dimensions are formed by separating the `value_key` values 94 from each example into rows using that example's partition key. 95 96 * If `len(partitions) > 1`, then: 97 98 * A feature from a single `tf.Example` is parsed into a `tf.RaggedTensor` 99 whose rank is `len(partitions)+1`, and whose ragged_rank is 100 `len(partitions)`. 101 102 * A feature from a batch of `tf.Example`s is parsed into a `tf.RaggedTensor` 103 whose rank is `len(partitions)+2` and whose ragged_rank is 104 `len(partitions)+1`, where the outer dimension is the batch dimension. 105 106 There is one exception: if the final (i.e., innermost) element(s) of 107 `partitions` are `UniformRowLength`s, then the values are simply reshaped (as 108 a higher-dimensional `tf.Tensor`), rather than being wrapped in a 109 `tf.RaggedTensor`. 110 111 #### Examples 112 113 >>> import google.protobuf.text_format as pbtext 114 >>> example_batch = [ 115 ... pbtext.Merge(r''' 116 ... features { 117 ... feature {key: "v" value {int64_list {value: [3, 1, 4, 1, 5, 9]}}} 118 ... feature {key: "s1" value {int64_list {value: [0, 2, 3, 3, 6]}}} 119 ... feature {key: "s2" value {int64_list {value: [0, 2, 3, 4]}}} 120 ... }''', tf.train.Example()).SerializeToString(), 121 ... pbtext.Merge(r''' 122 ... features { 123 ... feature {key: "v" value {int64_list {value: [2, 7, 1, 8, 2, 8, 1]}}} 124 ... feature {key: "s1" value {int64_list {value: [0, 3, 4, 5, 7]}}} 125 ... feature {key: "s2" value {int64_list {value: [0, 1, 1, 4]}}} 126 ... }''', tf.train.Example()).SerializeToString()] 127 128 >>> features = { 129 ... # Zero partitions: returns 1D tf.Tensor for each Example. 130 ... 'f1': tf.io.RaggedFeature(value_key="v", dtype=tf.int64), 131 ... # One partition: returns 2D tf.RaggedTensor for each Example. 132 ... 'f2': tf.io.RaggedFeature(value_key="v", dtype=tf.int64, partitions=[ 133 ... tf.io.RaggedFeature.RowSplits("s1")]), 134 ... # Two partitions: returns 3D tf.RaggedTensor for each Example. 135 ... 'f3': tf.io.RaggedFeature(value_key="v", dtype=tf.int64, partitions=[ 136 ... tf.io.RaggedFeature.RowSplits("s2"), 137 ... tf.io.RaggedFeature.RowSplits("s1")]) 138 ... } 139 140 >>> feature_dict = tf.io.parse_single_example(example_batch[0], features) 141 >>> for (name, val) in sorted(feature_dict.items()): 142 ... print('%s: %s' % (name, val)) 143 f1: tf.Tensor([3 1 4 1 5 9], shape=(6,), dtype=int64) 144 f2: <tf.RaggedTensor [[3, 1], [4], [], [1, 5, 9]]> 145 f3: <tf.RaggedTensor [[[3, 1], [4]], [[]], [[1, 5, 9]]]> 146 147 >>> feature_dict = tf.io.parse_example(example_batch, features) 148 >>> for (name, val) in sorted(feature_dict.items()): 149 ... print('%s: %s' % (name, val)) 150 f1: <tf.RaggedTensor [[3, 1, 4, 1, 5, 9], 151 [2, 7, 1, 8, 2, 8, 1]]> 152 f2: <tf.RaggedTensor [[[3, 1], [4], [], [1, 5, 9]], 153 [[2, 7, 1], [8], [2], [8, 1]]]> 154 f3: <tf.RaggedTensor [[[[3, 1], [4]], [[]], [[1, 5, 9]]], 155 [[[2, 7, 1]], [], [[8], [2], [8, 1]]]]> 156 157 Fields: 158 dtype: Data type of the `RaggedTensor`. Must be one of: 159 `tf.dtypes.int64`, `tf.dtypes.float32`, `tf.dtypes.string`. 160 value_key: (Optional.) Key for a `Feature` in the input `Example`, whose 161 parsed `Tensor` will be the resulting `RaggedTensor.flat_values`. If 162 not specified, then it defaults to the key for this `RaggedFeature`. 163 partitions: (Optional.) A list of objects specifying the row-partitioning 164 tensors (from outermost to innermost). Each entry in this list must be 165 one of: 166 * `tf.io.RaggedFeature.RowSplits(key: string)` 167 * `tf.io.RaggedFeature.RowLengths(key: string)` 168 * `tf.io.RaggedFeature.RowStarts(key: string)` 169 * `tf.io.RaggedFeature.RowLimits(key: string)` 170 * `tf.io.RaggedFeature.ValueRowIds(key: string)` 171 * `tf.io.RaggedFeature.UniformRowLength(length: int)`. 172 Where `key` is a key for a `Feature` in the input `Example`, whose parsed 173 `Tensor` will be the resulting row-partitioning tensor. 174 row_splits_dtype: (Optional.) Data type for the row-partitioning tensor(s). 175 One of `int32` or `int64`. Defaults to `int32`. 176 validate: (Optional.) Boolean indicating whether or not to validate that 177 the input values form a valid RaggedTensor. Defaults to `False`. 178 """ 179 180 # pylint: disable=invalid-name 181 RowSplits = collections.namedtuple("RowSplits", ["key"]) 182 RowLengths = collections.namedtuple("RowLengths", ["key"]) 183 RowStarts = collections.namedtuple("RowStarts", ["key"]) 184 RowLimits = collections.namedtuple("RowLimits", ["key"]) 185 ValueRowIds = collections.namedtuple("ValueRowIds", ["key"]) 186 UniformRowLength = collections.namedtuple("UniformRowLength", ["length"]) 187 # pylint: enable=invalid-name 188 189 _PARTITION_TYPES = (RowSplits, RowLengths, RowStarts, RowLimits, ValueRowIds, 190 UniformRowLength) 191 192 def __new__(cls, 193 dtype, 194 value_key=None, 195 partitions=(), 196 row_splits_dtype=dtypes.int32, 197 validate=False): 198 if value_key is not None: 199 if not isinstance(value_key, str): 200 raise ValueError("value_key must be a string; got %r" % value_key) 201 if not value_key: 202 raise ValueError("value_key may not be empty") 203 dtype = dtypes.as_dtype(dtype) 204 if dtype not in (dtypes.int64, dtypes.float32, dtypes.string): 205 raise ValueError("dtypes must be int64, float32, or bytes; got %r" % 206 dtype) 207 row_splits_dtype = dtypes.as_dtype(row_splits_dtype) 208 if row_splits_dtype not in (dtypes.int32, dtypes.int64): 209 raise ValueError("row_splits_dtype must be int32 or int64; got %r" % 210 row_splits_dtype) 211 if not isinstance(partitions, (list, tuple)): 212 raise TypeError("partitions must be a list or tuple") 213 for partition in partitions: 214 if not isinstance(partition, cls._PARTITION_TYPES): 215 raise TypeError("partitions must be a list of partition objects %s;" 216 " got: %r" % (cls._PARTITION_TYPES, partition)) 217 if not isinstance(validate, bool): 218 raise TypeError("validate must be a bool; got %r" % validate) 219 return super(RaggedFeature, cls).__new__(cls, dtype, value_key, partitions, 220 row_splits_dtype, validate) 221 222 223@tf_export("io.SparseFeature", v1=["io.SparseFeature", "SparseFeature"]) 224class SparseFeature( 225 collections.namedtuple( 226 "SparseFeature", 227 ["index_key", "value_key", "dtype", "size", "already_sorted"])): 228 """Configuration for parsing a sparse input feature from an `Example`. 229 230 Note, preferably use `VarLenFeature` (possibly in combination with a 231 `SequenceExample`) in order to parse out `SparseTensor`s instead of 232 `SparseFeature` due to its simplicity. 233 234 Closely mimicking the `SparseTensor` that will be obtained by parsing an 235 `Example` with a `SparseFeature` config, a `SparseFeature` contains a 236 237 * `value_key`: The name of key for a `Feature` in the `Example` whose parsed 238 `Tensor` will be the resulting `SparseTensor.values`. 239 240 * `index_key`: A list of names - one for each dimension in the resulting 241 `SparseTensor` whose `indices[i][dim]` indicating the position of 242 the `i`-th value in the `dim` dimension will be equal to the `i`-th value in 243 the Feature with key named `index_key[dim]` in the `Example`. 244 245 * `size`: A list of ints for the resulting `SparseTensor.dense_shape`. 246 247 For example, we can represent the following 2D `SparseTensor` 248 249 ```python 250 SparseTensor(indices=[[3, 1], [20, 0]], 251 values=[0.5, -1.0] 252 dense_shape=[100, 3]) 253 ``` 254 255 with an `Example` input proto 256 257 ```python 258 features { 259 feature { key: "val" value { float_list { value: [ 0.5, -1.0 ] } } } 260 feature { key: "ix0" value { int64_list { value: [ 3, 20 ] } } } 261 feature { key: "ix1" value { int64_list { value: [ 1, 0 ] } } } 262 } 263 ``` 264 265 and `SparseFeature` config with 2 `index_key`s 266 267 ```python 268 SparseFeature(index_key=["ix0", "ix1"], 269 value_key="val", 270 dtype=tf.float32, 271 size=[100, 3]) 272 ``` 273 274 Fields: 275 index_key: A single string name or a list of string names of index features. 276 For each key the underlying feature's type must be `int64` and its length 277 must always match that of the `value_key` feature. 278 To represent `SparseTensor`s with a `dense_shape` of `rank` higher than 1 279 a list of length `rank` should be used. 280 value_key: Name of value feature. The underlying feature's type must 281 be `dtype` and its length must always match that of all the `index_key`s' 282 features. 283 dtype: Data type of the `value_key` feature. 284 size: A Python int or list thereof specifying the dense shape. Should be a 285 list if and only if `index_key` is a list. In that case the list must be 286 equal to the length of `index_key`. Each for each entry `i` all values in 287 the `index_key`[i] feature must be in `[0, size[i])`. 288 already_sorted: A Python boolean to specify whether the values in 289 `value_key` are already sorted by their index position. If so skip 290 sorting. False by default (optional). 291 """ 292 293 def __new__(cls, index_key, value_key, dtype, size, already_sorted=False): 294 return super(SparseFeature, cls).__new__( 295 cls, index_key, value_key, dtype, size, already_sorted) 296 297 298@tf_export("io.FixedLenFeature", v1=["io.FixedLenFeature", "FixedLenFeature"]) 299class FixedLenFeature(collections.namedtuple( 300 "FixedLenFeature", ["shape", "dtype", "default_value"])): 301 """Configuration for parsing a fixed-length input feature. 302 303 To treat sparse input as dense, provide a `default_value`; otherwise, 304 the parse functions will fail on any examples missing this feature. 305 306 Fields: 307 shape: Shape of input data. 308 dtype: Data type of input. 309 default_value: Value to be used if an example is missing this feature. It 310 must be compatible with `dtype` and of the specified `shape`. 311 """ 312 313 def __new__(cls, shape, dtype, default_value=None): 314 return super(FixedLenFeature, cls).__new__( 315 cls, shape, dtype, default_value) 316 317 318@tf_export("io.FixedLenSequenceFeature", 319 v1=["io.FixedLenSequenceFeature", "FixedLenSequenceFeature"]) 320class FixedLenSequenceFeature(collections.namedtuple( 321 "FixedLenSequenceFeature", 322 ["shape", "dtype", "allow_missing", "default_value"])): 323 """Configuration for parsing a variable-length input feature into a `Tensor`. 324 325 The resulting `Tensor` of parsing a single `SequenceExample` or `Example` has 326 a static `shape` of `[None] + shape` and the specified `dtype`. 327 The resulting `Tensor` of parsing a `batch_size` many `Example`s has 328 a static `shape` of `[batch_size, None] + shape` and the specified `dtype`. 329 The entries in the `batch` from different `Examples` will be padded with 330 `default_value` to the maximum length present in the `batch`. 331 332 To treat a sparse input as dense, provide `allow_missing=True`; otherwise, 333 the parse functions will fail on any examples missing this feature. 334 335 Fields: 336 shape: Shape of input data for dimension 2 and higher. First dimension is 337 of variable length `None`. 338 dtype: Data type of input. 339 allow_missing: Whether to allow this feature to be missing from a feature 340 list item. Is available only for parsing `SequenceExample` not for 341 parsing `Examples`. 342 default_value: Scalar value to be used to pad multiple `Example`s to their 343 maximum length. Irrelevant for parsing a single `Example` or 344 `SequenceExample`. Defaults to "" for dtype string and 0 otherwise 345 (optional). 346 """ 347 348 def __new__(cls, shape, dtype, allow_missing=False, default_value=None): 349 return super(FixedLenSequenceFeature, cls).__new__( 350 cls, shape, dtype, allow_missing, default_value) 351 352 353class _ParseOpParams(object): 354 """Raw parameters used by `gen_parsing_ops`. 355 356 Attributes: 357 sparse_keys: A list of string keys in the examples' features. The results 358 for these keys will be returned as `SparseTensor` objects. 359 sparse_types: A list of `DTypes` of the same length as `sparse_keys`. Only 360 `tf.float32` (`FloatList`), `tf.int64` (`Int64List`), and `tf.string` 361 (`BytesList`) are supported. 362 dense_keys: A list of string keys in the examples' features. The results for 363 these keys will be returned as `Tensor`s 364 dense_types: A list of DTypes of the same length as `dense_keys`. Only 365 `tf.float32` (`FloatList`), `tf.int64` (`Int64List`), and `tf.string` 366 (`BytesList`) are supported. 367 dense_defaults: A dict mapping string keys to `Tensor`s. The keys of the 368 dict must match the dense_keys of the feature. 369 dense_shapes: A list of tuples with the same length as `dense_keys`. The 370 shape of the data for each dense feature referenced by `dense_keys`. 371 Required for any input tensors identified by `dense_keys`. Must be either 372 fully defined, or may contain an unknown first dimension. An unknown first 373 dimension means the feature is treated as having a variable number of 374 blocks, and the output shape along this dimension is considered unknown at 375 graph build time. Padding is applied for minibatch elements smaller than 376 the maximum number of blocks for the given feature along this dimension. 377 ragged_keys: A list of string keys in the examples' features. The 378 results for these keys will be returned as `RaggedTensor` objects. 379 ragged_value_types: A list of `DTypes` of the same length as `ragged_keys`, 380 specifying the value type for each ragged feature. Must be one of: 381 `tf.float32`, `tf.int64`, `tf.string`. 382 ragged_split_types: A list of `DTypes` of the same length as `ragged_keys`, 383 specifying the row_splits type for each ragged feature. Must be one of: 384 `tf.int32`, `tf.int64`. 385 dense_shapes_as_proto: dense_shapes converted to TensorShapeProto. 386 dense_defaults_vec: A vector of `Tensor`s containing the default values, 387 corresponding 1:1 with `dense_keys`. 388 num_features: The total number of feature keys. 389 """ 390 391 def __init__(self, 392 sparse_keys=None, 393 sparse_types=None, 394 dense_keys=None, 395 dense_types=None, 396 dense_defaults=None, 397 dense_shapes=None, 398 ragged_keys=None, 399 ragged_value_types=None, 400 ragged_split_types=None): 401 # Note: we use an OrderedDict for dense_defaults, to ensure consistent 402 # graph construction order for _e2e_test. 403 dense_defaults = ( 404 collections.OrderedDict() if dense_defaults is None else dense_defaults) 405 sparse_keys = [] if sparse_keys is None else sparse_keys 406 sparse_types = [] if sparse_types is None else sparse_types 407 dense_keys = [] if dense_keys is None else dense_keys 408 dense_types = [] if dense_types is None else dense_types 409 dense_shapes = ([[]] * 410 len(dense_keys) if dense_shapes is None else dense_shapes) 411 ragged_keys = [] if ragged_keys is None else ragged_keys 412 ragged_value_types = ([] 413 if ragged_value_types is None else ragged_value_types) 414 ragged_split_types = ([] 415 if ragged_split_types is None else ragged_split_types) 416 self.sparse_keys = sparse_keys 417 self.sparse_types = [dtypes.as_dtype(t) for t in sparse_types] 418 self.dense_keys = dense_keys 419 self.dense_types = [dtypes.as_dtype(t) for t in dense_types] 420 self.dense_shapes = [tensor_shape.as_shape(s) for s in dense_shapes] 421 self.dense_defaults = dense_defaults 422 self.ragged_keys = ragged_keys 423 self.ragged_value_types = [dtypes.as_dtype(t) for t in ragged_value_types] 424 self.ragged_split_types = [dtypes.as_dtype(t) for t in ragged_split_types] 425 self._validate() 426 427 @classmethod 428 def from_features(cls, features, types): 429 """Builds _ParseOpParams for a given set of features and allowed types. 430 431 Args: 432 features: A `dict` mapping feature keys to objects of a type in `types`. 433 types: Type of features to allow, among `FixedLenFeature`, 434 `VarLenFeature`, `SparseFeature`, and `FixedLenSequenceFeature`. 435 436 Returns: 437 A `_ParseOpParams` containing the raw parameters for `gen_parsing_ops`. 438 439 Raises: 440 ValueError: if `features` contains an item not in `types`, or an invalid 441 feature. 442 ValueError: if sparse and dense key sets intersect. 443 ValueError: if input lengths do not match up. 444 """ 445 params = cls() 446 if features: 447 # NOTE: We iterate over sorted keys to keep things deterministic. 448 for key in sorted(features.keys()): 449 feature = features[key] 450 if not isinstance(feature, tuple(types)): 451 raise ValueError("Unsupported %s %s for key '%s')." % 452 (type(feature).__name__, feature, key)) 453 params._add_feature(key, feature) # pylint: disable=protected-access 454 params._validate() # pylint: disable=protected-access 455 return params 456 457 @property 458 def dense_shapes_as_proto(self): 459 return [shape.as_proto() for shape in self.dense_shapes] 460 461 @property 462 def num_features(self): 463 return len(self.dense_keys) + len(self.sparse_keys) + len(self.ragged_keys) 464 465 @property 466 def dense_defaults_vec(self): 467 return [ 468 self._make_dense_default(k, s, t) 469 for k, s, t in zip(self.dense_keys, self.dense_shapes, self.dense_types) 470 ] 471 472 def _make_dense_default(self, key, shape, dtype): 473 """Construct the default value tensor for a specified dense feature. 474 475 Args: 476 key: The key string identifying the dense feature. 477 shape: The dense feature's shape. 478 dtype: The dense feature's dtype. 479 480 Returns: 481 A Tensor. 482 """ 483 default_value = self.dense_defaults.get(key) 484 if (shape.ndims is not None and shape.ndims > 0 and 485 shape.dims[0].value is None): 486 # Variable stride dense shape, the default value should be a 487 # scalar padding value. 488 if default_value is None: 489 default_value = ops.convert_to_tensor( 490 "" if dtype == dtypes.string else 0, dtype=dtype) 491 else: 492 # Reshape to a scalar to ensure user gets an error if they 493 # provide a tensor that's not intended to be a padding value 494 # (0 or 2+ elements). 495 key_name = "padding_" + re.sub("[^A-Za-z0-9_.\\-/]", "_", key) 496 default_value = ops.convert_to_tensor( 497 default_value, dtype=dtype, name=key_name) 498 default_value = array_ops.reshape(default_value, []) 499 else: 500 if default_value is None: 501 default_value = constant_op.constant([], dtype=dtype) 502 elif not isinstance(default_value, ops.Tensor): 503 key_name = "key_" + re.sub("[^A-Za-z0-9_.\\-/]", "_", key) 504 default_value = ops.convert_to_tensor( 505 default_value, dtype=dtype, name=key_name) 506 default_value = array_ops.reshape(default_value, shape) 507 508 return default_value 509 510 def _add_feature(self, key, feature): 511 """Adds the specified feature to this ParseOpParams.""" 512 if isinstance(feature, VarLenFeature): 513 self._add_varlen_feature(key, feature) 514 elif isinstance(feature, SparseFeature): 515 self._add_sparse_feature(key, feature) 516 elif isinstance(feature, FixedLenFeature): 517 self._add_fixed_len_feature(key, feature) 518 elif isinstance(feature, FixedLenSequenceFeature): 519 self._add_fixed_len_sequence_feature(key, feature) 520 elif isinstance(feature, RaggedFeature): 521 self._add_ragged_feature(key, feature) 522 else: 523 raise ValueError("Invalid feature %s:%s." % (key, feature)) 524 525 def _add_varlen_feature(self, key, feature): 526 """Adds a VarLenFeature.""" 527 if not feature.dtype: 528 raise ValueError("Missing type for feature %s." % key) 529 self._add_sparse_key(key, feature.dtype) 530 531 def _add_sparse_key(self, key, dtype): 532 """Adds a sparse key & dtype, checking for duplicates.""" 533 if key in self.sparse_keys: 534 original_dtype = self.sparse_types[self.sparse_keys.index(key)] 535 if original_dtype != dtype: 536 raise ValueError("Conflicting type %s vs %s for feature %s." % 537 (original_dtype, dtype, key)) 538 else: 539 self.sparse_keys.append(key) 540 self.sparse_types.append(dtype) 541 542 def _add_sparse_feature(self, key, feature): 543 """Adds a SparseFeature.""" 544 545 if not feature.index_key: 546 raise ValueError("Missing index_key for SparseFeature %s." % (feature,)) 547 if not feature.value_key: 548 raise ValueError("Missing value_key for SparseFeature %s." % (feature,)) 549 if not feature.dtype: 550 raise ValueError("Missing type for feature %s." % key) 551 index_keys = feature.index_key 552 if isinstance(index_keys, str): 553 index_keys = [index_keys] 554 elif len(index_keys) > 1: 555 tf_logging.warning("SparseFeature is a complicated feature config " 556 "and should only be used after careful " 557 "consideration of VarLenFeature.") 558 for index_key in sorted(index_keys): 559 self._add_sparse_key(index_key, dtypes.int64) 560 self._add_sparse_key(feature.value_key, feature.dtype) 561 562 def _add_fixed_len_feature(self, key, feature): 563 """Adds a FixedLenFeature.""" 564 if not feature.dtype: 565 raise ValueError("Missing type for feature %s." % key) 566 if feature.shape is None: 567 raise ValueError("Missing shape for feature %s." % key) 568 feature_tensor_shape = tensor_shape.as_shape(feature.shape) 569 if (feature.shape and feature_tensor_shape.ndims and 570 feature_tensor_shape.dims[0].value is None): 571 raise ValueError("First dimension of shape for feature %s unknown. " 572 "Consider using FixedLenSequenceFeature." % key) 573 if (feature.shape is not None and 574 not feature_tensor_shape.is_fully_defined()): 575 raise ValueError("All dimensions of shape for feature %s need to be " 576 "known but received %s." % (key, str(feature.shape))) 577 self.dense_keys.append(key) 578 self.dense_shapes.append(tensor_shape.as_shape(feature.shape)) 579 self.dense_types.append(feature.dtype) 580 if feature.default_value is not None: 581 self.dense_defaults[key] = feature.default_value 582 583 def _add_fixed_len_sequence_feature(self, key, feature): 584 """Adds a FixedLenSequenceFeature.""" 585 if not feature.dtype: 586 raise ValueError("Missing type for feature %s." % key) 587 if feature.shape is None: 588 raise ValueError("Missing shape for feature %s." % key) 589 self.dense_keys.append(key) 590 self.dense_shapes.append(tensor_shape.as_shape(feature.shape)) 591 self.dense_types.append(feature.dtype) 592 if feature.allow_missing: 593 self.dense_defaults[key] = None 594 if feature.default_value is not None: 595 self.dense_defaults[key] = feature.default_value 596 597 def _add_ragged_key(self, key, value_type, split_type): 598 """Adds a ragged key & dtype, checking for duplicates.""" 599 if key in self.ragged_keys: 600 original_value_type = self.ragged_value_types[self.ragged_keys.index(key)] 601 original_split_type = self.ragged_split_types[self.ragged_keys.index(key)] 602 if original_value_type != value_type: 603 raise ValueError("Conflicting type %s vs %s for feature %s." % 604 (original_value_type, value_type, key)) 605 if original_split_type != split_type: 606 raise ValueError("Conflicting partition type %s vs %s for feature %s." % 607 (original_split_type, split_type, key)) 608 else: 609 self.ragged_keys.append(key) 610 self.ragged_value_types.append(value_type) 611 self.ragged_split_types.append(split_type) 612 613 def _add_ragged_feature(self, key, feature): 614 """Adds a RaggedFeature.""" 615 value_key = key if feature.value_key is None else feature.value_key 616 self._add_ragged_key(value_key, feature.dtype, feature.row_splits_dtype) 617 for partition in feature.partitions: 618 if not isinstance(partition, RaggedFeature.UniformRowLength): 619 self._add_ragged_key(partition.key, dtypes.int64, 620 feature.row_splits_dtype) 621 622 def _validate(self): 623 """Validates the features in this ParseOpParams.""" 624 if len(self.dense_shapes) != len(self.dense_keys): 625 raise ValueError( 626 "len(self.dense_shapes) != len(self.dense_keys): %d vs %d" % 627 (len(self.dense_shapes), len(self.dense_keys))) 628 if len(self.dense_types) != len(self.dense_keys): 629 raise ValueError( 630 "len(self.dense_types) != len(self.dense_keys): %d vs %d" % 631 (len(self.dense_types), len(self.dense_keys))) 632 if len(self.sparse_types) != len(self.sparse_keys): 633 raise ValueError( 634 "len(self.sparse_types) != len(self.sparse_keys): %d vs %d" % 635 (len(self.sparse_types), len(self.sparse_keys))) 636 if len(self.ragged_value_types) != len(self.ragged_keys): 637 raise ValueError( 638 "len(self.ragged_value_types) != len(self.ragged_keys): %d vs %d" % 639 (len(self.ragged_value_types), len(self.ragged_keys))) 640 if len(self.ragged_split_types) != len(self.ragged_keys): 641 raise ValueError( 642 "len(self.ragged_split_types) != len(self.ragged_keys): %d vs %d" % 643 (len(self.ragged_split_types), len(self.ragged_keys))) 644 645 dense_key_set = set(self.dense_keys) 646 sparse_key_set = set(self.sparse_keys) 647 ragged_key_set = set(self.ragged_keys) 648 if not dense_key_set.isdisjoint(sparse_key_set): 649 raise ValueError( 650 "Dense and sparse keys must not intersect; intersection: %s" % 651 dense_key_set.intersection(sparse_key_set)) 652 if not dense_key_set.isdisjoint(ragged_key_set): 653 raise ValueError( 654 "Dense and ragged keys must not intersect; intersection: %s" % 655 dense_key_set.intersection(ragged_key_set)) 656 if not ragged_key_set.isdisjoint(sparse_key_set): 657 raise ValueError( 658 "Ragged and sparse keys must not intersect; intersection: %s" % 659 ragged_key_set.intersection(sparse_key_set)) 660 661 662def _construct_tensors_for_composite_features(features, tensor_dict): 663 """Creates tensors for SparseFeatures and RaggedFeatures. 664 665 Constructs new dict based on `tensor_dict`. 666 667 For each key in `features` whose value is a `SparseFeature`: 668 669 * Looks up that SparseFeature's value_key and index_keys in tensor_dict. 670 * Uses those tensors to construct a single SparseTensor. 671 * Stores that SparseTensor in the output dict under the same key. 672 673 For each key in `features` whose value is a `RaggedFeature`: 674 675 * Looks up that RaggedFeature's value_key and partition keys in tensor_dict. 676 * Uses those tensors to construct a single RaggedTensor. 677 * Stores that RaggedTensor in the output dict under the same key. 678 679 For any other key in `features`: 680 681 * Copies that key and its value from tensor_dict to the output dictionary. 682 683 Args: 684 features: A `dict` mapping feature keys to `SparseFeature` or 685 `RaggedFeature` values. Values of other types will be ignored. 686 tensor_dict: A `dict` mapping feature keys to `Tensor`, `SparseTensor`, and 687 `RaggedTensor` values. Expected to contain keys of the `SparseFeature`s' 688 `index_key`s and `value_key`s and mapping them to `SparseTensor`s. 689 690 Returns: 691 A `dict` mapping feature keys to `Tensor`, `SparseTensor`, and 692 `RaggedTensor` values. Similar to `tensor_dict` except each `SparseFeature` 693 in `features` results in a single `SparseTensor`; and each `RaggedFeature` 694 in `features` results in a single `RaggedTensor`. 695 """ 696 tensor_dict = dict(tensor_dict) # Do not modify argument passed in. 697 updates = {} 698 for key in sorted(features.keys()): 699 feature = features[key] 700 if isinstance(feature, SparseFeature): 701 # Construct SparseTensors for SparseFeatures 702 if isinstance(feature.index_key, str): 703 sp_ids = tensor_dict[feature.index_key] 704 else: 705 sp_ids = [tensor_dict[index_key] for index_key in feature.index_key] 706 sp_values = tensor_dict[feature.value_key] 707 updates[key] = sparse_ops.sparse_merge( 708 sp_ids, 709 sp_values, 710 vocab_size=feature.size, 711 already_sorted=feature.already_sorted) 712 elif isinstance(feature, RaggedFeature): 713 # Construct RaggedTensors for RaggedFeatures. 714 value_key = key if feature.value_key is None else feature.value_key 715 rt = tensor_dict[value_key] 716 if isinstance(rt, ragged_tensor.RaggedTensor): 717 # We processed a batch of tf.Example or tf.SequenceExample, or single 718 # tf.SequenceExample. 719 if rt.ragged_rank > 1: 720 # We're processing a batch of SequenceExample, and we effectively have 721 # two batch dimensions. Cllapse those batch dimensions here, and 722 # restore them below (using outer_splits). 723 outer_splits = rt.row_splits 724 rt = rt.values 725 else: 726 outer_splits = None 727 for partition in reversed(feature.partitions): 728 rt = _add_batched_ragged_partition(rt, partition, tensor_dict, 729 key, feature.validate, 730 outer_splits) 731 if outer_splits is not None: 732 rt = ragged_tensor.RaggedTensor.from_row_splits( 733 rt, outer_splits, validate=feature.validate) 734 else: 735 # We processed a single tf.Example. 736 for partition in reversed(feature.partitions): 737 rt = _add_ragged_partition(rt, partition, tensor_dict, 738 feature.row_splits_dtype, feature.validate) 739 updates[key] = rt 740 741 # Process updates after all composite tensors have been constructed (in case 742 # multiple features use the same value_key, and one uses that key as its 743 # feature key). 744 tensor_dict.update(updates) 745 746 # Remove tensors from dictionary that were only used to construct 747 # tensors for SparseFeature or RaggedTensor. 748 for key in set(tensor_dict) - set(features): 749 del tensor_dict[key] 750 return tensor_dict 751 752 753def _add_ragged_partition(values, partition, tensor_dict, row_splits_dtype, 754 validate): 755 """Creates a RaggedTensor from a values tensor and a partition tensor. 756 757 Args: 758 values: The values tensor for the new RaggedTensor. 759 partition: The partition configuration object. Specifies the key that 760 should be used to look up the partition tensor (unless partition is a 761 RaggedFeature.UniformRowLength, in which case there is no partition 762 tensor). 763 tensor_dict: The dictionary mapping keys to tensors. 764 row_splits_dtype: The dtype for the partition tensor. 765 validate: Whether to validate that the values form a valid RaggedTensor. 766 767 Returns: 768 A new RaggedTensor formed from the values and partition tensors. 769 """ 770 if isinstance(partition, RaggedFeature.UniformRowLength): 771 if isinstance(values, ragged_tensor.RaggedTensor): 772 length = ops.convert_to_tensor(partition.length, dtype=row_splits_dtype) 773 return ragged_tensor.RaggedTensor.from_uniform_row_length( 774 values, length, validate=validate) 775 else: 776 return array_ops.reshape(values, array_ops.concat( 777 [[-1, partition.length], array_ops.shape(values)[1:]], axis=0)) 778 else: 779 partition_t = math_ops.cast(tensor_dict[partition.key], row_splits_dtype) 780 if isinstance(partition, RaggedFeature.RowSplits): 781 return ragged_tensor.RaggedTensor.from_row_splits( 782 values, partition_t, validate=validate) 783 elif isinstance(partition, RaggedFeature.RowLengths): 784 return ragged_tensor.RaggedTensor.from_row_lengths( 785 values, partition_t, validate=validate) 786 elif isinstance(partition, RaggedFeature.RowStarts): 787 return ragged_tensor.RaggedTensor.from_row_starts( 788 values, partition_t, validate=validate) 789 elif isinstance(partition, RaggedFeature.RowLimits): 790 return ragged_tensor.RaggedTensor.from_row_limits( 791 values, partition_t, validate=validate) 792 elif isinstance(partition, RaggedFeature.ValueRowIds): 793 return ragged_tensor.RaggedTensor.from_value_rowids( 794 values, partition_t, validate=validate) 795 raise ValueError("Unhandled partition type %r" % partition) 796 797 798def _add_batched_ragged_partition(rt, partition, tensor_dict, feature_key, 799 validate, outer_splits=None): 800 """Adds a batched ragged partition tensor to a batched ragged tensor. 801 802 Args: 803 rt: A RaggedTensor with shape [batch_size, ...]. 804 partition: The partition configuration object. Specifies the key that 805 should be used to look up the partition tensor (unless partition is a 806 RaggedFeature.UniformRowLength, in which case there is no partition 807 tensor). The specified tensor must have shape [batch_size, ...]. 808 tensor_dict: The dictionary mapping keys to tensors. 809 feature_key: The name of the feature being parsed (for error messages). 810 validate: Whether to validate that the values form a valid RaggedTensor. 811 outer_splits: If not None, then we have two batch dimensions, and this 812 is the row-splits for the collapsed batch dimension. Every partition 813 tensor must have an outer row_splits that matches this value. 814 815 Returns: 816 A new RaggedTensor where each batch item `rt[i]` has been partitioned 817 using the `partition_t[i]`. 818 """ 819 if isinstance(partition, RaggedFeature.UniformRowLength): 820 if rt.ragged_rank > 1: 821 length = ops.convert_to_tensor(partition.length, rt.row_splits.dtype) 822 return ragged_tensor.RaggedTensor.from_row_splits( 823 ragged_tensor.RaggedTensor.from_uniform_row_length( 824 rt.values, length, validate=validate), 825 rt.row_splits // length, 826 validate=validate) 827 else: 828 reshaped_vals = array_ops.reshape(rt.values, array_ops.concat( 829 [[-1, partition.length], array_ops.shape(rt.values)[1:]], axis=0)) 830 return ragged_tensor.RaggedTensor.from_row_splits( 831 reshaped_vals, rt.row_splits // partition.length, validate=validate) 832 833 partition_t = tensor_dict[partition.key] 834 if partition_t.values.dtype != rt.row_splits.dtype: 835 partition_t = math_ops.cast(partition_t, rt.row_splits.dtype) 836 837 checks = [] 838 if outer_splits is not None: 839 if validate: 840 checks.append(check_ops.assert_equal( 841 outer_splits, partition_t.row_splits, 842 message="Feature %s: values and partitions are not aligned" 843 % feature_key)) 844 partition_t = partition_t.values 845 846 with ops.control_dependencies(checks): 847 if isinstance(partition, (RaggedFeature.RowSplits, 848 RaggedFeature.RowLimits)): 849 if isinstance(partition, RaggedFeature.RowSplits): 850 partition_t = partition_t[:, 1:] 851 adjusted_limits = partition_t.values + array_ops.repeat( 852 rt.row_starts(), partition_t.row_lengths()) 853 return partition_t.with_values( 854 ragged_tensor.RaggedTensor.from_row_limits( 855 rt.values, adjusted_limits, validate=validate)) 856 elif isinstance(partition, RaggedFeature.RowStarts): 857 adjusted_starts = partition_t.values + array_ops.repeat( 858 rt.row_starts(), partition_t.row_lengths()) 859 return partition_t.with_values( 860 ragged_tensor.RaggedTensor.from_row_starts( 861 rt.values, adjusted_starts, validate=validate)) 862 elif isinstance(partition, RaggedFeature.RowLengths): 863 return partition_t.with_values( 864 ragged_tensor.RaggedTensor.from_row_lengths( 865 rt.values, partition_t.values, validate=validate)) 866 elif isinstance(partition, RaggedFeature.ValueRowIds): 867 nrows = math_ops.maximum( # number of rows in each batch item 868 ragged_math_ops.reduce_max(partition_t + 1, axis=1), 0) 869 adjusted_rowids = partition_t.values + array_ops.repeat( 870 math_ops.cumsum(nrows, exclusive=True), partition_t.row_lengths()) 871 return ragged_tensor.RaggedTensor.from_row_lengths( 872 ragged_tensor.RaggedTensor.from_value_rowids( 873 rt.values, adjusted_rowids, validate=validate), 874 nrows, 875 validate=validate) 876 877 raise ValueError("Unhandled partition type %r" % partition) 878 879 880def _build_ragged_tensors(serialized_shape, 881 ragged_values, 882 ragged_row_splits, 883 ragged_inner_splits=None): 884 """Builds RaggedTensors from the outputs of a parse op.""" 885 if ragged_inner_splits is not None: 886 ragged_values = [ 887 ragged_tensor.RaggedTensor.from_row_splits(val, split, validate=False) 888 for (val, split) in zip(ragged_values, ragged_inner_splits) 889 ] 890 if serialized_shape.ndims == 0: 891 return ragged_values 892 else: 893 return [ 894 ragged_tensor.RaggedTensor.from_row_splits(val, split, validate=False) 895 for (val, split) in zip(ragged_values, ragged_row_splits) 896 ] 897