• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Python-style indexing and slicing for RaggedTensors."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from tensorflow.python.eager import context
22from tensorflow.python.framework import dtypes
23from tensorflow.python.framework import ops
24from tensorflow.python.ops import array_ops
25from tensorflow.python.ops import control_flow_ops
26from tensorflow.python.ops import math_ops
27from tensorflow.python.ops.ragged import ragged_gather_ops
28from tensorflow.python.ops.ragged import ragged_math_ops
29from tensorflow.python.ops.ragged import ragged_tensor
30
31
def ragged_tensor_getitem(self, key):
  """Extracts the piece of this RaggedTensor selected by `key`.

  Multidimensional indexing and slicing are both supported, with a single
  restriction: an integer index may not be applied to a ragged inner
  dimension.  Such an index is ambiguous because the requested element may
  be present in some rows and absent from others, and it is unclear whether
  the right response is to (1) raise an IndexError, (2) fill in a default
  value, or (3) drop the short rows and return fewer rows than the input
  had.  In keeping with the Zen of Python ("In the face of ambiguity, refuse
  the temptation to guess"), the operation is rejected outright.

  A dimension inserted with `array_ops.newaxis` is ragged whenever the
  dimension that follows it is ragged.

  Args:
    self: The RaggedTensor to slice.
    key: Selects the piece of the RaggedTensor to return, following standard
      Python semantics (e.g., negative values count from the end).  `key`
      may be any of:

      * `int` constant
      * Scalar integer `Tensor`
      * `slice` containing integer constants and/or scalar integer
        `Tensor`s
      * `Ellipsis`
      * `tf.newaxis`
      * `tuple` containing any of the above (for multidimensional indexing)

  Returns:
    A `Tensor` or `RaggedTensor` object.  Results that keep at least one
    ragged dimension are returned as `RaggedTensor`; results with no ragged
    dimensions are returned as `Tensor`.  See the examples below for
    expressions that return `Tensor`s vs `RaggedTensor`s.

  Raises:
    IndexError: When executing eagerly, if an integer row index is greater
      than or equal to the number of rows; or if `key` contains more indices
      than the tensor has dimensions when expanding an `Ellipsis`.
    ValueError: If `key` is not supported (e.g., it indexes into a ragged
      inner dimension, or uses `Ellipsis` when the rank is unknown).
    TypeError: If the indices in `key` have an unsupported type.

  Examples:

    ```python
    >>> # A 2-D ragged tensor with 1 ragged dimension.
    >>> rt = ragged.constant([['a', 'b', 'c'], ['d', 'e'], ['f'], ['g']])
    >>> rt[0].eval().tolist()       # First row (1-D `Tensor`)
    ['a', 'b', 'c']
    >>> rt[:3].eval().tolist()      # First three rows (2-D RaggedTensor)
    [['a', 'b', 'c'], ['d', 'e'], ['f']]
    >>> rt[3, 0].eval().tolist()    # 1st element of 4th row (scalar)
    'g'

    >>> # A 3-D ragged tensor with 2 ragged dimensions.
    >>> rt = ragged.constant([[[1, 2, 3], [4]],
    ...                    [[5], [], [6]],
    ...                    [[7]],
    ...                    [[8, 9], [10]]])
    >>> rt[1].eval().tolist()       # Second row (2-D RaggedTensor)
    [[5], [], [6]]
    >>> rt[3, 0].eval().tolist()    # First element of fourth row (1-D Tensor)
    [8, 9]
    >>> rt[:, 1:3].eval().tolist()  # Items 1-3 of each row (3-D RaggedTensor)
    [[[4]], [[], [6]], [], [[10]]]
    >>> rt[:, -1:].eval().tolist()  # Last item of each row (3-D RaggedTensor)
    [[[4]], [[6]], [[7]], [[10]]]
    ```
  """
  # Every Tensor mentioned anywhere in the key takes part in the name scope.
  key_tensors = list(_tensors_in_key_list(key))
  # Normalize the key to a list holding one entry per indexed dimension.
  key_list = list(key) if isinstance(key, (list, tuple)) else [key]
  with ops.name_scope(None, "RaggedGetItem", [self] + key_tensors):
    return _ragged_getitem(self, key_list)
