• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""where operation for RaggedTensors."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from tensorflow.python.framework import ops
22from tensorflow.python.ops import array_ops
23from tensorflow.python.ops import math_ops
24from tensorflow.python.ops.ragged import ragged_concat_ops
25from tensorflow.python.ops.ragged import ragged_functional_ops
26from tensorflow.python.ops.ragged import ragged_gather_ops
27from tensorflow.python.ops.ragged import ragged_tensor
28from tensorflow.python.ops.ragged import ragged_tensor_shape
29
30
def where_v2(condition, x=None, y=None, name=None):
  """Return the elements where `condition` is `True`.

  : If both `x` and `y` are None: Retrieve indices of true elements.

    Returns the coordinates of true elements of `condition`. The coordinates
    are returned in a 2-D tensor with shape
    `[num_true_values, rank(condition)]`, where `result[i]` is the
    coordinates of the `i`th true value (in row-major order).

  : If both `x` and `y` are non-`None`: Multiplex between `x` and `y`.

    Choose an output shape from the shapes of `condition`, `x`, and `y` that
    all three shapes are broadcastable to; and then use the broadcasted
    `condition` tensor as a mask that chooses whether the corresponding element
    in the output should be taken from `x` (if `condition` is true) or `y` (if
    `condition` is false).

  >>> # Example: retrieve indices of true elements
  >>> tf.where(tf.ragged.constant([[True, False], [True]]))
  <tf.Tensor: shape=(2, 2), dtype=int64, numpy= array([[0, 0], [1, 0]])>

  >>> # Example: multiplex between `x` and `y`
  >>> tf.where(tf.ragged.constant([[True, False], [True, False, True]]),
  ...          tf.ragged.constant([['A', 'B'], ['C', 'D', 'E']]),
  ...          tf.ragged.constant([['a', 'b'], ['c', 'd', 'e']]))
  <tf.RaggedTensor [[b'A', b'b'], [b'C', b'd', b'E']]>

  Args:
    condition: A potentially ragged tensor of type `bool`.
    x: A potentially ragged tensor (optional).
    y: A potentially ragged tensor (optional).  Must be specified if `x` is
      specified.  Must have the same type as `x`; `condition`, `x`, and `y`
      must all be broadcastable to a common shape.
    name: A name of the operation (optional).

  Returns:
    : If both `x` and `y` are `None`:
      A `Tensor` with shape `(num_true, rank(condition))`.
    : Otherwise:
      A potentially ragged tensor with the same type as `x` and `y`, and whose
      shape is broadcast-compatible with `x`, `y`, and `condition`.

  Raises:
    ValueError: When exactly one of `x` or `y` is non-`None`; or when
      `condition`, `x`, and `y` have incompatible shapes.
  """
  if (x is None) != (y is None):
    raise ValueError('x and y must be either both None or both non-None')

  with ops.name_scope('RaggedWhere', name, [condition, x, y]):
    condition = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        condition, name='condition')
    if x is None:
      return _coordinate_where(condition)
    else:
      x = ragged_tensor.convert_to_tensor_or_ragged_tensor(x, name='x')
      y = ragged_tensor.convert_to_tensor_or_ragged_tensor(y, name='y')
      # Bring the row-partition dtypes of the three inputs into agreement
      # before they are broadcast and combined.
      condition, x, y = ragged_tensor.match_row_splits_dtypes(condition, x, y)
      return _elementwise_where_v2(condition, x, y)
