• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Feature configuration for tf.io.parse_example."""
16
17import collections
18import re
19
20from tensorflow.python.framework import constant_op
21from tensorflow.python.framework import dtypes
22from tensorflow.python.framework import ops
23from tensorflow.python.framework import tensor_shape
24from tensorflow.python.ops import array_ops
25from tensorflow.python.ops import check_ops
26from tensorflow.python.ops import math_ops
27from tensorflow.python.ops import sparse_ops
28from tensorflow.python.ops.ragged import ragged_math_ops
29from tensorflow.python.ops.ragged import ragged_tensor
30from tensorflow.python.platform import tf_logging
31from tensorflow.python.util.tf_export import tf_export
32
33
# TODO(b/122887740) Refactor code:
#   * Move input verification to feature configuration objects (e.g.,
#     VarLenFeature should check that dtype is a valid dtype).
#   * Add an _add_feature() method to each feature configuration object
#     (rather than using a dispatch table in _ParseOpParams._add_feature).
#   * Update _construct_tensors_for_composite_features() to call a method
#     on the feature object (rather than using dispatch).
41
42
@tf_export("io.VarLenFeature", v1=["VarLenFeature", "io.VarLenFeature"])
class VarLenFeature(collections.namedtuple("VarLenFeature", ["dtype"])):
  """Configuration for parsing an input feature whose length may vary.

  Fields:
    dtype: Data type of input.
  """
51
52
@tf_export("io.RaggedFeature")
class RaggedFeature(
    collections.namedtuple(
        "RaggedFeature",
        ["dtype", "value_key", "partitions", "row_splits_dtype", "validate"])):
  """Configuration for passing a RaggedTensor input feature.

  `value_key` specifies the feature key for a variable-length list of values;
  and `partitions` specifies zero or more feature keys for partitioning those
  values into higher dimensions.  Each element of `partitions` must be one of
  the following:

    * `tf.io.RaggedFeature.RowSplits(key: string)`
    * `tf.io.RaggedFeature.RowLengths(key: string)`
    * `tf.io.RaggedFeature.RowStarts(key: string)`
    * `tf.io.RaggedFeature.RowLimits(key: string)`
    * `tf.io.RaggedFeature.ValueRowIds(key: string)`
    * `tf.io.RaggedFeature.UniformRowLength(length: int)`.

  Where `key` is a feature key whose values are used to partition the values.
  Partitions are listed from outermost to innermost.

  * If `len(partitions) == 0` (the default), then:

    * A feature from a single `tf.Example` is parsed into a 1D `tf.Tensor`.
    * A feature from a batch of `tf.Example`s is parsed into a 2D
      `tf.RaggedTensor`, where the outer dimension is the batch dimension, and
      the inner (ragged) dimension is the feature length in each example.

  * If `len(partitions) == 1`, then:

    * A feature from a single `tf.Example` is parsed into a 2D
      `tf.RaggedTensor`, where the values taken from the `value_key` are
      separated into rows using the partition key.
    * A feature from a batch of `tf.Example`s is parsed into a 3D
      `tf.RaggedTensor`, where the outer dimension is the batch dimension,
      the two inner dimensions are formed by separating the `value_key` values
      from each example into rows using that example's partition key.

  * If `len(partitions) > 1`, then:

    * A feature from a single `tf.Example` is parsed into a `tf.RaggedTensor`
      whose rank is `len(partitions)+1`, and whose ragged_rank is
      `len(partitions)`.

    * A feature from a batch of `tf.Example`s is parsed into a `tf.RaggedTensor`
      whose rank is `len(partitions)+2` and whose ragged_rank is
      `len(partitions)+1`, where the outer dimension is the batch dimension.

  There is one exception: if the final (i.e., innermost) element(s) of
  `partitions` are `UniformRowLength`s, then the values are simply reshaped (as
  a higher-dimensional `tf.Tensor`), rather than being wrapped in a
  `tf.RaggedTensor`.

  #### Examples

  >>> import google.protobuf.text_format as pbtext
  >>> example_batch = [
  ...   pbtext.Merge(r'''
  ...     features {
  ...       feature {key: "v" value {int64_list {value: [3, 1, 4, 1, 5, 9]}}}
  ...       feature {key: "s1" value {int64_list {value: [0, 2, 3, 3, 6]}}}
  ...       feature {key: "s2" value {int64_list {value: [0, 2, 3, 4]}}}
  ...     }''', tf.train.Example()).SerializeToString(),
  ...   pbtext.Merge(r'''
  ...     features {
  ...       feature {key: "v" value {int64_list {value: [2, 7, 1, 8, 2, 8, 1]}}}
  ...       feature {key: "s1" value {int64_list {value: [0, 3, 4, 5, 7]}}}
  ...       feature {key: "s2" value {int64_list {value: [0, 1, 1, 4]}}}
  ...     }''', tf.train.Example()).SerializeToString()]

  >>> features = {
  ...     # Zero partitions: returns 1D tf.Tensor for each Example.
  ...     'f1': tf.io.RaggedFeature(value_key="v", dtype=tf.int64),
  ...     # One partition: returns 2D tf.RaggedTensor for each Example.
  ...     'f2': tf.io.RaggedFeature(value_key="v", dtype=tf.int64, partitions=[
  ...         tf.io.RaggedFeature.RowSplits("s1")]),
  ...     # Two partitions: returns 3D tf.RaggedTensor for each Example.
  ...     'f3': tf.io.RaggedFeature(value_key="v", dtype=tf.int64, partitions=[
  ...         tf.io.RaggedFeature.RowSplits("s2"),
  ...         tf.io.RaggedFeature.RowSplits("s1")])
  ... }

  >>> feature_dict = tf.io.parse_single_example(example_batch[0], features)
  >>> for (name, val) in sorted(feature_dict.items()):
  ...   print('%s: %s' % (name, val))
  f1: tf.Tensor([3 1 4 1 5 9], shape=(6,), dtype=int64)
  f2: <tf.RaggedTensor [[3, 1], [4], [], [1, 5, 9]]>
  f3: <tf.RaggedTensor [[[3, 1], [4]], [[]], [[1, 5, 9]]]>

  >>> feature_dict = tf.io.parse_example(example_batch, features)
  >>> for (name, val) in sorted(feature_dict.items()):
  ...   print('%s: %s' % (name, val))
  f1: <tf.RaggedTensor [[3, 1, 4, 1, 5, 9],
                        [2, 7, 1, 8, 2, 8, 1]]>
  f2: <tf.RaggedTensor [[[3, 1], [4], [], [1, 5, 9]],
                        [[2, 7, 1], [8], [2], [8, 1]]]>
  f3: <tf.RaggedTensor [[[[3, 1], [4]], [[]], [[1, 5, 9]]],
                        [[[2, 7, 1]], [], [[8], [2], [8, 1]]]]>

  Fields:
    dtype: Data type of the `RaggedTensor`.  Must be one of:
      `tf.dtypes.int64`, `tf.dtypes.float32`, `tf.dtypes.string`.
    value_key: (Optional.) Key for a `Feature` in the input `Example`, whose
      parsed `Tensor` will be the resulting `RaggedTensor.flat_values`.  If
      not specified, then it defaults to the key for this `RaggedFeature`.
    partitions: (Optional.) A list of objects specifying the row-partitioning
      tensors (from outermost to innermost).  Each entry in this list must be
      one of:
        * `tf.io.RaggedFeature.RowSplits(key: string)`
        * `tf.io.RaggedFeature.RowLengths(key: string)`
        * `tf.io.RaggedFeature.RowStarts(key: string)`
        * `tf.io.RaggedFeature.RowLimits(key: string)`
        * `tf.io.RaggedFeature.ValueRowIds(key: string)`
        * `tf.io.RaggedFeature.UniformRowLength(length: int)`.
      Where `key` is a key for a `Feature` in the input `Example`, whose parsed
      `Tensor` will be the resulting row-partitioning tensor.
    row_splits_dtype: (Optional.) Data type for the row-partitioning tensor(s).
      One of `int32` or `int64`.  Defaults to `int32`.
    validate: (Optional.) Boolean indicating whether or not to validate that
      the input values form a valid RaggedTensor.  Defaults to `False`.
  """

  # pylint: disable=invalid-name
  RowSplits = collections.namedtuple("RowSplits", ["key"])
  RowLengths = collections.namedtuple("RowLengths", ["key"])
  RowStarts = collections.namedtuple("RowStarts", ["key"])
  RowLimits = collections.namedtuple("RowLimits", ["key"])
  ValueRowIds = collections.namedtuple("ValueRowIds", ["key"])
  UniformRowLength = collections.namedtuple("UniformRowLength", ["length"])
  # pylint: enable=invalid-name

  _PARTITION_TYPES = (RowSplits, RowLengths, RowStarts, RowLimits, ValueRowIds,
                      UniformRowLength)

  def __new__(cls,
              dtype,
              value_key=None,
              partitions=(),
              row_splits_dtype=dtypes.int32,
              validate=False):
    """Validates all arguments before constructing the namedtuple.

    Raises:
      ValueError: if `value_key`, `dtype`, or `row_splits_dtype` is invalid.
      TypeError: if `partitions` or `validate` has the wrong type.
    """
    if value_key is not None:
      if not isinstance(value_key, str):
        raise ValueError(
            f"Argument `value_key` must be a string; got {value_key}")
      if not value_key:
        raise ValueError("Argument `value_key` must not be empty")
    dtype = dtypes.as_dtype(dtype)
    if dtype not in (dtypes.int64, dtypes.float32, dtypes.string):
      raise ValueError("Argument `dtype` must be int64, float32, or bytes; got "
                       f"{dtype!r}")
    row_splits_dtype = dtypes.as_dtype(row_splits_dtype)
    if row_splits_dtype not in (dtypes.int32, dtypes.int64):
      # Bug fix: the adjacent string literals previously joined without a
      # separating space, producing e.g. "... int32 or int64; got<dtype>".
      raise ValueError("Argument `row_splits_dtype` must be int32 or int64; "
                       f"got {row_splits_dtype!r}")
    if not isinstance(partitions, (list, tuple)):
      # Bug fix: same missing-space issue ("Receivedpartitions=...").
      raise TypeError("Argument `partitions` must be a list or tuple. Received "
                      f"partitions={partitions} of type "
                      f"{type(partitions).__name__}.")
    for partition in partitions:
      if not isinstance(partition, cls._PARTITION_TYPES):
        raise TypeError("Argument `partitions` must be a list of partition "
                        f"objects {cls._PARTITION_TYPES}; got: {partition!r}")
    if not isinstance(validate, bool):
      raise TypeError(f"Argument `validate` must be a bool; got {validate!r}")
    return super(RaggedFeature, cls).__new__(cls, dtype, value_key, partitions,
                                             row_splits_dtype, validate)
220
221
@tf_export("io.SparseFeature", v1=["io.SparseFeature", "SparseFeature"])
class SparseFeature(
    collections.namedtuple(
        "SparseFeature",
        ["index_key", "value_key", "dtype", "size", "already_sorted"])):
  """Configuration for parsing a sparse input feature from an `Example`.

  Note, preferably use `VarLenFeature` (possibly in combination with a
  `SequenceExample`) in order to parse out `SparseTensor`s instead of
  `SparseFeature` due to its simplicity.

  Closely mimicking the `SparseTensor` that will be obtained by parsing an
  `Example` with a `SparseFeature` config, a `SparseFeature` contains a

  * `value_key`: The name of key for a `Feature` in the `Example` whose parsed
    `Tensor` will be the resulting `SparseTensor.values`.

  * `index_key`: A list of names - one for each dimension in the resulting
    `SparseTensor` whose `indices[i][dim]` indicating the position of
    the `i`-th value in the `dim` dimension will be equal to the `i`-th value in
    the Feature with key named `index_key[dim]` in the `Example`.

  * `size`: A list of ints for the resulting `SparseTensor.dense_shape`.

  For example, we can represent the following 2D `SparseTensor`

  ```python
  SparseTensor(indices=[[3, 1], [20, 0]],
               values=[0.5, -1.0]
               dense_shape=[100, 3])
  ```

  with an `Example` input proto

  ```python
  features {
    feature { key: "val" value { float_list { value: [ 0.5, -1.0 ] } } }
    feature { key: "ix0" value { int64_list { value: [ 3, 20 ] } } }
    feature { key: "ix1" value { int64_list { value: [ 1, 0 ] } } }
  }
  ```

  and `SparseFeature` config with 2 `index_key`s

  ```python
  SparseFeature(index_key=["ix0", "ix1"],
                value_key="val",
                dtype=tf.float32,
                size=[100, 3])
  ```

  Fields:
    index_key: A single string name or a list of string names of index features.
      For each key the underlying feature's type must be `int64` and its length
      must always match that of the `value_key` feature.
      To represent `SparseTensor`s with a `dense_shape` of `rank` higher than 1
      a list of length `rank` should be used.
    value_key: Name of value feature.  The underlying feature's type must
      be `dtype` and its length must always match that of all the `index_key`s'
      features.
    dtype: Data type of the `value_key` feature.
    size: A Python int or list thereof specifying the dense shape. Should be a
      list if and only if `index_key` is a list. In that case the list must be
      equal to the length of `index_key`. Each for each entry `i` all values in
      the `index_key`[i] feature must be in `[0, size[i])`.
    already_sorted: A Python boolean to specify whether the values in
      `value_key` are already sorted by their index position. If so skip
      sorting. False by default (optional).
  """

  def __new__(cls, index_key, value_key, dtype, size, already_sorted=False):
    # Pure record type: forward every field to the namedtuple base, only
    # supplying the default for `already_sorted`.
    return super().__new__(cls, index_key, value_key, dtype, size,
                           already_sorted)
295
296
@tf_export("io.FixedLenFeature", v1=["io.FixedLenFeature", "FixedLenFeature"])
class FixedLenFeature(collections.namedtuple(
    "FixedLenFeature", ["shape", "dtype", "default_value"])):
  """Configuration for parsing a fixed-length input feature.

  To treat sparse input as dense, provide a `default_value`; otherwise,
  the parse functions will fail on any examples missing this feature.

  Fields:
    shape: Shape of input data.
    dtype: Data type of input.
    default_value: Value to be used if an example is missing this feature. It
        must be compatible with `dtype` and of the specified `shape`.
  """

  def __new__(cls, shape, dtype, default_value=None):
    # Record type: only exists to give `default_value` a default of None.
    return super().__new__(cls, shape, dtype, default_value)
315
316
@tf_export("io.FixedLenSequenceFeature",
           v1=["io.FixedLenSequenceFeature", "FixedLenSequenceFeature"])
class FixedLenSequenceFeature(collections.namedtuple(
    "FixedLenSequenceFeature",
    ["shape", "dtype", "allow_missing", "default_value"])):
  """Configuration for parsing a variable-length input feature into a `Tensor`.

  The resulting `Tensor` of parsing a single `SequenceExample` or `Example` has
  a static `shape` of `[None] + shape` and the specified `dtype`.
  The resulting `Tensor` of parsing a `batch_size` many `Example`s has
  a static `shape` of `[batch_size, None] + shape` and the specified `dtype`.
  The entries in the `batch` from different `Examples` will be padded with
  `default_value` to the maximum length present in the `batch`.

  To treat a sparse input as dense, provide `allow_missing=True`; otherwise,
  the parse functions will fail on any examples missing this feature.

  Fields:
    shape: Shape of input data for dimension 2 and higher. First dimension is
      of variable length `None`.
    dtype: Data type of input.
    allow_missing: Whether to allow this feature to be missing from a feature
      list item. Is available only for parsing `SequenceExample` not for
      parsing `Examples`.
    default_value: Scalar value to be used to pad multiple `Example`s to their
      maximum length. Irrelevant for parsing a single `Example` or
      `SequenceExample`. Defaults to "" for dtype string and 0 otherwise
      (optional).
  """

  def __new__(cls, shape, dtype, allow_missing=False, default_value=None):
    # Record type: forwards all fields, supplying defaults for the two
    # optional ones.
    return super().__new__(cls, shape, dtype, allow_missing, default_value)
350
351
class _ParseOpParams:
  """Raw parameters used by `gen_parsing_ops`.

  Attributes:
    sparse_keys: A list of string keys in the examples' features. The results
      for these keys will be returned as `SparseTensor` objects.
    sparse_types: A list of `DTypes` of the same length as `sparse_keys`. Only
      `tf.float32` (`FloatList`), `tf.int64` (`Int64List`), and `tf.string`
      (`BytesList`) are supported.
    dense_keys: A list of string keys in the examples' features. The results for
      these keys will be returned as `Tensor`s
    dense_types: A list of DTypes of the same length as `dense_keys`. Only
      `tf.float32` (`FloatList`), `tf.int64` (`Int64List`), and `tf.string`
      (`BytesList`) are supported.
    dense_defaults: A dict mapping string keys to `Tensor`s. The keys of the
      dict must match the dense_keys of the feature.
    dense_shapes: A list of tuples with the same length as `dense_keys`. The
      shape of the data for each dense feature referenced by `dense_keys`.
      Required for any input tensors identified by `dense_keys`.  Must be either
      fully defined, or may contain an unknown first dimension. An unknown first
      dimension means the feature is treated as having a variable number of
      blocks, and the output shape along this dimension is considered unknown at
      graph build time.  Padding is applied for minibatch elements smaller than
      the maximum number of blocks for the given feature along this dimension.
    ragged_keys: A list of string keys in the examples' features.  The
      results for these keys will be returned as `RaggedTensor` objects.
    ragged_value_types: A list of `DTypes` of the same length as `ragged_keys`,
      specifying the value type for each ragged feature.  Must be one of:
      `tf.float32`, `tf.int64`, `tf.string`.
    ragged_split_types: A list of `DTypes` of the same length as `ragged_keys`,
      specifying the row_splits type for each ragged feature.  Must be one of:
      `tf.int32`, `tf.int64`.
    dense_shapes_as_proto: dense_shapes converted to TensorShapeProto.
    dense_defaults_vec: A vector of `Tensor`s containing the default values,
      corresponding 1:1 with `dense_keys`.
    num_features: The total number of feature keys.
  """

  def __init__(self,
               sparse_keys=None,
               sparse_types=None,
               dense_keys=None,
               dense_types=None,
               dense_defaults=None,
               dense_shapes=None,
               ragged_keys=None,
               ragged_value_types=None,
               ragged_split_types=None):
    # Note: we use an OrderedDict for dense_defaults, to ensure consistent
    # graph construction order for _e2e_test.
    dense_defaults = (
        collections.OrderedDict() if dense_defaults is None else dense_defaults)
    sparse_keys = [] if sparse_keys is None else sparse_keys
    sparse_types = [] if sparse_types is None else sparse_types
    dense_keys = [] if dense_keys is None else dense_keys
    dense_types = [] if dense_types is None else dense_types
    # Note: `[[]] * n` aliases one inner list, but that is safe here because
    # each entry is converted to an independent TensorShape below.
    dense_shapes = ([[]] *
                    len(dense_keys) if dense_shapes is None else dense_shapes)
    ragged_keys = [] if ragged_keys is None else ragged_keys
    ragged_value_types = ([]
                          if ragged_value_types is None else ragged_value_types)
    ragged_split_types = ([]
                          if ragged_split_types is None else ragged_split_types)
    self.sparse_keys = sparse_keys
    self.sparse_types = [dtypes.as_dtype(t) for t in sparse_types]
    self.dense_keys = dense_keys
    self.dense_types = [dtypes.as_dtype(t) for t in dense_types]
    self.dense_shapes = [tensor_shape.as_shape(s) for s in dense_shapes]
    self.dense_defaults = dense_defaults
    self.ragged_keys = ragged_keys
    self.ragged_value_types = [dtypes.as_dtype(t) for t in ragged_value_types]
    self.ragged_split_types = [dtypes.as_dtype(t) for t in ragged_split_types]
    self._validate()

  @classmethod
  def from_features(cls, features, types):
    """Builds _ParseOpParams for a given set of features and allowed types.

    Args:
      features: A `dict` mapping feature keys to objects of a type in `types`.
      types: Type of features to allow, among `FixedLenFeature`,
        `VarLenFeature`, `SparseFeature`, and `FixedLenSequenceFeature`.

    Returns:
      A `_ParseOpParams` containing the raw parameters for `gen_parsing_ops`.

    Raises:
      ValueError: if `features` contains an item not in `types`, or an invalid
          feature.
      ValueError: if sparse and dense key sets intersect.
      ValueError: if input lengths do not match up.
    """
    params = cls()
    if features:
      # NOTE: We iterate over sorted keys to keep things deterministic.
      for key in sorted(features.keys()):
        feature = features[key]
        if not isinstance(feature, tuple(types)):
          raise ValueError(
              f"Unsupported {type(feature).__name__} {feature} for key '{key}'")
        params._add_feature(key, feature)  # pylint: disable=protected-access
    params._validate()  # pylint: disable=protected-access
    return params

  @property
  def dense_shapes_as_proto(self):
    return [shape.as_proto() for shape in self.dense_shapes]

  @property
  def num_features(self):
    return len(self.dense_keys) + len(self.sparse_keys) + len(self.ragged_keys)

  @property
  def dense_defaults_vec(self):
    return [
        self._make_dense_default(k, s, t)
        for k, s, t in zip(self.dense_keys, self.dense_shapes, self.dense_types)
    ]

  def _make_dense_default(self, key, shape, dtype):
    """Construct the default value tensor for a specified dense feature.

    Args:
      key: The key string identifying the dense feature.
      shape: The dense feature's shape.
      dtype: The dense feature's dtype.

    Returns:
      A Tensor.
    """
    default_value = self.dense_defaults.get(key)
    if (shape.ndims is not None and shape.ndims > 0 and
        shape.dims[0].value is None):
      # Variable stride dense shape, the default value should be a
      # scalar padding value.
      if default_value is None:
        default_value = ops.convert_to_tensor(
            "" if dtype == dtypes.string else 0, dtype=dtype)
      else:
        # Reshape to a scalar to ensure user gets an error if they
        # provide a tensor that's not intended to be a padding value
        # (0 or 2+ elements).
        key_name = "padding_" + re.sub("[^A-Za-z0-9_.\\-/]", "_", key)
        default_value = ops.convert_to_tensor(
            default_value, dtype=dtype, name=key_name)
        default_value = array_ops.reshape(default_value, [])
    else:
      if default_value is None:
        default_value = constant_op.constant([], dtype=dtype)
      elif not isinstance(default_value, ops.Tensor):
        key_name = "key_" + re.sub("[^A-Za-z0-9_.\\-/]", "_", key)
        default_value = ops.convert_to_tensor(
            default_value, dtype=dtype, name=key_name)
        default_value = array_ops.reshape(default_value, shape)

    return default_value

  def _add_feature(self, key, feature):
    """Adds the specified feature to this ParseOpParams."""
    if isinstance(feature, VarLenFeature):
      self._add_varlen_feature(key, feature)
    elif isinstance(feature, SparseFeature):
      self._add_sparse_feature(key, feature)
    elif isinstance(feature, FixedLenFeature):
      self._add_fixed_len_feature(key, feature)
    elif isinstance(feature, FixedLenSequenceFeature):
      self._add_fixed_len_sequence_feature(key, feature)
    elif isinstance(feature, RaggedFeature):
      self._add_ragged_feature(key, feature)
    else:
      raise ValueError(f"Invalid feature {key}:{feature}.")

  def _add_varlen_feature(self, key, feature):
    """Adds a VarLenFeature."""
    if not feature.dtype:
      raise ValueError(
          f"Missing type for feature {key}. Received feature={feature}")
    self._add_sparse_key(key, feature.dtype)

  def _add_sparse_key(self, key, dtype):
    """Adds a sparse key & dtype, checking for duplicates."""
    if key in self.sparse_keys:
      original_dtype = self.sparse_types[self.sparse_keys.index(key)]
      if original_dtype != dtype:
        raise ValueError(
            f"Conflicting type {original_dtype} vs {dtype} for feature {key}.")
    else:
      self.sparse_keys.append(key)
      self.sparse_types.append(dtype)

  def _add_sparse_feature(self, key, feature):
    """Adds a SparseFeature."""

    if not feature.index_key:
      raise ValueError(f"Missing index_key for SparseFeature {feature}.")
    if not feature.value_key:
      raise ValueError(f"Missing value_key for SparseFeature {feature}.")
    if not feature.dtype:
      raise ValueError(f"Missing type for feature {key}. Received feature="
                       f"{feature}.")
    index_keys = feature.index_key
    if isinstance(index_keys, str):
      index_keys = [index_keys]
    elif len(index_keys) > 1:
      tf_logging.warning("SparseFeature is a complicated feature config "
                         "and should only be used after careful "
                         "consideration of VarLenFeature.")
    for index_key in sorted(index_keys):
      self._add_sparse_key(index_key, dtypes.int64)
    self._add_sparse_key(feature.value_key, feature.dtype)

  def _add_fixed_len_feature(self, key, feature):
    """Adds a FixedLenFeature."""
    if not feature.dtype:
      raise ValueError(f"Missing type for feature {key}. Received feature="
                       f"{feature}.")
    if feature.shape is None:
      raise ValueError(f"Missing shape for feature {key}. Received feature="
                       f"{feature}.")
    feature_tensor_shape = tensor_shape.as_shape(feature.shape)
    if (feature.shape and feature_tensor_shape.ndims and
        feature_tensor_shape.dims[0].value is None):
      raise ValueError(f"First dimension of shape for feature {key} unknown. "
                       "Consider using FixedLenSequenceFeature. Received "
                       f"feature={feature}.")
    if (feature.shape is not None and
        not feature_tensor_shape.is_fully_defined()):
      raise ValueError(f"All dimensions of shape for feature {key} need to be "
                       f"known but received {feature.shape!s}.")
    self.dense_keys.append(key)
    self.dense_shapes.append(tensor_shape.as_shape(feature.shape))
    self.dense_types.append(feature.dtype)
    if feature.default_value is not None:
      self.dense_defaults[key] = feature.default_value

  def _add_fixed_len_sequence_feature(self, key, feature):
    """Adds a FixedLenSequenceFeature."""
    if not feature.dtype:
      raise ValueError(f"Missing type for feature {key}. Received feature="
                       f"{feature}.")
    if feature.shape is None:
      raise ValueError(f"Missing shape for feature {key}. Received feature="
                       f"{feature}.")
    self.dense_keys.append(key)
    self.dense_shapes.append(tensor_shape.as_shape(feature.shape))
    self.dense_types.append(feature.dtype)
    if feature.allow_missing:
      self.dense_defaults[key] = None
    if feature.default_value is not None:
      self.dense_defaults[key] = feature.default_value

  def _add_ragged_key(self, key, value_type, split_type):
    """Adds a ragged key & dtype, checking for duplicates."""
    if key in self.ragged_keys:
      original_value_type = self.ragged_value_types[self.ragged_keys.index(key)]
      original_split_type = self.ragged_split_types[self.ragged_keys.index(key)]
      if original_value_type != value_type:
        raise ValueError(f"Conflicting type {original_value_type} vs "
                         f"{value_type} for feature {key}.")
      if original_split_type != split_type:
        raise ValueError(f"Conflicting partition type {original_split_type} vs "
                         f"{split_type} for feature {key}.")
    else:
      self.ragged_keys.append(key)
      self.ragged_value_types.append(value_type)
      self.ragged_split_types.append(split_type)

  def _add_ragged_feature(self, key, feature):
    """Adds a RaggedFeature."""
    value_key = key if feature.value_key is None else feature.value_key
    self._add_ragged_key(value_key, feature.dtype, feature.row_splits_dtype)
    for partition in feature.partitions:
      if not isinstance(partition, RaggedFeature.UniformRowLength):
        self._add_ragged_key(partition.key, dtypes.int64,
                             feature.row_splits_dtype)

  def _validate(self):
    """Validates the features in this ParseOpParams."""
    if len(self.dense_shapes) != len(self.dense_keys):
      raise ValueError("len(self.dense_shapes) != len(self.dense_keys): "
                       f"{len(self.dense_shapes)} vs {len(self.dense_keys)}.")
    if len(self.dense_types) != len(self.dense_keys):
      raise ValueError("len(self.dense_types) != len(self.dense_keys): "
                       f"{len(self.dense_types)} vs {len(self.dense_keys)}.")
    if len(self.sparse_types) != len(self.sparse_keys):
      raise ValueError("len(self.sparse_types) != len(self.sparse_keys): "
                       f"{len(self.sparse_types)} vs {len(self.sparse_keys)}.")
    if len(self.ragged_value_types) != len(self.ragged_keys):
      raise ValueError(
          "len(self.ragged_value_types) != len(self.ragged_keys): "
          f"{len(self.ragged_value_types)} vs {len(self.ragged_keys)}.")
    if len(self.ragged_split_types) != len(self.ragged_keys):
      raise ValueError(
          "len(self.ragged_split_types) != len(self.ragged_keys): "
          f"{len(self.ragged_split_types)} vs {len(self.ragged_keys)}.")

    dense_key_set = set(self.dense_keys)
    sparse_key_set = set(self.sparse_keys)
    ragged_key_set = set(self.ragged_keys)
    if not dense_key_set.isdisjoint(sparse_key_set):
      raise ValueError(
          "Dense and sparse keys must not intersect; dense_keys: "
          f"{self.dense_keys}, sparse_keys: {self.sparse_keys}, intersection: "
          f"{dense_key_set.intersection(sparse_key_set)}")
    if not dense_key_set.isdisjoint(ragged_key_set):
      # Bug fix: the first literal previously ended with a comma, so
      # ValueError received TWO arguments and rendered as a tuple; now the
      # message is a single concatenated string like its siblings.
      raise ValueError(
          "Dense and ragged keys must not intersect; dense_keys: "
          f"{self.dense_keys}, ragged_keys: {self.ragged_keys}, intersection: "
          f"{dense_key_set.intersection(ragged_key_set)}")
    if not ragged_key_set.isdisjoint(sparse_key_set):
      raise ValueError(
          "Ragged and sparse keys must not intersect; ragged_keys: "
          f"{self.ragged_keys}, sparse_keys: {self.sparse_keys}, intersection: "
          f"{ragged_key_set.intersection(sparse_key_set)}")
666
667
668def _construct_tensors_for_composite_features(features, tensor_dict):
669  """Creates tensors for SparseFeatures and RaggedFeatures.
670
671  Constructs new dict based on `tensor_dict`.
672
673  For each key in `features` whose value is a `SparseFeature`:
674
675    * Looks up that SparseFeature's value_key and index_keys in tensor_dict.
676    * Uses those tensors to construct a single SparseTensor.
677    * Stores that SparseTensor in the output dict under the same key.
678
679  For each key in `features` whose value is a `RaggedFeature`:
680
681    * Looks up that RaggedFeature's value_key and partition keys in tensor_dict.
682    * Uses those tensors to construct a single RaggedTensor.
683    * Stores that RaggedTensor in the output dict under the same key.
684
685  For any other key in `features`:
686
687    * Copies that key and its value from tensor_dict to the output dictionary.
688
689  Args:
690    features: A `dict` mapping feature keys to `SparseFeature` or
691      `RaggedFeature` values.  Values of other types will be ignored.
692    tensor_dict: A `dict` mapping feature keys to `Tensor`, `SparseTensor`, and
693      `RaggedTensor` values.  Expected to contain keys of the `SparseFeature`s'
694      `index_key`s and `value_key`s and mapping them to `SparseTensor`s.
695
696  Returns:
697    A `dict` mapping feature keys to `Tensor`, `SparseTensor`, and
698    `RaggedTensor` values. Similar to `tensor_dict` except each `SparseFeature`
699    in `features` results in a single `SparseTensor`; and each `RaggedFeature`
700    in `features` results in a single `RaggedTensor`.
701  """
702  tensor_dict = dict(tensor_dict)  # Do not modify argument passed in.
703  updates = {}
704  for key in sorted(features.keys()):
705    feature = features[key]
706    if isinstance(feature, SparseFeature):
707      # Construct SparseTensors for SparseFeatures
708      if isinstance(feature.index_key, str):
709        sp_ids = tensor_dict[feature.index_key]
710      else:
711        sp_ids = [tensor_dict[index_key] for index_key in feature.index_key]
712      sp_values = tensor_dict[feature.value_key]
713      updates[key] = sparse_ops.sparse_merge(
714          sp_ids,
715          sp_values,
716          vocab_size=feature.size,
717          already_sorted=feature.already_sorted)
718    elif isinstance(feature, RaggedFeature):
719      # Construct RaggedTensors for RaggedFeatures.
720      value_key = key if feature.value_key is None else feature.value_key
721      rt = tensor_dict[value_key]
722      if isinstance(rt, ragged_tensor.RaggedTensor):
723        # We processed a batch of tf.Example or tf.SequenceExample, or single
724        # tf.SequenceExample.
725        if rt.ragged_rank > 1:
726          # We're processing a batch of SequenceExample, and we effectively have
727          # two batch dimensions.  Cllapse those batch dimensions here, and
728          # restore them below (using outer_splits).
729          outer_splits = rt.row_splits
730          rt = rt.values
731        else:
732          outer_splits = None
733        for partition in reversed(feature.partitions):
734          rt = _add_batched_ragged_partition(rt, partition, tensor_dict,
735                                             key, feature.validate,
736                                             outer_splits)
737        if outer_splits is not None:
738          rt = ragged_tensor.RaggedTensor.from_row_splits(
739              rt, outer_splits, validate=feature.validate)
740      else:
741        # We processed a single tf.Example.
742        for partition in reversed(feature.partitions):
743          rt = _add_ragged_partition(rt, partition, tensor_dict,
744                                     feature.row_splits_dtype, feature.validate)
745      updates[key] = rt
746
747  # Process updates after all composite tensors have been constructed (in case
748  # multiple features use the same value_key, and one uses that key as its
749  # feature key).
750  tensor_dict.update(updates)
751
752  # Remove tensors from dictionary that were only used to construct
753  # tensors for SparseFeature or RaggedTensor.
754  for key in set(tensor_dict) - set(features):
755    del tensor_dict[key]
756  return tensor_dict
757
758
def _add_ragged_partition(values, partition, tensor_dict, row_splits_dtype,
                          validate):
  """Creates a RaggedTensor from a values tensor and a partition tensor.

  Args:
    values: The values tensor for the new RaggedTensor.
    partition: The partition configuration object.  Specifies the key that
      should be used to look up the partition tensor (unless partition is a
      RaggedFeature.UniformRowLength, in which case there is no partition
      tensor).
    tensor_dict: The dictionary mapping keys to tensors.
    row_splits_dtype: The dtype for the partition tensor.
    validate: Whether to validate that the values form a valid RaggedTensor.

  Returns:
    A new RaggedTensor formed from the values and partition tensors.
  """
  if isinstance(partition, RaggedFeature.UniformRowLength):
    # Uniform partitions carry no partition tensor of their own.
    if not isinstance(values, ragged_tensor.RaggedTensor):
      # Dense values: a reshape adds the uniform inner dimension directly.
      new_shape = array_ops.concat(
          [[-1, partition.length], array_ops.shape(values)[1:]], axis=0)
      return array_ops.reshape(values, new_shape)
    length = ops.convert_to_tensor(partition.length, dtype=row_splits_dtype)
    return ragged_tensor.RaggedTensor.from_uniform_row_length(
        values, length, validate=validate)

  # All remaining partition kinds look up their tensor by key and delegate to
  # the matching RaggedTensor factory.
  partition_t = math_ops.cast(tensor_dict[partition.key], row_splits_dtype)
  dispatch = (
      (RaggedFeature.RowSplits, ragged_tensor.RaggedTensor.from_row_splits),
      (RaggedFeature.RowLengths, ragged_tensor.RaggedTensor.from_row_lengths),
      (RaggedFeature.RowStarts, ragged_tensor.RaggedTensor.from_row_starts),
      (RaggedFeature.RowLimits, ragged_tensor.RaggedTensor.from_row_limits),
      (RaggedFeature.ValueRowIds,
       ragged_tensor.RaggedTensor.from_value_rowids),
  )
  for partition_cls, factory in dispatch:
    if isinstance(partition, partition_cls):
      return factory(values, partition_t, validate=validate)
  raise ValueError(f"Unhandled partition type {partition!r}")
802
803
def _add_batched_ragged_partition(rt, partition, tensor_dict, feature_key,
                                  validate, outer_splits=None):
  """Adds a batched ragged partition tensor to a batched ragged tensor.

  Args:
    rt: A RaggedTensor with shape [batch_size, ...].
    partition: The partition configuration object.  Specifies the key that
      should be used to look up the partition tensor (unless partition is a
      RaggedFeature.UniformRowLength, in which case there is no partition
      tensor).  The specified tensor must have shape [batch_size, ...].
    tensor_dict: The dictionary mapping keys to tensors.
    feature_key: The name of the feature being parsed (for error messages).
    validate: Whether to validate that the values form a valid RaggedTensor.
    outer_splits: If not None, then we have two batch dimensions, and this
      is the row-splits for the collapsed batch dimension.  Every partition
      tensor must have an outer row_splits that matches this value.

  Returns:
    A new RaggedTensor where each batch item `rt[i]` has been partitioned
    using the `partition_t[i]`.
  """
  if isinstance(partition, RaggedFeature.UniformRowLength):
    if rt.ragged_rank > 1:
      # rt's values are themselves ragged: impose the uniform row length on
      # the inner values, then rebuild the outer batch dimension.  Dividing
      # row_splits by `length` converts value offsets into row offsets
      # (assumes each batch item's length is a multiple of `length`).
      length = ops.convert_to_tensor(partition.length, rt.row_splits.dtype)
      return ragged_tensor.RaggedTensor.from_row_splits(
          ragged_tensor.RaggedTensor.from_uniform_row_length(
              rt.values, length, validate=validate),
          rt.row_splits // length,
          validate=validate)
    else:
      # Dense inner values: reshape adds the uniform dimension directly,
      # keeping any trailing dimensions of the values intact.
      reshaped_vals = array_ops.reshape(rt.values, array_ops.concat(
          [[-1, partition.length], array_ops.shape(rt.values)[1:]], axis=0))
      return ragged_tensor.RaggedTensor.from_row_splits(
          reshaped_vals, rt.row_splits // partition.length, validate=validate)

  # NOTE(review): the code below assumes partition_t is a RaggedTensor of
  # shape [batch_size, ...] (it reads .values/.row_splits/.row_lengths) —
  # presumably guaranteed by the parse op outputs; confirm with callers.
  partition_t = tensor_dict[partition.key]
  if partition_t.values.dtype != rt.row_splits.dtype:
    partition_t = math_ops.cast(partition_t, rt.row_splits.dtype)

  checks = []
  if outer_splits is not None:
    if validate:
      # Both batch dimensions were collapsed; the partition tensor's outer
      # splits must agree with the values tensor's.
      checks.append(check_ops.assert_equal(
          outer_splits, partition_t.row_splits,
          message="Feature %s: values and partitions are not aligned"
          % feature_key))
    partition_t = partition_t.values

  with ops.control_dependencies(checks):
    if isinstance(partition, (RaggedFeature.RowSplits,
                              RaggedFeature.RowLimits)):
      if isinstance(partition, RaggedFeature.RowSplits):
        # Dropping the leading 0 of each row's splits turns them into limits,
        # so both kinds share the from_row_limits path below.
        partition_t = partition_t[:, 1:]
      # Partition values are per-batch-item offsets; shift each batch item's
      # limits by where that item's values start in the flattened rt.values.
      adjusted_limits = partition_t.values + array_ops.repeat(
          rt.row_starts(), partition_t.row_lengths())
      return partition_t.with_values(
          ragged_tensor.RaggedTensor.from_row_limits(
              rt.values, adjusted_limits, validate=validate))
    elif isinstance(partition, RaggedFeature.RowStarts):
      # Same per-batch-item offset adjustment as above, but for row starts.
      adjusted_starts = partition_t.values + array_ops.repeat(
          rt.row_starts(), partition_t.row_lengths())
      return partition_t.with_values(
          ragged_tensor.RaggedTensor.from_row_starts(
              rt.values, adjusted_starts, validate=validate))
    elif isinstance(partition, RaggedFeature.RowLengths):
      # Row lengths are relative counts, so no offset adjustment is needed.
      return partition_t.with_values(
          ragged_tensor.RaggedTensor.from_row_lengths(
              rt.values, partition_t.values, validate=validate))
    elif isinstance(partition, RaggedFeature.ValueRowIds):
      nrows = math_ops.maximum(  # number of rows in each batch item
          ragged_math_ops.reduce_max(partition_t + 1, axis=1), 0)
      # Make per-batch-item row ids globally unique by offsetting each batch
      # item's ids by the total number of rows in the preceding items.
      adjusted_rowids = partition_t.values + array_ops.repeat(
          math_ops.cumsum(nrows, exclusive=True), partition_t.row_lengths())
      return ragged_tensor.RaggedTensor.from_row_lengths(
          ragged_tensor.RaggedTensor.from_value_rowids(
              rt.values, adjusted_rowids, validate=validate),
          nrows,
          validate=validate)

    raise ValueError(f"Unhandled partition type {partition!r}")
884
885
886def _build_ragged_tensors(serialized_shape,
887                          ragged_values,
888                          ragged_row_splits,
889                          ragged_inner_splits=None):
890  """Builds RaggedTensors from the outputs of a parse op."""
891  if ragged_inner_splits is not None:
892    ragged_values = [
893        ragged_tensor.RaggedTensor.from_row_splits(val, split, validate=False)
894        for (val, split) in zip(ragged_values, ragged_inner_splits)
895    ]
896  if serialized_shape.ndims == 0:
897    return ragged_values
898  else:
899    return [
900        ragged_tensor.RaggedTensor.from_row_splits(val, split, validate=False)
901        for (val, split) in zip(ragged_values, ragged_row_splits)
902    ]
903