106
107
108def _ragged_getitem(rt_input, key_list):
109  """Helper for indexing and slicing ragged tensors with __getitem__().
110
111  Extracts the specified piece of the `rt_input`.  See
112  `RaggedTensor.__getitem__` for examples and restrictions.
113
114  Args:
115    rt_input: The `RaggedTensor` from which a piece should be returned.
116    key_list: The list of keys specifying which piece to return. Each key
117      corresponds with a separate dimension.
118
119  Returns:
120    The indicated piece of rt_input.
121
122  Raises:
123    ValueError: If `key_list` is not supported.
124    TypeError: If any keys in `key_list` have an unsupported type.
125  """
126  if not key_list:
127    return rt_input
128  row_key = key_list[0]
129  inner_keys = key_list[1:]
130
131  if row_key is Ellipsis:
132    expanded_key_list = _expand_ellipsis(key_list, rt_input.shape.ndims)
133    return _ragged_getitem(rt_input, expanded_key_list)
134
135  # Adding a new axis: Get rt_input[inner_keys], and wrap it in a RaggedTensor
136  # that puts all values in a single row.
137  if row_key is array_ops.newaxis:
138    inner_rt = _ragged_getitem(rt_input, inner_keys)
139    nsplits = array_ops.shape(inner_rt.row_splits, out_type=dtypes.int64)[0]
140    return ragged_tensor.RaggedTensor.from_row_splits(
141        inner_rt, array_ops.stack([0, nsplits - 1]))
142
143  # Slicing a range of rows: first slice the outer dimension, and then
144  # call `_ragged_getitem_inner_dimensions` to handle the inner keys.
145  if isinstance(row_key, slice):
146    sliced_rt_input = _slice_ragged_row_dimension(rt_input, row_key)
147    return _ragged_getitem_inner_dimensions(sliced_rt_input, inner_keys)
148
149  # Indexing a single row: slice values to get the indicated row, and then
150  # use a recursive call to __getitem__ to handle the inner keys.
151  else:
152    starts = rt_input.row_splits[:-1]
153    limits = rt_input.row_splits[1:]
154    if context.executing_eagerly():
155      # In python, __getitem__ should throw IndexError for out of bound
156      # indices. This will allow iteration run correctly as python will
157      # translate IndexError into StopIteration for next()/__next__().
158      # Below is an example:
159      #    import tensorflow as tf
160      #    r = tf.ragged.constant([[1., 2.], [3., 4., 5.], [6.]])
161      #    for elem in r:
162      #      print(elem)
163      # In non eager mode, the exception is thrown when session runs
164      # so we don't know if out of bound happens before.
165      # In eager mode, however, it is possible to find out when to
166      # throw out of bound IndexError.
167      # In the following row_key >= len(starts) is checked. In case of
168      # TypeError which happens when row_key is not an integer, the exception
169      # will simply be ignored as it will be processed later anyway.
170      try:
171        if int(row_key) >= len(starts):
172          raise IndexError("Row key {} out of bounds".format(row_key))
173      except (TypeError, ValueError):
174        pass
175    row = rt_input.values[starts[row_key]:limits[row_key]]
176    return row.__getitem__(inner_keys)
177
178
179def _slice_ragged_row_dimension(rt_input, row_key):
180  """Slice the outer dimension of `rt_input` according to the given `slice`.
181
182  Args:
183    rt_input: The `RaggedTensor` to slice.
184    row_key: The `slice` object that should be used to slice `rt_input`.
185
186  Returns:
187    A `RaggedTensor` containing the indicated slice of `rt_input`.
188  """
189  if row_key.start is None and row_key.stop is None and row_key.step is None:
190    return rt_input
191
192  # Use row_key to slice the starts & limits.
193  new_starts = rt_input.row_splits[:-1][row_key]
194  new_limits = rt_input.row_splits[1:][row_key]
195  zero_pad = array_ops.zeros([1], dtypes.int64)
196
197  # If there's no slice step, then we can just select a single continuous
198  # span of `ragged.values(rt_input)`.
199  if row_key.step is None or row_key.step == 1:
200    # Construct the new splits.  If new_starts and new_limits are empty,
201    # then this reduces to [0].  Otherwise, this reduces to:
202    #   concat([[new_starts[0]], new_limits])
203    new_splits = array_ops.concat(
204        [zero_pad[array_ops.size(new_starts):], new_starts[:1], new_limits],
205        axis=0)
206    values_start = new_splits[0]
207    values_limit = new_splits[-1]
208    return ragged_tensor.RaggedTensor.from_row_splits(
209        rt_input.values[values_start:values_limit], new_splits - values_start)
210
211  # If there is a slice step (aka a strided slice), then use ragged_gather to
212  # collect the necessary elements of `ragged.values(rt_input)`.
213  else:
214    return _build_ragged_tensor_from_value_ranges(new_starts, new_limits, 1,
215                                                  rt_input.values)
216
217
218def _ragged_getitem_inner_dimensions(rt_input, key_list):
219  """Retrieve inner dimensions, keeping outermost dimension unchanged.
220
221  Args:
222    rt_input: The `RaggedTensor` or `Tensor` from which a piece should be
223      extracted.
224    key_list: The __getitem__ keys for slicing the inner dimensions.
225
226  Returns:
227    A `RaggedTensor`.
228
229  Raises:
230    ValueError: If key_list is not supported.
231  """
232  if not key_list:
233    return rt_input
234
235  if isinstance(rt_input, ops.Tensor):
236    return rt_input.__getitem__([slice(None, None, None)] + key_list)
237
238  column_key = key_list[0]
239  if column_key is Ellipsis:
240    expanded_key_list = _expand_ellipsis(key_list, rt_input.values.shape.ndims)
241    return _ragged_getitem_inner_dimensions(rt_input, expanded_key_list)
242
243  # Adding a new axis to a ragged inner dimension: recursively get the inner
244  # dimensions of rt_input with key_list[1:], and then wrap the result in a
245  # RaggedTensor that puts each value in its own row.
246  if column_key is array_ops.newaxis:
247    inner_rt = _ragged_getitem_inner_dimensions(rt_input, key_list[1:])
248    nsplits = array_ops.shape(inner_rt.row_splits, out_type=dtypes.int64)[0]
249    return ragged_tensor.RaggedTensor.from_row_splits(inner_rt,
250                                                      math_ops.range(nsplits))
251
252  # Slicing a range of columns in a ragged inner dimension.  We use a
253  # recursive call to process the values, and then assemble a RaggedTensor
254  # with those values.
255  if isinstance(column_key, slice):
256    if (column_key.start is None and column_key.stop is None and
257        column_key.step is None):
258      # Trivial slice: recursively process all values, & splits is unchanged.
259      return rt_input.with_values(
260          _ragged_getitem_inner_dimensions(rt_input.values, key_list[1:]))
261    else:
262      # Nontrivial slice: use ragged_gather to extract the indicated slice as
263      # a new RaggedTensor (inner_rt), and then recursively process its values.
264      # The splits can be taken from inner_rt.row_splits().
265      inner_rt_starts = rt_input.row_splits[:-1]
266      inner_rt_limits = rt_input.row_splits[1:]
267      if column_key.start is not None and column_key.start != 0:
268        inner_rt_starts = _add_offset_to_ranges(
269            column_key.start, rt_input.row_splits[:-1], rt_input.row_splits[1:])
270      if column_key.stop is not None and column_key.stop != 0:
271        inner_rt_limits = _add_offset_to_ranges(
272            column_key.stop, rt_input.row_splits[:-1], rt_input.row_splits[1:])
273      inner_rt = _build_ragged_tensor_from_value_ranges(
274          inner_rt_starts, inner_rt_limits, column_key.step, rt_input.values)
275      return inner_rt.with_values(
276          _ragged_getitem_inner_dimensions(inner_rt.values, key_list[1:]))
277
278  # Indexing a single column in a ragged inner dimension: raise an Exception.
279  # See RaggedTensor.__getitem__.__doc__ for an explanation of why indexing
280  # into a ragged inner dimension is problematic.
281  else:
282    raise ValueError("Cannot index into an inner ragged dimension.")
283
284
285def _expand_ellipsis(key_list, num_remaining_dims):
286  """Expands the ellipsis at the start of `key_list`.
287
288  Assumes that the first element of `key_list` is Ellipsis.  This will either
289  remove the Ellipsis (if it corresponds to zero indices) or prepend a new
290  `slice(None, None, None)` (if it corresponds to more than zero indices).
291
292  Args:
293    key_list: The arguments to `__getitem__()`.
294    num_remaining_dims: The number of dimensions remaining.
295
296  Returns:
297    A copy of `key_list` with he ellipsis expanded.
298  Raises:
299    ValueError: If ragged_rank.shape.ndims is None
300    IndexError: If there are too many elements in `key_list`.
301  """
302  if num_remaining_dims is None:
303    raise ValueError("Ellipsis not supported for unknown shape RaggedTensors")
304  num_indices = sum(1 for idx in key_list if idx is not array_ops.newaxis)
305  if num_indices > num_remaining_dims + 1:
306    raise IndexError("Too many indices for RaggedTensor")
307  elif num_indices == num_remaining_dims + 1:
308    return key_list[1:]
309  else:
310    return [slice(None, None, None)] + key_list
311
312
def _tensors_in_key_list(key_list):
  """Yields each Tensor contained in the given slice spec.

  The spec is walked recursively: a Tensor is yielded as-is, a list or tuple
  is searched element by element, and a slice is searched through its
  `start`, `stop` and `step` (in that order).  Any other value yields
  nothing.

  Args:
    key_list: A `__getitem__` key, or any piece of one.

  Yields:
    The `Tensor` objects found, in traversal order.
  """
  if isinstance(key_list, ops.Tensor):
    yield key_list
  elif isinstance(key_list, (list, tuple)):
    for item in key_list:
      for tensor in _tensors_in_key_list(item):
        yield tensor
  elif isinstance(key_list, slice):
    for bound in (key_list.start, key_list.stop, key_list.step):
      for tensor in _tensors_in_key_list(bound):
        yield tensor
