# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Python-style indexing and slicing for RaggedTensors."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.ragged import ragged_gather_ops
from tensorflow.python.ops.ragged import ragged_math_ops
from tensorflow.python.ops.ragged import ragged_tensor


def ragged_tensor_getitem(self, key):
  """Returns the specified piece of this RaggedTensor.

  Supports multidimensional indexing and slicing, with one restriction:
  indexing into a ragged inner dimension is not allowed.  This case is
  problematic because the indicated value may exist in some rows but not
  others.  In such cases, it's not obvious whether we should (1) report an
  IndexError; (2) use a default value; or (3) skip that value and return a
  tensor with fewer rows than we started with.  Following the guiding
  principles of Python ("In the face of ambiguity, refuse the temptation to
  guess"), we simply disallow this operation.

  Any dimensions added by `array_ops.newaxis` will be ragged if the following
  dimension is ragged.

  Args:
    self: The RaggedTensor to slice.
    key: Indicates which piece of the RaggedTensor to return, using standard
      Python semantics (e.g., negative values index from the end).  `key`
      may have any of the following types:

      * `int` constant
      * Scalar integer `Tensor`
      * `slice` containing integer constants and/or scalar integer
        `Tensor`s
      * `Ellipsis`
      * `tf.newaxis`
      * `tuple` containing any of the above (for multidimensional indexing)

  Returns:
    A `Tensor` or `RaggedTensor` object.  Values that include at least one
    ragged dimension are returned as `RaggedTensor`.  Values that include no
    ragged dimensions are returned as `Tensor`.  See below for examples of
    expressions that return `Tensor`s vs `RaggedTensor`s.

  Raises:
    ValueError: If `key` is out of bounds.
    ValueError: If `key` is not supported.
    TypeError: If the indices in `key` have an unsupported type.
    IndexError: When executing eagerly, if an integer row index is greater
      than or equal to the number of rows (this is what lets Python
      iteration over a RaggedTensor terminate).

  Examples:

    ```python
    >>> # A 2-D ragged tensor with 1 ragged dimension.
    >>> rt = ragged.constant([['a', 'b', 'c'], ['d', 'e'], ['f'], ['g']])
    >>> rt[0].eval().tolist()       # First row (1-D `Tensor`)
    ['a', 'b', 'c']
    >>> rt[:3].eval().tolist()      # First three rows (2-D RaggedTensor)
    [['a', 'b', 'c'], ['d', 'e'], ['f']]
    >>> rt[3, 0].eval().tolist()    # 1st element of 4th row (scalar)
    'g'

    >>> # A 3-D ragged tensor with 2 ragged dimensions.
    >>> rt = ragged.constant([[[1, 2, 3], [4]],
    ...                    [[5], [], [6]],
    ...                    [[7]],
    ...                    [[8, 9], [10]]])
    >>> rt[1].eval().tolist()       # Second row (2-D RaggedTensor)
    [[5], [], [6]]
    >>> rt[3, 0].eval().tolist()    # First element of fourth row (1-D Tensor)
    [8, 9]
    >>> rt[:, 1:3].eval().tolist()  # Items 1-3 of each row (3-D RaggedTensor)
    [[[4]], [[], [6]], [], [[10]]]
    >>> rt[:, -1:].eval().tolist()  # Last item of each row (3-D RaggedTensor)
    [[[4]], [[6]], [[7]], [[10]]]
    ```
  """
  # Collect every Tensor referenced by the key so that the name scope is
  # created in the graph those tensors belong to.
  scope_tensors = [self] + list(_tensors_in_key_list(key))
  if isinstance(key, (list, tuple)):
    key = list(key)
  else:
    key = [key]
  with ops.name_scope(None, "RaggedGetItem", scope_tensors):
    return _ragged_getitem(self, key)


