1# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Feature configuration for tf.io.parse_example.""" 16 17import collections 18import re 19 20from tensorflow.python.framework import constant_op 21from tensorflow.python.framework import dtypes 22from tensorflow.python.framework import ops 23from tensorflow.python.framework import tensor_shape 24from tensorflow.python.ops import array_ops 25from tensorflow.python.ops import check_ops 26from tensorflow.python.ops import math_ops 27from tensorflow.python.ops import sparse_ops 28from tensorflow.python.ops.ragged import ragged_math_ops 29from tensorflow.python.ops.ragged import ragged_tensor 30from tensorflow.python.platform import tf_logging 31from tensorflow.python.util.tf_export import tf_export 32 33 34# TODO(b/122887740) Refactor code: 35# * Move input verification to feature configuration objects (e.g., 36# VarLenFeature should check that dtype is a valid dtype). 37# * Add an _add_feature() method to each feature configuration object 38# (rather than using a dispatch table in _ParseOpParams._add_feature). 39# * Update _construct_tensors_for_composite_features() to call a method 40# on the feature object (rather than using dispatch). 41 42 43@tf_export("io.VarLenFeature", v1=["VarLenFeature", "io.VarLenFeature"]) 44class VarLenFeature(collections.namedtuple("VarLenFeature", ["dtype"])): 45 """Configuration for parsing a variable-length input feature. 46 47 Fields: 48 dtype: Data type of input. 49 """ 50 pass 51 52 53@tf_export("io.RaggedFeature") 54class RaggedFeature( 55 collections.namedtuple( 56 "RaggedFeature", 57 ["dtype", "value_key", "partitions", "row_splits_dtype", "validate"])): 58 """Configuration for passing a RaggedTensor input feature. 59 60 `value_key` specifies the feature key for a variable-length list of values; 61 and `partitions` specifies zero or more feature keys for partitioning those 62 values into higher dimensions. Each element of `partitions` must be one of 63 the following: 64 65 * `tf.io.RaggedFeature.RowSplits(key: string)` 66 * `tf.io.RaggedFeature.RowLengths(key: string)` 67 * `tf.io.RaggedFeature.RowStarts(key: string)` 68 * `tf.io.RaggedFeature.RowLimits(key: string)` 69 * `tf.io.RaggedFeature.ValueRowIds(key: string)` 70 * `tf.io.RaggedFeature.UniformRowLength(length: int)`. 71 72 Where `key` is a feature key whose values are used to partition the values. 73 Partitions are listed from outermost to innermost. 74 75 * If `len(partitions) == 0` (the default), then: 76 77 * A feature from a single `tf.Example` is parsed into a 1D `tf.Tensor`. 78 * A feature from a batch of `tf.Example`s is parsed into a 2D 79 `tf.RaggedTensor`, where the outer dimension is the batch dimension, and 80 the inner (ragged) dimension is the feature length in each example. 81 82 * If `len(partitions) == 1`, then: 83 84 * A feature from a single `tf.Example` is parsed into a 2D 85 `tf.RaggedTensor`, where the values taken from the `value_key` are 86 separated into rows using the partition key. 87 * A feature from a batch of `tf.Example`s is parsed into a 3D 88 `tf.RaggedTensor`, where the outer dimension is the batch dimension, 89 the two inner dimensions are formed by separating the `value_key` values 90 from each example into rows using that example's partition key. 91 92 * If `len(partitions) > 1`, then: 93 94 * A feature from a single `tf.Example` is parsed into a `tf.RaggedTensor` 95 whose rank is `len(partitions)+1`, and whose ragged_rank is 96 `len(partitions)`. 97 98 * A feature from a batch of `tf.Example`s is parsed into a `tf.RaggedTensor` 99 whose rank is `len(partitions)+2` and whose ragged_rank is 100 `len(partitions)+1`, where the outer dimension is the batch dimension. 101 102 There is one exception: if the final (i.e., innermost) element(s) of 103 `partitions` are `UniformRowLength`s, then the values are simply reshaped (as 104 a higher-dimensional `tf.Tensor`), rather than being wrapped in a 105 `tf.RaggedTensor`. 106 107 #### Examples 108 109 >>> import google.protobuf.text_format as pbtext 110 >>> example_batch = [ 111 ... pbtext.Merge(r''' 112 ... features { 113 ... feature {key: "v" value {int64_list {value: [3, 1, 4, 1, 5, 9]}}} 114 ... feature {key: "s1" value {int64_list {value: [0, 2, 3, 3, 6]}}} 115 ... feature {key: "s2" value {int64_list {value: [0, 2, 3, 4]}}} 116 ... }''', tf.train.Example()).SerializeToString(), 117 ... pbtext.Merge(r''' 118 ... features { 119 ... feature {key: "v" value {int64_list {value: [2, 7, 1, 8, 2, 8, 1]}}} 120 ... feature {key: "s1" value {int64_list {value: [0, 3, 4, 5, 7]}}} 121 ... feature {key: "s2" value {int64_list {value: [0, 1, 1, 4]}}} 122 ... }''', tf.train.Example()).SerializeToString()] 123 124 >>> features = { 125 ... # Zero partitions: returns 1D tf.Tensor for each Example. 126 ... 'f1': tf.io.RaggedFeature(value_key="v", dtype=tf.int64), 127 ... # One partition: returns 2D tf.RaggedTensor for each Example. 128 ... 'f2': tf.io.RaggedFeature(value_key="v", dtype=tf.int64, partitions=[ 129 ... tf.io.RaggedFeature.RowSplits("s1")]), 130 ... # Two partitions: returns 3D tf.RaggedTensor for each Example. 131 ... 'f3': tf.io.RaggedFeature(value_key="v", dtype=tf.int64, partitions=[ 132 ... tf.io.RaggedFeature.RowSplits("s2"), 133 ... tf.io.RaggedFeature.RowSplits("s1")]) 134 ... } 135 136 >>> feature_dict = tf.io.parse_single_example(example_batch[0], features) 137 >>> for (name, val) in sorted(feature_dict.items()): 138 ... print('%s: %s' % (name, val)) 139 f1: tf.Tensor([3 1 4 1 5 9], shape=(6,), dtype=int64) 140 f2: <tf.RaggedTensor [[3, 1], [4], [], [1, 5, 9]]> 141 f3: <tf.RaggedTensor [[[3, 1], [4]], [[]], [[1, 5, 9]]]> 142 143 >>> feature_dict = tf.io.parse_example(example_batch, features) 144 >>> for (name, val) in sorted(feature_dict.items()): 145 ... print('%s: %s' % (name, val)) 146 f1: <tf.RaggedTensor [[3, 1, 4, 1, 5, 9], 147 [2, 7, 1, 8, 2, 8, 1]]> 148 f2: <tf.RaggedTensor [[[3, 1], [4], [], [1, 5, 9]], 149 [[2, 7, 1], [8], [2], [8, 1]]]> 150 f3: <tf.RaggedTensor [[[[3, 1], [4]], [[]], [[1, 5, 9]]], 151 [[[2, 7, 1]], [], [[8], [2], [8, 1]]]]> 152 153 Fields: 154 dtype: Data type of the `RaggedTensor`. Must be one of: 155 `tf.dtypes.int64`, `tf.dtypes.float32`, `tf.dtypes.string`. 156 value_key: (Optional.) Key for a `Feature` in the input `Example`, whose 157 parsed `Tensor` will be the resulting `RaggedTensor.flat_values`. If 158 not specified, then it defaults to the key for this `RaggedFeature`. 159 partitions: (Optional.) A list of objects specifying the row-partitioning 160 tensors (from outermost to innermost). Each entry in this list must be 161 one of: 162 * `tf.io.RaggedFeature.RowSplits(key: string)` 163 * `tf.io.RaggedFeature.RowLengths(key: string)` 164 * `tf.io.RaggedFeature.RowStarts(key: string)` 165 * `tf.io.RaggedFeature.RowLimits(key: string)` 166 * `tf.io.RaggedFeature.ValueRowIds(key: string)` 167 * `tf.io.RaggedFeature.UniformRowLength(length: int)`. 168 Where `key` is a key for a `Feature` in the input `Example`, whose parsed 169 `Tensor` will be the resulting row-partitioning tensor. 170 row_splits_dtype: (Optional.) Data type for the row-partitioning tensor(s). 171 One of `int32` or `int64`. Defaults to `int32`. 172 validate: (Optional.) Boolean indicating whether or not to validate that 173 the input values form a valid RaggedTensor. Defaults to `False`. 174 """ 175 176 # pylint: disable=invalid-name 177 RowSplits = collections.namedtuple("RowSplits", ["key"]) 178 RowLengths = collections.namedtuple("RowLengths", ["key"]) 179 RowStarts = collections.namedtuple("RowStarts", ["key"]) 180 RowLimits = collections.namedtuple("RowLimits", ["key"]) 181 ValueRowIds = collections.namedtuple("ValueRowIds", ["key"]) 182 UniformRowLength = collections.namedtuple("UniformRowLength", ["length"]) 183 # pylint: enable=invalid-name 184 185 _PARTITION_TYPES = (RowSplits, RowLengths, RowStarts, RowLimits, ValueRowIds, 186 UniformRowLength) 187 188 def __new__(cls, 189 dtype, 190 value_key=None, 191 partitions=(), 192 row_splits_dtype=dtypes.int32, 193 validate=False): 194 if value_key is not None: 195 if not isinstance(value_key, str): 196 raise ValueError( 197 f"Argument `value_key` must be a string; got {value_key}") 198 if not value_key: 199 raise ValueError("Argument `value_key` must not be empty") 200 dtype = dtypes.as_dtype(dtype) 201 if dtype not in (dtypes.int64, dtypes.float32, dtypes.string): 202 raise ValueError("Argument `dtype` must be int64, float32, or bytes; got " 203 f"{dtype!r}") 204 row_splits_dtype = dtypes.as_dtype(row_splits_dtype) 205 if row_splits_dtype not in (dtypes.int32, dtypes.int64): 206 raise ValueError("Argument `row_splits_dtype` must be int32 or int64; got" 207 f"{row_splits_dtype!r}") 208 if not isinstance(partitions, (list, tuple)): 209 raise TypeError("Argument `partitions` must be a list or tuple. Received" 210 f"partitions={partitions} of type " 211 f"{type(partitions).__name__}.") 212 for partition in partitions: 213 if not isinstance(partition, cls._PARTITION_TYPES): 214 raise TypeError("Argument `partitions` must be a list of partition " 215 f"objects {cls._PARTITION_TYPES}; got: {partition!r}") 216 if not isinstance(validate, bool): 217 raise TypeError(f"Argument `validate` must be a bool; got {validate!r}") 218 return super(RaggedFeature, cls).__new__(cls, dtype, value_key, partitions, 219 row_splits_dtype, validate) 220 221 222@tf_export("io.SparseFeature", v1=["io.SparseFeature", "SparseFeature"]) 223class SparseFeature( 224 collections.namedtuple( 225 "SparseFeature", 226 ["index_key", "value_key", "dtype", "size", "already_sorted"])): 227 """Configuration for parsing a sparse input feature from an `Example`. 228 229 Note, preferably use `VarLenFeature` (possibly in combination with a 230 `SequenceExample`) in order to parse out `SparseTensor`s instead of 231 `SparseFeature` due to its simplicity. 232 233 Closely mimicking the `SparseTensor` that will be obtained by parsing an 234 `Example` with a `SparseFeature` config, a `SparseFeature` contains a 235 236 * `value_key`: The name of key for a `Feature` in the `Example` whose parsed 237 `Tensor` will be the resulting `SparseTensor.values`. 238 239 * `index_key`: A list of names - one for each dimension in the resulting 240 `SparseTensor` whose `indices[i][dim]` indicating the position of 241 the `i`-th value in the `dim` dimension will be equal to the `i`-th value in 242 the Feature with key named `index_key[dim]` in the `Example`. 243 244 * `size`: A list of ints for the resulting `SparseTensor.dense_shape`. 245 246 For example, we can represent the following 2D `SparseTensor` 247 248 ```python 249 SparseTensor(indices=[[3, 1], [20, 0]], 250 values=[0.5, -1.0] 251 dense_shape=[100, 3]) 252 ``` 253 254 with an `Example` input proto 255 256 ```python 257 features { 258 feature { key: "val" value { float_list { value: [ 0.5, -1.0 ] } } } 259 feature { key: "ix0" value { int64_list { value: [ 3, 20 ] } } } 260 feature { key: "ix1" value { int64_list { value: [ 1, 0 ] } } } 261 } 262 ``` 263 264 and `SparseFeature` config with 2 `index_key`s 265 266 ```python 267 SparseFeature(index_key=["ix0", "ix1"], 268 value_key="val", 269 dtype=tf.float32, 270 size=[100, 3]) 271 ``` 272 273 Fields: 274 index_key: A single string name or a list of string names of index features. 275 For each key the underlying feature's type must be `int64` and its length 276 must always match that of the `value_key` feature. 277 To represent `SparseTensor`s with a `dense_shape` of `rank` higher than 1 278 a list of length `rank` should be used. 279 value_key: Name of value feature. The underlying feature's type must 280 be `dtype` and its length must always match that of all the `index_key`s' 281 features. 282 dtype: Data type of the `value_key` feature. 283 size: A Python int or list thereof specifying the dense shape. Should be a 284 list if and only if `index_key` is a list. In that case the list must be 285 equal to the length of `index_key`. Each for each entry `i` all values in 286 the `index_key`[i] feature must be in `[0, size[i])`. 287 already_sorted: A Python boolean to specify whether the values in 288 `value_key` are already sorted by their index position. If so skip 289 sorting. False by default (optional). 290 """ 291 292 def __new__(cls, index_key, value_key, dtype, size, already_sorted=False): 293 return super(SparseFeature, cls).__new__( 294 cls, index_key, value_key, dtype, size, already_sorted) 295 296 297@tf_export("io.FixedLenFeature", v1=["io.FixedLenFeature", "FixedLenFeature"]) 298class FixedLenFeature(collections.namedtuple( 299 "FixedLenFeature", ["shape", "dtype", "default_value"])): 300 """Configuration for parsing a fixed-length input feature. 301 302 To treat sparse input as dense, provide a `default_value`; otherwise, 303 the parse functions will fail on any examples missing this feature. 304 305 Fields: 306 shape: Shape of input data. 307 dtype: Data type of input. 308 default_value: Value to be used if an example is missing this feature. It 309 must be compatible with `dtype` and of the specified `shape`. 310 """ 311 312 def __new__(cls, shape, dtype, default_value=None): 313 return super(FixedLenFeature, cls).__new__( 314 cls, shape, dtype, default_value) 315 316 317@tf_export("io.FixedLenSequenceFeature", 318 v1=["io.FixedLenSequenceFeature", "FixedLenSequenceFeature"]) 319class FixedLenSequenceFeature(collections.namedtuple( 320 "FixedLenSequenceFeature", 321 ["shape", "dtype", "allow_missing", "default_value"])): 322 """Configuration for parsing a variable-length input feature into a `Tensor`. 323 324 The resulting `Tensor` of parsing a single `SequenceExample` or `Example` has 325 a static `shape` of `[None] + shape` and the specified `dtype`. 326 The resulting `Tensor` of parsing a `batch_size` many `Example`s has 327 a static `shape` of `[batch_size, None] + shape` and the specified `dtype`. 328 The entries in the `batch` from different `Examples` will be padded with 329 `default_value` to the maximum length present in the `batch`. 330 331 To treat a sparse input as dense, provide `allow_missing=True`; otherwise, 332 the parse functions will fail on any examples missing this feature. 333 334 Fields: 335 shape: Shape of input data for dimension 2 and higher. First dimension is 336 of variable length `None`. 337 dtype: Data type of input. 338 allow_missing: Whether to allow this feature to be missing from a feature 339 list item. Is available only for parsing `SequenceExample` not for 340 parsing `Examples`. 341 default_value: Scalar value to be used to pad multiple `Example`s to their 342 maximum length. Irrelevant for parsing a single `Example` or 343 `SequenceExample`. Defaults to "" for dtype string and 0 otherwise 344 (optional). 345 """ 346 347 def __new__(cls, shape, dtype, allow_missing=False, default_value=None): 348 return super(FixedLenSequenceFeature, cls).__new__( 349 cls, shape, dtype, allow_missing, default_value) 350 351 352class _ParseOpParams: 353 """Raw parameters used by `gen_parsing_ops`. 354 355 Attributes: 356 sparse_keys: A list of string keys in the examples' features. The results 357 for these keys will be returned as `SparseTensor` objects. 358 sparse_types: A list of `DTypes` of the same length as `sparse_keys`. Only 359 `tf.float32` (`FloatList`), `tf.int64` (`Int64List`), and `tf.string` 360 (`BytesList`) are supported. 361 dense_keys: A list of string keys in the examples' features. The results for 362 these keys will be returned as `Tensor`s 363 dense_types: A list of DTypes of the same length as `dense_keys`. Only 364 `tf.float32` (`FloatList`), `tf.int64` (`Int64List`), and `tf.string` 365 (`BytesList`) are supported. 366 dense_defaults: A dict mapping string keys to `Tensor`s. The keys of the 367 dict must match the dense_keys of the feature. 368 dense_shapes: A list of tuples with the same length as `dense_keys`. The 369 shape of the data for each dense feature referenced by `dense_keys`. 370 Required for any input tensors identified by `dense_keys`. Must be either 371 fully defined, or may contain an unknown first dimension. An unknown first 372 dimension means the feature is treated as having a variable number of 373 blocks, and the output shape along this dimension is considered unknown at 374 graph build time. Padding is applied for minibatch elements smaller than 375 the maximum number of blocks for the given feature along this dimension. 376 ragged_keys: A list of string keys in the examples' features. The 377 results for these keys will be returned as `RaggedTensor` objects. 378 ragged_value_types: A list of `DTypes` of the same length as `ragged_keys`, 379 specifying the value type for each ragged feature. Must be one of: 380 `tf.float32`, `tf.int64`, `tf.string`. 381 ragged_split_types: A list of `DTypes` of the same length as `ragged_keys`, 382 specifying the row_splits type for each ragged feature. Must be one of: 383 `tf.int32`, `tf.int64`. 384 dense_shapes_as_proto: dense_shapes converted to TensorShapeProto. 385 dense_defaults_vec: A vector of `Tensor`s containing the default values, 386 corresponding 1:1 with `dense_keys`. 387 num_features: The total number of feature keys. 388 """ 389 390 def __init__(self, 391 sparse_keys=None, 392 sparse_types=None, 393 dense_keys=None, 394 dense_types=None, 395 dense_defaults=None, 396 dense_shapes=None, 397 ragged_keys=None, 398 ragged_value_types=None, 399 ragged_split_types=None): 400 # Note: we use an OrderedDict for dense_defaults, to ensure consistent 401 # graph construction order for _e2e_test. 402 dense_defaults = ( 403 collections.OrderedDict() if dense_defaults is None else dense_defaults) 404 sparse_keys = [] if sparse_keys is None else sparse_keys 405 sparse_types = [] if sparse_types is None else sparse_types 406 dense_keys = [] if dense_keys is None else dense_keys 407 dense_types = [] if dense_types is None else dense_types 408 dense_shapes = ([[]] * 409 len(dense_keys) if dense_shapes is None else dense_shapes) 410 ragged_keys = [] if ragged_keys is None else ragged_keys 411 ragged_value_types = ([] 412 if ragged_value_types is None else ragged_value_types) 413 ragged_split_types = ([] 414 if ragged_split_types is None else ragged_split_types) 415 self.sparse_keys = sparse_keys 416 self.sparse_types = [dtypes.as_dtype(t) for t in sparse_types] 417 self.dense_keys = dense_keys 418 self.dense_types = [dtypes.as_dtype(t) for t in dense_types] 419 self.dense_shapes = [tensor_shape.as_shape(s) for s in dense_shapes] 420 self.dense_defaults = dense_defaults 421 self.ragged_keys = ragged_keys 422 self.ragged_value_types = [dtypes.as_dtype(t) for t in ragged_value_types] 423 self.ragged_split_types = [dtypes.as_dtype(t) for t in ragged_split_types] 424 self._validate() 425 426 @classmethod 427 def from_features(cls, features, types): 428 """Builds _ParseOpParams for a given set of features and allowed types. 429 430 Args: 431 features: A `dict` mapping feature keys to objects of a type in `types`. 432 types: Type of features to allow, among `FixedLenFeature`, 433 `VarLenFeature`, `SparseFeature`, and `FixedLenSequenceFeature`. 434 435 Returns: 436 A `_ParseOpParams` containing the raw parameters for `gen_parsing_ops`. 437 438 Raises: 439 ValueError: if `features` contains an item not in `types`, or an invalid 440 feature. 441 ValueError: if sparse and dense key sets intersect. 442 ValueError: if input lengths do not match up. 443 """ 444 params = cls() 445 if features: 446 # NOTE: We iterate over sorted keys to keep things deterministic. 447 for key in sorted(features.keys()): 448 feature = features[key] 449 if not isinstance(feature, tuple(types)): 450 raise ValueError( 451 f"Unsupported {type(feature).__name__} {feature} for key '{key}'") 452 params._add_feature(key, feature) # pylint: disable=protected-access 453 params._validate() # pylint: disable=protected-access 454 return params 455 456 @property 457 def dense_shapes_as_proto(self): 458 return [shape.as_proto() for shape in self.dense_shapes] 459 460 @property 461 def num_features(self): 462 return len(self.dense_keys) + len(self.sparse_keys) + len(self.ragged_keys) 463 464 @property 465 def dense_defaults_vec(self): 466 return [ 467 self._make_dense_default(k, s, t) 468 for k, s, t in zip(self.dense_keys, self.dense_shapes, self.dense_types) 469 ] 470 471 def _make_dense_default(self, key, shape, dtype): 472 """Construct the default value tensor for a specified dense feature. 473 474 Args: 475 key: The key string identifying the dense feature. 476 shape: The dense feature's shape. 477 dtype: The dense feature's dtype. 478 479 Returns: 480 A Tensor. 481 """ 482 default_value = self.dense_defaults.get(key) 483 if (shape.ndims is not None and shape.ndims > 0 and 484 shape.dims[0].value is None): 485 # Variable stride dense shape, the default value should be a 486 # scalar padding value. 487 if default_value is None: 488 default_value = ops.convert_to_tensor( 489 "" if dtype == dtypes.string else 0, dtype=dtype) 490 else: 491 # Reshape to a scalar to ensure user gets an error if they 492 # provide a tensor that's not intended to be a padding value 493 # (0 or 2+ elements). 494 key_name = "padding_" + re.sub("[^A-Za-z0-9_.\\-/]", "_", key) 495 default_value = ops.convert_to_tensor( 496 default_value, dtype=dtype, name=key_name) 497 default_value = array_ops.reshape(default_value, []) 498 else: 499 if default_value is None: 500 default_value = constant_op.constant([], dtype=dtype) 501 elif not isinstance(default_value, ops.Tensor): 502 key_name = "key_" + re.sub("[^A-Za-z0-9_.\\-/]", "_", key) 503 default_value = ops.convert_to_tensor( 504 default_value, dtype=dtype, name=key_name) 505 default_value = array_ops.reshape(default_value, shape) 506 507 return default_value 508 509 def _add_feature(self, key, feature): 510 """Adds the specified feature to this ParseOpParams.""" 511 if isinstance(feature, VarLenFeature): 512 self._add_varlen_feature(key, feature) 513 elif isinstance(feature, SparseFeature): 514 self._add_sparse_feature(key, feature) 515 elif isinstance(feature, FixedLenFeature): 516 self._add_fixed_len_feature(key, feature) 517 elif isinstance(feature, FixedLenSequenceFeature): 518 self._add_fixed_len_sequence_feature(key, feature) 519 elif isinstance(feature, RaggedFeature): 520 self._add_ragged_feature(key, feature) 521 else: 522 raise ValueError(f"Invalid feature {key}:{feature}.") 523 524 def _add_varlen_feature(self, key, feature): 525 """Adds a VarLenFeature.""" 526 if not feature.dtype: 527 raise ValueError( 528 f"Missing type for feature {key}. Received feature={feature}") 529 self._add_sparse_key(key, feature.dtype) 530 531 def _add_sparse_key(self, key, dtype): 532 """Adds a sparse key & dtype, checking for duplicates.""" 533 if key in self.sparse_keys: 534 original_dtype = self.sparse_types[self.sparse_keys.index(key)] 535 if original_dtype != dtype: 536 raise ValueError( 537 f"Conflicting type {original_dtype} vs {dtype} for feature {key}.") 538 else: 539 self.sparse_keys.append(key) 540 self.sparse_types.append(dtype) 541 542 def _add_sparse_feature(self, key, feature): 543 """Adds a SparseFeature.""" 544 545 if not feature.index_key: 546 raise ValueError(f"Missing index_key for SparseFeature {feature}.") 547 if not feature.value_key: 548 raise ValueError(f"Missing value_key for SparseFeature {feature}.") 549 if not feature.dtype: 550 raise ValueError(f"Missing type for feature {key}. Received feature=" 551 f"{feature}.") 552 index_keys = feature.index_key 553 if isinstance(index_keys, str): 554 index_keys = [index_keys] 555 elif len(index_keys) > 1: 556 tf_logging.warning("SparseFeature is a complicated feature config " 557 "and should only be used after careful " 558 "consideration of VarLenFeature.") 559 for index_key in sorted(index_keys): 560 self._add_sparse_key(index_key, dtypes.int64) 561 self._add_sparse_key(feature.value_key, feature.dtype) 562 563 def _add_fixed_len_feature(self, key, feature): 564 """Adds a FixedLenFeature.""" 565 if not feature.dtype: 566 raise ValueError(f"Missing type for feature {key}. Received feature=" 567 f"{feature}.") 568 if feature.shape is None: 569 raise ValueError(f"Missing shape for feature {key}. Received feature=" 570 f"{feature}.") 571 feature_tensor_shape = tensor_shape.as_shape(feature.shape) 572 if (feature.shape and feature_tensor_shape.ndims and 573 feature_tensor_shape.dims[0].value is None): 574 raise ValueError(f"First dimension of shape for feature {key} unknown. " 575 "Consider using FixedLenSequenceFeature. Received " 576 f"feature={feature}.") 577 if (feature.shape is not None and 578 not feature_tensor_shape.is_fully_defined()): 579 raise ValueError(f"All dimensions of shape for feature {key} need to be " 580 f"known but received {feature.shape!s}.") 581 self.dense_keys.append(key) 582 self.dense_shapes.append(tensor_shape.as_shape(feature.shape)) 583 self.dense_types.append(feature.dtype) 584 if feature.default_value is not None: 585 self.dense_defaults[key] = feature.default_value 586 587 def _add_fixed_len_sequence_feature(self, key, feature): 588 """Adds a FixedLenSequenceFeature.""" 589 if not feature.dtype: 590 raise ValueError(f"Missing type for feature {key}. Received feature=" 591 f"{feature}.") 592 if feature.shape is None: 593 raise ValueError(f"Missing shape for feature {key}. Received feature=" 594 f"{feature}.") 595 self.dense_keys.append(key) 596 self.dense_shapes.append(tensor_shape.as_shape(feature.shape)) 597 self.dense_types.append(feature.dtype) 598 if feature.allow_missing: 599 self.dense_defaults[key] = None 600 if feature.default_value is not None: 601 self.dense_defaults[key] = feature.default_value 602 603 def _add_ragged_key(self, key, value_type, split_type): 604 """Adds a ragged key & dtype, checking for duplicates.""" 605 if key in self.ragged_keys: 606 original_value_type = self.ragged_value_types[self.ragged_keys.index(key)] 607 original_split_type = self.ragged_split_types[self.ragged_keys.index(key)] 608 if original_value_type != value_type: 609 raise ValueError(f"Conflicting type {original_value_type} vs " 610 f"{value_type} for feature {key}.") 611 if original_split_type != split_type: 612 raise ValueError(f"Conflicting partition type {original_split_type} vs " 613 f"{split_type} for feature {key}.") 614 else: 615 self.ragged_keys.append(key) 616 self.ragged_value_types.append(value_type) 617 self.ragged_split_types.append(split_type) 618 619 def _add_ragged_feature(self, key, feature): 620 """Adds a RaggedFeature.""" 621 value_key = key if feature.value_key is None else feature.value_key 622 self._add_ragged_key(value_key, feature.dtype, feature.row_splits_dtype) 623 for partition in feature.partitions: 624 if not isinstance(partition, RaggedFeature.UniformRowLength): 625 self._add_ragged_key(partition.key, dtypes.int64, 626 feature.row_splits_dtype) 627 628 def _validate(self): 629 """Validates the features in this ParseOpParams.""" 630 if len(self.dense_shapes) != len(self.dense_keys): 631 raise ValueError("len(self.dense_shapes) != len(self.dense_keys): " 632 f"{len(self.dense_shapes)} vs {len(self.dense_keys)}.") 633 if len(self.dense_types) != len(self.dense_keys): 634 raise ValueError("len(self.dense_types) != len(self.dense_keys): " 635 f"{len(self.dense_types)} vs {len(self.dense_keys)}.") 636 if len(self.sparse_types) != len(self.sparse_keys): 637 raise ValueError("len(self.sparse_types) != len(self.sparse_keys): " 638 f"{len(self.sparse_types)} vs {len(self.sparse_keys)}.") 639 if len(self.ragged_value_types) != len(self.ragged_keys): 640 raise ValueError( 641 "len(self.ragged_value_types) != len(self.ragged_keys): " 642 f"{len(self.ragged_value_types)} vs {len(self.ragged_keys)}.") 643 if len(self.ragged_split_types) != len(self.ragged_keys): 644 raise ValueError( 645 "len(self.ragged_split_types) != len(self.ragged_keys): " 646 f"{len(self.ragged_split_types)} vs {len(self.ragged_keys)}.") 647 648 dense_key_set = set(self.dense_keys) 649 sparse_key_set = set(self.sparse_keys) 650 ragged_key_set = set(self.ragged_keys) 651 if not dense_key_set.isdisjoint(sparse_key_set): 652 raise ValueError( 653 "Dense and sparse keys must not intersect; dense_keys: " 654 f"{self.dense_keys}, sparse_keys: {self.sparse_keys}, intersection: " 655 f"{dense_key_set.intersection(sparse_key_set)}") 656 if not dense_key_set.isdisjoint(ragged_key_set): 657 raise ValueError( 658 "Dense and ragged keys must not intersect; dense_keys: ", 659 f"{self.dense_keys}, ragged_keys: {self.ragged_keys}, intersection: " 660 f"{dense_key_set.intersection(ragged_key_set)}") 661 if not ragged_key_set.isdisjoint(sparse_key_set): 662 raise ValueError( 663 "Ragged and sparse keys must not intersect; ragged_keys: " 664 f"{self.ragged_keys}, sparse_keys: {self.sparse_keys}, intersection: " 665 f"{ragged_key_set.intersection(sparse_key_set)}") 666 667 668def _construct_tensors_for_composite_features(features, tensor_dict): 669 """Creates tensors for SparseFeatures and RaggedFeatures. 670 671 Constructs new dict based on `tensor_dict`. 672 673 For each key in `features` whose value is a `SparseFeature`: 674 675 * Looks up that SparseFeature's value_key and index_keys in tensor_dict. 676 * Uses those tensors to construct a single SparseTensor. 677 * Stores that SparseTensor in the output dict under the same key. 678 679 For each key in `features` whose value is a `RaggedFeature`: 680 681 * Looks up that RaggedFeature's value_key and partition keys in tensor_dict. 682 * Uses those tensors to construct a single RaggedTensor. 683 * Stores that RaggedTensor in the output dict under the same key. 684 685 For any other key in `features`: 686 687 * Copies that key and its value from tensor_dict to the output dictionary. 688 689 Args: 690 features: A `dict` mapping feature keys to `SparseFeature` or 691 `RaggedFeature` values. Values of other types will be ignored. 692 tensor_dict: A `dict` mapping feature keys to `Tensor`, `SparseTensor`, and 693 `RaggedTensor` values. Expected to contain keys of the `SparseFeature`s' 694 `index_key`s and `value_key`s and mapping them to `SparseTensor`s. 695 696 Returns: 697 A `dict` mapping feature keys to `Tensor`, `SparseTensor`, and 698 `RaggedTensor` values. Similar to `tensor_dict` except each `SparseFeature` 699 in `features` results in a single `SparseTensor`; and each `RaggedFeature` 700 in `features` results in a single `RaggedTensor`. 701 """ 702 tensor_dict = dict(tensor_dict) # Do not modify argument passed in. 703 updates = {} 704 for key in sorted(features.keys()): 705 feature = features[key] 706 if isinstance(feature, SparseFeature): 707 # Construct SparseTensors for SparseFeatures 708 if isinstance(feature.index_key, str): 709 sp_ids = tensor_dict[feature.index_key] 710 else: 711 sp_ids = [tensor_dict[index_key] for index_key in feature.index_key] 712 sp_values = tensor_dict[feature.value_key] 713 updates[key] = sparse_ops.sparse_merge( 714 sp_ids, 715 sp_values, 716 vocab_size=feature.size, 717 already_sorted=feature.already_sorted) 718 elif isinstance(feature, RaggedFeature): 719 # Construct RaggedTensors for RaggedFeatures. 720 value_key = key if feature.value_key is None else feature.value_key 721 rt = tensor_dict[value_key] 722 if isinstance(rt, ragged_tensor.RaggedTensor): 723 # We processed a batch of tf.Example or tf.SequenceExample, or single 724 # tf.SequenceExample. 725 if rt.ragged_rank > 1: 726 # We're processing a batch of SequenceExample, and we effectively have 727 # two batch dimensions. Cllapse those batch dimensions here, and 728 # restore them below (using outer_splits). 729 outer_splits = rt.row_splits 730 rt = rt.values 731 else: 732 outer_splits = None 733 for partition in reversed(feature.partitions): 734 rt = _add_batched_ragged_partition(rt, partition, tensor_dict, 735 key, feature.validate, 736 outer_splits) 737 if outer_splits is not None: 738 rt = ragged_tensor.RaggedTensor.from_row_splits( 739 rt, outer_splits, validate=feature.validate) 740 else: 741 # We processed a single tf.Example. 742 for partition in reversed(feature.partitions): 743 rt = _add_ragged_partition(rt, partition, tensor_dict, 744 feature.row_splits_dtype, feature.validate) 745 updates[key] = rt 746 747 # Process updates after all composite tensors have been constructed (in case 748 # multiple features use the same value_key, and one uses that key as its 749 # feature key). 750 tensor_dict.update(updates) 751 752 # Remove tensors from dictionary that were only used to construct 753 # tensors for SparseFeature or RaggedTensor. 754 for key in set(tensor_dict) - set(features): 755 del tensor_dict[key] 756 return tensor_dict 757 758 759def _add_ragged_partition(values, partition, tensor_dict, row_splits_dtype, 760 validate): 761 """Creates a RaggedTensor from a values tensor and a partition tensor. 762 763 Args: 764 values: The values tensor for the new RaggedTensor. 765 partition: The partition configuration object. Specifies the key that 766 should be used to look up the partition tensor (unless partition is a 767 RaggedFeature.UniformRowLength, in which case there is no partition 768 tensor). 769 tensor_dict: The dictionary mapping keys to tensors. 770 row_splits_dtype: The dtype for the partition tensor. 771 validate: Whether to validate that the values form a valid RaggedTensor. 772 773 Returns: 774 A new RaggedTensor formed from the values and partition tensors. 775 """ 776 if isinstance(partition, RaggedFeature.UniformRowLength): 777 if isinstance(values, ragged_tensor.RaggedTensor): 778 length = ops.convert_to_tensor(partition.length, dtype=row_splits_dtype) 779 return ragged_tensor.RaggedTensor.from_uniform_row_length( 780 values, length, validate=validate) 781 else: 782 return array_ops.reshape(values, array_ops.concat( 783 [[-1, partition.length], array_ops.shape(values)[1:]], axis=0)) 784 else: 785 partition_t = math_ops.cast(tensor_dict[partition.key], row_splits_dtype) 786 if isinstance(partition, RaggedFeature.RowSplits): 787 return ragged_tensor.RaggedTensor.from_row_splits( 788 values, partition_t, validate=validate) 789 elif isinstance(partition, RaggedFeature.RowLengths): 790 return ragged_tensor.RaggedTensor.from_row_lengths( 791 values, partition_t, validate=validate) 792 elif isinstance(partition, RaggedFeature.RowStarts): 793 return ragged_tensor.RaggedTensor.from_row_starts( 794 values, partition_t, validate=validate) 795 elif isinstance(partition, RaggedFeature.RowLimits): 796 return ragged_tensor.RaggedTensor.from_row_limits( 797 values, partition_t, validate=validate) 798 elif isinstance(partition, RaggedFeature.ValueRowIds): 799 return ragged_tensor.RaggedTensor.from_value_rowids( 800 values, partition_t, validate=validate) 801 raise ValueError(f"Unhandled partition type {partition!r}") 802 803 804def _add_batched_ragged_partition(rt, partition, tensor_dict, feature_key, 805 validate, outer_splits=None): 806 """Adds a batched ragged partition tensor to a batched ragged tensor. 807 808 Args: 809 rt: A RaggedTensor with shape [batch_size, ...]. 810 partition: The partition configuration object. Specifies the key that 811 should be used to look up the partition tensor (unless partition is a 812 RaggedFeature.UniformRowLength, in which case there is no partition 813 tensor). The specified tensor must have shape [batch_size, ...]. 814 tensor_dict: The dictionary mapping keys to tensors. 815 feature_key: The name of the feature being parsed (for error messages). 816 validate: Whether to validate that the values form a valid RaggedTensor. 817 outer_splits: If not None, then we have two batch dimensions, and this 818 is the row-splits for the collapsed batch dimension. Every partition 819 tensor must have an outer row_splits that matches this value. 820 821 Returns: 822 A new RaggedTensor where each batch item `rt[i]` has been partitioned 823 using the `partition_t[i]`. 824 """ 825 if isinstance(partition, RaggedFeature.UniformRowLength): 826 if rt.ragged_rank > 1: 827 length = ops.convert_to_tensor(partition.length, rt.row_splits.dtype) 828 return ragged_tensor.RaggedTensor.from_row_splits( 829 ragged_tensor.RaggedTensor.from_uniform_row_length( 830 rt.values, length, validate=validate), 831 rt.row_splits // length, 832 validate=validate) 833 else: 834 reshaped_vals = array_ops.reshape(rt.values, array_ops.concat( 835 [[-1, partition.length], array_ops.shape(rt.values)[1:]], axis=0)) 836 return ragged_tensor.RaggedTensor.from_row_splits( 837 reshaped_vals, rt.row_splits // partition.length, validate=validate) 838 839 partition_t = tensor_dict[partition.key] 840 if partition_t.values.dtype != rt.row_splits.dtype: 841 partition_t = math_ops.cast(partition_t, rt.row_splits.dtype) 842 843 checks = [] 844 if outer_splits is not None: 845 if validate: 846 checks.append(check_ops.assert_equal( 847 outer_splits, partition_t.row_splits, 848 message="Feature %s: values and partitions are not aligned" 849 % feature_key)) 850 partition_t = partition_t.values 851 852 with ops.control_dependencies(checks): 853 if isinstance(partition, (RaggedFeature.RowSplits, 854 RaggedFeature.RowLimits)): 855 if isinstance(partition, RaggedFeature.RowSplits): 856 partition_t = partition_t[:, 1:] 857 adjusted_limits = partition_t.values + array_ops.repeat( 858 rt.row_starts(), partition_t.row_lengths()) 859 return partition_t.with_values( 860 ragged_tensor.RaggedTensor.from_row_limits( 861 rt.values, adjusted_limits, validate=validate)) 862 elif isinstance(partition, RaggedFeature.RowStarts): 863 adjusted_starts = partition_t.values + array_ops.repeat( 864 rt.row_starts(), partition_t.row_lengths()) 865 return partition_t.with_values( 866 ragged_tensor.RaggedTensor.from_row_starts( 867 rt.values, adjusted_starts, validate=validate)) 868 elif isinstance(partition, RaggedFeature.RowLengths): 869 return partition_t.with_values( 870 ragged_tensor.RaggedTensor.from_row_lengths( 871 rt.values, partition_t.values, validate=validate)) 872 elif isinstance(partition, RaggedFeature.ValueRowIds): 873 nrows = math_ops.maximum( # number of rows in each batch item 874 ragged_math_ops.reduce_max(partition_t + 1, axis=1), 0) 875 adjusted_rowids = partition_t.values + array_ops.repeat( 876 math_ops.cumsum(nrows, exclusive=True), partition_t.row_lengths()) 877 return ragged_tensor.RaggedTensor.from_row_lengths( 878 ragged_tensor.RaggedTensor.from_value_rowids( 879 rt.values, adjusted_rowids, validate=validate), 880 nrows, 881 validate=validate) 882 883 raise ValueError(f"Unhandled partition type {partition!r}") 884 885 886def _build_ragged_tensors(serialized_shape, 887 ragged_values, 888 ragged_row_splits, 889 ragged_inner_splits=None): 890 """Builds RaggedTensors from the outputs of a parse op.""" 891 if ragged_inner_splits is not None: 892 ragged_values = [ 893 ragged_tensor.RaggedTensor.from_row_splits(val, split, validate=False) 894 for (val, split) in zip(ragged_values, ragged_inner_splits) 895 ] 896 if serialized_shape.ndims == 0: 897 return ragged_values 898 else: 899 return [ 900 ragged_tensor.RaggedTensor.from_row_splits(val, split, validate=False) 901 for (val, split) in zip(ragged_values, ragged_row_splits) 902 ] 903