• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Support for ragged tensors."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from tensorflow.python.framework import ops
22from tensorflow.python.ops.ragged import ragged_tensor
23from tensorflow.python.ops.ragged import ragged_util
24from tensorflow.python.util.tf_export import tf_export
25
26
@tf_export("ragged.map_flat_values")
def map_flat_values(op, *args, **kwargs):
  """Applies `op` to the values of one or more RaggedTensors.

  Replaces any `RaggedTensor` in `args` or `kwargs` with its `flat_values`
  tensor, and then calls `op`.  Returns a `RaggedTensor` that is constructed
  from the input `RaggedTensor`s' `nested_row_splits` and the value returned by
  the `op`.

  `RaggedTensor`s are found anywhere in `args` and `kwargs`, including inside
  (possibly nested) `list`, `tuple`, and `dict` arguments.  If the arguments
  contain multiple `RaggedTensor`s, then they must have identical
  `nested_row_splits`.  If the arguments contain no `RaggedTensor`s at all,
  then `op` is simply called with the original `args` and `kwargs` and its
  result is returned as-is.

  Examples:

  ```python
  >>> rt = ragged.constant([[1, 2, 3], [], [4, 5], [6]])
  >>> ragged.map_flat_values(tf.ones_like, rt).eval().tolist()
  [[1, 1, 1], [], [1, 1], [1]]
  >>> ragged.map_flat_values(tf.multiply, rt, rt).eval().tolist()
  [[1, 4, 9], [], [16, 25], [36]]
  >>> ragged.map_flat_values(tf.add, rt, 5).eval().tolist()
  [[6, 7, 8], [], [9, 10], [11]]
  ```

  Args:
    op: The operation that should be applied to the RaggedTensor `flat_values`.
      `op` is typically an element-wise operation (such as math_ops.add), but
      any operation that preserves the size of the outermost dimension can be
      used.  I.e., `shape[0]` of the value returned by `op` must match
      `shape[0]` of the `RaggedTensor`s' `flat_values` tensors.
    *args: Arguments for `op`.
    **kwargs: Keyword arguments for `op`.

  Returns:
    If `args` or `kwargs` contain at least one `RaggedTensor`: a
    `RaggedTensor` whose `ragged_rank` matches the `ragged_rank` of all input
    `RaggedTensor`s, and whose `flat_values` are the value returned by `op`.
    Otherwise: the value returned by `op(*args, **kwargs)`, unmodified.

  Raises:
    ValueError: If the `nested_row_splits` of the input `RaggedTensor`s are
      not identical, as reported by `ragged_util.assert_splits_match`.  The
      checks returned by that function are also installed as control
      dependencies of the call to `op` and of the constructed result.
  """
  # Replace RaggedTensors with their flat_values, collecting the
  # nested_row_splits of every RaggedTensor that was replaced.
  nested_splits_lists = []
  inner_args = _replace_ragged_with_flat_values(args, nested_splits_lists)
  inner_kwargs = _replace_ragged_with_flat_values(kwargs, nested_splits_lists)
  if not nested_splits_lists:
    # No RaggedTensors anywhere: nothing to map over, so this is a plain call
    # of `op` on the original (unconverted) arguments.
    return op(*args, **kwargs)

  # All input RaggedTensors must share one row partitioning; the checks from
  # assert_splits_match gate both the call to `op` and building the result.
  with ops.control_dependencies(
      ragged_util.assert_splits_match(nested_splits_lists)):
    # Delegate to op, and then compose the result from the transformed values
    # and the splits.  Since all splits must match, the first list is used.
    return ragged_tensor.RaggedTensor.from_nested_row_splits(
        op(*inner_args, **inner_kwargs), nested_splits_lists[0])
81
82
def _replace_ragged_with_flat_values(value, nested_splits_lists):
  """Returns `value` with every nested ragged value swapped for its values.

  Walks `value` depth-first, left to right, descending into `list`, `tuple`,
  and `dict` containers.  Those containers are rebuilt as plain `list`,
  `tuple`, and `dict` objects.  Each leaf accepted by `ragged_tensor.is_ragged`
  is converted with `ragged_tensor.convert_to_tensor_or_ragged_tensor`, its
  `nested_row_splits` are appended to `nested_splits_lists`, and it is replaced
  by its `flat_values`.  Any other leaf is returned as-is.

  Args:
    value: The (possibly nested) value to transform.
    nested_splits_lists: A list that this function mutates.  The
      `nested_row_splits` of each replaced ragged value are appended to it in
      traversal order.

  Returns:
    `value` with its containers rebuilt and its ragged leaves replaced by their
    `flat_values`.
  """
  # Leaf case: a ragged value is swapped for its flat_values.
  if ragged_tensor.is_ragged(value):
    rt = ragged_tensor.convert_to_tensor_or_ragged_tensor(value)
    nested_splits_lists.append(rt.nested_row_splits)
    return rt.flat_values

  # Container cases: rebuild the container from transformed elements.
  if isinstance(value, list):
    return [_replace_ragged_with_flat_values(item, nested_splits_lists)
            for item in value]
  if isinstance(value, tuple):
    return tuple(
        _replace_ragged_with_flat_values(item, nested_splits_lists)
        for item in value)
  if isinstance(value, dict):
    return {
        key: _replace_ragged_with_flat_values(item, nested_splits_lists)
        for key, item in value.items()
    }

  # Anything else is passed through untouched.
  return value
117