def _ragged_getitem(rt_input, key_list):
  """Helper for indexing and slicing ragged tensors with __getitem__().

  Extracts the specified piece of the `rt_input`.  See
  `RaggedTensor.__getitem__` for examples and restrictions.

  Args:
    rt_input: The `RaggedTensor` from which a piece should be returned.
    key_list: The list of keys specifying which piece to return. Each key
      corresponds with a separate dimension.

  Returns:
    The indicated piece of rt_input.

  Raises:
    ValueError: If `key_list` is not supported.
    TypeError: If any keys in `key_list` have an unsupported type.
    IndexError: When executing eagerly, if the leading key is an integer that
      is greater than or equal to the number of rows.
  """
  if not key_list:
    return rt_input
  row_key = key_list[0]
  inner_keys = key_list[1:]

  if row_key is Ellipsis:
    expanded_key_list = _expand_ellipsis(key_list, rt_input.shape.ndims)
    return _ragged_getitem(rt_input, expanded_key_list)

  # Adding a new axis: Get rt_input[inner_keys], and wrap it in a RaggedTensor
  # that puts all values in a single row.
  if row_key is array_ops.newaxis:
    inner_rt = _ragged_getitem(rt_input, inner_keys)
    nsplits = array_ops.shape(inner_rt.row_splits, out_type=dtypes.int64)[0]
    return ragged_tensor.RaggedTensor.from_row_splits(
        inner_rt, array_ops.stack([0, nsplits - 1]))

  # Slicing a range of rows: first slice the outer dimension, and then
  # call `_ragged_getitem_inner_dimensions` to handle the inner keys.
  if isinstance(row_key, slice):
    sliced_rt_input = _slice_ragged_row_dimension(rt_input, row_key)
    return _ragged_getitem_inner_dimensions(sliced_rt_input, inner_keys)

  # Indexing a single row: slice values to get the indicated row, and then
  # use a recursive call to __getitem__ to handle the inner keys.
  starts = rt_input.row_splits[:-1]
  limits = rt_input.row_splits[1:]
  if context.executing_eagerly():
    # In Python, __getitem__ should raise IndexError for out-of-bounds
    # indices.  This lets iteration work, since Python translates the
    # IndexError into StopIteration for next()/__next__(), e.g.:
    #
    #   r = tf.ragged.constant([[1., 2.], [3., 4., 5.], [6.]])
    #   for elem in r:
    #     print(elem)
    #
    # In graph mode the failure only surfaces when the session runs, so the
    # check can only be made eagerly.  A TypeError/ValueError here means
    # `row_key` is not convertible to an int; it is ignored because the
    # slicing below will report the problem.
    try:
      if int(row_key) >= len(starts):
        raise IndexError("Row key {} out of bounds".format(row_key))
    except (TypeError, ValueError):
      pass
  row = rt_input.values[starts[row_key]:limits[row_key]]
  return row.__getitem__(inner_keys)


def _slice_ragged_row_dimension(rt_input, row_key):
  """Slice the outer dimension of `rt_input` according to the given `slice`.

  Args:
    rt_input: The `RaggedTensor` to slice.
    row_key: The `slice` object that should be used to slice `rt_input`.

  Returns:
    A `RaggedTensor` containing the indicated slice of `rt_input`.
  """
  if row_key.start is None and row_key.stop is None and row_key.step is None:
    return rt_input

  # Use row_key to slice the starts & limits.
  new_starts = rt_input.row_splits[:-1][row_key]
  new_limits = rt_input.row_splits[1:][row_key]
  zero_pad = array_ops.zeros([1], dtypes.int64)

  # If there's no slice step, then we can just select a single continuous
  # span of `ragged.values(rt_input)`.
  if row_key.step is None or row_key.step == 1:
    # Construct the new splits.  If new_starts and new_limits are empty,
    # then this reduces to [0].  Otherwise, this reduces to:
    #   concat([[new_starts[0]], new_limits])
    new_splits = array_ops.concat(
        [zero_pad[array_ops.size(new_starts):], new_starts[:1], new_limits],
        axis=0)
    values_start = new_splits[0]
    values_limit = new_splits[-1]
    return ragged_tensor.RaggedTensor.from_row_splits(
        rt_input.values[values_start:values_limit], new_splits - values_start)

  # If there is a slice step (aka a strided slice), then use ragged_gather to
  # collect the necessary elements of `ragged.values(rt_input)`.  The row
  # stride has already been applied to new_starts/new_limits above, so the
  # values inside each selected row are taken with step 1.
  return _build_ragged_tensor_from_value_ranges(new_starts, new_limits, 1,
                                                rt_input.values)