328
329
def _build_ragged_tensor_from_value_ranges(starts, limits, step, values):
  """Gathers one strided range of `values` per output row.

  Builds a RaggedTensor `output` such that:

  ```python
  output.shape[0] = starts.shape[0]
  output[i] = values[starts[i]:limits[i]:step]
  ```

  Requires that `starts.shape == limits.shape` and
  `0 <= starts[i] <= limits[i] <= values.shape[0]`.

  Args:
    starts: 1D integer Tensor specifying the start indices for the sequences of
      values to include.
    limits: 1D integer Tensor specifying the limit indices for the sequences of
      values to include.
    step: Integer value specifying the step size for strided slices.  `None`
      is treated as 1.
    values: The set of values to select from.

  Returns:
    A `RaggedTensor`.

  Raises:
    TypeError: If `step` does not have an integer dtype.
  """
  # Normalize the step to an int64 tensor, rejecting non-integer strides.
  stride = 1 if step is None else step
  stride = ops.convert_to_tensor(stride, name="step")
  if not stride.dtype.is_integer:
    raise TypeError("slice strides must be integers or None")
  stride = math_ops.cast(stride, dtypes.int64)

  # `ragged_range` lists, row by row, the index of every value to include.
  value_indices = ragged_math_ops.range(starts, limits, stride)

  # Pick the gather op matching the kind of `values`, then collect them.
  if isinstance(values, ragged_tensor.RaggedTensor):
    gather_fn = ragged_gather_ops.gather
  else:
    gather_fn = array_ops.gather
  gathered_values = gather_fn(params=values, indices=value_indices.values)

  # The index tensor's row partition is reused for the gathered values.
  return value_indices.with_values(gathered_values)