90
91
def where(condition, x=None, y=None, name=None):
  """Return the elements, either from `x` or `y`, depending on the `condition`.

  : If both `x` and `y` are `None`:
    Returns the coordinates of true elements of `condition`. The coordinates
    are returned in a 2-D tensor with shape
    `[num_true_values, rank(condition)]`, where `result[i]` is the
    coordinates of the `i`th true value (in row-major order).

  : If both `x` and `y` are non-`None`:
    Returns a tensor formed by selecting values from `x` where condition is
    true, and from `y` when condition is false.  In particular:

    : If `condition`, `x`, and `y` all have the same shape:

      * `result[i1...iN] = x[i1...iN]` if `condition[i1...iN]` is true.
      * `result[i1...iN] = y[i1...iN]` if `condition[i1...iN]` is false.

    : Otherwise:

      * `condition` must be a vector.
      * `x` and `y` must have the same number of dimensions.
      * The outermost dimensions of `condition`, `x`, and `y` must all have the
        same size.
      * `result[i] = x[i]` if `condition[i]` is true.
      * `result[i] = y[i]` if `condition[i]` is false.

  Args:
    condition: A potentially ragged tensor of type `bool`.
    x: A potentially ragged tensor (optional).
    y: A potentially ragged tensor (optional).  Must be specified if `x` is
      specified.  Must have the same rank and type as `x`.
    name: A name of the operation (optional).

  Returns:
    : If both `x` and `y` are `None`:
      A `Tensor` with shape `(num_true, rank(condition))`.
    : Otherwise:
      A potentially ragged tensor with the same type, rank, and outermost
      dimension size as `x` and `y`.
      `result.ragged_rank = max(x.ragged_rank, y.ragged_rank)`.

  Raises:
    ValueError: When exactly one of `x` or `y` is non-`None`; or when
      `condition`, `x`, and `y` have incompatible shapes.

  #### Examples:

  >>> # Coordinates where condition is true.
  >>> condition = tf.ragged.constant([[True, False, True], [False, True]])
  >>> print(where(condition))
  tf.Tensor( [[0 0] [0 2] [1 1]], shape=(3, 2), dtype=int64)

  >>> # Elementwise selection between x and y, based on condition.
  >>> condition = tf.ragged.constant([[True, False, True], [False, True]])
  >>> x = tf.ragged.constant([['A', 'B', 'C'], ['D', 'E']])
  >>> y = tf.ragged.constant([['a', 'b', 'c'], ['d', 'e']])
  >>> print(where(condition, x, y))
  <tf.RaggedTensor [[b'A', b'b', b'C'], [b'd', b'E']]>

  >>> # Row selection between x and y, based on condition.
  >>> condition = [True, False]
  >>> x = tf.ragged.constant([['A', 'B', 'C'], ['D', 'E']])
  >>> y = tf.ragged.constant([['a', 'b', 'c'], ['d', 'e']])
  >>> print(where(condition, x, y))
  <tf.RaggedTensor [[b'A', b'B', b'C'], [b'd', b'e']]>
  """
  if (x is None) != (y is None):
    raise ValueError('x and y must be either both None or both non-None')

  with ops.name_scope('RaggedWhere', name, [condition, x, y]):
    condition = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        condition, name='condition')
    if x is None:
      return _coordinate_where(condition)
    else:
      x = ragged_tensor.convert_to_tensor_or_ragged_tensor(x, name='x')
      y = ragged_tensor.convert_to_tensor_or_ragged_tensor(y, name='y')
      # Bring the row-partition dtypes of the three inputs into agreement
      # before they are combined.
      condition, x, y = ragged_tensor.match_row_splits_dtypes(condition, x, y)
      return _elementwise_where(condition, x, y)
171
172
def _elementwise_where(condition, x, y):
  """Ragged version of tf.where(condition, x, y).

  Dispatches on which of the three inputs are `RaggedTensor`s:

  * none ragged: defer to the dense `array_ops.where`.
  * all ragged: apply `array_ops.where` to the flat values.
  * dense `condition`, at least one of `x`/`y` ragged: `condition` must be a
    vector, and whole rows are taken from `x` (true) or `y` (false).
  * ragged `condition` with a dense `x` or `y`: unsupported.

  Args:
    condition: A `Tensor` or `RaggedTensor` of type `bool`.
    x: A `Tensor` or `RaggedTensor`.
    y: A `Tensor` or `RaggedTensor`.

  Returns:
    A `Tensor` when no input is ragged; otherwise a `RaggedTensor`.

  Raises:
    ValueError: If `condition` is ragged but `x` or `y` is not, or if a dense
      `condition` is not rank 1 while `x` or `y` is ragged.
  """
  is_ragged = [
      isinstance(t, ragged_tensor.RaggedTensor) for t in (condition, x, y)
  ]

  if not any(is_ragged):
    return array_ops.where(condition, x, y)

  if all(is_ragged):
    return ragged_functional_ops.map_flat_values(array_ops.where, condition,
                                                 x, y)

  if is_ragged[0]:
    # A ragged `condition` cannot be lined up against a dense `x` or `y`.
    raise ValueError('Input shapes do not match.')

  # Row selection: stack the rows of `x` on top of the rows of `y`, compute
  # for each output row its index into that stack, then gather those rows.
  condition.shape.assert_has_rank(1)
  stacked = ragged_concat_ops.concat([x, y], axis=0)
  index_dtype = stacked.row_splits.dtype
  num_x_rows = _nrows(x, out_type=index_dtype)
  num_y_rows = _nrows(y, out_type=index_dtype)
  row_indices = array_ops.where(condition, math_ops.range(num_x_rows),
                                num_x_rows + math_ops.range(num_y_rows))
  return ragged_gather_ops.gather(stacked, row_indices)