def _ragged_getitem_inner_dimensions(rt_input, key_list):
  """Retrieve inner dimensions, keeping outermost dimension unchanged.

  Args:
    rt_input: The `RaggedTensor` or `Tensor` from which a piece should be
      extracted.
    key_list: The __getitem__ keys for slicing the inner dimensions.

  Returns:
    A `RaggedTensor`.

  Raises:
    ValueError: If key_list is not supported.
  """
  if not key_list:
    return rt_input

  if isinstance(rt_input, ops.Tensor):
    return rt_input.__getitem__([slice(None, None, None)] + key_list)

  column_key = key_list[0]
  if column_key is Ellipsis:
    expanded_key_list = _expand_ellipsis(key_list, rt_input.values.shape.ndims)
    return _ragged_getitem_inner_dimensions(rt_input, expanded_key_list)

  # Adding a new axis to a ragged inner dimension: recursively get the inner
  # dimensions of rt_input with key_list[1:], and then wrap the result in a
  # RaggedTensor that puts each value in its own row.
  if column_key is array_ops.newaxis:
    inner_rt = _ragged_getitem_inner_dimensions(rt_input, key_list[1:])
    nsplits = array_ops.shape(inner_rt.row_splits, out_type=dtypes.int64)[0]
    return ragged_tensor.RaggedTensor.from_row_splits(inner_rt,
                                                      math_ops.range(nsplits))

  # Slicing a range of columns in a ragged inner dimension.  We use a
  # recursive call to process the values, and then assemble a RaggedTensor
  # with those values.
  if isinstance(column_key, slice):
    if (column_key.start is None and column_key.stop is None and
        column_key.step is None):
      # Trivial slice: recursively process all values, & splits is unchanged.
      return rt_input.with_values(
          _ragged_getitem_inner_dimensions(rt_input.values, key_list[1:]))

    # Nontrivial slice: use ragged_gather to extract the indicated slice as
    # a new RaggedTensor (inner_rt), and then recursively process its values.
    # The splits can be taken from inner_rt.row_splits().
    row_starts = rt_input.row_splits[:-1]
    row_limits = rt_input.row_splits[1:]
    inner_rt_starts = row_starts
    inner_rt_limits = row_limits

    # A start of 0 leaves the row starts unchanged, so the offset op can be
    # skipped.  Tensor bounds are never compared against 0 in Python: their
    # value is not known here, and the comparison may yield a symbolic Tensor
    # that cannot be used as a Python bool.
    start_key = column_key.start
    if start_key is not None and (isinstance(start_key, ops.Tensor) or
                                  start_key != 0):
      inner_rt_starts = _add_offset_to_ranges(start_key, row_starts,
                                              row_limits)

    # A stop of 0 is NOT a no-op: `x[:0]` is empty, so every row's limit must
    # collapse to its start.  Always apply an explicit stop.
    stop_key = column_key.stop
    if stop_key is not None:
      inner_rt_limits = _add_offset_to_ranges(stop_key, row_starts, row_limits)

    inner_rt = _build_ragged_tensor_from_value_ranges(
        inner_rt_starts, inner_rt_limits, column_key.step, rt_input.values)
    return inner_rt.with_values(
        _ragged_getitem_inner_dimensions(inner_rt.values, key_list[1:]))

  # Indexing a single column in a ragged inner dimension: raise an Exception.
  # See RaggedTensor.__getitem__.__doc__ for an explanation of why indexing
  # into a ragged inner dimension is problematic.
  raise ValueError("Cannot index into an inner ragged dimension.")


def _expand_ellipsis(key_list, num_remaining_dims):
  """Expands the ellipsis at the start of `key_list`.

  Assumes that the first element of `key_list` is Ellipsis.  This will either
  remove the Ellipsis (if it corresponds to zero indices) or prepend a new
  `slice(None, None, None)` (if it corresponds to more than zero indices).

  Args:
    key_list: The arguments to `__getitem__()`.
    num_remaining_dims: The number of dimensions remaining.

  Returns:
    A copy of `key_list` with the ellipsis expanded.

  Raises:
    ValueError: If `num_remaining_dims` is None.
    IndexError: If there are too many elements in `key_list`.
  """
  if num_remaining_dims is None:
    raise ValueError("Ellipsis not supported for unknown shape RaggedTensors")
  # `newaxis` keys do not consume an input dimension.  The Ellipsis itself is
  # included in this count, hence the `+ 1` in the comparisons below.
  num_indices = sum(1 for idx in key_list if idx is not array_ops.newaxis)
  if num_indices > num_remaining_dims + 1:
    raise IndexError("Too many indices for RaggedTensor")
  elif num_indices == num_remaining_dims + 1:
    return key_list[1:]
  else:
    return [slice(None, None, None)] + key_list


