• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Support for ragged tensors."""
16
17from tensorflow.python.framework import dtypes
18from tensorflow.python.framework import tensor_shape
19from tensorflow.python.ops.ragged import ragged_config
20from tensorflow.python.ops.ragged import ragged_tensor
21from tensorflow.python.util import dispatch
22from tensorflow.python.util.tf_export import tf_export
23
24
@tf_export("ragged.map_flat_values")
@dispatch.add_dispatch_support
def map_flat_values(op, *args, **kwargs):
  """Applies `op` to the `flat_values` of one or more RaggedTensors.

  Every `RaggedTensor` found in `args` or `kwargs` (including ones nested
  inside lists, tuples, and dicts) is swapped for its `flat_values` tensor,
  which collapses all ragged dimensions.  `op` is then called with the
  substituted arguments, and its result is wrapped back into a `RaggedTensor`
  using the row partitions of the ragged inputs.

  When several `RaggedTensor`s are passed, they must have identical
  `nested_row_splits`.

  If none of the arguments is a `RaggedTensor`, the result of
  `op(*args, **kwargs)` is returned as-is.

  This operation is generally used to apply elementwise operations to each value
  in a `RaggedTensor`.

  Warning: `tf.ragged.map_flat_values` does *not* apply `op` to each row of a
  ragged tensor.  This difference is important for non-elementwise operations,
  such as `tf.reduce_sum`.  If you wish to apply a non-elementwise operation to
  each row of a ragged tensor, use `tf.map_fn` instead.  (You may need to
  specify an `output_signature` when using `tf.map_fn` with ragged tensors.)

  Examples:

  >>> rt = tf.ragged.constant([[1, 2, 3], [], [4, 5], [6]])
  >>> tf.ragged.map_flat_values(tf.ones_like, rt)
  <tf.RaggedTensor [[1, 1, 1], [], [1, 1], [1]]>
  >>> tf.ragged.map_flat_values(tf.multiply, rt, rt)
  <tf.RaggedTensor [[1, 4, 9], [], [16, 25], [36]]>
  >>> tf.ragged.map_flat_values(tf.add, rt, 5)
  <tf.RaggedTensor [[6, 7, 8], [], [9, 10], [11]]>

  Example with a non-elementwise operation (note that `map_flat_values` and
  `map_fn` return different results):

  >>> rt = tf.ragged.constant([[1.0, 3.0], [], [3.0, 6.0, 3.0]])
  >>> def normalized(x):
  ...   return x / tf.reduce_sum(x)
  >>> tf.ragged.map_flat_values(normalized, rt)
  <tf.RaggedTensor [[0.0625, 0.1875], [], [0.1875, 0.375, 0.1875]]>
  >>> tf.map_fn(normalized, rt)
  <tf.RaggedTensor [[0.25, 0.75], [], [0.25, 0.5, 0.25]]>

  Args:
    op: The operation that should be applied to the RaggedTensor `flat_values`.
      `op` is typically an element-wise operation (such as math_ops.add), but
      any operation that preserves the size of the outermost dimension can be
      used.  I.e., `shape[0]` of the value returned by `op` must match
      `shape[0]` of the `RaggedTensor`s' `flat_values` tensors.
    *args: Arguments for `op`.
    **kwargs: Keyword arguments for `op`.

  Returns:
    A `RaggedTensor` whose `ragged_rank` matches the `ragged_rank` of all
    input `RaggedTensor`s; or, when no input is a `RaggedTensor`, whatever
    `op(*args, **kwargs)` returns.

  Raises:
    ValueError: If the statically known outer-dimension sizes of the ragged
      inputs' `flat_values` differ; if the inputs use different row partition
      dtypes and automatic casting is disabled; if the ragged inputs do not
      all have the same `ragged_rank`; or if the outer dimension of the value
      returned by `op` is statically incompatible with the inputs'
      `flat_values`.
  """
  # Swap each RaggedTensor for its flat_values.  The helper fills in both
  # output lists; `args` is processed before `kwargs`, which fixes the order
  # in which the partitions are merged later on.
  partition_lists = []
  known_nrows = []
  inner_args = _replace_ragged_with_flat_values(args, partition_lists,
                                                known_nrows)
  inner_kwargs = _replace_ragged_with_flat_values(kwargs, partition_lists,
                                                  known_nrows)

  # No ragged inputs: there is nothing to unwrap or rewrap.
  if not partition_lists:
    return op(*args, **kwargs)

  # Fail early when the inputs are statically known to be incompatible.  Full
  # compatibility cannot be established statically; this only lets us report
  # some errors sooner.
  expected_nrows = None
  if known_nrows:
    distinct_nrows = set(known_nrows)
    if len(distinct_nrows) != 1:
      raise ValueError("Input RaggedTensors' flat_values must all have the "
                       "same outer-dimension size.  Got sizes: %s" %
                       distinct_nrows)
    (expected_nrows,) = distinct_nrows

  # Row partitions must share one dtype before they can be merged.  The dtype
  # of each input is read from its outermost partition.
  partition_dtypes = {partitions[0].dtype for partitions in partition_lists}
  if len(partition_dtypes) > 1:
    if not ragged_config.auto_cast_partition_dtype():
      raise ValueError("Input RaggedTensors have mismatched row partition "
                       "dtypes; use RaggedTensor.with_row_splits_dtype() to "
                       "convert them to compatible dtypes.")
    # Auto-casting is enabled: bring every partition to int64.
    int64_partition_lists = []
    for partitions in partition_lists:
      int64_partition_lists.append(
          [partition.with_dtype(dtypes.int64) for partition in partitions])
    partition_lists = int64_partition_lists

  # Delegate to `op`.
  op_output = op(*inner_args, **inner_kwargs)

  # Check that the result has the expected outer dimension (if known).
  if (expected_nrows is not None and
      not op_output.shape[:1].is_compatible_with([expected_nrows])):
    raise ValueError(
        "tf.ragged.map_flat_values requires that the output of `op` have "
        "the same outer-dimension size as flat_values of any ragged "
        "inputs. (output shape: %s; expected outer dimension size: %s)" %
        (op_output.shape, expected_nrows))

  # Compose the result from the transformed values and the merged partitions.
  merged_partitions = _merge_partition_lists(partition_lists)
  # pylint: disable=protected-access
  return ragged_tensor.RaggedTensor._from_nested_row_partitions(
      op_output, merged_partitions, validate=False)
137
138
def _replace_ragged_with_flat_values(value, partition_lists, flat_values_nrows):
  """Swaps RaggedTensors in `value` for flat_values; records their partitions.

  Walks `value`, descending into lists, tuples, and dicts, and builds a copy in
  which each ragged value is replaced by its `flat_values` tensor.  Anything
  that is neither ragged nor one of those containers is returned unchanged.

  Note that the rebuilt containers are always plain `list`, `tuple`, and
  `dict` objects, even if the input used a subclass of those types.

  Args:
    value: The (possibly nested) value to transform.
    partition_lists: Output parameter.  For every ragged value that is
      replaced, its nested row partitions are appended to this list.
    flat_values_nrows: Output parameter.  For every ragged value that is
      replaced, the outer dimension size of its `flat_values` is appended to
      this list, but only when that size is statically known.

  Returns:
    A copy of `value` in which ragged values have been replaced by their
    `flat_values`.
  """
  # Leaf: a ragged value.  Record its bookkeeping info and unwrap it.
  if ragged_tensor.is_ragged(value):
    ragged = ragged_tensor.convert_to_tensor_or_ragged_tensor(value)
    partition_lists.append(ragged._nested_row_partitions)  # pylint: disable=protected-access
    flat_values = ragged.flat_values
    outer_dim = tensor_shape.dimension_at_index(flat_values.shape, 0)
    if outer_dim.value is not None:
      flat_values_nrows.append(outer_dim.value)
    return flat_values

  def replace(item):
    # Recurse while sharing the same two output lists.
    return _replace_ragged_with_flat_values(item, partition_lists,
                                            flat_values_nrows)

  # Containers: rebuild with each element transformed.
  if isinstance(value, list):
    return [replace(item) for item in value]
  if isinstance(value, tuple):
    return tuple(replace(item) for item in value)
  if isinstance(value, dict):
    return {key: replace(item) for key, item in value.items()}

  # Any other leaf is passed through untouched.
  return value
180
181
def _merge_partition_lists(partition_lists):
  """Combines several lists of RowPartitions into a single list.

  Args:
    partition_lists: A non-empty list of lists of RowPartition, one inner list
      per ragged input.

  Returns:
    A list of RowPartitions with the same length as `partition_lists[0]`.
    Element `i` is obtained by folding `partition_lists[j][i]` over all `j`,
    left to right, with `RowPartition._merge_precomputed_encodings`.

  Raises:
    ValueError: If the inner lists do not all have the same length.
  """
  merged = list(partition_lists[0])
  for other in partition_lists[1:]:
    # One partition per ragged dimension, so unequal lengths mean the inputs
    # have different ragged ranks.
    if len(other) != len(merged):
      raise ValueError("All ragged inputs must have the same ragged_rank.")
    # pylint: disable=protected-access
    merged = [
        current._merge_precomputed_encodings(incoming)
        for current, incoming in zip(merged, other)
    ]
  return merged
201