197
198
def _elementwise_where_v2(condition, x, y):
  """Ragged version of tf.where_v2(condition, x, y).

  `condition`, `x`, and `y` are first broadcast to a common (possibly ragged)
  dynamic shape. The broadcast is skipped only when all three static shapes
  are fully defined and identical. The selection is then done with
  `array_ops.where_v2`, applied to the flat values when any input is ragged.

  Args:
    condition: A `Tensor` or `RaggedTensor` of type `bool`.
    x: A `Tensor` or `RaggedTensor`.
    y: A `Tensor` or `RaggedTensor`.

  Returns:
    A `Tensor` if none of the (broadcast) inputs is ragged, otherwise a
    `RaggedTensor`.
  """
  same_static_shape = (
      condition.shape.is_fully_defined() and x.shape.is_fully_defined() and
      y.shape.is_fully_defined() and x.shape == y.shape and
      condition.shape == x.shape)

  if not same_static_shape:
    from_tensor = ragged_tensor_shape.RaggedTensorDynamicShape.from_tensor
    cond_shape, x_shape, y_shape = [from_tensor(t) for t in (condition, x, y)]
    xy_shape = ragged_tensor_shape.broadcast_dynamic_shape(x_shape, y_shape)
    target_shape = ragged_tensor_shape.broadcast_dynamic_shape(
        cond_shape, xy_shape)
    condition, x, y = [
        ragged_tensor_shape.broadcast_to(t, target_shape)
        for t in (condition, x, y)
    ]

  if any(
      isinstance(t, ragged_tensor.RaggedTensor) for t in (condition, x, y)):
    return ragged_functional_ops.map_flat_values(array_ops.where_v2,
                                                 condition, x, y)
  return array_ops.where_v2(condition, x, y)
223
224
def _coordinate_where(condition):
  """Ragged version of tf.where(condition).

  For a dense `condition` this is `array_ops.where(condition)`.

  For a `RaggedTensor`, it recurses on `condition.values`. The first
  coordinate of each hit is a flat position in `values`, so it is split
  back into a (row, column) pair using the row partition. The remaining
  coordinates are then appended unchanged.

  Args:
    condition: A `Tensor` or `RaggedTensor` of type `bool`.

  Returns:
    A 2-D `Tensor` with one row per true element, holding that element's
    coordinates.
  """
  if not isinstance(condition, ragged_tensor.RaggedTensor):
    return array_ops.where(condition)

  # Coordinates of each true entry within `condition.values`.
  value_coords = _coordinate_where(condition.values)

  # Match the row-partition dtype to the coordinate dtype so the arithmetic
  # below mixes a single integer type.
  condition = condition.with_row_splits_dtype(value_coords.dtype)
  flat_pos = value_coords[:, 0]
  row_idx = array_ops.gather(condition.value_rowids(), flat_pos)
  row_begin = array_ops.gather(condition.row_splits, row_idx)
  col_idx = flat_pos - row_begin

  # Final coordinate = [row, column, <inner coordinates...>].
  pieces = [
      array_ops.expand_dims(row_idx, 1),
      array_ops.expand_dims(col_idx, 1),
      value_coords[:, 1:],
  ]
  return array_ops.concat(pieces, axis=1)
246
247
def _nrows(rt_input, out_type):
  """Returns the size of the outermost dimension of `rt_input`.

  Args:
    rt_input: A `Tensor` or `RaggedTensor`.
    out_type: Integer dtype for the returned row count.

  Returns:
    A scalar of dtype `out_type`.
  """
  if not isinstance(rt_input, ragged_tensor.RaggedTensor):
    return array_ops.shape(rt_input, out_type=out_type)[0]
  return rt_input.nrows(out_type=out_type)
253