377
378
def _add_offset_to_ranges(offset, starts, limits):
  """Adds an indexing offset to each of the specified ranges.

  If offset>=0, then return output[i]=min(starts[i]+offset, limits[i])
  If offset<0, then return output[i]=max(limits[i]+offset, starts[i])

  In other words, a non-negative offset counts forward from the start of
  each range and a negative offset counts backward from its limit; either
  way the result is clamped so that it stays inside the range.

  Args:
    offset: The offset to add.  An int, or a scalar Tensor (which is cast to
      int64).
    starts: 1-D int64 tensor containing start indices.
    limits: 1-D int64 tensor containing limit indices.

  Returns:
    A 1-D int64 tensor.

  Raises:
    TypeError: If `offset` is neither an int nor a Tensor (this includes
      `None`).
  """

  def map_positive_offset(offset):
    return math_ops.minimum(starts + offset, limits)

  def map_negative_offset(offset):
    return math_ops.maximum(limits + offset, starts)

  if isinstance(offset, ops.Tensor):
    # The sign is only known at run time, so choose the mapping with a cond.
    offset = math_ops.cast(offset, dtypes.int64)
    return control_flow_ops.cond(offset >= 0,
                                 lambda: map_positive_offset(offset),
                                 lambda: map_negative_offset(offset))
  elif isinstance(offset, int):
    # Zero is a non-negative offset: it must map to `starts`, exactly as the
    # Tensor branch above does.  Treating it as negative would give `limits`.
    return (map_positive_offset(offset)
            if offset >= 0 else map_negative_offset(offset))

  else:
    raise TypeError("slice offsets must be integers or None")
411