def _tensors_in_key_list(key_list):
  """Generates all Tensors in the given slice spec.

  Args:
    key_list: A `__getitem__` key: a `Tensor`, a `slice`, a (possibly nested)
      list or tuple of keys, or any other object (which yields nothing).

  Yields:
    Each `Tensor` found directly in `key_list`, inside nested lists/tuples,
    or in the `start`/`stop`/`step` of a `slice`.
  """
  if isinstance(key_list, ops.Tensor):
    yield key_list
  elif isinstance(key_list, (list, tuple)):
    for v in key_list:
      for tensor in _tensors_in_key_list(v):
        yield tensor
  elif isinstance(key_list, slice):
    for bound in (key_list.start, key_list.stop, key_list.step):
      for tensor in _tensors_in_key_list(bound):
        yield tensor


def _build_ragged_tensor_from_value_ranges(starts, limits, step, values):
  """Returns a `RaggedTensor` containing the specified sequences of values.

  Returns a RaggedTensor `output` where:

  ```python
  output.shape[0] = starts.shape[0]
  output[i] = values[starts[i]:limits[i]:step]
  ```

  Requires that `starts.shape == limits.shape` and
  `0 <= starts[i] <= limits[i] <= values.shape[0]`.

  Args:
    starts: 1D integer Tensor specifying the start indices for the sequences of
      values to include.
    limits: 1D integer Tensor specifying the limit indices for the sequences of
      values to include.
    step: Integer value specifying the step size for strided slices.  `None`
      is treated as a step of 1.
    values: The set of values to select from.

  Returns:
    A `RaggedTensor`.

  Raises:
    TypeError: If `step` is neither `None` nor an integer value.
  """
  # Use `ragged_range` to get the index of each value we should include.
  if step is None:
    step = 1
  step = ops.convert_to_tensor(step, name="step")
  if step.dtype.is_integer:
    step = math_ops.cast(step, dtypes.int64)
  else:
    raise TypeError("slice strides must be integers or None")
  value_indices = ragged_math_ops.range(starts, limits, step)

  # Use `ragged_gather` or `array_ops.gather` to collect the values.
  if isinstance(values, ragged_tensor.RaggedTensor):
    gathered_values = ragged_gather_ops.gather(
        params=values, indices=value_indices.values)
  else:
    gathered_values = array_ops.gather(
        params=values, indices=value_indices.values)

  # Assemble the RaggedTensor from splits & values.
  return value_indices.with_values(gathered_values)


def _add_offset_to_ranges(offset, starts, limits):
  """Adds an indexing offset to each of the specified ranges.

  If offset>=0, then return output[i]=min(starts[i]+offset, limits[i])
  If offset<0, then return output[i]=max(limits[i]+offset, starts[i])

  Args:
    offset: The offset to add.  An int, or a scalar integer Tensor.
    starts: 1-D int64 tensor containing start indices.
    limits: 1-D int64 tensor containing limit indices.

  Returns:
    A 1-D int64 tensor.

  Raises:
    TypeError: If `offset` is neither an `int` nor a `Tensor`.
  """

  def map_positive_offset(offset):
    # Count forward from the start of each range, clamped to its end.
    return math_ops.minimum(starts + offset, limits)

  def map_negative_offset(offset):
    # Count backward from the end of each range, clamped to its start.
    return math_ops.maximum(limits + offset, starts)

  if isinstance(offset, ops.Tensor):
    offset = math_ops.cast(offset, dtypes.int64)
    return control_flow_ops.cond(offset >= 0,
                                 lambda: map_positive_offset(offset),
                                 lambda: map_negative_offset(offset))
  elif isinstance(offset, int):
    # Zero counts from the start of each range (matching the Tensor branch
    # above): an offset of 0 must yield `starts`, not `limits`.
    return (map_positive_offset(offset)
            if offset >= 0 else map_negative_offset(offset))

  else:
    raise TypeError("slice offsets must be integers or None")